dataloader
This commit is contained in:
parent
22b229891a
commit
b9ca3349c9
|
|
@ -188,35 +188,34 @@ def load_traffic_data(config, client_cfgs):
|
|||
return data_list, config
|
||||
|
||||
|
||||
def split_into_mini_graphs(tensor, minigraph_size):
|
||||
def split_into_mini_graphs(tensor, graph_size, dummy_value=0):
|
||||
"""
|
||||
Splits a tensor into mini-graphs of specified size. Pads the last mini-graph with dummy nodes if necessary.
|
||||
|
||||
Args:
|
||||
tensor (np.ndarray): Input tensor with shape (timestep, horizon, node_num, dim).
|
||||
minigraph_size (int): The size of each mini-graph.
|
||||
graph_size (int): The size of each mini-graph.
|
||||
dummy_value (float, optional): The value to use for dummy nodes. Default is 0.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Output tensor with shape (timestep, horizon, minigraph_num, minigraph_size, dim).
|
||||
np.ndarray: Output tensor with shape (timestep, horizon, graph_num, graph_size, dim).
|
||||
"""
|
||||
timestep, horizon, node_num, dim = tensor.shape
|
||||
|
||||
# Calculate the number of mini-graphs
|
||||
minigraph_num = (node_num + minigraph_size - 1) // minigraph_size # Round up division
|
||||
graph_num = (node_num + graph_size - 1) // graph_size # Round up division
|
||||
|
||||
# Initialize output tensor with zeros (dummy nodes)
|
||||
output = np.zeros((timestep, horizon, minigraph_num, minigraph_size, dim), dtype=tensor.dtype)
|
||||
# Initialize output tensor with dummy values
|
||||
output = np.full((timestep, horizon, graph_size, graph_num, dim), dummy_value, dtype=tensor.dtype)
|
||||
|
||||
# Fill in the real data
|
||||
for i in range(minigraph_num):
|
||||
start_idx = i * minigraph_size
|
||||
end_idx = min(start_idx + minigraph_size, node_num) # Ensure we don't exceed the node number
|
||||
for i in range(graph_num):
|
||||
start_idx = i * graph_size
|
||||
end_idx = min(start_idx + graph_size, node_num) # Ensure we don't exceed the node number
|
||||
slice_size = end_idx - start_idx
|
||||
|
||||
# Assign the data to the corresponding mini-graph
|
||||
output[:, :, i, :slice_size, :] = tensor[:, :, start_idx:end_idx, :]
|
||||
|
||||
# For the remaining part in the mini-graph, it remains as dummy nodes (zeros)
|
||||
# Assign the data to the corresponding mini-graph (adjusted indexing)
|
||||
output[:, :, :slice_size, i, :] = tensor[:, :, start_idx:end_idx, :]
|
||||
|
||||
return output
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ model:
|
|||
use_day: True
|
||||
use_week: True
|
||||
use_minigraph: True
|
||||
minigraph_size: 5
|
||||
minigraph_size: 3
|
||||
train:
|
||||
batch_or_epoch: 'epoch'
|
||||
local_update_steps: 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue