# New file: lib/training_stats.py
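"""Lightweight helpers for collecting training-run statistics: per-step
timings, total iteration throughput, and GPU / CPU memory usage, reported
as a summary at the end of training."""
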
import os
import time

import psutil
import torch


class TrainingStats:
    """Collects per-step timings and memory usage for a training run."""

    def __init__(self, device):
        self.device = device
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.gpu_mem_usage_list = []
        self.cpu_mem_usage_list = []
        self.train_time_list = []
        self.infer_time_list = []
        self.total_iters = 0
        self.start_time = None
        self.end_time = None

    def start_training(self):
        self.start_time = time.time()

    def end_training(self):
        self.end_time = time.time()

    def record_step_time(self, duration, mode):
        """Record the duration of a single step and bump the iteration count."""
        if mode == "train":
            self.train_time_list.append(duration)
        else:
            self.infer_time_list.append(duration)
        self.total_iters += 1

    def record_memory_usage(self):
        """Record the current GPU and CPU memory usage."""
        process = psutil.Process(os.getpid())
        cpu_mem = process.memory_info().rss / (1024**2)  # resident set size, MB

        if torch.cuda.is_available():
            # Peak CUDA memory allocated since the last reset, in MB.
            gpu_mem = torch.cuda.max_memory_allocated(device=self.device) / (1024**2)
            torch.cuda.reset_peak_memory_stats(device=self.device)
        else:
            gpu_mem = 0.0

        self.cpu_mem_usage_list.append(cpu_mem)
        self.gpu_mem_usage_list.append(gpu_mem)

    def report(self, logger):
        """Log summary statistics at the end of training."""
        if not self.start_time or not self.end_time:
            logger.warning("TrainingStats: start/end time not recorded properly.")
            return

        total_time = self.end_time - self.start_time
        avg_gpu_mem = (
            sum(self.gpu_mem_usage_list) / len(self.gpu_mem_usage_list)
            if self.gpu_mem_usage_list
            else 0
        )
        avg_cpu_mem = (
            sum(self.cpu_mem_usage_list) / len(self.cpu_mem_usage_list)
            if self.cpu_mem_usage_list
            else 0
        )
        avg_train_time = (
            sum(self.train_time_list) / len(self.train_time_list)
            if self.train_time_list
            else 0
        )
        avg_infer_time = (
            sum(self.infer_time_list) / len(self.infer_time_list)
            if self.infer_time_list
            else 0
        )
        iters_per_sec = self.total_iters / total_time if total_time > 0 else 0

        logger.info("===== Training Summary =====")
        logger.info(f"Total training time: {total_time:.2f} s")
        logger.info(f"Total iterations: {self.total_iters}")
        logger.info(f"Average iterations per second: {iters_per_sec:.2f}")
        logger.info(f"Average GPU Memory Usage: {avg_gpu_mem:.2f} MB")
        logger.info(f"Average CPU Memory Usage: {avg_cpu_mem:.2f} MB")
        if avg_train_time:
            logger.info(f"Average training step time: {avg_train_time * 1000:.2f} ms")
        if avg_infer_time:
            logger.info(f"Average inference step time: {avg_infer_time * 1000:.2f} ms")
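

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, self-contained
# demo of how TrainingStats is meant to be driven by a training loop. The
# dummy loop below just sleeps instead of running a real model so the summary
# output can be inspected directly; the "training_stats_demo" logger name is
# an arbitrary placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)
    demo_logger = logging.getLogger("training_stats_demo")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    stats = TrainingStats(device)

    stats.start_training()
    for _ in range(5):
        step_start = time.time()
        time.sleep(0.01)  # stand-in for a real forward/backward pass
        stats.record_step_time(time.time() - step_start, mode="train")
        stats.record_memory_usage()
    stats.end_training()
    stats.report(demo_logger)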