TrafficWheel/run.py

65 lines
1.8 KiB
Python
Executable File

import torch
from utils.Download_data import check_and_download_data
# Verify (and, if needed, fetch) the dataset before the remaining project
# imports below are executed. NOTE(review): this runs at import time and
# precedes the config/dataloader/trainer imports — presumably some of those
# modules expect the data to be on disk; confirm before reordering or moving
# this into main().
data_complete = check_and_download_data()
# Explicit check instead of `assert`: assertions are stripped under `python -O`,
# which would silently skip this validation.
# Message: "Dataset download failed, please retry!"
if data_complete is None:
    raise RuntimeError("数据集下载失败,请重试!")
from config.args_parser import parse_args
import utils.initializer as init
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
def main():
args = parse_args()
args = init.init_device(args)
init.init_seed(args["train"]["seed"])
model = init.init_model(args)
# Load dataset
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
args, normalizer=args["data"]["normalizer"], single=False
)
loss = init.init_loss(args, scaler)
optimizer, lr_scheduler = init.init_optimizer(model, args["train"])
init.create_logs(args)
# Start training or testing
trainer = select_trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
extra_data,
)
match args["basic"]["mode"]:
case "train":
trainer.train()
case "test":
model.load_state_dict(
torch.load(
f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth",
map_location=args["device"],
weights_only=True,
)
)
trainer.test(
model.to(args["basic"]["device"]),
trainer.args,
test_loader,
scaler,
trainer.logger,
)
case _:
raise ValueError(f"Unsupported mode: {args['basic']['mode']}")
if __name__ == "__main__":
    # Entry point: only run the pipeline when executed as a script
    # (e.g. `python run.py`), not when this module is imported.
    main()