#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Dataset analyzer.

Reads the datasets shipped with the BasicTS project and generates a
detailed report for each of them.
"""

import os
import json
import pickle
import warnings
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

warnings.filterwarnings('ignore')


class DatasetAnalyzer:
    """Analyzer for BasicTS-style datasets."""

    def __init__(self, datasets_dir: str = "datasets"):
        """
        Initialize the dataset analyzer.

        Args:
            datasets_dir: Path to the directory that holds the datasets.
        """
        self.datasets_dir = Path(datasets_dir)
        self.datasets_info = {}
        self.analysis_results = {}

    def get_available_datasets(self) -> List[str]:
        """Return the names of all available datasets, sorted alphabetically."""
        datasets = []
        for item in self.datasets_dir.iterdir():
            # A sub-directory counts as a dataset if it ships a desc.json.
            if item.is_dir() and (item / "desc.json").exists():
                datasets.append(item.name)
        return sorted(datasets)

    def load_dataset_description(self, dataset_name: str) -> Dict:
        """Load the description file (desc.json) of a dataset."""
        desc_path = self.datasets_dir / dataset_name / "desc.json"
        with open(desc_path, 'r', encoding='utf-8') as f:
            return json.load(f)
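
    # For reference, desc.json is assumed to look roughly like the example
    # below. This is a sketch inferred from the fields accessed in this
    # file (name, domain, shape, has_graph, "frequency (minutes)",
    # regular_settings.NULL_VAL), not the authoritative BasicTS schema:
    #
    #     {
    #         "name": "PEMS04",
    #         "domain": "traffic flow",
    #         "shape": [16992, 307, 3],
    #         "has_graph": true,
    #         "frequency (minutes)": 5,
    #         "regular_settings": {"NULL_VAL": 0.0}
    #     }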

    def load_dataset_data(self, dataset_name: str) -> np.ndarray:
        """Load the raw data (data.dat) of a dataset."""
        desc = self.load_dataset_description(dataset_name)
        data_path = self.datasets_dir / dataset_name / "data.dat"

        # Memory-map the (potentially large) binary file, then copy the
        # array into memory.
        data = np.memmap(data_path, dtype='float32', mode='r',
                         shape=tuple(desc['shape']))
        return data.copy()

    def load_adjacency_matrix(self, dataset_name: str) -> Optional[np.ndarray]:
        """Load the adjacency matrix (adj_mx.pkl), if the dataset has one."""
        adj_path = self.datasets_dir / dataset_name / "adj_mx.pkl"
        if adj_path.exists():
            with open(adj_path, 'rb') as f:
                adj_data = pickle.load(f)
            # Handle the different formats an adjacency matrix may be stored in.
            if isinstance(adj_data, tuple):
                # By convention the first element is the adjacency matrix.
                return adj_data[0]
            elif isinstance(adj_data, dict):
                return adj_data.get('adj_mx', adj_data.get('adj', None))
            else:
                return adj_data
        return None

    def analyze_missing_values(self, data: np.ndarray, null_val: float = 0.0) -> Dict:
        """Analyze missing values."""
        # Build the missing-value mask. A NaN placeholder needs a dedicated
        # check because NaN never compares equal to anything.
        if np.isnan(null_val):
            missing_mask = np.isnan(data)
        else:
            missing_mask = (data == null_val)

        total_elements = data.size
        missing_elements = np.sum(missing_mask)
        missing_rate = (missing_elements / total_elements) * 100

        # Aggregate missing values per time step and per node.
        missing_by_time = np.sum(missing_mask, axis=(1, 2)) if data.ndim == 3 else np.sum(missing_mask, axis=1)
        missing_by_node = np.sum(missing_mask, axis=(0, 2)) if data.ndim == 3 else np.sum(missing_mask, axis=0)

        return {
            'total_missing_rate': missing_rate,
            'missing_elements': missing_elements,
            'total_elements': total_elements,
            'missing_by_time': missing_by_time,
            'missing_by_node': missing_by_node,
            'max_missing_time': np.max(missing_by_time),
            'max_missing_node': np.max(missing_by_node) if data.ndim == 3 else 0
        }
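
    # Quick sanity check (a hypothetical usage sketch, not part of the tool):
    #
    #     data = np.zeros((12, 4, 1), dtype=np.float32)
    #     data[0, 0, 0] = 1.0
    #     stats = DatasetAnalyzer().analyze_missing_values(data, null_val=0.0)
    #     # 47 of the 48 entries equal the null value:
    #     # stats['total_missing_rate'] -> 47 / 48 * 100 ≈ 97.92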

    def analyze_temporal_continuity(self, data: np.ndarray, freq_minutes: Optional[int]) -> Dict:
        """Analyze temporal continuity."""
        # Compute the covered time span. Some descriptions do not record a
        # sampling frequency, in which case the span cannot be derived.
        total_time_steps = data.shape[0]
        if freq_minutes is not None:
            total_hours = (total_time_steps * freq_minutes) / 60
            total_days = total_hours / 24
        else:
            total_hours = None
            total_days = None

        # Data density: the fraction of non-zero data points.
        non_zero_ratio = np.sum(data != 0) / data.size

        return {
            'total_time_steps': total_time_steps,
            'frequency_minutes': freq_minutes,
            'total_hours': total_hours,
            'total_days': total_days,
            'data_density': non_zero_ratio
        }

    def analyze_spatial_coverage(self, data: np.ndarray, adj_matrix: Optional[np.ndarray] = None) -> Dict:
        """Analyze spatial coverage."""
        if data.ndim == 3:
            num_nodes = data.shape[1]
            num_features = data.shape[2]
        else:
            num_nodes = data.shape[1]
            num_features = 1

        # Derive graph statistics from the adjacency matrix, if present.
        edge_info = {}
        if adj_matrix is not None:
            num_edges = np.sum(adj_matrix > 0)
            edge_density = num_edges / (num_nodes * num_nodes)
            degrees = np.sum(adj_matrix > 0, axis=1)

            edge_info = {
                'num_edges': int(num_edges),
                'edge_density': edge_density,
                'avg_degree': float(np.mean(degrees)),
                'max_degree': int(np.max(degrees)),
                'min_degree': int(np.min(degrees))
            }

        return {
            'num_nodes': num_nodes,
            'num_features': num_features,
            **edge_info
        }

    def analyze_dataset(self, dataset_name: str) -> Dict:
        """Analyze a single dataset."""
        print(f"Analyzing dataset: {dataset_name}")

        # Load the raw data.
        desc = self.load_dataset_description(dataset_name)
        data = self.load_dataset_data(dataset_name)
        adj_matrix = self.load_adjacency_matrix(dataset_name)

        # Basic information.
        basic_info = {
            'name': desc['name'],
            'domain': desc['domain'],
            'shape': desc['shape'],
            'has_graph': desc.get('has_graph', False),
            'frequency_minutes': desc.get('frequency (minutes)', None)
        }

        # Missing-value analysis.
        null_val = desc.get('regular_settings', {}).get('NULL_VAL', 0.0)
        missing_analysis = self.analyze_missing_values(data, null_val)

        # Temporal-continuity analysis.
        temporal_analysis = self.analyze_temporal_continuity(data, basic_info['frequency_minutes'])

        # Spatial-coverage analysis.
        spatial_analysis = self.analyze_spatial_coverage(data, adj_matrix)

        return {
            'basic_info': basic_info,
            'missing_analysis': missing_analysis,
            'temporal_analysis': temporal_analysis,
            'spatial_analysis': spatial_analysis,
            'description': desc
        }

    def analyze_all_datasets(self) -> Dict:
        """Analyze all available datasets."""
        datasets = self.get_available_datasets()
        print(f"Found {len(datasets)} datasets: {datasets}")

        for dataset_name in datasets:
            try:
                self.analysis_results[dataset_name] = self.analyze_dataset(dataset_name)
            except Exception as e:
                print(f"Error while analyzing dataset {dataset_name}: {e}")
                continue

        return self.analysis_results
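
    # Example usage (a sketch; assumes a BasicTS-style "datasets" directory
    # with one sub-directory per dataset, each holding desc.json and data.dat):
    #
    #     analyzer = DatasetAnalyzer("datasets")
    #     results = analyzer.analyze_all_datasets()
    #     for name, result in results.items():
    #         rate = result['missing_analysis']['total_missing_rate']
    #         print(f"{name}: shape={result['basic_info']['shape']}, "
    #               f"missing={rate:.2f}%")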