# -*- coding: utf-8 -*-
"""
STDEN数据加载器

负责数据的加载、预处理和批处理
"""

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import logging


class STDENDataset(Dataset):
    """Sliding-window dataset for STDEN.

    Sample ``i`` pairs the input window ``x_data[i : i+seq_len]`` with the
    target window ``y_data[i+seq_len : i+seq_len+horizon]``.

    NOTE(review): the loader feeds this class arrays loaded straight from
    ``.npz`` splits; if those are already pre-windowed, windowing again over
    axis 0 may not be intended — confirm against the dataset format.
    """

    def __init__(self, x_data, y_data, sequence_length, horizon):
        """Initialize the dataset.

        Args:
            x_data: input series, indexable/sliceable along time axis 0.
            y_data: target series, aligned with ``x_data`` on axis 0.
            sequence_length: number of input time steps per sample.
            horizon: number of future time steps to predict.
        """
        self.x_data = x_data
        self.y_data = y_data
        self.sequence_length = sequence_length
        self.horizon = horizon
        # Clamp at 0: a series shorter than seq_len + horizon yields an
        # empty dataset instead of a negative __len__, which DataLoader
        # would reject with a ValueError.
        self.num_samples = max(0, len(x_data) - sequence_length - horizon + 1)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Input window is [idx, split); target window is [split, split+horizon).
        split = idx + self.sequence_length
        x_seq = self.x_data[idx:split]
        y_seq = self.y_data[split:split + self.horizon]

        return torch.FloatTensor(x_seq), torch.FloatTensor(y_seq)


class StandardScaler:
    """Z-score normalizer: scales data to zero mean and unit variance."""

    def __init__(self, mean=None, std=None):
        # mean/std may be supplied directly or computed later via fit().
        self.mean = mean
        self.std = std

    def fit(self, data):
        """Compute per-feature mean and std over axis 0.

        Features with zero std get std=1 so ``transform`` never divides by
        zero. ``np.asarray``/``np.where`` also make this work for 1-D input,
        where ``np.std(data, axis=0)`` is a 0-d scalar and the previous
        boolean-index assignment (``std[std == 0] = 1.0``) raised TypeError.
        """
        self.mean = np.mean(data, axis=0)
        std = np.asarray(np.std(data, axis=0), dtype=float)
        # Avoid division by zero for constant features.
        self.std = np.where(std == 0, 1.0, std)

    def transform(self, data):
        """Standardize data: ``(data - mean) / std``."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Undo ``transform``: ``data * std + mean``."""
        return (data * self.std) + self.mean


class STDENDataloader:
    """Top-level data pipeline for STDEN.

    Loads the configured dataset from disk, standardizes every split with a
    scaler fitted on the training inputs, and exposes PyTorch DataLoaders
    for the train/val/test splits.
    """

    def __init__(self, config):
        """Build the full pipeline from a configuration dictionary.

        Args:
            config: dict with ``data`` (dataset_dir, batch_size,
                optional val_batch_size) and ``model`` (seq_len, horizon)
                sections.
        """
        self.config = config
        self.logger = logging.getLogger('STDEN')

        data_cfg = config['data']
        model_cfg = config['model']

        # Data-related settings.
        self.dataset_dir = data_cfg['dataset_dir']
        self.batch_size = data_cfg['batch_size']
        self.val_batch_size = data_cfg.get('val_batch_size', self.batch_size)
        self.sequence_length = model_cfg['seq_len']
        self.horizon = model_cfg['horizon']

        # Load and standardize the dataset.
        self.data = self._load_dataset()
        self.scaler = self.data['scaler']

        # One DataLoader per split; only the training split is shuffled.
        self.train_loader = self._create_dataloader(
            self.data['x_train'], self.data['y_train'],
            self.batch_size, shuffle=True)
        self.val_loader = self._create_dataloader(
            self.data['x_val'], self.data['y_val'],
            self.val_batch_size, shuffle=False)
        self.test_loader = self._create_dataloader(
            self.data['x_test'], self.data['y_test'],
            self.val_batch_size, shuffle=False)

        self.logger.info(f"数据加载完成 - 训练集: {len(self.data['x_train'])} 样本")
        self.logger.info(f"验证集: {len(self.data['x_val'])} 样本, 测试集: {len(self.data['x_test'])} 样本")

    def _load_dataset(self):
        """Load the dataset from disk, dispatching on dataset type."""
        root = Path(self.dataset_dir)

        if not root.exists():
            raise FileNotFoundError(f"数据集目录不存在: {self.dataset_dir}")

        # BJ_GM-style datasets ship as one flow.npz archive; everything
        # else uses per-split train/val/test files.
        if 'BJ' in self.dataset_dir:
            return self._load_bj_dataset(root)
        return self._load_standard_dataset(root)

    def _load_bj_dataset(self, dataset_path):
        """Load a BJ_GM dataset stored as a single ``flow.npz`` archive."""
        archive = dataset_path / 'flow.npz'
        if not archive.exists():
            raise FileNotFoundError(f"BJ数据集文件不存在: {archive}")

        raw = dict(np.load(archive))

        # Pull out the six train/val/test arrays in processing order.
        splits = [raw[key] for key in
                  ('x_train', 'y_train', 'x_val', 'y_val', 'x_test', 'y_test')]
        return self._process_data(*splits)

    def _load_standard_dataset(self, dataset_path):
        """Load a dataset stored as per-split ``{train,val,test}.npz`` files."""
        arrays = {}

        for split in ('train', 'val', 'test'):
            split_file = dataset_path / f"{split}.npz"
            if not split_file.exists():
                raise FileNotFoundError(f"数据集文件不存在: {split_file}")

            npz = np.load(split_file)
            arrays[f'x_{split}'] = npz['x']
            arrays[f'y_{split}'] = npz['y']

        return self._process_data(
            arrays['x_train'], arrays['y_train'],
            arrays['x_val'], arrays['y_val'],
            arrays['x_test'], arrays['y_test'])

    def _process_data(self, x_train, y_train, x_val, y_val, x_test, y_test):
        """Standardize all splits with a scaler fitted on the training inputs."""
        scaler = StandardScaler()
        scaler.fit(x_train)

        # NOTE(review): the x-fitted scaler is applied to the y arrays too;
        # this assumes x and y share scale and feature layout — confirm
        # against the dataset schema.
        result = {
            'x_train': scaler.transform(x_train),
            'y_train': scaler.transform(y_train),
            'x_val': scaler.transform(x_val),
            'y_val': scaler.transform(y_val),
            'x_test': scaler.transform(x_test),
            'y_test': scaler.transform(y_test),
        }
        result['scaler'] = scaler
        return result

    def _create_dataloader(self, x_data, y_data, batch_size, shuffle=False):
        """Wrap (x, y) arrays in an STDENDataset and a PyTorch DataLoader."""
        return DataLoader(
            STDENDataset(x_data, y_data, self.sequence_length, self.horizon),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,  # single-process loading avoids worker issues
            drop_last=False,
        )

    def get_data_loaders(self):
        """Return the train/val/test DataLoaders keyed by split name."""
        return {'train': self.train_loader,
                'val': self.val_loader,
                'test': self.test_loader}

    def get_scaler(self):
        """Return the scaler fitted on the training inputs."""
        return self.scaler

    def get_data_info(self):
        """Return a summary of dataset dimensions and windowing settings."""
        x_train = self.data['x_train']
        return {
            'input_dim': x_train.shape[-1],
            'output_dim': self.data['y_train'].shape[-1],
            'sequence_length': self.sequence_length,
            'horizon': self.horizon,
            # Fall back to 1 node when the array is 1-D.
            'num_nodes': x_train.shape[1] if len(x_train.shape) > 1 else 1,
        }