# REPST/scripts/dataset_analysis.py
#
# 488 lines
# 19 KiB
# Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据集分析脚本
用于读取BasicTS项目中的数据集并生成详细的报告
包括节点/边数量、时间频率、缺失值率、空间覆盖密度等分析
"""
import os
import json
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import warnings
# Suppress every warning globally for the whole run of this script.
warnings.filterwarnings('ignore')
# Prefer a CJK-capable font (SimHei, falling back to DejaVu Sans) so the
# Chinese chart titles/labels render, and draw minus signs with an ASCII
# hyphen because CJK fonts commonly lack the Unicode minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
class DatasetAnalyzer:
    """Analyze BasicTS-style datasets and produce text reports and charts.

    A dataset is a sub-directory of ``datasets_dir`` that contains a
    ``desc.json`` description, a raw float32 ``data.dat`` array and,
    optionally, an ``adj_mx.pkl`` adjacency matrix. For every dataset the
    analyzer computes missing-value, temporal-span and spatial/graph
    statistics and stores them in ``self.analysis_results``.
    """

    def __init__(self, datasets_dir: str = "datasets"):
        """Create an analyzer rooted at ``datasets_dir``.

        Args:
            datasets_dir: Directory holding one sub-directory per dataset.
        """
        self.datasets_dir = Path(datasets_dir)
        # Not used by any method of this class; kept for backward compatibility.
        self.datasets_info = {}
        # Maps dataset name -> result dict returned by analyze_dataset().
        self.analysis_results = {}

    # ------------------------------------------------------------------
    # Small formatting / serialization helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _fmt(value, spec: str = '', missing: str = 'N/A') -> str:
        """Format ``value`` with ``spec``, or return ``missing`` when it is None."""
        return missing if value is None else format(value, spec)

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback: numpy arrays -> lists, numpy scalars -> Python
        scalars, anything else -> ``str(obj)``."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return str(obj)

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def get_available_datasets(self) -> List[str]:
        """Return the sorted names of sub-directories that contain ``desc.json``."""
        datasets = []
        for item in self.datasets_dir.iterdir():
            if item.is_dir() and (item / "desc.json").exists():
                datasets.append(item.name)
        return sorted(datasets)

    def load_dataset_description(self, dataset_name: str) -> Dict:
        """Parse and return ``<datasets_dir>/<dataset_name>/desc.json``."""
        desc_path = self.datasets_dir / dataset_name / "desc.json"
        with open(desc_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_dataset_data(self, dataset_name: str) -> np.ndarray:
        """Load ``data.dat`` as a float32 array shaped like ``desc['shape']``.

        The file is memory-mapped read-only and then copied into memory so
        later analysis works on an in-memory array.
        """
        desc = self.load_dataset_description(dataset_name)
        data_path = self.datasets_dir / dataset_name / "data.dat"
        data = np.memmap(data_path, dtype='float32', mode='r',
                         shape=tuple(desc['shape']))
        return data.copy()

    def load_adjacency_matrix(self, dataset_name: str) -> Optional[np.ndarray]:
        """Load the dataset's adjacency matrix from ``adj_mx.pkl`` if present.

        Supported pickle layouts:
          * a bare matrix (returned as-is);
          * a tuple such as DCRNN-style ``(sensor_ids, sensor_id_to_ind, adj_mx)``:
            the first square 2-D element is returned (None if there is none);
          * a dict with an ``'adj_mx'`` or ``'adj'`` key.

        Returns:
            The matrix, or None when the file does not exist / holds no matrix.

        NOTE: ``pickle.load`` can execute arbitrary code; only use this on
        trusted dataset files.
        """
        adj_path = self.datasets_dir / dataset_name / "adj_mx.pkl"
        if not adj_path.exists():
            return None
        with open(adj_path, 'rb') as f:
            try:
                adj_data = pickle.load(f)
            except UnicodeDecodeError:
                # Pickles written by Python 2 (e.g. numpy arrays) need latin1.
                f.seek(0)
                adj_data = pickle.load(f, encoding='latin1')

        if isinstance(adj_data, tuple):
            # The matrix is not always the first element (DCRNN stores it
            # last), so pick the first square 2-D array-like item.
            for item in adj_data:
                shape = getattr(item, 'shape', None)
                if getattr(item, 'ndim', None) == 2 and shape[0] == shape[1]:
                    return item
            return None
        if isinstance(adj_data, dict):
            return adj_data.get('adj_mx', adj_data.get('adj', None))
        return adj_data

    # ------------------------------------------------------------------
    # Analyses
    # ------------------------------------------------------------------
    def analyze_missing_values(self, data: np.ndarray, null_val: float = 0.0) -> Dict:
        """Count entries equal to ``null_val`` (NaN entries when ``null_val`` is NaN).

        Axis 0 is treated as time and axis 1 as nodes; any further axes
        (e.g. features) are summed over.

        Returns:
            Dict with the overall missing rate in percent, raw counts,
            per-time-step and per-node missing counts, and their maxima.
        """
        # NaN never compares equal to itself, so a NaN sentinel needs isnan().
        if np.isnan(null_val):
            missing_mask = np.isnan(data)
        else:
            missing_mask = (data == null_val)

        total_elements = int(data.size)
        missing_elements = int(np.count_nonzero(missing_mask))
        missing_rate = (missing_elements / total_elements) * 100 if total_elements else 0.0

        # Sum over every axis except time (axis 0) / except nodes (axis 1).
        time_axes = tuple(range(1, data.ndim))
        node_axes = tuple(axis for axis in range(data.ndim) if axis != 1)
        missing_by_time = np.sum(missing_mask, axis=time_axes)
        missing_by_node = np.sum(missing_mask, axis=node_axes)

        return {
            'total_missing_rate': missing_rate,
            'missing_elements': missing_elements,
            'total_elements': total_elements,
            'missing_by_time': missing_by_time,
            'missing_by_node': missing_by_node,
            'max_missing_time': int(np.max(missing_by_time)) if np.size(missing_by_time) else 0,
            'max_missing_node': int(np.max(missing_by_node)) if np.size(missing_by_node) else 0,
        }

    def analyze_temporal_continuity(self, data: np.ndarray,
                                    freq_minutes: Optional[float]) -> Dict:
        """Compute the time span covered by ``data`` and its observation density.

        Args:
            data: Array whose axis 0 is time.
            freq_minutes: Sampling interval in minutes; None when unknown, in
                which case ``total_hours`` / ``total_days`` are None.
        """
        total_time_steps = data.shape[0]
        if freq_minutes is None:
            total_hours = total_days = None
        else:
            total_hours = (total_time_steps * freq_minutes) / 60
            total_days = total_hours / 24

        # Density = share of entries holding an actual (non-zero, non-NaN)
        # value; `data != 0` alone would count NaN placeholders as data.
        observed = np.count_nonzero((data != 0) & ~np.isnan(data))
        non_zero_ratio = observed / data.size if data.size else 0.0

        return {
            'total_time_steps': total_time_steps,
            'frequency_minutes': freq_minutes,
            'total_hours': total_hours,
            'total_days': total_days,
            'data_density': non_zero_ratio,
        }

    def analyze_spatial_coverage(self, data: np.ndarray,
                                 adj_matrix: Optional[np.ndarray] = None) -> Dict:
        """Report node/feature counts and, if a graph is given, edge statistics.

        Edges are the positive entries of ``adj_matrix`` counted per row, so
        a symmetric matrix counts each undirected edge twice and diagonal
        self-loops are included. Edge keys are absent when ``adj_matrix`` is None.
        """
        num_nodes = data.shape[1]
        num_features = data.shape[2] if data.ndim == 3 else 1

        edge_info = {}
        if adj_matrix is not None:
            # Out-degree per node; np.asarray(...).ravel() also flattens the
            # matrix-typed result a scipy sparse matrix would give.
            degrees = np.asarray((adj_matrix > 0).sum(axis=1)).ravel()
            num_edges = int(degrees.sum())
            edge_info = {
                'num_edges': num_edges,
                'edge_density': num_edges / (num_nodes * num_nodes),
                'avg_degree': float(degrees.mean()),
                'max_degree': int(degrees.max()),
                'min_degree': int(degrees.min()),
            }

        return {
            'num_nodes': num_nodes,
            'num_features': num_features,
            **edge_info,
        }

    def analyze_dataset(self, dataset_name: str) -> Dict:
        """Run every analysis on one dataset and return the combined results.

        Returns:
            Dict with keys ``basic_info``, ``missing_analysis``,
            ``temporal_analysis``, ``spatial_analysis`` and ``description``
            (the raw ``desc.json`` contents).
        """
        print(f"正在分析数据集: {dataset_name}")
        desc = self.load_dataset_description(dataset_name)
        data = self.load_dataset_data(dataset_name)
        adj_matrix = self.load_adjacency_matrix(dataset_name)

        basic_info = {
            'name': desc['name'],
            'domain': desc['domain'],
            'shape': desc['shape'],
            'has_graph': desc.get('has_graph', False),
            'frequency_minutes': desc.get('frequency (minutes)', None),
        }
        # Missing-value sentinel comes from regular_settings.NULL_VAL (default 0.0).
        null_val = desc.get('regular_settings', {}).get('NULL_VAL', 0.0)

        return {
            'basic_info': basic_info,
            'missing_analysis': self.analyze_missing_values(data, null_val),
            'temporal_analysis': self.analyze_temporal_continuity(
                data, basic_info['frequency_minutes']),
            'spatial_analysis': self.analyze_spatial_coverage(data, adj_matrix),
            'description': desc,
        }

    def analyze_all_datasets(self) -> Dict:
        """Analyze every available dataset, storing results in ``analysis_results``.

        A dataset that fails to load or analyze is reported on stdout and
        skipped so the remaining datasets are still processed.
        """
        datasets = self.get_available_datasets()
        print(f"发现 {len(datasets)} 个数据集: {datasets}")
        for dataset_name in datasets:
            try:
                self.analysis_results[dataset_name] = self.analyze_dataset(dataset_name)
            except Exception as e:  # best-effort batch: report and move on
                print(f"分析数据集 {dataset_name} 时出错: {e}")
        return self.analysis_results

    # ------------------------------------------------------------------
    # Text reports
    # ------------------------------------------------------------------
    def generate_summary_report(self) -> str:
        """Build the plain-text overview table plus a per-dataset detail section.

        Values that are unknown (e.g. the time span of a dataset without a
        sampling frequency) are shown as ``N/A``.
        """
        if not self.analysis_results:
            return "没有可用的分析结果"

        report = [
            "=" * 80,
            "BasicTS 数据集分析报告",
            "=" * 80,
            f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"分析数据集数量: {len(self.analysis_results)}",
            "",
            "数据集概览:",
            "-" * 80,
            f"{'数据集名称':<15} {'领域':<20} {'时间步数':<10} {'节点数':<8} {'特征数':<8} {'频率(分钟)':<12} {'缺失值率(%)':<12}",
            "-" * 80,
        ]
        for name, result in self.analysis_results.items():
            basic = result['basic_info']
            missing = result['missing_analysis']
            spatial = result['spatial_analysis']
            report.append(f"{name:<15} {basic['domain']:<20} {basic['shape'][0]:<10} "
                          f"{spatial['num_nodes']:<8} {spatial['num_features']:<8} "
                          f"{self._fmt(basic['frequency_minutes']):<12} "
                          f"{missing['total_missing_rate']:<12.3f}")
        report.append("")

        for name, result in self.analysis_results.items():
            basic = result['basic_info']
            missing = result['missing_analysis']
            temporal = result['temporal_analysis']
            spatial = result['spatial_analysis']
            report.append(f"数据集: {name}")
            report.append("-" * 40)
            report.append(f"领域: {basic['domain']}")
            report.append(f"数据形状: {basic['shape']}")
            report.append(f"时间频率: {self._fmt(basic['frequency_minutes'])} 分钟")
            report.append(f"时间跨度: {self._fmt(temporal['total_days'], '.1f')} 天 "
                          f"({self._fmt(temporal['total_hours'], '.1f')} 小时)")
            report.append(f"节点数量: {spatial['num_nodes']}")
            report.append(f"特征数量: {spatial['num_features']}")
            # Presence (not truthiness) so a graph with zero edges is still shown.
            if 'num_edges' in spatial:
                report.append(f"边数量: {spatial['num_edges']}")
                report.append(f"边密度: {spatial['edge_density']:.4f}")
                report.append(f"平均度数: {spatial['avg_degree']:.2f}")
            report.append(f"缺失值率: {missing['total_missing_rate']:.3f}%")
            report.append(f"数据密度: {temporal['data_density']:.3f}")
            report.append("")
        return "\n".join(report)

    def generate_comparative_analysis(self) -> str:
        """Build a report comparing datasets per domain, by graph structure and
        by time span (longest first; unknown spans listed last)."""
        if not self.analysis_results:
            return "没有可用的分析结果"

        report = ["=" * 80, "数据集对比分析", "=" * 80, ""]

        # Group datasets by domain, preserving first-seen order.
        domains: Dict[str, list] = {}
        for name, result in self.analysis_results.items():
            domains.setdefault(result['basic_info']['domain'], []).append((name, result))

        for domain, datasets in domains.items():
            report.append(f"领域: {domain}")
            report.append("-" * 40)
            missing_rates = [r['missing_analysis']['total_missing_rate'] for _, r in datasets]
            node_counts = [r['spatial_analysis']['num_nodes'] for _, r in datasets]
            time_steps = [r['basic_info']['shape'][0] for _, r in datasets]
            report.append(f"数据集数量: {len(datasets)}")
            report.append(f"平均缺失值率: {np.mean(missing_rates):.3f}%")
            report.append(f"缺失值率范围: {min(missing_rates):.3f}% - {max(missing_rates):.3f}%")
            report.append(f"平均节点数: {np.mean(node_counts):.1f}")
            report.append(f"节点数范围: {min(node_counts)} - {max(node_counts)}")
            report.append(f"平均时间步数: {np.mean(time_steps):.0f}")
            report.append("")

        report.append("空间覆盖密度分析:")
        report.append("-" * 40)
        spatial_datasets = [(name, result) for name, result in self.analysis_results.items()
                            if 'num_edges' in result['spatial_analysis']]
        if spatial_datasets:
            for name, result in spatial_datasets:
                spatial = result['spatial_analysis']
                report.append(f"{name}: {spatial['num_nodes']} 个节点, {spatial['num_edges']} 条边, "
                              f"密度 {spatial['edge_density']:.4f}, 平均度数 {spatial['avg_degree']:.2f}")
        else:
            report.append("没有发现包含图结构的数据集")
        report.append("")

        report.append("时间连续性分析:")
        report.append("-" * 40)
        temporal_data = []
        for name, result in self.analysis_results.items():
            temporal = result['temporal_analysis']
            temporal_data.append({
                'name': name,
                'days': temporal['total_days'],
                'density': temporal['data_density'],
                'frequency': temporal['frequency_minutes'],
            })
        # Longest span first; unknown (None) spans sort to the end.
        temporal_data.sort(
            key=lambda x: x['days'] if x['days'] is not None else float('-inf'),
            reverse=True)
        for item in temporal_data:
            report.append(f"{item['name']}: {self._fmt(item['days'], '.1f')} 天, "
                          f"数据密度 {item['density']:.3f}, "
                          f"频率 {self._fmt(item['frequency'])} 分钟")
        return "\n".join(report)

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------
    def save_reports(self, output_dir: str = "analysis_reports"):
        """Write the summary, comparative and detailed JSON reports to ``output_dir``.

        numpy arrays/scalars in the results are serialized as JSON lists and
        numbers rather than their (possibly truncated) string representations.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        with open(output_path / "summary_report.txt", 'w', encoding='utf-8') as f:
            f.write(self.generate_summary_report())
        with open(output_path / "comparative_analysis.txt", 'w', encoding='utf-8') as f:
            f.write(self.generate_comparative_analysis())
        with open(output_path / "detailed_analysis.json", 'w', encoding='utf-8') as f:
            json.dump(self.analysis_results, f, indent=2, ensure_ascii=False,
                      default=self._json_default)

        print(f"报告已保存到目录: {output_path}")

    @staticmethod
    def _save_bar_chart(names, values, labels, label_offset, color, title, ylabel, file_path):
        """Draw one bar chart with a text label above each bar and save it as PNG."""
        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(names, values, color=color, alpha=0.7)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('数据集名称')
        ax.set_ylabel(ylabel)
        ax.tick_params(axis='x', rotation=45)
        for bar, label in zip(bars, labels):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset,
                    label, ha='center', va='bottom')
        plt.tight_layout()
        plt.savefig(file_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

    def create_visualizations(self, output_dir: str = "analysis_reports"):
        """Save comparison charts (PNG, 300 dpi) for all analyzed datasets.

        Produces missing-rate, node-count and time-span bar charts plus a
        node-count vs. missing-rate scatter plot in ``output_dir``.
        """
        if not self.analysis_results:
            print("没有可用的分析结果")
            return

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # A style sheet may overwrite the CJK font configured at import time,
        # so remember those settings and restore them after applying it.
        font_settings = {key: plt.rcParams[key]
                         for key in ('font.sans-serif', 'axes.unicode_minus')}
        try:
            plt.style.use('seaborn-v0_8')
        except OSError:
            # Style name unavailable in this matplotlib version; keep defaults.
            pass
        plt.rcParams.update(font_settings)

        names = list(self.analysis_results.keys())
        results = [self.analysis_results[name] for name in names]
        missing_rates = [r['missing_analysis']['total_missing_rate'] for r in results]
        node_counts = [r['spatial_analysis']['num_nodes'] for r in results]
        time_days = [r['temporal_analysis']['total_days'] for r in results]

        # 1. Missing-rate comparison
        self._save_bar_chart(
            names, missing_rates, [f'{rate:.2f}%' for rate in missing_rates], 0.1,
            'skyblue', '各数据集缺失值率对比', '缺失值率 (%)',
            output_path / "missing_rates_comparison.png")

        # 2. Node-count comparison
        self._save_bar_chart(
            names, node_counts, [f'{count}' for count in node_counts],
            max(node_counts) * 0.01,
            'lightgreen', '各数据集节点数量对比', '节点数量',
            output_path / "node_counts_comparison.png")

        # 3. Time-span comparison; unknown spans become zero-height "N/A" bars.
        day_heights = [days if days is not None else 0.0 for days in time_days]
        self._save_bar_chart(
            names, day_heights, [self._fmt(days, '.1f') for days in time_days],
            max(day_heights) * 0.01,
            'orange', '各数据集时间跨度对比', '时间跨度 (天)',
            output_path / "time_span_comparison.png")

        # 4. Scatter plot: node count vs missing rate
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.scatter(node_counts, missing_rates, s=100, alpha=0.7, c='red')
        for name, x, y in zip(names, node_counts, missing_rates):
            ax.annotate(name, (x, y), xytext=(5, 5), textcoords='offset points', fontsize=8)
        ax.set_xlabel('节点数量')
        ax.set_ylabel('缺失值率 (%)')
        ax.set_title('节点数量与缺失值率关系', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(output_path / "nodes_vs_missing_rates.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f"可视化图表已保存到目录: {output_path}")
def main():
    """Script entry point.

    Analyzes every dataset found under the analyzer's default directory,
    prints the summary and comparative reports, then saves the reports and
    charts via ``save_reports()`` / ``create_visualizations()`` with their
    default output directory.
    """
    banner = "=" * 80
    print("BasicTS 数据集分析工具")
    print("=" * 50)

    analyzer = DatasetAnalyzer()
    analyzer.analyze_all_datasets()

    # Each section: header banner, title, banner, then the generated report.
    sections = (
        ("数据集分析报告", analyzer.generate_summary_report),
        ("对比分析报告", analyzer.generate_comparative_analysis),
    )
    for title, build_report in sections:
        print("\n" + banner)
        print(title)
        print(banner)
        print(build_report())

    analyzer.save_reports()
    analyzer.create_visualizations()
    print("\n分析完成!")


if __name__ == "__main__":
    main()