# New file: lib/training_stats.py
import os
import time

import psutil
import torch


class TrainingStats:
    """Collect timing and memory statistics for a training run.

    Per-step durations and memory samples are accumulated in plain lists;
    ``report`` logs their averages once the run has been bracketed with
    ``start_training`` / ``end_training``.
    """

    def __init__(self, device):
        """Create an empty tracker.

        Args:
            device: Device whose CUDA memory statistics are sampled by
                ``record_memory_usage``. It is passed through to the
                ``torch.cuda`` memory functions.
        """
        self.device = device
        self.reset()

    def reset(self):
        """Discard every recorded sample, counter and timestamp."""
        self.gpu_mem_usage_list = []
        self.cpu_mem_usage_list = []
        self.train_time_list = []
        self.infer_time_list = []
        self.total_iters = 0
        self.start_time = None
        self.end_time = None

    def start_training(self):
        """Store the current ``time.time()`` value as the run start."""
        self.start_time = time.time()

    def end_training(self):
        """Store the current ``time.time()`` value as the run end."""
        self.end_time = time.time()

    def record_step_time(self, duration, mode):
        """Record the duration of one step and count the iteration.

        Args:
            duration: Step duration; ``report`` treats it as seconds when
                converting the averages to milliseconds.
            mode: ``'train'`` files the sample under training steps; any
                other value files it under inference steps.
        """
        if mode == 'train':
            self.train_time_list.append(duration)
        else:
            self.infer_time_list.append(duration)
        # Both kinds of step contribute to the overall iteration count.
        self.total_iters += 1

    def record_memory_usage(self):
        """Append one CPU and one GPU memory sample, both in MiB.

        The CPU sample is the resident set size of the current process.
        The GPU sample is the peak allocated CUDA memory since the previous
        call (the peak counter is reset after each read), or ``0.0`` when
        the tracked device is not a usable CUDA device.
        """
        process = psutil.Process(os.getpid())
        cpu_mem = process.memory_info().rss / (1024 ** 2)

        if self._uses_cuda():
            gpu_mem = torch.cuda.max_memory_allocated(device=self.device) / (1024 ** 2)
            torch.cuda.reset_peak_memory_stats(device=self.device)
        else:
            gpu_mem = 0.0

        self.cpu_mem_usage_list.append(cpu_mem)
        self.gpu_mem_usage_list.append(gpu_mem)

    def report(self, logger):
        """Log the summary statistics of the run.

        Args:
            logger: Object exposing ``info(msg)`` and ``warning(msg)``.

        Logs a warning and returns without a summary when either the start
        or the end timestamp has not been recorded.
        """
        # Compare against None explicitly: a timestamp of 0.0 is falsy but
        # is still a recorded value.
        if self.start_time is None or self.end_time is None:
            logger.warning("TrainingStats: start/end time not recorded properly.")
            return

        total_time = self.end_time - self.start_time
        avg_gpu_mem = self._mean(self.gpu_mem_usage_list)
        avg_cpu_mem = self._mean(self.cpu_mem_usage_list)
        avg_train_time = self._mean(self.train_time_list)
        avg_infer_time = self._mean(self.infer_time_list)
        iters_per_sec = self.total_iters / total_time if total_time > 0 else 0

        logger.info("===== Training Summary =====")
        logger.info(f"Total training time: {total_time:.2f} s")
        logger.info(f"Total iterations: {self.total_iters}")
        logger.info(f"Average iterations per second: {iters_per_sec:.2f}")
        logger.info(f"Average GPU Memory Usage: {avg_gpu_mem:.2f} MB")
        logger.info(f"Average CPU Memory Usage: {avg_cpu_mem:.2f} MB")
        # Step-time lines are emitted only when the average is non-zero.
        if avg_train_time:
            logger.info(f"Average training step time: {avg_train_time*1000:.2f} ms")
        if avg_infer_time:
            logger.info(f"Average inference step time: {avg_infer_time*1000:.2f} ms")

    def _uses_cuda(self):
        """Return True when CUDA statistics can be queried for ``self.device``."""
        if not torch.cuda.is_available():
            return False
        # None and bare integer indices are handed to torch.cuda unchanged,
        # exactly as before; torch.cuda resolves them itself.
        if self.device is None or isinstance(self.device, int):
            return True
        # The torch.cuda memory functions reject non-CUDA devices, so a CPU
        # device on a CUDA-capable host must not reach them.
        return torch.device(self.device).type == 'cuda'

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or 0 when it is empty."""
        return sum(values) / len(values) if values else 0