Skip to content

Commit 1b7c3af

Browse files
add pems training loop
1 parent 9cc98cf commit 1b7c3af

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import numpy as np
5+
from torch_geometric_temporal.dataset import PemsDatasetLoader
6+
from torch_geometric_temporal.nn.recurrent import TGCN2
7+
import argparse
8+
import time
9+
10+
11+
def parse_arguments(argv=None):
    """Parse the command-line options for the index-batching demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is the behaviour of a plain script invocation.

    Returns:
        argparse.Namespace with the attributes ``epochs`` (int),
        ``batch_size`` (int), ``gpu`` (str) and ``debug`` (str).
    """
    parser = argparse.ArgumentParser(description="Demo of index batching with PemsBay dataset")

    parser.add_argument(
        "-e", "--epochs", type=int, default=100, help="The desired number of training epochs"
    )
    parser.add_argument(
        "-bs", "--batch-size", type=int, default=64, help="The desired batch size"
    )
    # --gpu and --debug are kept as free-form strings ("True"/"False"/...);
    # the caller is responsible for interpreting them as booleans.
    parser.add_argument(
        "-g", "--gpu", type=str, default="False", help="Should data be preprocessed and migrated directly to the GPU"
    )
    parser.add_argument(
        "-d", "--debug", type=str, default="False", help="Print values for debugging"
    )
    return parser.parse_args(argv)
27+
28+
# --- Model ---
29+
class BatchedTGCN(nn.Module):
    """Recurrent graph model: a TGCN2 cell unrolled over time plus a linear head.

    The cell is applied once per time step, carrying its hidden state forward,
    and every hidden state is projected to ``out_channels`` features.
    """

    def __init__(self, in_channels, hidden_dim, out_channels):
        super().__init__()
        self.tgnn = TGCN2(in_channels, hidden_dim, 1)
        self.linear = nn.Linear(hidden_dim, out_channels)

    def forward(self, x, edge_index, edge_weight):
        """Run the cell over the last axis of ``x`` and return all step outputs.

        Args:
            x: 4-D input laid out as [B, N, F, T].
            edge_index: Graph connectivity forwarded to the recurrent cell.
            edge_weight: Edge weights forwarded to the recurrent cell.

        Returns:
            Tensor of shape [B, T, N, output_dim].
        """
        # Unpacking four names keeps the implicit "x must be 4-D" check.
        _, _, _, num_steps = x.shape

        hidden = None
        per_step = []
        for step in range(num_steps):
            # hidden: [B, N, hidden_dim]
            hidden = self.tgnn(x[..., step], edge_index, edge_weight, hidden)
            # Projection of the activated state: [B, N, output_dim]
            per_step.append(self.linear(F.relu(hidden)))

        # Stacking on a new axis 1 places time right after the batch axis.
        return torch.stack(per_step, dim=1)
48+
49+
def masked_mae_loss(y_pred, y_true):
    """Mean absolute error that ignores positions where the target is zero.

    Positions with ``y_true == 0`` get weight 0; the remaining positions are
    re-weighted by ``1 / mean(mask)`` so that the final ``mean()`` over all
    elements equals the plain MAE over the non-zero targets.

    Args:
        y_pred: Predicted values.
        y_true: Target values, broadcast-compatible with ``y_pred``.

    Returns:
        Scalar tensor holding the masked MAE. It is 0 when every target is
        zero.
    """
    mask = (y_true != 0).float()
    mask = mask / mask.mean()
    # If every target is zero the division above is 0/0 and the mask is all
    # NaN. Sanitise it *before* it is multiplied into the loss: backward
    # multiplies the incoming gradient by the mask, so a NaN mask would yield
    # NaN gradients even when the forward value is later reset to 0.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)

    loss = torch.abs(y_pred - y_true) * mask
    # Zero out any NaN that is still present (e.g. NaN inputs), out-of-place.
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return loss.mean()
57+
58+
def _prepare_batch(x, y, device, on_device):
    """Reorder ``x`` for the model and cast both tensors to float32.

    ``x`` is permuted with (0, 2, 3, 1), i.e. the second axis is moved to the
    end (presumably [B, T, N, F] -> [B, N, F, T]; confirm against the loader).
    Tensors are copied to ``device`` unless ``on_device`` says they already
    live there.
    """
    x = x.permute(0, 2, 3, 1)
    if not on_device:
        x = x.to(device)
        y = y.to(device)
    return x.float(), y.float()


def _denormalized_loss(model, x, y, mean, std, edge_index, edge_weight):
    """Return the masked MAE between prediction and target in original units."""
    y_hat = model(x, edge_index, edge_weight).squeeze()
    # Undo the normalisation on both sides so the loss is in data units.
    return masked_mae_loss((y_hat * std) + mean, (y * std) + mean)


def train(train_dataloader, val_dataloader, mean, std, batch_size, epochs, edge_index, edge_weight, DEVICE, allGPU=False, debug=False):
    """Train a BatchedTGCN and report the best per-epoch train/validation MAE.

    Args:
        train_dataloader: Iterable of ``(x, y)`` training batches.
        val_dataloader: Iterable of ``(x, y)`` validation batches.
        mean: Tensor used to de-normalise predictions and targets.
        std: Tensor used to de-normalise predictions and targets.
        batch_size: Unused here; kept so existing callers keep working.
        epochs: Number of passes over ``train_dataloader``.
        edge_index: Graph connectivity passed to the model.
        edge_weight: Edge weights passed to the model.
        DEVICE: Device the model (and, unless ``allGPU``, the data) is moved to.
        allGPU: When true, batches and ``mean``/``std`` are assumed to already
            be on ``DEVICE`` and are not moved.
        debug: When true, print a per-batch progress line.

    Returns:
        Tuple ``(best_train_mae, best_val_mae)`` of Python floats: the minimum
        epoch-averaged loss seen on each split. Both are NaN when ``epochs``
        is 0.
    """
    # currently predicting speed and time of day. This can be changed to just predict speed.
    model = BatchedTGCN(in_channels=2, out_channels=2, hidden_dim=32).to(DEVICE)

    if not allGPU:
        mean = mean.to(DEVICE)
        std = std.to(DEVICE)

    edge_index = edge_index.to(DEVICE)
    edge_weight = edge_weight.to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    t_maes = []
    v_maes = []

    for epoch in range(epochs):
        t1 = time.time()

        # ---- training pass ----
        model.train()
        epoch_loss = []
        total = len(train_dataloader)
        for step, (x, y) in enumerate(train_dataloader, start=1):
            x, y = _prepare_batch(x, y, DEVICE, allGPU)
            loss = _denormalized_loss(model, x, y, mean, std, edge_index, edge_weight)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())
            if debug:
                print(f"Train Batch: {step}/{total}", end="\r")

        if debug:
            # Blank the whole progress line; a single space would leave the
            # previous "Train Batch" text on screen.
            print(" " * len(f"Train Batch: {total}/{total}"), end="\r")

        # ---- validation pass ----
        model.eval()
        test_loss = []
        total = len(val_dataloader)
        with torch.no_grad():
            for step, (x, y) in enumerate(val_dataloader, start=1):
                x, y = _prepare_batch(x, y, DEVICE, allGPU)
                loss = _denormalized_loss(model, x, y, mean, std, edge_index, edge_weight)
                test_loss.append(loss.item())

                if debug:
                    print(f"Test Batch: {step}/{total}", end="\r")

        train_mae = np.mean(epoch_loss)
        val_mae = np.mean(test_loss)
        t_maes.append(train_mae)
        v_maes.append(val_mae)
        t2 = time.time()

        print(f"Epoch {epoch + 1}/{epochs}, Runtime: {t2 - t1}, Train Loss: {train_mae:.4f}, Val Loss: {val_mae:.4f}", flush=True)

    # With zero epochs there is nothing to take a minimum of; report NaN
    # instead of letting min() raise ValueError on an empty list.
    not_run = float("nan")
    return float(min(t_maes, default=not_run)), float(min(v_maes, default=not_run))
132+
133+
def main():
    """Script entry point: load the PeMS index dataset, train, print a summary.

    Reads the command-line options, builds the index-batched data loaders
    (optionally asking the loader to keep the data on the GPU), runs
    ``train`` and prints the total runtime with the best train/validation
    MAE.

    Raises:
        RuntimeError: If ``--gpu`` is requested but CUDA is not available.
    """
    args = parse_arguments()
    truthy = ("true", "y", "t", "yes")
    allGPU = args.gpu.lower() in truthy
    debug = args.debug.lower() in truthy

    # Fail early with a clear message instead of an error from inside the
    # loader when GPU-resident data is requested on a CPU-only machine.
    if allGPU and not torch.cuda.is_available():
        raise RuntimeError("--gpu was requested but CUDA is not available")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start = time.time()
    loader_kwargs = {"batch_size": args.batch_size, "shuffle": True}
    if allGPU:
        # Presumably selects CUDA device 0 for GPU-side preprocessing;
        # confirm against the PemsDatasetLoader documentation.
        loader_kwargs["allGPU"] = 0

    indexLoader = PemsDatasetLoader(index=True)
    # The held-out test loader is returned by the dataset but not used here.
    (train_dataloader, val_dataloader, _test_dataloader,
     edges, edge_weights, mean, std) = indexLoader.get_index_dataset(**loader_kwargs)

    # NOTE(review): train() is intentionally left on its default allGPU=False
    # path. Its .to(device) calls return the same tensor when the data is
    # already on that device, whereas allGPU=True would skip moving mean/std,
    # which is only safe if the loader has already put them on the GPU.
    t_mae, v_mae = train(train_dataloader, val_dataloader, mean, std, args.batch_size, args.epochs, edges, edge_weights, device, debug=debug)
    end = time.time()

    print(f"Runtime: {round(end - start,2)}; T-MAE: {round(t_mae, 5)}; V-MAE: {round(v_mae, 5)}")
155+
156+
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)