diff --git a/run.py b/run.py index 5785dbf..0a4c49a 100755 --- a/run.py +++ b/run.py @@ -1,5 +1,7 @@ import torch - +from utils.Download_data import check_and_download_data +data_complete = check_and_download_data() +assert data_complete is not None, "数据集下载失败,请重试!" # import time from config.args_parser import parse_args @@ -58,10 +60,5 @@ def main(): case _: raise ValueError(f"Unsupported mode: {args['basic']['mode']}") - if __name__ == "__main__": - from utils.Download_data import check_and_download_data - - data_complete = check_and_download_data() - assert data_complete is not None, "数据集下载失败,请重试!" main()