Skip to content

Commit 0c1190d

Browse files
add data conversion to np.float32 for exact parity with standard batching
1 parent 5b3090c commit 0c1190d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torch_geometric_temporal/dataset/pems_bay.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ def get_index_dataset(self, lags: int = 12, batch_size: int = 64, shuffle: bool
186186

187187
# setup data
188188
data = np.load(os.path.join(self.raw_data_dir, "pems_node_values.npy")).transpose((1, 2, 0))
189-
189+
data = data.astype(np.float32)
190+
190191
if allGPU != -1:
191192
data = torch.tensor(data,dtype=torch.float).to(f"cuda:{allGPU}")
192193
means = torch.mean(data, dim=(0, 2), keepdim=True)

0 commit comments

Comments
 (0)