90 lines
2.9 KiB
Python
90 lines
2.9 KiB
Python
import os
|
||
import shutil
|
||
|
||
# Check dataset integrity before the heavier imports below are executed.
from lib.Download_data import check_and_download_data

data_complete = check_and_download_data()
# Explicit check instead of ``assert``: assertions are stripped under
# ``python -O``, which would let a failed download pass silently.
# NOTE(review): a ``None`` result is treated as "download failed", as the
# error message implies -- confirm against check_and_download_data.
if data_complete is None:
    raise RuntimeError("数据集下载失败,请重试!")
|
||
|
||
import torch
|
||
from datetime import datetime
|
||
# import time
|
||
|
||
from config.args_parser import parse_args
|
||
from lib.initializer import init_model, init_optimizer
|
||
from lib.loss_function import get_loss_function
|
||
|
||
from dataloader.loader_selector import get_dataloader
|
||
from trainer.trainer_selector import select_trainer
|
||
import yaml
|
||
|
||
|
||
def main():
|
||
args = parse_args()
|
||
|
||
# Set device
|
||
if torch.cuda.is_available():
|
||
torch.cuda.set_device(int(args['device'].split(':')[1]))
|
||
args['model']['device'] = args['device']
|
||
else:
|
||
args['device'] = 'cpu'
|
||
args['model']['device'] = args['device']
|
||
|
||
|
||
# Initialize model
|
||
model = init_model(args['model'], device=args['device'])
|
||
|
||
# Load dataset
|
||
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
|
||
args,
|
||
normalizer=args['data']['normalizer'],
|
||
single=False
|
||
)
|
||
|
||
# Initialize loss function
|
||
loss = get_loss_function(args['train'], scaler)
|
||
|
||
# Initialize optimizer and learning rate scheduler
|
||
optimizer, lr_scheduler = init_optimizer(model, args['train'])
|
||
|
||
# Configure log path
|
||
current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||
args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['data']['type'], current_time)
|
||
|
||
# 配置文件路径
|
||
config_filename = f"{args['data']['type']}.yaml"
|
||
# config_path = os.path.join(current_dir, 'config', args['model']['type'], config_filename)
|
||
# 确保日志目录存在
|
||
os.makedirs(args['train']['log_dir'], exist_ok=True)
|
||
|
||
# 生成配置文件内容(将 args 转换为 YAML 格式)
|
||
config_content = yaml.safe_dump(args, default_flow_style=False)
|
||
|
||
# 生成新的 YAML 文件名(例如:config.auto.yaml 或其他名称)
|
||
destination_path = os.path.join(args['train']['log_dir'], config_filename)
|
||
|
||
# 将 args 保存为 YAML 文件
|
||
with open(destination_path, 'w') as f:
|
||
f.write(config_content)
|
||
|
||
# Start training or testing
|
||
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
|
||
lr_scheduler, extra_data)
|
||
|
||
match args['mode']:
|
||
case 'train':
|
||
trainer.train()
|
||
case 'test':
|
||
model.load_state_dict(torch.load(
|
||
f"./pre-trained/{args['model']['type']}/{args['data']['type']}.pth",
|
||
map_location=args['device'], weights_only=True))
|
||
# print(f"Loaded saved model on {args['device']}")
|
||
trainer.test(model.to(args['device']), trainer.args, test_loader, scaler, trainer.logger)
|
||
case _:
|
||
raise ValueError(f"Unsupported mode: {args['mode']}")
|
||
|
||
|
||
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|