import torch from utils.Download_data import check_and_download_data data_complete = check_and_download_data() assert data_complete is not None, "数据集下载失败,请重试!" # import time from config.args_parser import parse_args import utils.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer def main(): args = parse_args() args = init.init_device(args) init.init_seed(args["train"]["seed"]) model = init.init_model(args) # Load dataset train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( args, normalizer=args["data"]["normalizer"], single=False ) loss = init.init_loss(args, scaler) optimizer, lr_scheduler = init.init_optimizer(model, args["train"]) init.create_logs(args) # Start training or testing trainer = select_trainer( model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler, extra_data, ) match args["basic"]["mode"]: case "train": trainer.train() case "test": model.load_state_dict( torch.load( f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth", map_location=args["device"], weights_only=True, ) ) trainer.test( model.to(args["basic"]["device"]), trainer.args, test_loader, scaler, trainer.logger, ) case _: raise ValueError(f"Unsupported mode: {args['basic']['mode']}") if __name__ == "__main__": main()