Skip to content

Commit 2db2dc2

Browse files
add tests for the index batching datasets
1 parent 0c1190d commit 2db2dc2

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

test/index_test.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import numpy as np
2+
import networkx as nx
3+
4+
import torch
5+
from torch_geometric_temporal.signal import temporal_signal_split
6+
7+
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
8+
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal
9+
from torch_geometric_temporal.signal import DynamicGraphStaticSignal
10+
11+
from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignal
12+
from torch_geometric_temporal.signal import DynamicHeteroGraphTemporalSignal
13+
from torch_geometric_temporal.signal import DynamicHeteroGraphStaticSignal
14+
15+
from torch_geometric_temporal.dataset import (METRLADatasetLoader, PemsBayDatasetLoader,
16+
WindmillOutputLargeDatasetLoader, ChickenpoxDatasetLoader)
17+
18+
def test_index_metrla():
19+
loader = METRLADatasetLoader(raw_data_dir="/tmp/")
20+
dataset = loader.get_dataset(num_timesteps_in=6, num_timesteps_out=6)
21+
22+
indexLoader = METRLADatasetLoader(raw_data_dir="/tmp/",index=True)
23+
train_dataloader, _,_, edges, edge_weights, _, _ = indexLoader.get_index_dataset(batch_size=1, shuffle=False, lags=6)
24+
25+
for epoch in range(2):
26+
for snapshot, indexed_batch in zip(dataset, train_dataloader):
27+
x,y = indexed_batch
28+
x = torch.squeeze(x).permute(1,2,0)
29+
y = torch.squeeze(y)[...,0].permute(1,0)
30+
31+
assert torch.equal(snapshot.x,x)
32+
assert torch.equal(snapshot.y,y)
33+
34+
assert torch.equal(snapshot.edge_index,edges)
35+
assert torch.equal(snapshot.edge_attr,edge_weights)
36+
37+
assert edges.shape == (2, 1722)
38+
assert edge_weights.shape == (1722,)
39+
assert x.shape == (207, 2, 6)
40+
assert y.shape == (207, 6)
41+
42+
def test_index_pemsbay():
43+
44+
loader = PemsBayDatasetLoader(raw_data_dir="/tmp/")
45+
dataset = loader.get_dataset()
46+
47+
indexLoader = PemsBayDatasetLoader(raw_data_dir="/tmp/",index=True)
48+
train_dataloader, _,_, edges, edge_weights, _, _ = indexLoader.get_index_dataset(batch_size=1, shuffle=False)
49+
50+
for epoch in range(2):
51+
52+
for snapshot, indexed_batch in zip(dataset, train_dataloader):
53+
x,y = indexed_batch
54+
x = torch.squeeze(x).permute(1,2,0)
55+
y = torch.squeeze(y).permute(1,2,0)
56+
57+
assert torch.equal(snapshot.x,x)
58+
assert torch.equal(snapshot.y,y)
59+
60+
assert torch.equal(snapshot.edge_index,edges)
61+
assert torch.equal(snapshot.edge_attr,edge_weights)
62+
63+
assert edges.shape == (2, 2694)
64+
assert edge_weights.shape == (2694,)
65+
assert x.shape == (325, 2, 12)
66+
assert y.shape == (325, 2, 12)
67+
68+
def test_index_windmilllarge():
69+
70+
loader = WindmillOutputLargeDatasetLoader(raw_data_dir="/tmp/")
71+
dataset = loader.get_dataset()
72+
73+
indexLoader = WindmillOutputLargeDatasetLoader(raw_data_dir="/tmp/",index=True)
74+
train_dataloader, _,_, edges, edge_weights, _, _ = indexLoader.get_index_dataset(batch_size=1, shuffle=False)
75+
76+
for epoch in range(2):
77+
for snapshot, indexed_batch in zip(dataset, train_dataloader):
78+
x,y = indexed_batch
79+
x = torch.squeeze(x).permute(1,0).float()
80+
y = torch.squeeze(y).permute(1,0).float()[...,0]
81+
82+
assert torch.equal(snapshot.x,x)
83+
assert torch.equal(snapshot.y,y)
84+
85+
assert torch.equal(snapshot.edge_index,edges)
86+
assert torch.equal(snapshot.edge_attr,edge_weights)
87+
88+
assert edges.shape == (2, 101761)
89+
assert edge_weights.shape == (101761,)
90+
assert x.shape == (319, 8)
91+
assert y.shape == (319,)
92+
93+
def test_index_chickenpox():
94+
loader = ChickenpoxDatasetLoader()
95+
dataset = loader.get_dataset()
96+
97+
indexLoader = ChickenpoxDatasetLoader(index=True)
98+
train_dataloader, _,_, edges, edge_weights = indexLoader.get_index_dataset(batch_size=1, shuffle=False)
99+
100+
for epoch in range(2):
101+
for snapshot, indexed_batch in zip(dataset, train_dataloader):
102+
x,y = indexed_batch
103+
x = torch.squeeze(x).permute(1,0).float()
104+
y = torch.squeeze(y).float()[0,...]
105+
106+
assert torch.equal(snapshot.x,x)
107+
assert torch.equal(snapshot.y,y)
108+
109+
assert torch.equal(snapshot.edge_index,edges)
110+
assert torch.equal(snapshot.edge_attr,edge_weights)
111+
112+
assert edges.shape == (2, 102)
113+
assert edge_weights.shape == (102,)
114+
assert x.shape == (20, 4)
115+
assert y.shape == (20,)

0 commit comments

Comments
 (0)