#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Dataset analyzer.

Reads the datasets of a BasicTS project and computes detailed analysis
results (basic info, missing values, temporal span, spatial/graph coverage).
"""

import os
import json
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from datetime import datetime
import warnings

# NOTE(review): this silences *every* warning in the process as an import-time
# side effect; consider scoping it with `warnings.catch_warnings()` instead.
warnings.filterwarnings('ignore')


class DatasetAnalyzer:
    """Analyze datasets laid out as ``<datasets_dir>/<name>/``.

    A sub-directory counts as a dataset when it contains ``desc.json``.
    Numeric data is read from ``data.dat`` (raw float32, shaped by
    ``desc['shape']``) and an optional graph from ``adj_mx.pkl``.
    """

    def __init__(self, datasets_dir: str = "datasets"):
        """Create an analyzer rooted at ``datasets_dir``.

        Args:
            datasets_dir: Directory containing one sub-directory per dataset.
        """
        self.datasets_dir = Path(datasets_dir)
        self.datasets_info = {}
        # Maps dataset name -> dict returned by analyze_dataset().
        self.analysis_results = {}

    def get_available_datasets(self) -> List[str]:
        """Return the sorted names of sub-directories that contain ``desc.json``."""
        return sorted(
            item.name
            for item in self.datasets_dir.iterdir()
            if item.is_dir() and (item / "desc.json").exists()
        )

    def load_dataset_description(self, dataset_name: str) -> Dict:
        """Load and return the parsed ``desc.json`` of ``dataset_name``."""
        desc_path = self.datasets_dir / dataset_name / "desc.json"
        with open(desc_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_dataset_data(self, dataset_name: str) -> np.ndarray:
        """Load ``data.dat`` fully into memory as a float32 ndarray.

        The shape is taken from ``desc['shape']`` in the dataset's ``desc.json``.
        """
        desc = self.load_dataset_description(dataset_name)
        data_path = self.datasets_dir / dataset_name / "data.dat"
        # memmap interprets the raw file as float32 values with the given shape.
        mapped = np.memmap(data_path, dtype='float32', mode='r', shape=tuple(desc['shape']))
        # np.array() copies into a plain in-memory ndarray (not a memmap
        # subclass); dropping the last reference lets the mapping be released.
        data = np.array(mapped)
        del mapped
        return data

    def load_adjacency_matrix(self, dataset_name: str) -> Optional[np.ndarray]:
        """Load the adjacency matrix from ``adj_mx.pkl`` if the file exists.

        Supported pickle layouts:
          * a bare matrix (returned as-is);
          * a dict with key ``'adj_mx'`` or ``'adj'``;
          * a list/tuple that either *is* a square matrix or contains one,
            e.g. ``[sensor_ids, sensor_id_to_ind, adj_mx]``.

        Returns:
            The matrix, or ``None`` when the file is absent or no matrix is found.
        """
        adj_path = self.datasets_dir / dataset_name / "adj_mx.pkl"
        if not adj_path.exists():
            return None
        # SECURITY: pickle.load can execute arbitrary code -- only use this on
        # trusted dataset files.
        try:
            with open(adj_path, 'rb') as f:
                adj_data = pickle.load(f)
        except UnicodeDecodeError:
            # Pickles written by Python 2 need latin1 to decode numpy buffers.
            with open(adj_path, 'rb') as f:
                adj_data = pickle.load(f, encoding='latin1')
        return self._extract_adjacency_matrix(adj_data)

    def _as_square_matrix(self, obj) -> Optional[np.ndarray]:
        """Return ``obj`` as a 2-D square numeric ndarray, or ``None`` if it is not one."""
        if hasattr(obj, 'toarray'):  # e.g. scipy.sparse matrices
            obj = obj.toarray()
        try:
            arr = np.asarray(obj)
        except (ValueError, TypeError):  # ragged / unconvertible sequences
            return None
        if arr.ndim == 2 and arr.shape[0] == arr.shape[1] and arr.dtype.kind in 'biuf':
            return arr
        return None

    def _extract_adjacency_matrix(self, adj_data):
        """Pick the adjacency matrix out of the object loaded from ``adj_mx.pkl``."""
        if isinstance(adj_data, dict):
            return adj_data.get('adj_mx', adj_data.get('adj', None))
        if isinstance(adj_data, (list, tuple)):
            # The sequence may itself be a (nested-list) matrix ...
            whole = self._as_square_matrix(adj_data)
            if whole is not None:
                return whole
            # ... or a container such as [sensor_ids, sensor_id_to_ind, adj_mx]
            # where the matrix is not the first element, so search for it.
            for item in adj_data:
                matrix = self._as_square_matrix(item)
                if matrix is not None:
                    return matrix
            return None
        return adj_data

    def analyze_missing_values(self, data: np.ndarray, null_val: float = 0.0) -> Dict:
        """Count missing entries overall, per time step and per node.

        Args:
            data: Array indexed as ``data[time, node(, feature)]`` (2-D or 3-D).
            null_val: Value marking a missing entry. ``NaN`` or ``None`` means
                "missing entries are NaN".

        Returns:
            Dict with the total missing rate (percent), raw counts, per-time and
            per-node missing counts, and their maxima.
        """
        if null_val is None or np.isnan(null_val):
            missing_mask = np.isnan(data)
        else:
            missing_mask = data == null_val

        total_elements = int(data.size)
        missing_elements = int(np.count_nonzero(missing_mask))
        missing_rate = (missing_elements / total_elements) * 100 if total_elements else 0.0

        # Axis 0 is time and axis 1 is node in both the 2-D and 3-D layouts.
        if data.ndim == 3:
            missing_by_time = np.sum(missing_mask, axis=(1, 2))
            missing_by_node = np.sum(missing_mask, axis=(0, 2))
        else:
            missing_by_time = np.sum(missing_mask, axis=1)
            missing_by_node = np.sum(missing_mask, axis=0)

        return {
            'total_missing_rate': missing_rate,
            'missing_elements': missing_elements,
            'total_elements': total_elements,
            'missing_by_time': missing_by_time,
            'missing_by_node': missing_by_node,
            'max_missing_time': int(np.max(missing_by_time)) if missing_by_time.size else 0,
            'max_missing_node': int(np.max(missing_by_node)) if missing_by_node.size else 0,
        }

    def analyze_temporal_continuity(self, data: np.ndarray, freq_minutes: Optional[int]) -> Dict:
        """Summarize the time span and data density.

        Args:
            data: Array whose axis 0 is time.
            freq_minutes: Sampling interval in minutes, or ``None`` if unknown
                (then ``total_hours``/``total_days`` are ``None``).

        Returns:
            Dict with step count, frequency, total hours/days and data density
            (fraction of entries that are neither zero nor NaN).
        """
        total_time_steps = data.shape[0]
        if freq_minutes is None:
            total_hours = None
            total_days = None
        else:
            total_hours = (total_time_steps * freq_minutes) / 60
            total_days = total_hours / 24

        # NaN compares != 0, so exclude it explicitly from the "valid" count.
        valid_non_zero = np.count_nonzero((data != 0) & ~np.isnan(data))
        data_density = valid_non_zero / data.size if data.size else 0.0

        return {
            'total_time_steps': total_time_steps,
            'frequency_minutes': freq_minutes,
            'total_hours': total_hours,
            'total_days': total_days,
            'data_density': data_density,
        }

    def analyze_spatial_coverage(self, data: np.ndarray, adj_matrix: Optional[np.ndarray] = None) -> Dict:
        """Report node/feature counts and, when available, graph statistics.

        Args:
            data: Array indexed as ``data[time, node(, feature)]``.
            adj_matrix: Optional adjacency matrix (ndarray, nested list or an
                object with ``toarray()``); entries ``> 0`` count as edges.

        Returns:
            Dict with ``num_nodes``, ``num_features`` and, if a matrix was
            given, edge count, density and average/max/min out-degree.
        """
        num_nodes = data.shape[1]
        num_features = data.shape[2] if data.ndim == 3 else 1

        edge_info = {}
        if adj_matrix is not None:
            if hasattr(adj_matrix, 'toarray'):  # e.g. scipy.sparse matrices
                adj_matrix = adj_matrix.toarray()
            adj = np.asarray(adj_matrix)
            # Row-wise count of positive entries, computed once and reused.
            degrees = np.count_nonzero(adj > 0, axis=1)
            num_edges = int(degrees.sum())
            has_rows = degrees.size > 0
            edge_info = {
                'num_edges': num_edges,
                'edge_density': num_edges / adj.size if adj.size else 0.0,
                'avg_degree': float(np.mean(degrees)) if has_rows else 0.0,
                'max_degree': int(np.max(degrees)) if has_rows else 0,
                'min_degree': int(np.min(degrees)) if has_rows else 0,
            }

        return {
            'num_nodes': num_nodes,
            'num_features': num_features,
            **edge_info,
        }

    def analyze_dataset(self, dataset_name: str) -> Dict:
        """Load one dataset and run every analysis on it.

        Returns:
            Dict with ``basic_info``, ``missing_analysis``, ``temporal_analysis``,
            ``spatial_analysis`` and the raw ``description``.
        """
        print(f"正在分析数据集: {dataset_name}")

        desc = self.load_dataset_description(dataset_name)
        data = self.load_dataset_data(dataset_name)
        adj_matrix = self.load_adjacency_matrix(dataset_name)

        basic_info = {
            'name': desc['name'],
            'domain': desc['domain'],
            'shape': desc['shape'],
            'has_graph': desc.get('has_graph', False),
            'frequency_minutes': desc.get('frequency (minutes)', None),
        }

        null_val = desc.get('regular_settings', {}).get('NULL_VAL', 0.0)
        missing_analysis = self.analyze_missing_values(data, null_val)
        temporal_analysis = self.analyze_temporal_continuity(data, basic_info['frequency_minutes'])
        spatial_analysis = self.analyze_spatial_coverage(data, adj_matrix)

        return {
            'basic_info': basic_info,
            'missing_analysis': missing_analysis,
            'temporal_analysis': temporal_analysis,
            'spatial_analysis': spatial_analysis,
            'description': desc,
        }

    def analyze_all_datasets(self) -> Dict:
        """Analyze every available dataset and store results in ``analysis_results``.

        A dataset whose analysis raises is reported on stdout and skipped, so
        one broken dataset does not abort the whole run.
        """
        datasets = self.get_available_datasets()
        print(f"发现 {len(datasets)} 个数据集: {datasets}")

        for dataset_name in datasets:
            try:
                self.analysis_results[dataset_name] = self.analyze_dataset(dataset_name)
            except Exception as e:
                print(f"分析数据集 {dataset_name} 时出错: {e}")
                continue

        return self.analysis_results