52 lines
1.6 KiB
Python
Executable File
52 lines
1.6 KiB
Python
Executable File
import os
|
|
|
|
import torch
|
|
from datetime import datetime
|
|
|
|
# import time
|
|
from config.args_parser import parse_args
|
|
import lib.initializer as init
|
|
from dataloader.loader_selector import get_dataloader
|
|
from trainer.trainer_selector import select_trainer
|
|
|
|
|
|
|
|
def main():
|
|
args = parse_args()
|
|
args = init.init_device(args)
|
|
init.init_seed(args['train']['seed'])
|
|
model = init.init_model(args)
|
|
|
|
# Load dataset
|
|
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
|
|
args,
|
|
normalizer=args['data']['normalizer'],
|
|
single=False
|
|
)
|
|
|
|
loss = init.init_loss(args, scaler)
|
|
optimizer, lr_scheduler = init.init_optimizer(model, args['train'])
|
|
init.create_logs(args)
|
|
|
|
# Start training or testing
|
|
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
|
|
lr_scheduler, extra_data)
|
|
|
|
match args['basic']['mode']:
|
|
case 'train':
|
|
trainer.train()
|
|
case 'test':
|
|
model.load_state_dict(torch.load(
|
|
f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth",
|
|
map_location=args['device'], weights_only=True))
|
|
trainer.test(model.to(args['basic']['device']), trainer.args, test_loader, scaler, trainer.logger)
|
|
case _:
|
|
raise ValueError(f"Unsupported mode: {args['basic']['mode']}")
|
|
|
|
|
|
if __name__ == '__main__':
    # Imported here rather than at the top of the file, so the download
    # helper is only loaded when this file runs as a script.
    from lib.Download_data import check_and_download_data

    # Check/fetch the datasets before building anything. A return value of
    # None is treated as failure.
    data_complete = check_and_download_data()
    # Use an explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently skip this guard.
    # Message means: "Dataset download failed, please retry!"
    if data_complete is None:
        raise RuntimeError("数据集下载失败,请重试!")

    main()
|