Make the traffic flow model config clearer: namespace its options under cfg.model.tfp
commit b335047b87
parent d3021c8112
@@ -49,20 +49,6 @@ def extend_model_cfg(cfg):
     cfg.model.contrast_topk = 100
     cfg.model.contrast_temp = 1.0
 
-    # Traffic Flow model parameters, These are only default values.
-    # Please modify the specific parameters directly in the baselines/YAML files.
-    cfg.model.num_nodes = 0
-    cfg.model.rnn_units = 64
-    cfg.model.dropout = 0.1
-    cfg.model.horizon = 12
-    cfg.model.input_dim = 1  # If 0, model will be built by data.shape
-    cfg.model.output_dim = 1
-    cfg.model.embed_dim = 10
-    cfg.model.num_layers = 1  # In GPR-GNN, K = layer
-    cfg.model.cheb_order = 1  # A tuple, e.g., (in_channel, h, w)
-    cfg.model.use_day = True
-    cfg.model.use_week = True
-
 
     # ---------------------------------------------------------------------- #
     # Criterion related options
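The hunk above only deletes the traffic-flow defaults from the shared model config; the replacement keys appear in the next hunk. A minimal sketch of the consequence, assuming yacs-style CfgNode semantics for CN (as FederatedScope-style configs use): any caller still reading the old flat path now fails loudly instead of picking up a stale default.

    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.model = CN()
    cfg.model.contrast_topk = 100      # general model options stay flat

    try:
        _ = cfg.model.input_dim        # deleted by the hunk above
    except AttributeError:
        print("cfg.model.input_dim is gone; read cfg.model.tfp.input_dim instead")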
@@ -12,24 +12,29 @@ def extend_trafficflow_cfg(cfg):
     # ---------------------------------------------------------------------- #
     # Model related options
     # ---------------------------------------------------------------------- #
-    cfg.model = CN()
+    cfg.model.tfp = CN()
 
-    cfg.model.model_num_per_trainer = 1  # some methods may leverage more
+    cfg.model.tfp.model_num_per_trainer = 1  # some methods may leverage more
     # than one model in each trainer
-    cfg.model.type = 'trafficflow'
-    cfg.model.use_bias = True
-    cfg.model.task = 'trafficflowprediction'
-    cfg.model.num_nodes = 0
-    cfg.model.rnn_units = 64
-    cfg.model.dropout = 0.1
-    cfg.model.horizon = 12
-    cfg.model.input_dim = 1  # If 0, model will be built by data.shape
-    cfg.model.output_dim = 1
-    cfg.model.embed_dim = 10
-    cfg.model.num_layers = 1  # In GPR-GNN, K = layer
-    cfg.model.cheb_order = 1  # A tuple, e.g., (in_channel, h, w)
-    cfg.model.use_day = True
-    cfg.model.use_week = True
+    # cfg.tfp.model.type = 'trafficflow'
+    # cfg.tfp.model.use_bias = True
+    # cfg.tfp.model.task = 'trafficflowprediction'
+    cfg.model.tfp.num_nodes = 0
+    cfg.model.tfp.rnn_units = 64
+    cfg.model.tfp.dropout = 0.1
+    cfg.model.tfp.horizon = 12
+    cfg.model.tfp.input_dim = 1  # If 0, model will be built by data.shape
+    cfg.model.tfp.output_dim = 1
+    cfg.model.tfp.embed_dim = 10
+    cfg.model.tfp.num_layers = 1  # In GPR-GNN, K = layer
+    cfg.model.tfp.cheb_order = 1  # A tuple, e.g., (in_channel, h, w)
+    cfg.model.tfp.use_day = True
+    cfg.model.tfp.use_week = True
+    cfg.model.tfp.minigraph = CN()
+    cfg.model.tfp.minigraph.enable = False
+    cfg.model.tfp.minigraph.size = 5
+
+
 
     # ---------------------------------------------------------------------- #
     # Criterion related options
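A minimal, self-contained sketch of the new layout (assuming yacs-style CfgNode for CN): the traffic-flow prediction options now sit one level deeper under cfg.model.tfp, with minigraph as its own sub-node, so they no longer collide with general model keys.

    from yacs.config import CfgNode as CN

    def extend_trafficflow_cfg(cfg):
        cfg.model.tfp = CN()
        cfg.model.tfp.num_nodes = 0
        cfg.model.tfp.input_dim = 1        # if 0, the model is built from data.shape
        cfg.model.tfp.output_dim = 1
        cfg.model.tfp.minigraph = CN()
        cfg.model.tfp.minigraph.enable = False
        cfg.model.tfp.minigraph.size = 5

    cfg = CN()
    cfg.model = CN()
    cfg.model.type = 'FedDGCN'             # the model type stays at the top level
    extend_trafficflow_cfg(cfg)
    print(cfg.model.tfp.minigraph.size)    # -> 5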
@@ -125,22 +125,22 @@ def load_traffic_data(config, client_cfgs):
 
     # normalize st data
     normalizer = 'std'
-    scaler = normalize_dataset(x_train[..., :config.model.input_dim], normalizer, config.data.column_wise)
+    scaler = normalize_dataset(x_train[..., :config.model.tfp.input_dim], normalizer, config.data.column_wise)
     config.data.scaler = [float(scaler.mean), float(scaler.std)]
 
-    x_train[..., :config.model.input_dim] = scaler.transform(x_train[..., :config.model.input_dim])
-    x_val[..., :config.model.input_dim] = scaler.transform(x_val[..., :config.model.input_dim])
-    x_test[..., :config.model.input_dim] = scaler.transform(x_test[..., :config.model.input_dim])
-    # y_train[..., :config.model.output_dim] = scaler.transform(y_train[..., :config.model.output_dim])
-    # y_val[..., :config.model.output_dim] = scaler.transform(y_val[..., :config.model.output_dim])
-    # y_test[..., :config.model.output_dim] = scaler.transform(y_test[..., :config.model.output_dim])
+    x_train[..., :config.model.tfp.input_dim] = scaler.transform(x_train[..., :config.model.tfp.input_dim])
+    x_val[..., :config.model.tfp.input_dim] = scaler.transform(x_val[..., :config.model.tfp.input_dim])
+    x_test[..., :config.model.tfp.input_dim] = scaler.transform(x_test[..., :config.model.tfp.input_dim])
+    # y_train[..., :config.model.tfp.output_dim] = scaler.transform(y_train[..., :config.model.tfp.output_dim])
+    # y_val[..., :config.model.tfp.output_dim] = scaler.transform(y_val[..., :config.model.tfp.output_dim])
+    # y_test[..., :config.model.tfp.output_dim] = scaler.transform(y_test[..., :config.model.tfp.output_dim])
 
     # Client-side dataset splitting
     node_num = config.data.num_nodes
     client_num = config.federate.client_num
     per_samples = node_num // client_num
     data_list, cur_index = dict(), 0
-    input_dim, output_dim = config.model.input_dim, config.model.output_dim
+    input_dim, output_dim = config.model.tfp.input_dim, config.model.tfp.output_dim
     for i in range(client_num):
         if cur_index + per_samples <= node_num:
             # Normal slicing
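A minimal sketch of the standardisation step above. StandardScaler and normalize_dataset here are simplified stand-ins for the project's own helpers (the normalizer and column_wise arguments are omitted); the point is that only the first input_dim feature channels are scaled, exactly the slice that x[..., :config.model.tfp.input_dim] selects.

    import numpy as np

    class StandardScaler:
        """Stand-in scaler holding the global mean/std stored in config.data.scaler."""
        def __init__(self, mean, std):
            self.mean, self.std = mean, std
        def transform(self, x):
            return (x - self.mean) / self.std
        def inverse_transform(self, x):
            return x * self.std + self.mean

    def normalize_dataset(data):
        return StandardScaler(data.mean(), data.std())

    x_train = np.random.rand(32, 12, 307, 3)   # (batch, lag, nodes, channels)
    input_dim = 1                              # i.e. config.model.tfp.input_dim
    scaler = normalize_dataset(x_train[..., :input_dim])
    x_train[..., :input_dim] = scaler.transform(x_train[..., :input_dim])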
@@ -156,7 +156,7 @@ def load_traffic_data(config, client_cfgs):
             sub_array_train = x_train[:, :, cur_index:cur_index + per_samples, :]
             sub_array_val = x_val[:, :, cur_index:cur_index + per_samples, :]
             sub_array_test = x_test[:, :, cur_index:cur_index + per_samples, :]
-            padding = np.zeros((x_train.shape[0], config.data.lag ,config.data.lag, per_samples - x_train.shape[1], config.model.output_dim))
+            padding = np.zeros((x_train.shape[0], config.data.lag ,config.data.lag, per_samples - x_train.shape[1], config.model.tfp.output_dim))
             sub_array_train = np.concatenate((sub_array_train, padding), axis=2)
             sub_array_val = np.concatenate((sub_array_val, padding), axis=2)
             sub_array_test = np.concatenate((sub_array_test, padding), axis=2)
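A self-contained sketch of the client split these loader hunks belong to (the shapes and the ceil-based split are illustrative; the loader itself uses node_num // client_num and also carries the y arrays): each client receives a contiguous block of per_samples nodes along axis 2, and a short final block is zero-padded on that axis so every client model sees the same node count.

    import numpy as np

    x_train = np.random.rand(64, 12, 307, 1)    # (batch, lag, nodes, channels)
    client_num = 4
    node_num = x_train.shape[2]
    per_samples = -(-node_num // client_num)    # ceil, so the last shard needs padding

    shards, cur_index = {}, 0
    for i in range(client_num):
        sub = x_train[:, :, cur_index:cur_index + per_samples, :]
        if sub.shape[2] < per_samples:          # last client: fill missing nodes with zeros
            pad = np.zeros(sub.shape[:2] + (per_samples - sub.shape[2],) + sub.shape[3:])
            sub = np.concatenate((sub, pad), axis=2)
        shards[i + 1] = sub
        cur_index += per_samples

    print({cid: s.shape for cid, s in shards.items()})   # every shard: (64, 12, 77, 1)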
@@ -185,7 +185,7 @@ def load_traffic_data(config, client_cfgs):
             )
         }
         cur_index += per_samples
-    config.model.num_nodes = per_samples
+    config.model.tfp.num_nodes = per_samples
     return data_list, config
 
 
@@ -40,33 +40,33 @@ class DGCRM(nn.Module):
 class FedDGCN(nn.Module):
     def __init__(self, args):
         super(FedDGCN, self).__init__()
-        self.num_node = args.num_nodes
-        self.input_dim = args.input_dim
-        self.hidden_dim = args.rnn_units
-        self.output_dim = args.output_dim
-        self.horizon = args.horizon
-        self.num_layers = args.num_layers
-        self.use_D = args.use_day
-        self.use_W = args.use_week
-        self.dropout1 = nn.Dropout(p=args.dropout)  # 0.1
-        self.dropout2 = nn.Dropout(p=args.dropout)
-        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
-        self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
-        self.T_i_D_emb = nn.Parameter(torch.empty(288, args.embed_dim))
-        self.D_i_W_emb = nn.Parameter(torch.empty(7, args.embed_dim))
+        self.num_node = args.tfp.num_nodes
+        self.input_dim = args.tfp.input_dim
+        self.hidden_dim = args.tfp.rnn_units
+        self.output_dim = args.tfp.output_dim
+        self.horizon = args.tfp.horizon
+        self.num_layers = args.tfp.num_layers
+        self.use_D = args.tfp.use_day
+        self.use_W = args.tfp.use_week
+        self.dropout1 = nn.Dropout(p=args.tfp.dropout)  # 0.1
+        self.dropout2 = nn.Dropout(p=args.tfp.dropout)
+        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args.tfp.embed_dim), requires_grad=True)
+        self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.tfp.embed_dim), requires_grad=True)
+        self.T_i_D_emb = nn.Parameter(torch.empty(288, args.tfp.embed_dim))
+        self.D_i_W_emb = nn.Parameter(torch.empty(7, args.tfp.embed_dim))
         # Initialize parameters
         nn.init.xavier_uniform_(self.node_embeddings1)
         nn.init.xavier_uniform_(self.T_i_D_emb)
         nn.init.xavier_uniform_(self.D_i_W_emb)
 
-        self.encoder1 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
-                              args.embed_dim, args.num_layers)
-        self.encoder2 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
-                              args.embed_dim, args.num_layers)
+        self.encoder1 = DGCRM(args.tfp.num_nodes, args.tfp.input_dim, args.tfp.rnn_units, args.tfp.cheb_order,
+                              args.tfp.embed_dim, args.tfp.num_layers)
+        self.encoder2 = DGCRM(args.tfp.num_nodes, args.tfp.input_dim, args.tfp.rnn_units, args.tfp.cheb_order,
+                              args.tfp.embed_dim, args.tfp.num_layers)
         # predictor
-        self.end_conv1 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
-        self.end_conv2 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
-        self.end_conv3 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
+        self.end_conv1 = nn.Conv2d(1, args.tfp.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
+        self.end_conv2 = nn.Conv2d(1, args.tfp.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
+        self.end_conv3 = nn.Conv2d(1, args.tfp.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
 
     def forward(self, source, i=2):
         node_embedding1 = self.node_embeddings1
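Beyond the args.X to args.tfp.X renames, the model itself is unchanged. As a side note, a minimal sketch of what the end_conv prediction head above does (shapes are illustrative and the exact reshape in the real forward pass may differ): a 1 x hidden_dim convolution turns the last encoder state into horizon * output_dim values per node.

    import torch
    import torch.nn as nn

    batch, num_nodes, hidden_dim = 8, 77, 64    # e.g. args.tfp.rnn_units = 64
    horizon, output_dim = 12, 1                 # args.tfp.horizon, args.tfp.output_dim

    end_conv = nn.Conv2d(1, horizon * output_dim, kernel_size=(1, hidden_dim), bias=True)

    h = torch.randn(batch, 1, num_nodes, hidden_dim)   # last DGCRM state, as a 1-channel map
    out = end_conv(h)                                  # (batch, horizon*output_dim, num_nodes, 1)
    out = out.squeeze(-1).reshape(batch, horizon, output_dim, num_nodes)
    out = out.permute(0, 1, 3, 2)                      # (batch, horizon, num_nodes, output_dim)
    print(out.shape)                                   # torch.Size([8, 12, 77, 1])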
@@ -107,10 +107,10 @@ def ModelBuilder(model_config, local_data):
     return model
 
 
-def call_ddgcrn(model_config, local_data):
-    if model_config.type == "DDGCRN":
+def call_feddgcn(model_config, local_data):
+    if model_config.type == "FedDGCN":
         model = ModelBuilder(model_config, local_data)
         return model
 
 
-register_model("DDGCRN", call_ddgcrn)
+register_model("FedDGCN", call_feddgcn)
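A minimal sketch of the registration pattern this hunk renames (register_model, the registry dict, and the ModelBuilder stub below are simplified stand-ins for the framework's real helpers): each call_* function answers only for its own model.type, and the framework tries the registered builders until one returns a model.

    from types import SimpleNamespace

    MODEL_REGISTRY = {}

    def register_model(name, builder):
        MODEL_REGISTRY[name] = builder

    def ModelBuilder(model_config, local_data):          # stub standing in for the real builder
        return f"FedDGCN({model_config.type}) on {local_data}"

    def call_feddgcn(model_config, local_data):
        if model_config.type == "FedDGCN":
            return ModelBuilder(model_config, local_data)

    register_model("FedDGCN", call_feddgcn)

    cfg = SimpleNamespace(type="FedDGCN")
    for builder in MODEL_REGISTRY.values():              # framework-side dispatch
        model = builder(cfg, local_data="client-1 subgraph")
        if model is not None:
            break
    print(model)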
@@ -1,4 +1,4 @@
-use_gpu: True
+use_gpu: False
 seed: 10
 device: 0
 early_stop:
@@ -33,17 +33,21 @@ dataloader:
 model:
   type: FedDGCN
   task: TrafficFlowPrediction
-  dropout: 0.1
-  horizon: 12
-  num_nodes: 0
-  input_dim: 1
-  output_dim: 1
-  embed_dim: 10
-  rnn_units: 64
-  num_layers: 1
-  cheb_order: 2
-  use_day: True
-  use_week: True
+  tfp:
+    dropout: 0.1
+    horizon: 12
+    num_nodes: 0
+    input_dim: 1
+    output_dim: 1
+    embed_dim: 10
+    rnn_units: 64
+    num_layers: 1
+    cheb_order: 2
+    use_day: True
+    use_week: True
+    minigraph:
+      enable: True
+      size: 5
 train:
   batch_or_epoch: 'epoch'
   local_update_steps: 1
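A minimal sketch of how the nested YAML above maps onto the config tree (loading from an inline string here instead of the baseline file, and using plain yacs plus pyyaml rather than the framework's own merge helpers): every indented key becomes an attribute under cfg.model.tfp.

    import yaml
    from yacs.config import CfgNode as CN

    yaml_snippet = """
    model:
      type: FedDGCN
      task: TrafficFlowPrediction
      tfp:
        cheb_order: 2
        use_day: True
        minigraph:
          enable: True
          size: 5
    """

    cfg = CN(yaml.safe_load(yaml_snippet))
    print(cfg.model.type)                    # FedDGCN
    print(cfg.model.tfp.cheb_order)          # 2
    print(cfg.model.tfp.minigraph.enable)    # True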