# Project-I/dataloader/stden_dataloader.py
# -*- coding: utf-8 -*-
"""
STDEN数据加载器
负责数据的加载、预处理和批处理
"""
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import logging
class STDENDataset(Dataset):
    """Sliding-window dataset over an aligned (x, y) time series.

    Each sample pairs ``sequence_length`` consecutive input steps with the
    ``horizon`` target steps that immediately follow them.
    """

    def __init__(self, x_data, y_data, sequence_length, horizon):
        """
        Initialize the dataset.

        Args:
            x_data: input series, indexed by time along axis 0.
            y_data: target series, aligned with ``x_data`` along axis 0.
            sequence_length: number of input time steps per sample.
            horizon: number of prediction time steps per sample.
        """
        self.x_data = x_data
        self.y_data = y_data
        self.sequence_length = sequence_length
        self.horizon = horizon
        # Clamp at zero: a series shorter than one full (input + horizon)
        # window must yield an empty dataset. The original expression could
        # go negative, and a negative __len__ breaks DataLoader.
        self.num_samples = max(0, len(x_data) - sequence_length - horizon + 1)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Input window: [idx, idx + sequence_length)
        x_seq = self.x_data[idx:idx + self.sequence_length]
        # Target window: the `horizon` steps immediately after the input.
        target_start = idx + self.sequence_length
        y_seq = self.y_data[target_start:target_start + self.horizon]
        return torch.FloatTensor(x_seq), torch.FloatTensor(y_seq)
class StandardScaler:
    """Feature-wise z-score normalizer.

    Statistics are computed over axis 0 (the time/sample axis) by ``fit``;
    they may also be injected directly via the constructor.
    """

    def __init__(self, mean=None, std=None):
        # mean/std may be supplied pre-computed or derived later via fit().
        self.mean = mean
        self.std = std

    def fit(self, data):
        """Compute per-feature mean and standard deviation over axis 0."""
        self.mean = np.mean(data, axis=0)
        std = np.std(data, axis=0)
        # Guard constant features against division by zero. np.where also
        # handles the 0-d case (1-D input), where the original boolean-index
        # assignment `std[std == 0] = 1.0` raised an IndexError.
        self.std = np.where(std == 0, 1.0, std)

    def transform(self, data):
        """Normalize data to zero mean / unit variance using fitted stats."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Map normalized data back to the original scale."""
        return (data * self.std) + self.mean
class STDENDataloader:
    """Top-level data manager for STDEN.

    Loads the dataset from disk, fits a scaler on the training inputs,
    normalizes all splits, and exposes PyTorch DataLoaders for
    train/val/test.
    """

    def __init__(self, config):
        """
        Initialize the data loader.

        Args:
            config: configuration dict; reads ``config['data']``
                (``dataset_dir``, ``batch_size``, optional
                ``val_batch_size``) and ``config['model']``
                (``seq_len``, ``horizon``).
        """
        self.config = config
        self.logger = logging.getLogger('STDEN')
        # Data configuration.
        self.dataset_dir = config['data']['dataset_dir']
        self.batch_size = config['data']['batch_size']
        self.val_batch_size = config['data'].get('val_batch_size', self.batch_size)
        self.sequence_length = config['model']['seq_len']
        self.horizon = config['model']['horizon']
        # Load and normalize the raw arrays.
        self.data = self._load_dataset()
        self.scaler = self.data['scaler']
        # Build one DataLoader per split; only the training split is shuffled.
        self.train_loader = self._create_dataloader(
            self.data['x_train'], self.data['y_train'],
            self.batch_size, shuffle=True
        )
        self.val_loader = self._create_dataloader(
            self.data['x_val'], self.data['y_val'],
            self.val_batch_size, shuffle=False
        )
        self.test_loader = self._create_dataloader(
            self.data['x_test'], self.data['y_test'],
            self.val_batch_size, shuffle=False
        )
        self.logger.info(f"数据加载完成 - 训练集: {len(self.data['x_train'])} 样本")
        self.logger.info(f"验证集: {len(self.data['x_val'])} 样本, 测试集: {len(self.data['x_test'])} 样本")

    def _load_dataset(self):
        """Dispatch to the loader matching the dataset's on-disk layout.

        Raises:
            FileNotFoundError: if the dataset directory does not exist.
        """
        dataset_path = Path(self.dataset_dir)
        if not dataset_path.exists():
            raise FileNotFoundError(f"数据集目录不存在: {self.dataset_dir}")
        # BJ_GM datasets ship a single flow.npz; everything else uses the
        # standard train/val/test .npz layout.
        if 'BJ' in self.dataset_dir:
            return self._load_bj_dataset(dataset_path)
        else:
            return self._load_standard_dataset(dataset_path)

    def _load_bj_dataset(self, dataset_path):
        """Load the BJ_GM dataset (all splits in one flow.npz archive).

        Raises:
            FileNotFoundError: if flow.npz is missing.
        """
        flow_file = dataset_path / 'flow.npz'
        if not flow_file.exists():
            raise FileNotFoundError(f"BJ数据集文件不存在: {flow_file}")
        # Fix: close the npz zip handle instead of leaking it; dict(...)
        # materializes the arrays before the file is closed.
        with np.load(flow_file) as npz:
            data = dict(npz)
        return self._process_data(
            data['x_train'], data['y_train'],
            data['x_val'], data['y_val'],
            data['x_test'], data['y_test']
        )

    def _load_standard_dataset(self, dataset_path):
        """Load a standard-layout dataset: one {split}.npz per split.

        Raises:
            FileNotFoundError: if any split file is missing.
        """
        data = {}
        for category in ['train', 'val', 'test']:
            cat_file = dataset_path / f"{category}.npz"
            if not cat_file.exists():
                raise FileNotFoundError(f"数据集文件不存在: {cat_file}")
            # Fix: close each npz handle instead of leaking it.
            with np.load(cat_file) as cat_data:
                data[f'x_{category}'] = cat_data['x']
                data[f'y_{category}'] = cat_data['y']
        return self._process_data(
            data['x_train'], data['y_train'],
            data['x_val'], data['y_val'],
            data['x_test'], data['y_test']
        )

    def _process_data(self, x_train, y_train, x_val, y_val, x_test, y_test):
        """Normalize every split with statistics fitted on x_train only.

        NOTE(review): y arrays are scaled with the x-fitted scaler — this
        assumes x and y share the same feature layout; confirm upstream.
        """
        scaler = StandardScaler()
        scaler.fit(x_train)
        return {
            'x_train': scaler.transform(x_train),
            'y_train': scaler.transform(y_train),
            'x_val': scaler.transform(x_val),
            'y_val': scaler.transform(y_val),
            'x_test': scaler.transform(x_test),
            'y_test': scaler.transform(y_test),
            'scaler': scaler
        }

    def _create_dataloader(self, x_data, y_data, batch_size, shuffle=False):
        """Wrap the arrays in a STDENDataset and build a PyTorch DataLoader."""
        dataset = STDENDataset(x_data, y_data, self.sequence_length, self.horizon)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,  # single-process loading avoids worker issues
            drop_last=False
        )

    def get_data_loaders(self):
        """Return all three DataLoaders keyed by split name."""
        return {
            'train': self.train_loader,
            'val': self.val_loader,
            'test': self.test_loader
        }

    def get_scaler(self):
        """Return the scaler fitted on the training inputs."""
        return self.scaler

    def get_data_info(self):
        """Return shape metadata for model construction."""
        return {
            'input_dim': self.data['x_train'].shape[-1],
            'output_dim': self.data['y_train'].shape[-1],
            'sequence_length': self.sequence_length,
            'horizon': self.horizon,
            'num_nodes': self.data['x_train'].shape[1] if len(self.data['x_train'].shape) > 1 else 1
        }