TrafficWheel/run.py

57 lines
1.7 KiB
Python
Executable File

import torch
from utils.Download_data import check_and_download_data
# Check/download the dataset at import time, before the remaining project
# imports below run. A ``None`` return from check_and_download_data() is
# treated as a failed download.
data_complete = check_and_download_data()
if data_complete is None:
    # Explicit raise instead of ``assert``: assertions are stripped under
    # ``python -O``, which would silently skip this check.
    # Message: "Dataset download failed, please retry!"
    raise RuntimeError("数据集下载失败,请重试!")
# import time
from config.args_parser import parse_args
import utils.initializer as init
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
def _load_pretrained_weights(model, args):
    """Load ``./pre-trained/<model>/<dataset>.pth`` into *model* in place.

    The checkpoint is mapped onto ``args["basic"]["device"]`` and read with
    ``weights_only=True``, which restricts unpickling to tensors and plain
    container/primitive data.
    """
    basic = args["basic"]
    checkpoint_path = f"./pre-trained/{basic['model']}/{basic['dataset']}.pth"
    state_dict = torch.load(
        checkpoint_path,
        map_location=basic["device"],
        weights_only=True,
    )
    model.load_state_dict(state_dict)


def main():
    """Parse the configuration, build all components, then train or test.

    The run mode is read from ``args["basic"]["mode"]``:

    * ``"train"`` -- calls ``trainer.train()``.
    * ``"test"``  -- loads ``./pre-trained/<model>/<dataset>.pth`` into the
      model and calls ``trainer.test(...)`` on the test loader.

    Raises:
        ValueError: if the mode is neither ``"train"`` nor ``"test"``. The
            check runs before the model, data loaders and logs are created,
            so a bad config value fails immediately.
    """
    # Read the configuration.
    args = parse_args()
    # init_device returns the args object used from here on.
    args = init.init_device(args)

    # Fail fast on an unknown mode, before any expensive setup happens.
    mode = args["basic"]["mode"]
    if mode not in ("train", "test"):
        raise ValueError(f"Unsupported mode: {mode}")

    # Initialise seed, model, data, loss, optimizer, logs and trainer.
    init.init_seed(args["basic"]["seed"])
    model = init.init_model(args)
    # Any loader outputs beyond the first four are forwarded to the trainer
    # as a list.
    train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
        args, normalizer=args["data"]["normalizer"], single=False
    )
    loss = init.init_loss(args, scaler)
    optimizer, lr_scheduler = init.init_optimizer(model, args["train"])
    init.create_logs(args)
    trainer = select_trainer(
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        args,
        lr_scheduler,
        extra_data,
    )

    # Run training or evaluation.
    if mode == "train":
        trainer.train()
    else:  # mode == "test" (validated above)
        _load_pretrained_weights(model, args)
        trainer.test(
            model.to(args["basic"]["device"]),
            trainer.args,
            test_loader,
            scaler,
            trainer.logger,
        )
# Entry point: run the pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()