import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import yaml

from lib.loss_function import masked_mae_loss
from model.model_selector import model_selector


def init_model(args):
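    """Build the model named in the config and initialize its parameters.

    Weight matrices (dim > 1) get Xavier-uniform init; 1-D parameters such
    as biases get a plain uniform init. A rough usage sketch, assuming
    ``args`` carries the keys that ``model_selector`` expects::

        model = init_model({"device": "cpu", ...})  # "..." = model-specific keys
    """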
    device = args["device"]
    model = model_selector(args).to(device)
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {total_params}")
    return model


def init_optimizer(model, args):
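    """Create an Adam optimizer and, optionally, a MultiStepLR scheduler.

    A usage sketch with hypothetical values (the real ones come from the
    ``train`` section of the YAML config)::

        train_args = {"lr_init": 1e-3, "weight_decay": 1e-4, "lr_decay": True,
                      "lr_decay_step": "20,40,60", "lr_decay_rate": 0.5}
        optimizer, scheduler = init_optimizer(model, train_args)
        # step the scheduler once per epoch so the 20/40/60 milestones line up
    """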
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=args["lr_init"],
        eps=1.0e-8,
        weight_decay=args["weight_decay"],
        amsgrad=False,
    )

    lr_scheduler = None
    if args["lr_decay"]:
        lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
        )

    return optimizer, lr_scheduler


def init_seed(seed):
    """Initialize all RNG seeds so that results are reproducible."""
    # Disabling cuDNN (and forcing deterministic kernels) trades speed for
    # bitwise-reproducible results.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
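    # Note: torch.cuda.manual_seed_all(seed) would additionally seed every GPU
    # in a multi-GPU setup; a single-GPU run is already covered above.

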
def init_device(args):
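    """Resolve the device named in the config to one that is actually available.

    Falls back to CPU when the requested MPS/CUDA backend is missing. A usage
    sketch, assuming the nested config layout used throughout this module::

        args = init_device({"basic": {"device": "cuda:0"}})
        print(args["device"], args["model"]["device"])  # e.g. "cuda:0" twice
    """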
    device_name = args["basic"]["device"]
    if "model" not in args or not isinstance(args["model"], dict):
        args["model"] = {}  # ensure args["model"] is a dictionary
    match device_name:
        case "mps":
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                args["device"] = "mps"
            else:
                args["device"] = "cpu"
        case device if device.startswith("cuda"):
            if torch.cuda.is_available():
                # Select the GPU index only when one is given ("cuda:1");
                # a bare "cuda" keeps the current default device.
                if ":" in device:
                    torch.cuda.set_device(int(device.split(":")[1]))
                args["device"] = device
            else:
                args["device"] = "cpu"
        case _:
            args["device"] = "cpu"
    args["model"]["device"] = args["device"]
    return args


def init_loss(args, scaler):
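    """Return the loss function named by args["train"]["loss_func"].

    "mask_mae" wraps the scaler-aware masked MAE from lib.loss_function; the
    other names map onto the corresponding torch.nn losses. A sketch with a
    hypothetical config::

        loss_fn = init_loss({"basic": {"device": "cpu"},
                             "train": {"loss_func": "mae"}}, scaler=None)
    """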
    device = args["basic"]["device"]
    train_args = args["train"]
    match train_args["loss_func"]:
        case "mask_mae":
            return masked_mae_loss(scaler, mask_value=None)
        case "mae":
            return torch.nn.L1Loss().to(device)
        case "mse":
            return torch.nn.MSELoss().to(device)
        case "Huber":
            return torch.nn.HuberLoss().to(device)
        case _:
            raise ValueError(f"Unsupported loss function: {train_args['loss_func']}")


def create_logs(args):
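    """Create a timestamped experiment directory and snapshot the config there.

    The layout is experiments/<model>/<dataset>/<YYYY-MM-DD_HH-MM-SS>/, with
    the full ``args`` dict dumped there as <dataset>.yaml so every run can be
    reproduced from its own directory.
    """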
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    current_dir = os.path.abspath(os.getcwd())

    model_name = args["basic"]["model"]
    dataset_name = args["basic"]["dataset"]

    args["train"]["log_dir"] = os.path.join(
        current_dir, "experiments", model_name, dataset_name, current_time
    )
    config_filename = f"{dataset_name}.yaml"
    os.makedirs(args["train"]["log_dir"], exist_ok=True)
    config_content = yaml.safe_dump(args, default_flow_style=False)
    destination_path = os.path.join(args["train"]["log_dir"], config_filename)
    # Save args as a YAML file alongside the run's logs
    with open(destination_path, "w") as f:
        f.write(config_content)
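

# A minimal end-to-end sketch of how these helpers chain together, assuming a
# config dict loaded from one of the repo's YAML files (the keys shown are the
# ones the functions above actually read; the config path and the "seed" key
# are hypothetical):
#
#     with open("config/PEMS04.yaml") as f:
#         args = yaml.safe_load(f)
#     init_seed(args["basic"]["seed"])
#     args = init_device(args)
#     create_logs(args)
#     loss_fn = init_loss(args, scaler=None)  # pass the dataset's scaler here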