diff --git a/.gitignore b/.gitignore index d67c4d1..efc5a1f 100755 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ .DS_Store +Result.xlsx \ No newline at end of file diff --git a/Result.xlsx b/Result.xlsx index 045e098..2d7f91c 100755 Binary files a/Result.xlsx and b/Result.xlsx differ diff --git a/run.py b/run.py index 03a8ac9..9cb2df7 100755 --- a/run.py +++ b/run.py @@ -20,8 +20,11 @@ import yaml def main(): args = parse_args() - # Set device - if torch.cuda.is_available() and args['device'] != 'cpu': + # Set device (prefer MPS on macOS, then CUDA, else CPU) + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and args['device'] != 'cpu': + args['device'] = 'mps' + args['model']['device'] = args['device'] + elif torch.cuda.is_available() and args['device'] != 'cpu': torch.cuda.set_device(int(args['device'].split(':')[1])) args['model']['device'] = args['device'] else: