import os import torch from datetime import datetime # import time from config.args_parser import parse_args import lib.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer def main(): args = parse_args() args = init.init_device(args) init.init_seed(args['train']['seed']) model = init.init_model(args) # Load dataset train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( args, normalizer=args['data']['normalizer'], single=False ) loss = init.init_loss(args, scaler) optimizer, lr_scheduler = init.init_optimizer(model, args['train']) init.create_logs(args) # Start training or testing trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler, extra_data) match args['basic']['mode']: case 'train': trainer.train() case 'test': model.load_state_dict(torch.load( f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth", map_location=args['device'], weights_only=True)) trainer.test(model.to(args['basic']['device']), trainer.args, test_loader, scaler, trainer.logger) case _: raise ValueError(f"Unsupported mode: {args['basic']['mode']}") if __name__ == '__main__': from lib.Download_data import check_and_download_data data_complete = check_and_download_data() assert data_complete is not None, "数据集下载失败,请重试!" main()