# File-viewer header (scraped): 116 lines, 4.1 KiB, Python, executable file.
import os
|
||
|
||
# Check dataset integrity before the remaining imports run.
from lib.Download_data import check_and_download_data

data_complete = check_and_download_data()
if data_complete is None:
    # Explicit raise instead of `assert`: assertions are stripped when Python
    # runs with -O, which would let a failed download go unnoticed.
    raise RuntimeError("数据集下载失败,请重试!")
|
||
|
||
import torch
|
||
from datetime import datetime
|
||
# import time
|
||
from config.args_parser import parse_args
|
||
from lib.initializer import init_model, init_optimizer, init_seed
|
||
from lib.loss_function import get_loss_function
|
||
|
||
from dataloader.loader_selector import get_dataloader
|
||
from trainer.trainer_selector import select_trainer
|
||
import yaml
|
||
|
||
|
||
|
||
|
||
def main():
|
||
args = parse_args()
|
||
|
||
# Set device (prefer MPS on macOS, then CUDA, else CPU)
|
||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and args['device'] != 'cpu':
|
||
args['device'] = 'mps'
|
||
args['model']['device'] = args['device']
|
||
elif torch.cuda.is_available() and args['device'] != 'cpu':
|
||
torch.cuda.set_device(int(args['device'].split(':')[1]))
|
||
args['model']['device'] = args['device']
|
||
else:
|
||
args['device'] = 'cpu'
|
||
args['model']['device'] = args['device']
|
||
init_seed(args['train']['seed'])
|
||
# Initialize model
|
||
model = init_model(args['model'], device=args['device'])
|
||
|
||
|
||
|
||
if args['mode'] == "benchmark":
|
||
# 支持计算消耗分析,设置 mode为 benchmark
|
||
import torch.profiler as profiler
|
||
dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
|
||
min_val = dummy_input.min(dim=-1, keepdim=True)[0]
|
||
max_val = dummy_input.max(dim=-1, keepdim=True)[0]
|
||
|
||
dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
|
||
with profiler.profile(
|
||
activities=[
|
||
profiler.ProfilerActivity.CPU,
|
||
profiler.ProfilerActivity.CUDA
|
||
],
|
||
with_stack=True,
|
||
profile_memory=True,
|
||
record_shapes=True
|
||
) as prof:
|
||
out = model(dummy_input)
|
||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||
return 0
|
||
|
||
# Load dataset
|
||
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
|
||
args,
|
||
normalizer=args['data']['normalizer'],
|
||
single=False
|
||
)
|
||
|
||
# Initialize loss function
|
||
loss = get_loss_function(args['train'], scaler)
|
||
|
||
# Initialize optimizer and learning rate scheduler
|
||
optimizer, lr_scheduler = init_optimizer(model, args['train'])
|
||
|
||
# Configure log path
|
||
current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||
args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['data']['type'], current_time)
|
||
|
||
# 配置文件路径
|
||
config_filename = f"{args['data']['type']}.yaml"
|
||
# config_path = os.path.join(current_dir, 'config', args['model']['type'], config_filename)
|
||
# 确保日志目录存在
|
||
os.makedirs(args['train']['log_dir'], exist_ok=True)
|
||
|
||
# 生成配置文件内容(将 args 转换为 YAML 格式)
|
||
config_content = yaml.safe_dump(args, default_flow_style=False)
|
||
|
||
# 生成新的 YAML 文件名(例如:config.auto.yaml 或其他名称)
|
||
destination_path = os.path.join(args['train']['log_dir'], config_filename)
|
||
|
||
# 将 args 保存为 YAML 文件
|
||
with open(destination_path, 'w') as f:
|
||
f.write(config_content)
|
||
|
||
# Start training or testing
|
||
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
|
||
lr_scheduler, extra_data)
|
||
|
||
match args['mode']:
|
||
case 'train':
|
||
trainer.train()
|
||
case 'test':
|
||
model.load_state_dict(torch.load(
|
||
f"./pre-trained/{args['model']['type']}/{args['data']['type']}.pth",
|
||
map_location=args['device'], weights_only=True))
|
||
# print(f"Loaded saved model on {args['device']}")
|
||
trainer.test(model.to(args['device']), trainer.args, test_loader, scaler, trainer.logger)
|
||
case _:
|
||
raise ValueError(f"Unsupported mode: {args['mode']}")
|
||
|
||
|
||
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|