TrafficWheel/run.py

52 lines
1.6 KiB
Python
Executable File

import os
import torch
from datetime import datetime
# import time
from config.args_parser import parse_args
import lib.initializer as init
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
def main():
    """Assemble the experiment from the parsed arguments and run it.

    The steps, in order: parse the arguments, let ``init`` set up the
    device, seed and model, fetch the data loaders and scaler, build the
    loss, optimizer and LR scheduler, call ``init.create_logs``, construct
    the trainer, and finally dispatch on ``args['basic']['mode']``.

    Raises:
        ValueError: If ``args['basic']['mode']`` is neither ``'train'`` nor
            ``'test'``.
    """
    args = init.init_device(parse_args())
    init.init_seed(args['train']['seed'])
    model = init.init_model(args)

    # Load dataset. Anything returned after the scaler is collected into
    # ``extra_data`` and handed to the trainer untouched.
    loaders = get_dataloader(
        args,
        normalizer=args['data']['normalizer'],
        single=False,
    )
    train_loader, val_loader, test_loader, scaler, *extra_data = loaders

    loss = init.init_loss(args, scaler)
    optimizer, lr_scheduler = init.init_optimizer(model, args['train'])
    init.create_logs(args)

    # Start training or testing
    trainer = select_trainer(
        model, loss, optimizer, train_loader, val_loader, test_loader,
        scaler, args, lr_scheduler, extra_data,
    )

    mode = args['basic']['mode']
    if mode == 'train':
        trainer.train()
    elif mode == 'test':
        checkpoint_path = (
            f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth"
        )
        # NOTE(review): the checkpoint is mapped with args['device'] while the
        # model is moved with args['basic']['device'] -- presumably both are
        # populated by init.init_device; confirm they refer to the same device.
        state_dict = torch.load(
            checkpoint_path,
            map_location=args['device'],
            weights_only=True,
        )
        model.load_state_dict(state_dict)
        trainer.test(
            model.to(args['basic']['device']),
            trainer.args,
            test_loader,
            scaler,
            trainer.logger,
        )
    else:
        raise ValueError(f"Unsupported mode: {args['basic']['mode']}")
if __name__ == '__main__':
    # Imported here so the download helper is only loaded when this file is
    # executed as a script, not when it is imported as a module.
    from lib.Download_data import check_and_download_data

    # A None result is treated as a failed dataset download. Raise explicitly
    # instead of using `assert`: assert statements are removed under
    # `python -O`, which would let main() start without this check.
    # The message reads: "Dataset download failed, please retry!"
    if check_and_download_data() is None:
        raise RuntimeError("数据集下载失败,请重试!")
    main()