# TrafficWheel/utils/training_stats.py (89 lines, 2.9 KiB, Python)

# Newly created as lib/training_stats.py
import time
import psutil
import torch
import os
class TrainingStats:
    """Collect step timings and memory samples, then log a summary.

    ``start_training``/``end_training`` record wall-clock timestamps,
    ``record_step_time`` accumulates per-step durations, and
    ``record_memory_usage`` samples CPU RSS and peak CUDA memory.
    ``report`` logs the aggregated results.
    """

    def __init__(self, device):
        """Create an empty statistics collector.

        Args:
            device: Device whose CUDA memory ``record_memory_usage`` samples
                (e.g. ``"cuda:0"``, a ``torch.device``, an int CUDA ordinal,
                or ``None`` for the current CUDA device). Non-CUDA devices
                record 0.0 MB of GPU memory.
        """
        self.device = device
        self.reset()

    def reset(self):
        """Clear all recorded samples, the iteration counter and the timestamps."""
        self.gpu_mem_usage_list = []  # MB, one entry per record_memory_usage() call
        self.cpu_mem_usage_list = []  # MB of process RSS, one entry per call
        self.train_time_list = []  # step durations; report() treats them as seconds
        self.infer_time_list = []  # durations of steps whose mode != "train"
        self.total_iters = 0  # steps of either mode
        self.start_time = None
        self.end_time = None

    def start_training(self):
        """Record the wall-clock start time via ``time.time()``."""
        self.start_time = time.time()

    def end_training(self):
        """Record the wall-clock end time via ``time.time()``."""
        self.end_time = time.time()

    def record_step_time(self, duration, mode):
        """Record the duration of one step and count it as an iteration.

        Args:
            duration: Step duration in seconds (``report`` multiplies by 1000
                and labels the result ms).
            mode: ``"train"`` stores the duration as a training step; any other
                value stores it as an inference step.
        """
        if mode == "train":
            self.train_time_list.append(duration)
        else:
            self.infer_time_list.append(duration)
        # Both modes contribute to the iterations/s figure in report().
        self.total_iters += 1

    def _uses_cuda(self):
        """Return True if CUDA is available and ``self.device`` is a CUDA device."""
        if not torch.cuda.is_available():
            return False
        if self.device is None or isinstance(self.device, int):
            # None selects the current CUDA device; an int is a CUDA ordinal.
            return True
        return torch.device(self.device).type == "cuda"

    def record_memory_usage(self):
        """Sample current CPU memory and peak GPU memory, both in MB.

        CPU memory is this process's resident set size. GPU memory is the
        peak tensor memory allocated by PyTorch on ``self.device`` since the
        last peak reset. The peak counter is then reset, so each sample covers
        the interval since the previous call. GPU usage is recorded as 0.0
        when CUDA is unavailable or ``self.device`` is not a CUDA device.
        Querying CUDA stats for a CPU device would raise.
        """
        process = psutil.Process(os.getpid())
        cpu_mem = process.memory_info().rss / (1024**2)
        if self._uses_cuda():
            gpu_mem = torch.cuda.max_memory_allocated(device=self.device) / (1024**2)
            torch.cuda.reset_peak_memory_stats(device=self.device)
        else:
            gpu_mem = 0.0
        self.cpu_mem_usage_list.append(cpu_mem)
        self.gpu_mem_usage_list.append(gpu_mem)

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or 0.0 when empty."""
        return sum(values) / len(values) if values else 0.0

    def report(self, logger):
        """Log summary statistics, typically once training has finished.

        Args:
            logger: Object exposing ``info`` and ``warning`` methods (e.g. a
                ``logging.Logger``).

        If either timestamp was never recorded, logs a warning and returns
        without a summary. Step-time lines are omitted when their average is 0.
        """
        if self.start_time is None or self.end_time is None:
            logger.warning("TrainingStats: start/end time not recorded properly.")
            return
        total_time = self.end_time - self.start_time
        avg_gpu_mem = self._mean(self.gpu_mem_usage_list)
        avg_cpu_mem = self._mean(self.cpu_mem_usage_list)
        avg_train_time = self._mean(self.train_time_list)
        avg_infer_time = self._mean(self.infer_time_list)
        iters_per_sec = self.total_iters / total_time if total_time > 0 else 0
        logger.info("===== Training Summary =====")
        logger.info(f"Total training time: {total_time:.2f} s")
        logger.info(f"Total iterations: {self.total_iters}")
        logger.info(f"Average iterations per second: {iters_per_sec:.2f}")
        logger.info(f"Average GPU Memory Usage: {avg_gpu_mem:.2f} MB")
        logger.info(f"Average CPU Memory Usage: {avg_cpu_mem:.2f} MB")
        if avg_train_time:
            logger.info(f"Average training step time: {avg_train_time * 1000:.2f} ms")
        if avg_infer_time:
            logger.info(f"Average inference step time: {avg_infer_time * 1000:.2f} ms")