488 lines
19 KiB
Python
488 lines
19 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
数据集分析脚本
|
|
用于读取BasicTS项目中的数据集并生成详细的报告
|
|
包括节点/边数量、时间频率、缺失值率、空间覆盖密度等分析
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pickle
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple, Optional
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
from datetime import datetime, timedelta
|
|
import warnings
|
|
# Silence every warning process-wide (numpy, pandas, matplotlib, ...).
warnings.filterwarnings('ignore')

# Use a CJK-capable font first (SimHei, falling back to DejaVu Sans) so the
# Chinese chart titles and labels can render, and draw minus signs with an
# ASCII hyphen instead of U+2212 (presumably missing from some CJK fonts).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
|
|
|
|
class DatasetAnalyzer:
    """Analyze BasicTS-style datasets and produce text, JSON and chart reports.

    Every sub-directory of ``datasets_dir`` that contains a ``desc.json`` is
    treated as one dataset. Next to it the loader reads ``data.dat`` (raw
    float32 values laid out with the ``shape`` listed in ``desc.json``) and,
    when present, ``adj_mx.pkl`` (a pickled adjacency matrix).
    """

    def __init__(self, datasets_dir: str = "datasets"):
        """Create an analyzer.

        Args:
            datasets_dir: Directory whose sub-directories hold the datasets.
        """
        self.datasets_dir = Path(datasets_dir)
        # Not populated anywhere in this class; kept for API compatibility.
        self.datasets_info = {}
        # Dataset name -> result dict produced by analyze_dataset().
        self.analysis_results = {}

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def get_available_datasets(self) -> List[str]:
        """Return the sorted names of sub-directories that contain ``desc.json``.

        Raises:
            FileNotFoundError: If ``datasets_dir`` does not exist.
        """
        return sorted(
            entry.name
            for entry in self.datasets_dir.iterdir()
            if entry.is_dir() and (entry / "desc.json").exists()
        )

    def load_dataset_description(self, dataset_name: str) -> Dict:
        """Parse and return ``<datasets_dir>/<dataset_name>/desc.json``."""
        desc_path = self.datasets_dir / dataset_name / "desc.json"
        with open(desc_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_dataset_data(self, dataset_name: str) -> np.ndarray:
        """Load ``data.dat`` fully into memory.

        The file is read as raw float32 values reshaped to ``desc['shape']``.

        Returns:
            A plain in-memory float32 ndarray (not backed by the file).
        """
        desc = self.load_dataset_description(dataset_name)
        data_path = self.datasets_dir / dataset_name / "data.dat"
        mapped = np.memmap(data_path, dtype='float32', mode='r',
                           shape=tuple(desc['shape']))
        # np.array copies into a base-class ndarray (memmap.copy() would keep
        # the memmap subclass); dropping the memmap lets the mapping be freed.
        data = np.array(mapped)
        del mapped
        return data

    def load_adjacency_matrix(self, dataset_name: str) -> Optional[np.ndarray]:
        """Load the adjacency matrix from ``adj_mx.pkl`` if that file exists.

        Supported pickle layouts: a bare matrix, a dict with an ``'adj_mx'``
        or ``'adj'`` key, or a tuple such as the DCRNN/BasicTS layout
        ``(sensor_ids, sensor_id_to_ind, adj_mx)``.

        Returns:
            The matrix, or ``None`` when the file is missing or holds no matrix.
        """
        adj_path = self.datasets_dir / dataset_name / "adj_mx.pkl"
        if not adj_path.exists():
            return None
        # NOTE: pickle can execute arbitrary code; only load trusted datasets.
        with open(adj_path, 'rb') as f:
            try:
                adj_data = pickle.load(f)
            except UnicodeDecodeError:
                # Pickles written under Python 2 (e.g. METR-LA / PEMS-BAY)
                # need latin1 to decode their numpy byte buffers.
                f.seek(0)
                adj_data = pickle.load(f, encoding='latin1')
        return self._extract_adjacency(adj_data)

    @staticmethod
    def _extract_adjacency(adj_data) -> Optional[np.ndarray]:
        """Pick the adjacency matrix out of an unpickled object."""
        if isinstance(adj_data, tuple):
            # The matrix is NOT the first element in the common
            # (sensor_ids, sensor_id_to_ind, adj_mx) layout, so take the last
            # 2-D element; fall back to the first element otherwise.
            for item in reversed(adj_data):
                if getattr(item, 'ndim', None) == 2:
                    return item
            return adj_data[0] if adj_data else None
        if isinstance(adj_data, dict):
            return adj_data.get('adj_mx', adj_data.get('adj', None))
        return adj_data

    # ------------------------------------------------------------------
    # Analysis
    # ------------------------------------------------------------------
    def analyze_missing_values(self, data: np.ndarray,
                               null_val: Optional[float] = 0.0) -> Dict:
        """Count missing entries in *data*.

        An entry is missing when it equals ``null_val``; when ``null_val`` is
        NaN or ``None`` the NaN entries are counted instead. Axis 0 is treated
        as time and axis 1 as node; all other axes are summed away.

        Returns:
            Dict with the overall missing rate in percent, raw counts, the
            per-time-step and per-node missing counts, and their maxima.
        """
        data = np.asarray(data)
        if null_val is None or np.isnan(null_val):
            missing_mask = np.isnan(data)
        else:
            missing_mask = data == null_val

        total_elements = int(data.size)
        missing_elements = int(np.count_nonzero(missing_mask))
        missing_rate = (missing_elements / total_elements * 100) if total_elements else 0.0

        time_axes = tuple(range(1, data.ndim))
        node_axes = tuple(axis for axis in range(data.ndim) if axis != 1)
        missing_by_time = missing_mask.sum(axis=time_axes)
        missing_by_node = missing_mask.sum(axis=node_axes)

        return {
            'total_missing_rate': missing_rate,
            'missing_elements': missing_elements,
            'total_elements': total_elements,
            'missing_by_time': missing_by_time,
            'missing_by_node': missing_by_node,
            'max_missing_time': int(np.max(missing_by_time)) if np.size(missing_by_time) else 0,
            # Per-node counts exist for 2-D data too, so always report the max.
            'max_missing_node': int(np.max(missing_by_node)) if np.size(missing_by_node) else 0,
        }

    def analyze_temporal_continuity(self, data: np.ndarray,
                                    freq_minutes: Optional[float]) -> Dict:
        """Compute the time span covered by *data* and its data density.

        Args:
            data: Array whose axis 0 is time.
            freq_minutes: Sampling interval in minutes; ``None`` (unknown)
                yields NaN hours/days instead of raising.

        Returns:
            Dict with step count, frequency, span in hours/days, and
            ``data_density`` = share of entries that are non-zero and not NaN.
        """
        data = np.asarray(data)
        total_time_steps = int(data.shape[0])
        if freq_minutes is None:
            total_hours = total_days = float('nan')
        else:
            total_hours = (total_time_steps * freq_minutes) / 60
            total_days = total_hours / 24

        # NaN compares != 0, so exclude it explicitly: NaN is not observed data.
        observed = np.count_nonzero((data != 0) & ~np.isnan(data))
        density = observed / data.size if data.size else 0.0

        return {
            'total_time_steps': total_time_steps,
            'frequency_minutes': freq_minutes,
            'total_hours': total_hours,
            'total_days': total_days,
            'data_density': density,
        }

    def analyze_spatial_coverage(self, data: np.ndarray,
                                 adj_matrix: Optional[np.ndarray] = None) -> Dict:
        """Describe node/feature counts and, if a graph is given, its edges.

        Axis 1 of *data* is the node axis; axis 2 (3-D data only) holds the
        features. In the adjacency matrix any entry > 0 counts as a directed
        edge (positive diagonal entries, i.e. self-loops, included).
        """
        data = np.asarray(data)
        result = {
            'num_nodes': int(data.shape[1]),
            'num_features': int(data.shape[2]) if data.ndim == 3 else 1,
        }
        if adj_matrix is None:
            return result

        if hasattr(adj_matrix, 'toarray'):  # scipy.sparse matrices
            adj_matrix = adj_matrix.toarray()
        adj = np.asarray(adj_matrix)
        degrees = np.count_nonzero(adj > 0, axis=1)  # out-degree per row
        num_edges = int(degrees.sum())
        has_nodes = degrees.size > 0

        result.update({
            'num_edges': num_edges,
            'edge_density': num_edges / adj.size if adj.size else 0.0,
            'avg_degree': float(degrees.mean()) if has_nodes else 0.0,
            'max_degree': int(degrees.max()) if has_nodes else 0,
            'min_degree': int(degrees.min()) if has_nodes else 0,
        })
        return result

    def analyze_dataset(self, dataset_name: str) -> Dict:
        """Load one dataset and run every analysis on it.

        Returns:
            Dict with ``basic_info``, ``missing_analysis``,
            ``temporal_analysis``, ``spatial_analysis`` and the raw
            ``description`` from desc.json.
        """
        print(f"正在分析数据集: {dataset_name}")

        desc = self.load_dataset_description(dataset_name)
        data = self.load_dataset_data(dataset_name)
        adj_matrix = self.load_adjacency_matrix(dataset_name)

        basic_info = {
            'name': desc.get('name', dataset_name),
            'domain': desc.get('domain', 'unknown'),
            'shape': desc['shape'],
            'has_graph': desc.get('has_graph', False),
            'frequency_minutes': desc.get('frequency (minutes)', None),
        }

        # 'regular_settings' may be absent or explicitly null in desc.json.
        null_val = (desc.get('regular_settings') or {}).get('NULL_VAL', 0.0)

        return {
            'basic_info': basic_info,
            'missing_analysis': self.analyze_missing_values(data, null_val),
            'temporal_analysis': self.analyze_temporal_continuity(
                data, basic_info['frequency_minutes']),
            'spatial_analysis': self.analyze_spatial_coverage(data, adj_matrix),
            'description': desc,
        }

    def analyze_all_datasets(self) -> Dict:
        """Analyze every available dataset, storing results in ``analysis_results``.

        A dataset that fails to analyze is reported on stdout and skipped.
        """
        datasets = self.get_available_datasets()
        print(f"发现 {len(datasets)} 个数据集: {datasets}")

        for dataset_name in datasets:
            try:
                self.analysis_results[dataset_name] = self.analyze_dataset(dataset_name)
            except Exception as e:  # top-level boundary: report and keep going
                print(f"分析数据集 {dataset_name} 时出错: {e}")
                continue

        return self.analysis_results

    # ------------------------------------------------------------------
    # Reports
    # ------------------------------------------------------------------
    def generate_summary_report(self) -> str:
        """Build the summary report: an overview table plus per-dataset details."""
        if not self.analysis_results:
            return "没有可用的分析结果"

        report = [
            "=" * 80,
            "BasicTS 数据集分析报告",
            "=" * 80,
            f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"分析数据集数量: {len(self.analysis_results)}",
            "",
            "数据集概览:",
            "-" * 80,
            f"{'数据集名称':<15} {'领域':<20} {'时间步数':<10} {'节点数':<8} {'特征数':<8} {'频率(分钟)':<12} {'缺失值率(%)':<12}",
            "-" * 80,
        ]

        for name, result in self.analysis_results.items():
            basic = result['basic_info']
            missing = result['missing_analysis']
            spatial = result['spatial_analysis']
            # `!s` lets None values (e.g. unknown frequency) take a width spec
            # instead of raising TypeError.
            report.append(
                f"{name:<15} {basic['domain']!s:<20} {basic['shape'][0]!s:<10} "
                f"{spatial['num_nodes']!s:<8} {spatial['num_features']!s:<8} "
                f"{basic['frequency_minutes']!s:<12} {missing['total_missing_rate']:<12.3f}"
            )

        report.append("")

        for name, result in self.analysis_results.items():
            basic = result['basic_info']
            missing = result['missing_analysis']
            temporal = result['temporal_analysis']
            spatial = result['spatial_analysis']

            report.append(f"数据集: {name}")
            report.append("-" * 40)
            report.append(f"领域: {basic['domain']}")
            report.append(f"数据形状: {basic['shape']}")
            report.append(f"时间频率: {basic['frequency_minutes']} 分钟")
            report.append(f"时间跨度: {temporal['total_days']:.1f} 天 ({temporal['total_hours']:.1f} 小时)")
            report.append(f"节点数量: {spatial['num_nodes']}")
            report.append(f"特征数量: {spatial['num_features']}")

            # Key presence (not truthiness) so graphs with 0 edges are shown.
            if 'num_edges' in spatial:
                report.append(f"边数量: {spatial['num_edges']}")
                report.append(f"边密度: {spatial['edge_density']:.4f}")
                report.append(f"平均度数: {spatial['avg_degree']:.2f}")

            report.append(f"缺失值率: {missing['total_missing_rate']:.3f}%")
            report.append(f"数据密度: {temporal['data_density']:.3f}")
            report.append("")

        return "\n".join(report)

    def generate_comparative_analysis(self) -> str:
        """Build the cross-dataset report: per-domain stats, graph density, time spans."""
        if not self.analysis_results:
            return "没有可用的分析结果"

        report = ["=" * 80, "数据集对比分析", "=" * 80, ""]

        # Group datasets by domain, keeping first-seen order.
        domains = {}
        for name, result in self.analysis_results.items():
            domains.setdefault(result['basic_info']['domain'], []).append((name, result))

        for domain, members in domains.items():
            missing_rates = [r['missing_analysis']['total_missing_rate'] for _, r in members]
            node_counts = [r['spatial_analysis']['num_nodes'] for _, r in members]
            time_steps = [r['basic_info']['shape'][0] for _, r in members]
            report.extend([
                f"领域: {domain}",
                "-" * 40,
                f"数据集数量: {len(members)}",
                f"平均缺失值率: {np.mean(missing_rates):.3f}%",
                f"缺失值率范围: {min(missing_rates):.3f}% - {max(missing_rates):.3f}%",
                f"平均节点数: {np.mean(node_counts):.1f}",
                f"节点数范围: {min(node_counts)} - {max(node_counts)}",
                f"平均时间步数: {np.mean(time_steps):.0f}",
                "",
            ])

        report.append("空间覆盖密度分析:")
        report.append("-" * 40)

        # Every dataset that had an adjacency matrix, including 0-edge graphs.
        spatial_datasets = [(name, result) for name, result in self.analysis_results.items()
                            if 'num_edges' in result['spatial_analysis']]
        if spatial_datasets:
            for name, result in spatial_datasets:
                spatial = result['spatial_analysis']
                report.append(f"{name}: {spatial['num_nodes']} 个节点, {spatial['num_edges']} 条边, "
                              f"密度 {spatial['edge_density']:.4f}, 平均度数 {spatial['avg_degree']:.2f}")
        else:
            report.append("没有发现包含图结构的数据集")

        report.append("")

        report.append("时间连续性分析:")
        report.append("-" * 40)

        temporal_data = [
            {
                'name': name,
                'days': result['temporal_analysis']['total_days'],
                'density': result['temporal_analysis']['data_density'],
                'frequency': result['temporal_analysis']['frequency_minutes'],
            }
            for name, result in self.analysis_results.items()
        ]
        # Longest span first; unknown (NaN) spans go last instead of breaking
        # the comparison order.
        temporal_data.sort(key=lambda e: (bool(np.isnan(e['days'])), -e['days']))

        for entry in temporal_data:
            report.append(f"{entry['name']}: {entry['days']:.1f} 天, "
                          f"数据密度 {entry['density']:.3f}, "
                          f"频率 {entry['frequency']} 分钟")

        return "\n".join(report)

    @staticmethod
    def _json_default(obj):
        """JSON fallback: numpy arrays -> lists, numpy scalars -> Python numbers."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return str(obj)

    def save_reports(self, output_dir: str = "analysis_reports"):
        """Write the summary, comparative and detailed-JSON reports to *output_dir*.

        The directory (and any missing parents) is created if needed.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        with open(output_path / "summary_report.txt", 'w', encoding='utf-8') as f:
            f.write(self.generate_summary_report())

        with open(output_path / "comparative_analysis.txt", 'w', encoding='utf-8') as f:
            f.write(self.generate_comparative_analysis())

        # Serialize numpy values faithfully; `default=str` would store arrays
        # as truncated reprs such as "[0 0 ... 0]".
        with open(output_path / "detailed_analysis.json", 'w', encoding='utf-8') as f:
            json.dump(self.analysis_results, f, indent=2, ensure_ascii=False,
                      default=self._json_default)

        print(f"报告已保存到目录: {output_path}")

    # ------------------------------------------------------------------
    # Visualization
    # ------------------------------------------------------------------
    @staticmethod
    def _apply_plot_style():
        """Switch to the seaborn look while keeping the current font settings."""
        font_settings = {
            'font.sans-serif': list(plt.rcParams['font.sans-serif']),
            'axes.unicode_minus': plt.rcParams['axes.unicode_minus'],
        }
        try:
            plt.style.use('seaborn-v0_8')
        except OSError:
            # The 'seaborn-v0_8' name only exists in matplotlib >= 3.6;
            # keep the current style on older versions.
            pass
        # The seaborn style sets its own font.sans-serif list, which would drop
        # the CJK font configured at import time and render labels as boxes.
        plt.rcParams.update(font_settings)

    @staticmethod
    def _relative_pad(values) -> float:
        """Return 1% of the largest finite value (label offset above bars), or 0."""
        finite = [v for v in values if np.isfinite(v)]
        return max(finite) * 0.01 if finite else 0.0

    @staticmethod
    def _save_bar_chart(path, names, values, *, color, title, ylabel,
                        label_fmt, label_pad):
        """Draw one bar per dataset with a value label on top and save to *path*."""
        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(names, values, color=color, alpha=0.7)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('数据集名称')
        ax.set_ylabel(ylabel)
        ax.tick_params(axis='x', rotation=45)

        for bar, value in zip(bars, values):
            if not np.isfinite(value):
                continue  # e.g. unknown time span: no meaningful label position
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_pad,
                    label_fmt.format(value), ha='center', va='bottom')

        fig.tight_layout()
        fig.savefig(path, dpi=300, bbox_inches='tight')
        plt.close(fig)

    def create_visualizations(self, output_dir: str = "analysis_reports"):
        """Save comparison charts (missing rate, node count, time span, scatter)."""
        if not self.analysis_results:
            print("没有可用的分析结果")
            return

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        self._apply_plot_style()

        names = list(self.analysis_results.keys())
        missing_rates = [self.analysis_results[n]['missing_analysis']['total_missing_rate']
                         for n in names]
        node_counts = [self.analysis_results[n]['spatial_analysis']['num_nodes']
                       for n in names]
        time_days = [self.analysis_results[n]['temporal_analysis']['total_days']
                     for n in names]

        # 1. Missing-rate comparison
        self._save_bar_chart(output_path / "missing_rates_comparison.png", names, missing_rates,
                             color='skyblue', title='各数据集缺失值率对比',
                             ylabel='缺失值率 (%)', label_fmt='{:.2f}%', label_pad=0.1)

        # 2. Node-count comparison
        self._save_bar_chart(output_path / "node_counts_comparison.png", names, node_counts,
                             color='lightgreen', title='各数据集节点数量对比',
                             ylabel='节点数量', label_fmt='{}',
                             label_pad=self._relative_pad(node_counts))

        # 3. Time-span comparison
        self._save_bar_chart(output_path / "time_span_comparison.png", names, time_days,
                             color='orange', title='各数据集时间跨度对比',
                             ylabel='时间跨度 (天)', label_fmt='{:.1f}',
                             label_pad=self._relative_pad(time_days))

        # 4. Scatter: node count vs missing rate
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.scatter(node_counts, missing_rates, s=100, alpha=0.7, c='red')
        for name, x, y in zip(names, node_counts, missing_rates):
            ax.annotate(name, (x, y), xytext=(5, 5), textcoords='offset points', fontsize=8)
        ax.set_xlabel('节点数量')
        ax.set_ylabel('缺失值率 (%)')
        ax.set_title('节点数量与缺失值率关系', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_path / "nodes_vs_missing_rates.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f"可视化图表已保存到目录: {output_path}")
|
|
|
|
|
|
def main():
    """Script entry point.

    Analyzes every dataset under the default ``datasets`` directory, prints
    the summary and comparative reports, then saves the text/JSON reports
    and the comparison charts.
    """
    banner = "=" * 80
    print("BasicTS 数据集分析工具")
    print("=" * 50)

    analyzer = DatasetAnalyzer()
    analyzer.analyze_all_datasets()

    # Each report is printed under its own banner, summary first.
    sections = (
        ("数据集分析报告", analyzer.generate_summary_report),
        ("对比分析报告", analyzer.generate_comparative_analysis),
    )
    for title, build_report in sections:
        print("\n" + banner)
        print(title)
        print(banner)
        print(build_report())

    analyzer.save_reports()
    analyzer.create_visualizations()

    print("\n分析完成!")


if __name__ == "__main__":
    main()
|