Skip to content

Commit 9cc98cf

Browse files
cast to float
1 parent d53fbfb commit 9cc98cf

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

examples/indexBatching/tgcn/pems_all_la_main.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ def train(train_dataloader, val_dataloader, mean, std, batch_size, epochs, edge_
6161

6262
edge_index = edge_index.to(DEVICE)
6363
edge_weight = edge_weight.to(DEVICE)
64+
65+
if not allGPU:
66+
mean = mean.to(DEVICE)
67+
std = std.to(DEVICE)
6468

6569
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
6670

@@ -76,10 +80,12 @@ def train(train_dataloader, val_dataloader, mean, std, batch_size, epochs, edge_
7680
for x, y in train_dataloader:
7781

7882
if allGPU:
79-
x = x.permute(0,2,3,1)
83+
x = x.permute(0,2,3,1).float()
84+
y = y.float()
85+
8086
else:
81-
x = x.permute(0,2,3,1).to(DEVICE)
82-
y = y.to(DEVICE)
87+
x = x.permute(0,2,3,1).to(DEVICE).float()
88+
y = y.to(DEVICE).float()
8389

8490
y_hat = model(x, edge_index, edge_weight).squeeze()
8591
loss = masked_mae_loss((y_hat * std) + mean, (y * std) + mean)
@@ -101,11 +107,11 @@ def train(train_dataloader, val_dataloader, mean, std, batch_size, epochs, edge_
101107
with torch.no_grad():
102108
for x, y in val_dataloader:
103109
if allGPU:
104-
x = x.permute(0,2,3,1)
105-
# y = y[...,0]
110+
x = x.permute(0,2,3,1).float()
111+
y = y.float()
106112
else:
107-
x = x.permute(0,2,3,1).to(DEVICE)
108-
y = y.to(DEVICE)
113+
x = x.permute(0,2,3,1).to(DEVICE).float()
114+
y = y.to(DEVICE).float()
109115

110116
y_hat = model(x, edge_index, edge_weight).squeeze()
111117

0 commit comments

Comments
 (0)