diff --git a/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py b/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py index bba354c..84f6df0 100644 --- a/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py +++ b/federatedscope/trafficflow/dataloader/traffic_dataloader_v2.py @@ -188,35 +188,34 @@ def load_traffic_data(config, client_cfgs): return data_list, config -def split_into_mini_graphs(tensor, minigraph_size): +def split_into_mini_graphs(tensor, graph_size, dummy_value=0): """ Splits a tensor into mini-graphs of specified size. Pads the last mini-graph with dummy nodes if necessary. Args: tensor (np.ndarray): Input tensor with shape (timestep, horizon, node_num, dim). - minigraph_size (int): The size of each mini-graph. + graph_size (int): The size of each mini-graph. + dummy_value (float, optional): The value to use for dummy nodes. Default is 0. Returns: - np.ndarray: Output tensor with shape (timestep, horizon, minigraph_num, minigraph_size, dim). + np.ndarray: Output tensor with shape (timestep, horizon, graph_num, graph_size, dim). """ timestep, horizon, node_num, dim = tensor.shape # Calculate the number of mini-graphs - minigraph_num = (node_num + minigraph_size - 1) // minigraph_size # Round up division + graph_num = (node_num + graph_size - 1) // graph_size # Round up division - # Initialize output tensor with zeros (dummy nodes) - output = np.zeros((timestep, horizon, minigraph_num, minigraph_size, dim), dtype=tensor.dtype) + # Initialize output tensor with dummy values + output = np.full((timestep, horizon, graph_size, graph_num, dim), dummy_value, dtype=tensor.dtype) # Fill in the real data - for i in range(minigraph_num): - start_idx = i * minigraph_size - end_idx = min(start_idx + minigraph_size, node_num) # Ensure we don't exceed the node number + for i in range(graph_num): + start_idx = i * graph_size + end_idx = min(start_idx + graph_size, node_num) # Ensure we don't exceed the node number slice_size = end_idx - start_idx - # Assign the data to the corresponding mini-graph - output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :] - - # For the remaining part in the mini-graph, it remains as dummy nodes (zeros) + # Assign the data to the corresponding mini-graph (adjusted indexing) + output[:, :, :slice_size, i, :] = tensor[:, :, start_idx:end_idx, :] return output diff --git a/scripts/trafficflow_exp_scripts/D4.yaml b/scripts/trafficflow_exp_scripts/D4.yaml index fe0e8f9..73d9f70 100644 --- a/scripts/trafficflow_exp_scripts/D4.yaml +++ b/scripts/trafficflow_exp_scripts/D4.yaml @@ -45,7 +45,7 @@ model: use_day: True use_week: True use_minigraph: True - minigraph_size: 5 + minigraph_size: 3 train: batch_or_epoch: 'epoch' local_update_steps: 1