36 lines
1.1 KiB
Python
Executable File
36 lines
1.1 KiB
Python
Executable File
from model.STGNRDE.GRDE import NeuralGCDE
|
|
from model.STGNRDE.vector_fields import FinalTanh_f, VectorField_g
|
|
|
|
|
|
def make_model(args):
    """Assemble a ``NeuralGCDE`` model together with its two vector fields.

    Args:
        args: Configuration object indexed by string key. The keys read
            here are ``"input_dim"``, ``"hid_dim"``, ``"hid_hid_dim"``,
            ``"num_layers"``, ``"num_nodes"``, ``"cheb_k"``, ``"embed_dim"``,
            ``"g_type"``, ``"output_dim"``, ``"device"`` and ``"solver"``.
            The object itself is also forwarded to ``NeuralGCDE`` as its
            first positional argument.

    Returns:
        The constructed ``NeuralGCDE`` instance.
    """
    # Both vector fields are built with the same channel/depth settings,
    # so gather those keyword arguments once.
    shared_field_kwargs = {
        "input_channels": args["input_dim"],
        "hidden_channels": args["hid_dim"],
        "hidden_hidden_channels": args["hid_hid_dim"],
        "num_hidden_layers": args["num_layers"],
    }

    func_f = FinalTanh_f(**shared_field_kwargs)

    # The g field takes the shared settings plus four extra options.
    func_g = VectorField_g(
        **shared_field_kwargs,
        num_nodes=args["num_nodes"],
        cheb_k=args["cheb_k"],
        embed_dim=args["embed_dim"],
        g_type=args["g_type"],
    )

    # `initial` and the atol/rtol values are fixed here, not read from args.
    return NeuralGCDE(
        args,
        func_f=func_f,
        func_g=func_g,
        input_channels=args["input_dim"],
        hidden_channels=args["hid_dim"],
        output_channels=args["output_dim"],
        initial=True,
        device=args["device"],
        atol=1e-9,
        rtol=1e-7,
        solver=args["solver"],
    )
|