@@ -185,7 +185,7 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12):
185185 def forward (self , x , M , T , H , W , padded_days_mask = None ):
186186 """
187187 Args:
188- x: (B, M*T*H* W, C) containing spatio-temporal tokens, where C is the embedding dimension.
188+ x: (B, M, T, H, W, C) containing spatio-temporal tokens, where C is the embedding dimension.
189189 M: number of months
190190 T: number of temporal tokens per month after temporal patching (Tp)
191191 H: spatial height after spatial patching
@@ -194,9 +194,12 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
194194 True indicating which day tokens are padded (because some months
195195 have fewer days). This is used to mask out padded tokens in attention computation.
196196 Returns:
197- Tensor of shape (B, M* H*W, C) with one temporally aggregated, where C is the embedding dimension.
 197+ Tensor of shape (B, M, H*W, C) with one temporally aggregated token per (month, spatial patch), where C is the embedding dimension.
198198 """
199- seq = rearrange (x , "b (m t h w) c -> b (h w) m t c" , m = M , t = T , h = H , w = W )
199+ B , M , Tp , Hp , Wp , C = x .shape
200+
201+ # Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing
202+ seq = x .permute (0 , 3 , 4 , 1 , 2 , 5 ).reshape (B , Hp * Wp , M , Tp , C )
200203
201204 pe_days = self .pos_days (T ).to (seq .device ).to (seq .dtype ) # (T, C)
202205 pe_months = self .pos_months (M ).to (seq .device ).to (seq .dtype ) # (M, C)
@@ -209,10 +212,10 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
209212
210213 # padded_days_mask is (B, M, T) true=padded, -> (B, HW, M, T)
211214 if padded_days_mask is not None :
212- pad = padded_days_mask [:, None , :, :].expand (x . shape [ 0 ] , H * W , M , T )
215+ pad = padded_days_mask [:, None , :, :].expand (B , H * W , M , T )
213216 day_logits = day_logits .masked_fill (pad , float ("-inf" ))
214217
215- day_w = torch .softmax (day_logits , dim = - 1 )
 215- day_w = torch .softmax (day_logits , dim = - 1 )
 218+ day_w = torch .softmax (day_logits , dim = - 1 ) # masked (-inf) logits become zero attention weights
216219 month_tokens = (seq * day_w .unsqueeze (- 1 )).sum (dim = 3 ) # (B, HW, M, C)
217220
218221 # Cross-month attention at each spatial location
@@ -222,10 +225,10 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
222225 z = z + attn_out
223226 z = z + self .month_ffn (z )
224227
225- # Back to flattened tokens with month kept
226- z = rearrange ( z , "(b s) m c -> b s m c" , b = x . shape [ 0 ], s = H * W )
227- out = rearrange ( z , "b (h w) m c -> b (m h w) c" , h = H , w = W )
228- return out # (B, M* H*W, C) C: embedding dimension
228+ # Back to (B, M, Hp*Wp, C)
229+ z = z . view ( B , Hp * Wp , M , C )
230+ out = z . permute ( 0 , 2 , 1 , 3 ) # (B, M, Hp*Wp, C )
231+ return out # (B, M, H*W, C) C: embedding dimension
229232
230233
231234class MonthlyConvDecoder (nn .Module ):
@@ -293,10 +296,10 @@ def __init__(
 293296 # Refinement block: a small stack of conv layers to smooth patch boundaries
294297 self .refine = nn .Sequential (
295298 nn .Conv2d (in_channels , out_channels , kernel_size = 3 , padding = 1 ),
296- nn .BatchNorm2d ( out_channels ),
299+ nn .GroupNorm ( num_groups = 8 , num_channels = out_channels ),
297300 nn .GELU (),
298301 nn .Conv2d (out_channels , out_channels , kernel_size = 3 , padding = 1 ),
299- nn .BatchNorm2d ( out_channels ),
302+ nn .GroupNorm ( num_groups = 8 , num_channels = out_channels ),
300303 nn .GELU (),
301304 )
302305
@@ -314,18 +317,18 @@ def forward(self, latent, M, out_H, out_W, land_mask=None):
314317 M: Number of months (temporal patches)
315318 out_H: Target output height (must be divisible by patch_h)
316319 out_W: Target output width (must be divisible by patch_w)
317- land_mask: Optional boolean tensor of shape (out_H, out_W). Values set to True
320+ land_mask: Optional boolean tensor of shape (B, out_H, out_W). Values set to True
318321 will be masked out (set to 0) in the output (only ocean pixels exist).
319322 Returns:
320323 Tensor of shape (B, M, out_H, out_W) representing the monthly variable e.g. SST.
321324 """
322- B , MHW , C = latent .shape
325+ B , M , Np , C = latent .shape
323326 Hp = out_H // self .patch_h
324327 Wp = out_W // self .patch_w
325- assert MHW == M * Hp * Wp , f"Token mismatch: got { MHW } , expected { M * Hp * Wp } "
328+ assert Np == Hp * Wp , f"Token mismatch: got { Np } , expected { Hp * Wp } "
326329
327330 # transforms the latent tensor from sequence format to image format for
328- # convolution operations; (B, M*Hp*Wp, C) -> (B*M, C, Hp, Wp)
331+ # convolution operations;
329332 out = latent .view (B , M , Hp , Wp , C ).permute (0 , 1 , 4 , 2 , 3 ).contiguous ()
330333 out = out .view (B * M , C , Hp , Wp )
331334
@@ -349,7 +352,7 @@ def forward(self, latent, M, out_H, out_W, land_mask=None):
349352
350353 # Mask out land areas if land_mask is provided
351354 if land_mask is not None :
352- out = out .masked_fill (land_mask .bool ()[None , None , :, :], 0.0 )
355+ out = out .masked_fill (land_mask .bool ()[: , None , :, :], 0.0 )
353356 return out # (B, M, out_H, out_W)
354357
355358
@@ -500,10 +503,11 @@ def __init__(
500503 patch_size = (1 , 4 , 4 ),
501504 max_days = 31 ,
502505 max_months = 12 ,
503- hidden = 128 ,
506+ num_months = 12 ,
507+ hidden = 256 ,
504508 overlap = 1 ,
505- max_H = 1024 ,
506- max_W = 1024 ,
509+ max_H = 256 ,
510+ max_W = 256 ,
507511 spatial_depth = 2 ,
508512 spatial_heads = 4 ,
509513 ):
@@ -515,6 +519,7 @@ def __init__(
515519 patch_size: Tuple of (T, H, W) patch sizes for temporal and spatial patching
516520 max_days: Maximum number of days for temporal positional encoding
517521 max_months: Maximum number of months for temporal positional encoding
522+ num_months: Number of months to predict (output channels in decoder)
518523 hidden: Hidden dimension used in the decoder
519524 overlap: Overlap for deconvolution in the decoder
520525 max_H: Maximum spatial height for 2D positional encoding
@@ -541,7 +546,7 @@ def __init__(
541546 patch_w = patch_size [2 ],
542547 hidden = hidden ,
543548 overlap = overlap ,
544- num_months = max_months ,
549+ num_months = num_months ,
545550 )
546551 self .patch_size = patch_size
547552
@@ -552,7 +557,7 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
552557 daily_data: Tensor of shape (B, C, M, T, H, W) containing daily
553558 data, where C is the number of channels (e.g., 1 for SST)
554559 daily_mask: Boolean tensor of same shape as daily_data indicating missing values
555- land_mask_patch: Boolean tensor of shape (H, W) to mask land areas in the output
560+ land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output
556561 padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded
557562 (True for padded tokens). Used to mask out padded tokens in temporal attention.
558563 Returns:
@@ -574,18 +579,6 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
574579 )
575580 assert T % self .patch_size [0 ] == 0 , "T must be divisible by patch size"
576581
577- # Step 1: Encode spatio-temporal patches
578- # each month independently by folding M into batch
579- daily_data_reshaped = daily_data .reshape (B * M , C , T , H , W )
580- daily_mask_reshaped = daily_mask .reshape (B * M , C , T , H , W )
581- latent = self .encoder (
582- daily_data_reshaped , daily_mask_reshaped
583- ) # (B*M, N_patches, embed_dim)
584-
585- # Step 2: Aggregate temporal information for each spatial patch
586- # latent -> (B, M*Np, embed_dim) to match the aggregator input x: (B, M*Tp*Hp*Wp, embed_dim)
587- latent = latent .reshape (B , M * Np , - 1 )
588-
589582 if padded_days_mask is not None and self .patch_size [0 ] > 1 :
590583 B , M , T_days = padded_days_mask .shape
591584 if T_days % self .patch_size [0 ] != 0 :
@@ -596,23 +589,45 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
596589 B , M , T_days // self .patch_size [0 ], self .patch_size [0 ]
597590 ).all (dim = - 1 ) # (B, M, Tp)
598591
592+ # Step 1: Encode spatio-temporal patches
593+ # each month independently by folding M into batch
594+ # encoder input shape = (B, C, T, H, W) where C is channel.
595+ # encoder output shape = (B, N_patches, embed_dim)
596+ # so M is folded into B, and T, H, W are the spatio-temporal dimensions to be patched.
597+ daily_data_reshaped = daily_data .reshape (B * M , C , T , H , W )
598+ daily_mask_reshaped = daily_mask .reshape (B * M , C , T , H , W )
599+
600+ latent = self .encoder (
601+ daily_data_reshaped , daily_mask_reshaped
602+ ) # (B*M, N_patches, embed_dim)
603+
604+ # Step 2: Aggregate temporal information for each spatial patch
 605+ # temporal input shape = (B, M, Tp, Hp, Wp, C), C: embedding dimension
606+ # temporal output shape = (B, M, H*W, C) C: embedding dimension
607+ embed_dim = latent .shape [- 1 ]
608+ latent = latent .view (B , M , Tp , Hp , Wp , embed_dim )
609+
599610 agg_latent = self .temporal (
600611 latent , M , Tp , Hp , Wp , padded_days_mask = padded_days_mask
601- ) # (B, M* Hp*Wp, embed_dim)
612+ ) # (B, M, Hp*Wp, embed_dim)
602613
603614 # Step 3: Add spatial positional encodings and mix spatial features
604- E = agg_latent .shape [- 1 ]
605- agg_latent = agg_latent .view (B , M , Hp * Wp , E )
615+ # spatial PE output shape = (Hp, Wp, embed_dim)
606616 pe = (
607617 self .spatial_pe (Hp , Wp ).to (agg_latent .device ).to (agg_latent .dtype )
608- ) # (Hp* Wp, E)
609- x = agg_latent + pe [None , None , :, :]
618+ ) # (Hp, Wp, E)
619+ x = agg_latent + pe [None , None , :, :] # (B, M, Hp*Wp, E)
610620
611621 # Step 4: Spatial mixing with Transformer
612- x = x .view (B * M , Hp * Wp , E )
613- x = self .spatial_tr (x ) # (B*M, Hp*Wp, E)
614- x = x .view (B , M * Hp * Wp , E ) # back to (B, M*Hp*Wp, E)
622+ # spatial transformer input shape = (B, N, C), output shape = (B, N, C) C: embedding dimension
623+ # M is folded in B.
624+ C = x .shape [- 1 ]
625+ x = x .reshape (B * M , Hp * Wp , C )
626+ x = self .spatial_tr (x )
627+ x = x .view (B , M , Hp * Wp , C )
615628
616629 # Step 5: Decode to full-resolution 2D map
 630+ # decoder input shape is (B, M, Hp*Wp, C), C: embedding dimension
631+ # decoder output shape is (B, M, H, W)
617632 monthly_pred = self .decoder (x , M , H , W , land_mask_patch ) # (B, M, H, W)
618633 return monthly_pred
0 commit comments