TrafficWheel/model/STGNCDE/Make_model.py

36 lines
1.0 KiB
Python
Executable File

from model.STGNCDE.GCDE import *
def make_model(args, *, atol=1e-9, rtol=1e-7):
    """Build an STG-NCDE ``NeuralGCDE`` model from a configuration object.

    Two vector-field modules (``FinalTanh_f`` and ``VectorField_g``) are
    constructed from ``args`` and handed to ``NeuralGCDE`` as ``func_f`` and
    ``func_g`` respectively.

    Args:
        args: Configuration object indexed by string keys. This function
            reads ``"input_dim"``, ``"hid_dim"``, ``"hid_hid_dim"``,
            ``"num_layers"``, ``"num_nodes"``, ``"cheb_k"``, ``"embed_dim"``,
            ``"g_type"``, ``"output_dim"``, ``"device"`` and ``"solver"``.
            The whole object is also passed positionally to ``NeuralGCDE``,
            which may read further keys (not visible here).
        atol: Absolute tolerance forwarded to ``NeuralGCDE`` (presumably
            used by the ODE/CDE solver -- confirm in ``GCDE.NeuralGCDE``).
            Keyword-only; defaults to ``1e-9``.
        rtol: Relative tolerance forwarded to ``NeuralGCDE`` (same caveat as
            ``atol``). Keyword-only; defaults to ``1e-7``.

    Returns:
        The constructed ``NeuralGCDE`` instance.
    """
    # Both vector fields share the same channel/layer configuration;
    # VectorField_g additionally takes the graph-related settings.
    vector_field_f = FinalTanh_f(
        input_channels=args["input_dim"],
        hidden_channels=args["hid_dim"],
        hidden_hidden_channels=args["hid_hid_dim"],
        num_hidden_layers=args["num_layers"],
    )
    vector_field_g = VectorField_g(
        input_channels=args["input_dim"],
        hidden_channels=args["hid_dim"],
        hidden_hidden_channels=args["hid_hid_dim"],
        num_hidden_layers=args["num_layers"],
        num_nodes=args["num_nodes"],
        cheb_k=args["cheb_k"],
        embed_dim=args["embed_dim"],
        g_type=args["g_type"],
    )
    model = NeuralGCDE(
        args,
        func_f=vector_field_f,
        func_g=vector_field_g,
        input_channels=args["input_dim"],
        hidden_channels=args["hid_dim"],
        output_channels=args["output_dim"],
        # NOTE(review): the meaning of ``initial`` is defined by
        # GCDE.NeuralGCDE; kept fixed at True as in the original.
        initial=True,
        device=args["device"],
        atol=atol,
        rtol=rtol,
        solver=args["solver"],
    )
    return model