dataloader

This commit is contained in:
HengZhang 2024-11-28 10:12:34 +08:00
parent 22b229891a
commit b9ca3349c9
2 changed files with 13 additions and 14 deletions

View File

@ -188,35 +188,34 @@ def load_traffic_data(config, client_cfgs):
return data_list, config
def split_into_mini_graphs(tensor, minigraph_size):
def split_into_mini_graphs(tensor, graph_size, dummy_value=0):
"""
Splits a tensor into mini-graphs of specified size. Pads the last mini-graph with dummy nodes if necessary.
Args:
tensor (np.ndarray): Input tensor with shape (timestep, horizon, node_num, dim).
minigraph_size (int): The size of each mini-graph.
graph_size (int): The size of each mini-graph.
dummy_value (float, optional): The value to use for dummy nodes. Default is 0.
Returns:
np.ndarray: Output tensor with shape (timestep, horizon, minigraph_num, minigraph_size, dim).
np.ndarray: Output tensor with shape (timestep, horizon, graph_num, graph_size, dim).
"""
timestep, horizon, node_num, dim = tensor.shape
# Calculate the number of mini-graphs
minigraph_num = (node_num + minigraph_size - 1) // minigraph_size # Round up division
graph_num = (node_num + graph_size - 1) // graph_size # Round up division
# Initialize output tensor with zeros (dummy nodes)
output = np.zeros((timestep, horizon, minigraph_num, minigraph_size, dim), dtype=tensor.dtype)
# Initialize output tensor with dummy values
output = np.full((timestep, horizon, graph_size, graph_num, dim), dummy_value, dtype=tensor.dtype)
# Fill in the real data
for i in range(minigraph_num):
start_idx = i * minigraph_size
end_idx = min(start_idx + minigraph_size, node_num) # Ensure we don't exceed the node number
for i in range(graph_num):
start_idx = i * graph_size
end_idx = min(start_idx + graph_size, node_num) # Ensure we don't exceed the node number
slice_size = end_idx - start_idx
# Assign the data to the corresponding mini-graph
output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :]
# For the remaining part in the mini-graph, it remains as dummy nodes (zeros)
# Assign the data to the corresponding mini-graph (adjusted indexing)
output[:, :, :slice_size, i, :] = tensor[:, :, start_idx:end_idx, :]
return output

View File

@ -45,7 +45,7 @@ model:
use_day: True
use_week: True
use_minigraph: True
minigraph_size: 5
minigraph_size: 3
train:
batch_or_epoch: 'epoch'
local_update_steps: 1