Skip to content

Commit a3a6fcb

Browse files
Vincent GuoVincent Guo
authored andcommitted
Troubleshooting 6/6 microsoft#6
1 parent 8220730 commit a3a6fcb

File tree

1 file changed

+13
-0
lines changed
  • src/climax/regional_forecast

1 file changed

+13
-0
lines changed

src/climax/regional_forecast/arch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ def forward_encoder(self, x: torch.Tensor, lead_times: torch.Tensor, variables,
2929

3030
# get the patch ids corresponding to the region
3131
region_patch_ids = region_info['patch_ids']
32+
33+
34+
######### LOOOK HEREEEEEEEEEE ############
35+
36+
print("x.shape:", x.shape)
37+
print("region_patch_ids:", region_patch_ids)
38+
print("max region_patch_ids:", region_patch_ids.max().item())
39+
print("x.shape[2]:", x.shape[2])
40+
41+
assert region_patch_ids.max() < x.shape[2], \
42+
f"region_patch_ids max index {region_patch_ids.max()} exceeds x.shape[2] {x.shape[2]}"
43+
####### TILL HERE #####################
44+
3245
x = x[:, :, region_patch_ids, :]
3346

3447
# variable aggregation

0 commit comments

Comments
 (0)