1+ import numpy as np
2+ import networkx as nx
3+
4+ import torch
5+ from torch_geometric_temporal .signal import temporal_signal_split
6+
7+ from torch_geometric_temporal .signal import StaticGraphTemporalSignal
8+ from torch_geometric_temporal .signal import DynamicGraphTemporalSignal
9+ from torch_geometric_temporal .signal import DynamicGraphStaticSignal
10+
11+ from torch_geometric_temporal .signal import StaticHeteroGraphTemporalSignal
12+ from torch_geometric_temporal .signal import DynamicHeteroGraphTemporalSignal
13+ from torch_geometric_temporal .signal import DynamicHeteroGraphStaticSignal
14+
15+ from torch_geometric_temporal .dataset import (METRLADatasetLoader , PemsBayDatasetLoader ,
16+ WindmillOutputLargeDatasetLoader , ChickenpoxDatasetLoader )
17+
18+ def test_index_metrla ():
19+ loader = METRLADatasetLoader (raw_data_dir = "/tmp/" )
20+ dataset = loader .get_dataset (num_timesteps_in = 6 , num_timesteps_out = 6 )
21+
22+ indexLoader = METRLADatasetLoader (raw_data_dir = "/tmp/" ,index = True )
23+ train_dataloader , _ ,_ , edges , edge_weights , _ , _ = indexLoader .get_index_dataset (batch_size = 1 , shuffle = False , lags = 6 )
24+
25+ for epoch in range (2 ):
26+ for snapshot , indexed_batch in zip (dataset , train_dataloader ):
27+ x ,y = indexed_batch
28+ x = torch .squeeze (x ).permute (1 ,2 ,0 )
29+ y = torch .squeeze (y )[...,0 ].permute (1 ,0 )
30+
31+ assert torch .equal (snapshot .x ,x )
32+ assert torch .equal (snapshot .y ,y )
33+
34+ assert torch .equal (snapshot .edge_index ,edges )
35+ assert torch .equal (snapshot .edge_attr ,edge_weights )
36+
37+ assert edges .shape == (2 , 1722 )
38+ assert edge_weights .shape == (1722 ,)
39+ assert x .shape == (207 , 2 , 6 )
40+ assert y .shape == (207 , 6 )
41+
42+ def test_index_pemsbay ():
43+
44+ loader = PemsBayDatasetLoader (raw_data_dir = "/tmp/" )
45+ dataset = loader .get_dataset ()
46+
47+ indexLoader = PemsBayDatasetLoader (raw_data_dir = "/tmp/" ,index = True )
48+ train_dataloader , _ ,_ , edges , edge_weights , _ , _ = indexLoader .get_index_dataset (batch_size = 1 , shuffle = False )
49+
50+ for epoch in range (2 ):
51+
52+ for snapshot , indexed_batch in zip (dataset , train_dataloader ):
53+ x ,y = indexed_batch
54+ x = torch .squeeze (x ).permute (1 ,2 ,0 )
55+ y = torch .squeeze (y ).permute (1 ,2 ,0 )
56+
57+ assert torch .equal (snapshot .x ,x )
58+ assert torch .equal (snapshot .y ,y )
59+
60+ assert torch .equal (snapshot .edge_index ,edges )
61+ assert torch .equal (snapshot .edge_attr ,edge_weights )
62+
63+ assert edges .shape == (2 , 2694 )
64+ assert edge_weights .shape == (2694 ,)
65+ assert x .shape == (325 , 2 , 12 )
66+ assert y .shape == (325 , 2 , 12 )
67+
68+ def test_index_windmilllarge ():
69+
70+ loader = WindmillOutputLargeDatasetLoader (raw_data_dir = "/tmp/" )
71+ dataset = loader .get_dataset ()
72+
73+ indexLoader = WindmillOutputLargeDatasetLoader (raw_data_dir = "/tmp/" ,index = True )
74+ train_dataloader , _ ,_ , edges , edge_weights , _ , _ = indexLoader .get_index_dataset (batch_size = 1 , shuffle = False )
75+
76+ for epoch in range (2 ):
77+ for snapshot , indexed_batch in zip (dataset , train_dataloader ):
78+ x ,y = indexed_batch
79+ x = torch .squeeze (x ).permute (1 ,0 ).float ()
80+ y = torch .squeeze (y ).permute (1 ,0 ).float ()[...,0 ]
81+
82+ assert torch .equal (snapshot .x ,x )
83+ assert torch .equal (snapshot .y ,y )
84+
85+ assert torch .equal (snapshot .edge_index ,edges )
86+ assert torch .equal (snapshot .edge_attr ,edge_weights )
87+
88+ assert edges .shape == (2 , 101761 )
89+ assert edge_weights .shape == (101761 ,)
90+ assert x .shape == (319 , 8 )
91+ assert y .shape == (319 ,)
92+
93+ def test_index_chickenpox ():
94+ loader = ChickenpoxDatasetLoader ()
95+ dataset = loader .get_dataset ()
96+
97+ indexLoader = ChickenpoxDatasetLoader (index = True )
98+ train_dataloader , _ ,_ , edges , edge_weights = indexLoader .get_index_dataset (batch_size = 1 , shuffle = False )
99+
100+ for epoch in range (2 ):
101+ for snapshot , indexed_batch in zip (dataset , train_dataloader ):
102+ x ,y = indexed_batch
103+ x = torch .squeeze (x ).permute (1 ,0 ).float ()
104+ y = torch .squeeze (y ).float ()[0 ,...]
105+
106+ assert torch .equal (snapshot .x ,x )
107+ assert torch .equal (snapshot .y ,y )
108+
109+ assert torch .equal (snapshot .edge_index ,edges )
110+ assert torch .equal (snapshot .edge_attr ,edge_weights )
111+
112+ assert edges .shape == (2 , 102 )
113+ assert edge_weights .shape == (102 ,)
114+ assert x .shape == (20 , 4 )
115+ assert y .shape == (20 ,)
0 commit comments