# -*- coding: utf-8 -*-
"""
STDEN data loader.
Handles dataset loading, preprocessing, and batching.
"""

import logging
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class STDENDataset(Dataset):
    """Dataset class for STDEN."""

    def __init__(self, x_data, y_data, sequence_length, horizon):
        """
        Initialize the dataset.

        Args:
            x_data: input data
            y_data: target data
            sequence_length: length of the input sequence
            horizon: number of prediction steps
        """
        self.x_data = x_data
        self.y_data = y_data
        self.sequence_length = sequence_length
        self.horizon = horizon
        # Number of complete (input, target) windows in the series
        self.num_samples = len(x_data) - sequence_length - horizon + 1

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Input sequence: sequence_length consecutive steps starting at idx
        x_seq = self.x_data[idx:idx + self.sequence_length]
        # Target sequence: the following horizon steps
        y_seq = self.y_data[idx + self.sequence_length:
                            idx + self.sequence_length + self.horizon]
        return torch.FloatTensor(x_seq), torch.FloatTensor(y_seq)


class StandardScaler:
    """Standardizes data to zero mean and unit variance."""

    def __init__(self, mean=None, std=None):
        self.mean = mean
        self.std = std

    def fit(self, data):
        """Fit the scaler: compute mean and standard deviation."""
        self.mean = np.mean(data, axis=0)
        self.std = np.std(data, axis=0)
        # Avoid division by zero for constant features
        # (np.where also handles the scalar-std case safely)
        self.std = np.where(self.std == 0, 1.0, self.std)

    def transform(self, data):
        """Standardize data."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Undo standardization."""
        return (data * self.std) + self.mean


class STDENDataloader:
    """Main data loader class for STDEN."""

    def __init__(self, config):
        """
        Initialize the data loader.

        Args:
            config: configuration dictionary
        """
        self.config = config
        self.logger = logging.getLogger('STDEN')

        # Data configuration
        self.dataset_dir = config['data']['dataset_dir']
        self.batch_size = config['data']['batch_size']
        self.val_batch_size = config['data'].get('val_batch_size', self.batch_size)
        self.sequence_length = config['model']['seq_len']
        self.horizon = config['model']['horizon']

        # Load the data
        self.data = self._load_dataset()
        self.scaler = self.data['scaler']

        # Build the PyTorch data loaders
        self.train_loader = self._create_dataloader(
            self.data['x_train'], self.data['y_train'],
            self.batch_size, shuffle=True
        )
        self.val_loader = self._create_dataloader(
            self.data['x_val'], self.data['y_val'],
            self.val_batch_size, shuffle=False
        )
        self.test_loader = self._create_dataloader(
            self.data['x_test'], self.data['y_test'],
            self.val_batch_size, shuffle=False
        )

        self.logger.info(f"Data loaded - train: {len(self.data['x_train'])} samples")
        self.logger.info(f"val: {len(self.data['x_val'])} samples, "
                         f"test: {len(self.data['x_test'])} samples")

    def _load_dataset(self):
        """Load the dataset."""
        dataset_path = Path(self.dataset_dir)
        if not dataset_path.exists():
            raise FileNotFoundError(f"Dataset directory does not exist: {self.dataset_dir}")

        # Dispatch on dataset type (BJ_GM vs. standard format)
        if 'BJ' in self.dataset_dir:
            return self._load_bj_dataset(dataset_path)
        else:
            return self._load_standard_dataset(dataset_path)

    def _load_bj_dataset(self, dataset_path):
        """Load the BJ_GM dataset."""
        flow_file = dataset_path / 'flow.npz'
        if not flow_file.exists():
            raise FileNotFoundError(f"BJ dataset file does not exist: {flow_file}")

        data = dict(np.load(flow_file))

        # Extract train / validation / test splits
        x_train = data['x_train']
        y_train = data['y_train']
        x_val = data['x_val']
        y_val = data['y_val']
        x_test = data['x_test']
        y_test = data['y_test']

        return self._process_data(x_train, y_train, x_val, y_val, x_test, y_test)

    def _load_standard_dataset(self, dataset_path):
        """Load a dataset in the standard {train,val,test}.npz format."""
        data = {}
        for category in ['train', 'val', 'test']:
            cat_file = dataset_path / f"{category}.npz"
            if not cat_file.exists():
                raise FileNotFoundError(f"Dataset file does not exist: {cat_file}")
            cat_data = np.load(cat_file)
            data[f'x_{category}'] = cat_data['x']
            data[f'y_{category}'] = cat_data['y']

        return self._process_data(
            data['x_train'], data['y_train'],
            data['x_val'], data['y_val'],
            data['x_test'], data['y_test']
        )

    def _process_data(self, x_train, y_train, x_val, y_val, x_test, y_test):
        """Process the data (standardization, etc.)."""
        # Fit the scaler on the training inputs only, to avoid leakage
        scaler = StandardScaler()
        scaler.fit(x_train)

        # Standardize all splits with the training statistics
        x_train_scaled = scaler.transform(x_train)
        y_train_scaled = scaler.transform(y_train)
        x_val_scaled = scaler.transform(x_val)
        y_val_scaled = scaler.transform(y_val)
        x_test_scaled = scaler.transform(x_test)
        y_test_scaled = scaler.transform(y_test)

        return {
            'x_train': x_train_scaled,
            'y_train': y_train_scaled,
            'x_val': x_val_scaled,
            'y_val': y_val_scaled,
            'x_test': x_test_scaled,
            'y_test': y_test_scaled,
            'scaler': scaler
        }

    def _create_dataloader(self, x_data, y_data, batch_size, shuffle=False):
        """Create a PyTorch DataLoader."""
        dataset = STDENDataset(x_data, y_data, self.sequence_length, self.horizon)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,  # avoid multiprocessing issues
            drop_last=False
        )

    def get_data_loaders(self):
        """Return all data loaders."""
        return {
            'train': self.train_loader,
            'val': self.val_loader,
            'test': self.test_loader
        }

    def get_scaler(self):
        """Return the scaler."""
        return self.scaler

    def get_data_info(self):
        """Return basic information about the data."""
        return {
            'input_dim': self.data['x_train'].shape[-1],
            'output_dim': self.data['y_train'].shape[-1],
            'sequence_length': self.sequence_length,
            'horizon': self.horizon,
            'num_nodes': self.data['x_train'].shape[1] if self.data['x_train'].ndim > 1 else 1
        }
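

# Example usage (a minimal sketch): the config keys below mirror exactly what
# STDENDataloader reads in __init__. The dataset path 'data/BJ_GM' and the
# batch/sequence sizes are placeholder assumptions for illustration, not values
# shipped with this module.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    example_config = {
        'data': {
            'dataset_dir': 'data/BJ_GM',  # assumed location of flow.npz
            'batch_size': 64,
            # 'val_batch_size' is optional; defaults to batch_size
        },
        'model': {
            'seq_len': 12,   # 12 input time steps
            'horizon': 12,   # predict 12 steps ahead
        },
    }

    loader = STDENDataloader(example_config)
    print(loader.get_data_info())

    # Inspect one training batch; shapes follow STDENDataset.__getitem__
    for x_batch, y_batch in loader.get_data_loaders()['train']:
        print(x_batch.shape, y_batch.shape)
        break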