# File-viewer metadata captured with this listing (not part of the program):
# 65 lines, 1.8 KiB, Python, executable file.
import torch
|
|
from utils.Download_data import check_and_download_data
|
|
# Runs at import time, ahead of the project imports that follow in this file.
# Presumably this verifies (and if needed fetches) the dataset, judging by the
# helper's name and the failure message below; a None result means failure.
data_complete = check_and_download_data()

# Raise explicitly instead of using `assert`: assert statements are removed
# when Python runs with -O, which would silently skip this check.
if data_complete is None:
    raise RuntimeError("数据集下载失败,请重试!")
|
|
|
|
# import time
|
|
from config.args_parser import parse_args
|
|
import utils.initializer as init
|
|
from dataloader.loader_selector import get_dataloader
|
|
from trainer.trainer_selector import select_trainer
|
|
|
|
|
|
def main():
    """Run the pipeline in the mode selected by the parsed configuration.

    Parses the arguments, calls the ``init`` helpers for device, seed, model,
    loss, optimizer and logs, builds the data loaders, and passes everything
    to the trainer returned by ``select_trainer``. Then, depending on
    ``args["basic"]["mode"]``:

    * ``"train"`` calls ``trainer.train()``.
    * ``"test"`` loads the checkpoint at
      ``./pre-trained/<model>/<dataset>.pth`` into the model and calls
      ``trainer.test(...)`` on the test loader.

    Raises:
        ValueError: If the mode is neither ``"train"`` nor ``"test"``.
    """
    args = init.init_device(parse_args())
    init.init_seed(args["train"]["seed"])
    net = init.init_model(args)

    # The loader factory returns at least four items; anything beyond the
    # scaler is gathered into a list and forwarded to the trainer unchanged.
    loader_bundle = get_dataloader(
        args, normalizer=args["data"]["normalizer"], single=False
    )
    train_dl, val_dl, test_dl, scaler, *extras = loader_bundle

    criterion = init.init_loss(args, scaler)
    optim, scheduler = init.init_optimizer(net, args["train"])
    init.create_logs(args)

    trainer = select_trainer(
        net,
        criterion,
        optim,
        train_dl,
        val_dl,
        test_dl,
        scaler,
        args,
        scheduler,
        extras,
    )

    mode = args["basic"]["mode"]
    if mode == "train":
        trainer.train()
    elif mode == "test":
        # weights_only=True restricts unpickling to tensor/state-dict data.
        checkpoint = torch.load(
            f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth",
            map_location=args["device"],
            weights_only=True,
        )
        net.load_state_dict(checkpoint)
        # NOTE(review): the weights are mapped using args["device"] while the
        # model is moved using args["basic"]["device"] -- confirm both keys
        # exist and are meant to refer to the same device.
        trainer.test(
            net.to(args["basic"]["device"]),
            trainer.args,
            test_dl,
            scaler,
            trainer.logger,
        )
    else:
        raise ValueError(f"Unsupported mode: {args['basic']['mode']}")
|
|
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|