Skip to content

Commit efa26c4

Browse files
Merge pull request #29 from ESMValGroup/fix_blocky_effect
Fix blocky effect
2 parents 737c138 + 3c5eebd commit efa26c4

6 files changed

Lines changed: 340 additions & 172 deletions

File tree

climanet/dataset.py

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import numpy as np
24
from .utils import add_month_day_dims
35
import xarray as xr
@@ -16,12 +18,10 @@ def __init__(
1618
land_mask: xr.DataArray = None,
1719
time_dim: str = "time",
1820
spatial_dims: Tuple[str, str] = ("lat", "lon"),
19-
patch_size: Tuple[int, int] = (16, 16),
20-
overlap: int = 0,
21+
patch_size: Tuple[int, int] = (16, 16), # (lat, lon)
2122
):
2223
self.spatial_dims = spatial_dims
2324
self.patch_size = patch_size
24-
self.overlap = overlap
2525

2626
# Check that the input data has the expected dimensions
2727
if time_dim not in daily_da.dims or time_dim not in monthly_da.dims:
@@ -30,6 +30,13 @@ def __init__(
3030
if dim not in daily_da.dims or dim not in monthly_da.dims:
3131
raise ValueError(f"Spatial dimension '{dim}' not found in input data")
3232

33+
if (
34+
patch_size[0] > daily_da.sizes[spatial_dims[0]] or patch_size[1] > daily_da.sizes[spatial_dims[1]]
35+
):
36+
raise ValueError(
37+
f"Patch size {patch_size} is larger than data dimensions {daily_da.sizes[spatial_dims]}"
38+
)
39+
3340
# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
3441
# and get padded_days_mask → (M, T=31)
3542
daily_mt, monthly_m, padded_days_mask = add_month_day_dims(
@@ -41,6 +48,10 @@ def __init__(
4148
self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float
4249
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool
4350

51+
# Store coordinate arrays
52+
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
53+
self.lon_coords = daily_da[spatial_dims[1]].to_numpy().copy()
54+
4455
if land_mask is not None:
4556
lm = land_mask.to_numpy().copy()
4657
if lm.ndim == 3:
@@ -60,25 +71,45 @@ def __init__(
6071
self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool()
6172

6273
# Precompute lazy index mapping for patches
63-
self.stride = self.patch_size[0] - self.overlap
6474
H, W = self.daily_np.shape[2], self.daily_np.shape[3]
65-
self.n_i = (H - self.patch_size[0]) // self.stride + 1
66-
self.n_j = (W - self.patch_size[1]) // self.stride + 1
75+
self.patch_indices = self._compute_patch_indices(H, W)
76+
77+
def _compute_patch_indices(self, H: int, W: int) -> list:
78+
"""Generate non-overlapping patch start indices with coverage warning."""
79+
ph, pw = self.patch_size
80+
81+
# Compute number of full non-overlapping patches
82+
n_patches_h = H // ph
83+
n_patches_w = W // pw
84+
85+
# Check for incomplete coverage
86+
remainder_h = H % ph
87+
remainder_w = W % pw
88+
89+
if remainder_h > 0 or remainder_w > 0:
90+
warnings.warn(
91+
f"Patch size {self.patch_size} does not evenly divide image dimensions (H={H}, W={W}). "
92+
f"Uncovered pixels: {remainder_h} in height, {remainder_w} in width. "
93+
f"Consider adjusting patch_size or image dimensions for full coverage.",
94+
UserWarning
95+
)
96+
97+
# Generate non-overlapping patch indices
98+
i_starts = [i * ph for i in range(n_patches_h)]
99+
j_starts = [j * pw for j in range(n_patches_w)]
100+
101+
return [(i, j) for i in i_starts for j in j_starts]
67102

68-
# Total length is only spatial patches (all months included in each sample)
69-
self.total_len = self.n_i * self.n_j
70103

71104
def __len__(self):
72-
return self.total_len
105+
return len(self.patch_indices)
73106

74107
def __getitem__(self, idx):
75108
"""Get a spatiotemporal patch sample based on the index."""
76-
if idx < 0 or idx >= self.total_len:
109+
if idx < 0 or idx >= len(self.patch_indices):
77110
raise IndexError("Index out of range")
78111

79-
i_idx, j_idx = divmod(idx, self.n_j)
80-
i = i_idx * self.stride
81-
j = j_idx * self.stride
112+
i, j = self.patch_indices[idx]
82113
ph, pw = self.patch_size
83114

84115
# Extract spatial patch via numpy slicing — faster than xarray indexing
@@ -108,6 +139,10 @@ def __getitem__(self, idx):
108139
~land_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0)
109140
)
110141

142+
# Extract lat/lon coordinates for this patch
143+
lat_patch = self.lat_coords[i : i + ph]
144+
lon_patch = self.lon_coords[j : j + pw]
145+
111146
# Convert to tensors
112147
return {
113148
"daily_patch": daily_tensor, # (C=1, M, T=31, H, W)
@@ -116,4 +151,6 @@ def __getitem__(self, idx):
116151
"land_mask_patch": land_tensor, # (H,W) True=Land
117152
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
118153
"coords": (i, j),
154+
"lat_patch": lat_patch, # (H,)
155+
"lon_patch": lon_patch, # (W,)
119156
}

climanet/st_encoder_decoder.py

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12):
185185
def forward(self, x, M, T, H, W, padded_days_mask=None):
186186
"""
187187
Args:
188-
x: (B, M*T*H*W, C) containing spatio-temporal tokens, where C is the embedding dimension.
188+
x: (B, M, T, H, W, C) containing spatio-temporal tokens, where C is the embedding dimension.
189189
M: number of months
190190
T: number of temporal tokens per month after temporal patching (Tp)
191191
H: spatial height after spatial patching
@@ -194,9 +194,12 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
194194
True indicating which day tokens are padded (because some months
195195
have fewer days). This is used to mask out padded tokens in attention computation.
196196
Returns:
197-
Tensor of shape (B, M*H*W, C) with one temporally aggregated, where C is the embedding dimension.
197+
Tensor of shape (B, M, H*W, C) with one temporally aggregated token per month and spatial location, where C is the embedding dimension.
198198
"""
199-
seq = rearrange(x, "b (m t h w) c -> b (h w) m t c", m=M, t=T, h=H, w=W)
199+
B, M, Tp, Hp, Wp, C = x.shape
200+
201+
# Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing
202+
seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, Hp * Wp, M, Tp, C)
200203

201204
pe_days = self.pos_days(T).to(seq.device).to(seq.dtype) # (T, C)
202205
pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C)
@@ -209,10 +212,10 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
209212

210213
# padded_days_mask is (B, M, T) true=padded, -> (B, HW, M, T)
211214
if padded_days_mask is not None:
212-
pad = padded_days_mask[:, None, :, :].expand(x.shape[0], H * W, M, T)
215+
pad = padded_days_mask[:, None, :, :].expand(B, H * W, M, T)
213216
day_logits = day_logits.masked_fill(pad, float("-inf"))
214217

215-
day_w = torch.softmax(day_logits, dim=-1)
218+
day_w = torch.softmax(day_logits, dim=-1) # turns inf to 0
216219
month_tokens = (seq * day_w.unsqueeze(-1)).sum(dim=3) # (B, HW, M, C)
217220

218221
# Cross-month attention at each spatial location
@@ -222,10 +225,10 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
222225
z = z + attn_out
223226
z = z + self.month_ffn(z)
224227

225-
# Back to flattened tokens with month kept
226-
z = rearrange(z, "(b s) m c -> b s m c", b=x.shape[0], s=H * W)
227-
out = rearrange(z, "b (h w) m c -> b (m h w) c", h=H, w=W)
228-
return out # (B, M*H*W, C) C: embedding dimension
228+
# Back to (B, M, Hp*Wp, C)
229+
z = z.view(B, Hp * Wp, M, C)
230+
out = z.permute(0, 2, 1, 3) # (B, M, Hp*Wp, C)
231+
return out # (B, M, H*W, C) C: embedding dimension
229232

230233

231234
class MonthlyConvDecoder(nn.Module):
@@ -293,10 +296,10 @@ def __init__(
293296
# Refinement block: a small conv layers to smooth patch boundaries
294297
self.refine = nn.Sequential(
295298
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
296-
nn.BatchNorm2d(out_channels),
299+
nn.GroupNorm(num_groups=8, num_channels=out_channels),
297300
nn.GELU(),
298301
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
299-
nn.BatchNorm2d(out_channels),
302+
nn.GroupNorm(num_groups=8, num_channels=out_channels),
300303
nn.GELU(),
301304
)
302305

@@ -314,18 +317,18 @@ def forward(self, latent, M, out_H, out_W, land_mask=None):
314317
M: Number of months (temporal patches)
315318
out_H: Target output height (must be divisible by patch_h)
316319
out_W: Target output width (must be divisible by patch_w)
317-
land_mask: Optional boolean tensor of shape (out_H, out_W). Values set to True
320+
land_mask: Optional boolean tensor of shape (B, out_H, out_W). Values set to True
318321
will be masked out (set to 0) in the output (only ocean pixels exist).
319322
Returns:
320323
Tensor of shape (B, M, out_H, out_W) representing the monthly variable e.g. SST.
321324
"""
322-
B, MHW, C = latent.shape
325+
B, M, Np, C = latent.shape
323326
Hp = out_H // self.patch_h
324327
Wp = out_W // self.patch_w
325-
assert MHW == M * Hp * Wp, f"Token mismatch: got {MHW}, expected {M * Hp * Wp}"
328+
assert Np == Hp * Wp, f"Token mismatch: got {Np}, expected {Hp * Wp}"
326329

327330
# transforms the latent tensor from sequence format to image format for
328-
# convolution operations; (B, M*Hp*Wp, C) -> (B*M, C, Hp, Wp)
331+
# convolution operations;
329332
out = latent.view(B, M, Hp, Wp, C).permute(0, 1, 4, 2, 3).contiguous()
330333
out = out.view(B * M, C, Hp, Wp)
331334

@@ -349,7 +352,7 @@ def forward(self, latent, M, out_H, out_W, land_mask=None):
349352

350353
# Mask out land areas if land_mask is provided
351354
if land_mask is not None:
352-
out = out.masked_fill(land_mask.bool()[None, None, :, :], 0.0)
355+
out = out.masked_fill(land_mask.bool()[:, None, :, :], 0.0)
353356
return out # (B, M, out_H, out_W)
354357

355358

@@ -500,10 +503,11 @@ def __init__(
500503
patch_size=(1, 4, 4),
501504
max_days=31,
502505
max_months=12,
503-
hidden=128,
506+
num_months=12,
507+
hidden=256,
504508
overlap=1,
505-
max_H=1024,
506-
max_W=1024,
509+
max_H=256,
510+
max_W=256,
507511
spatial_depth=2,
508512
spatial_heads=4,
509513
):
@@ -515,6 +519,7 @@ def __init__(
515519
patch_size: Tuple of (T, H, W) patch sizes for temporal and spatial patching
516520
max_days: Maximum number of days for temporal positional encoding
517521
max_months: Maximum number of months for temporal positional encoding
522+
num_months: Number of months to predict (output channels in decoder)
518523
hidden: Hidden dimension used in the decoder
519524
overlap: Overlap for deconvolution in the decoder
520525
max_H: Maximum spatial height for 2D positional encoding
@@ -541,7 +546,7 @@ def __init__(
541546
patch_w=patch_size[2],
542547
hidden=hidden,
543548
overlap=overlap,
544-
num_months=max_months,
549+
num_months=num_months,
545550
)
546551
self.patch_size = patch_size
547552

@@ -552,7 +557,7 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
552557
daily_data: Tensor of shape (B, C, M, T, H, W) containing daily
553558
data, where C is the number of channels (e.g., 1 for SST)
554559
daily_mask: Boolean tensor of same shape as daily_data indicating missing values
555-
land_mask_patch: Boolean tensor of shape (H, W) to mask land areas in the output
560+
land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output
556561
padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded
557562
(True for padded tokens). Used to mask out padded tokens in temporal attention.
558563
Returns:
@@ -574,18 +579,6 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
574579
)
575580
assert T % self.patch_size[0] == 0, "T must be divisible by patch size"
576581

577-
# Step 1: Encode spatio-temporal patches
578-
# each month independently by folding M into batch
579-
daily_data_reshaped = daily_data.reshape(B * M, C, T, H, W)
580-
daily_mask_reshaped = daily_mask.reshape(B * M, C, T, H, W)
581-
latent = self.encoder(
582-
daily_data_reshaped, daily_mask_reshaped
583-
) # (B*M, N_patches, embed_dim)
584-
585-
# Step 2: Aggregate temporal information for each spatial patch
586-
# latent -> (B, M*Np, embed_dim) to match the aggregator input x: (B, M*Tp*Hp*Wp, embed_dim)
587-
latent = latent.reshape(B, M * Np, -1)
588-
589582
if padded_days_mask is not None and self.patch_size[0] > 1:
590583
B, M, T_days = padded_days_mask.shape
591584
if T_days % self.patch_size[0] != 0:
@@ -596,23 +589,45 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
596589
B, M, T_days // self.patch_size[0], self.patch_size[0]
597590
).all(dim=-1) # (B, M, Tp)
598591

592+
# Step 1: Encode spatio-temporal patches
593+
# each month independently by folding M into batch
594+
# encoder input shape = (B, C, T, H, W) where C is channel.
595+
# encoder output shape = (B, N_patches, embed_dim)
596+
# so M is folded into B, and T, H, W are the spatio-temporal dimensions to be patched.
597+
daily_data_reshaped = daily_data.reshape(B * M, C, T, H, W)
598+
daily_mask_reshaped = daily_mask.reshape(B * M, C, T, H, W)
599+
600+
latent = self.encoder(
601+
daily_data_reshaped, daily_mask_reshaped
602+
) # (B*M, N_patches, embed_dim)
603+
604+
# Step 2: Aggregate temporal information for each spatial patch
605+
# temporal input shape = (B, M*T*H*W, C), C: embedding dimension
606+
# temporal output shape = (B, M, H*W, C) C: embedding dimension
607+
embed_dim = latent.shape[-1]
608+
latent = latent.view(B, M, Tp, Hp, Wp, embed_dim)
609+
599610
agg_latent = self.temporal(
600611
latent, M, Tp, Hp, Wp, padded_days_mask=padded_days_mask
601-
) # (B, M*Hp*Wp, embed_dim)
612+
) # (B, M, Hp*Wp, embed_dim)
602613

603614
# Step 3: Add spatial positional encodings and mix spatial features
604-
E = agg_latent.shape[-1]
605-
agg_latent = agg_latent.view(B, M, Hp * Wp, E)
615+
# spatial PE output shape = (Hp, Wp, embed_dim)
606616
pe = (
607617
self.spatial_pe(Hp, Wp).to(agg_latent.device).to(agg_latent.dtype)
608-
) # (Hp*Wp, E)
609-
x = agg_latent + pe[None, None, :, :]
618+
) # (Hp, Wp, E)
619+
x = agg_latent + pe[None, None, :, :] # (B, M, Hp*Wp, E)
610620

611621
# Step 4: Spatial mixing with Transformer
612-
x = x.view(B * M, Hp * Wp, E)
613-
x = self.spatial_tr(x) # (B*M, Hp*Wp, E)
614-
x = x.view(B, M * Hp * Wp, E) # back to (B, M*Hp*Wp, E)
622+
# spatial transformer input shape = (B, N, C), output shape = (B, N, C) C: embedding dimension
623+
# M is folded in B.
624+
C = x.shape[-1]
625+
x = x.reshape(B * M, Hp * Wp, C)
626+
x = self.spatial_tr(x)
627+
x = x.view(B, M, Hp * Wp, C)
615628

616629
# Step 5: Decode to full-resolution 2D map
630+
# decoder input shape is (B, M, Hp*Wp, C), C: embedding dimension
631+
# decoder output shape is (B, M, H, W)
617632
monthly_pred = self.decoder(x, M, H, W, land_mask_patch) # (B, M, H, W)
618633
return monthly_pred

climanet/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None):
133133
"""
134134
pred: (B, M, H_pad,W_pad) or (B, H, W) torch tensor
135135
orig_H/W: original sizes before padding (optional)
136-
land_mask: (H_pad,W_pad) or (H,W) bool; if given, land will be set to NaN
136+
land_mask: (B, H_pad,W_pad) or (B, H,W) bool; if given, land will be set to NaN
137137
returns: (H,W) numpy array
138138
"""
139139
# crop to original size if provided
@@ -145,6 +145,8 @@ def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None):
145145
# set land to NaN (broadcast mask across batch)
146146
if land_mask is not None:
147147
pred = pred.clone().to(torch.float32)
148-
pred[:, :, land_mask.bool()] = float("nan")
148+
land_mask = land_mask.bool()
149+
land_mask = land_mask.unsqueeze(1) # (B, H,W) -> (B, 1, H, W) for broadcasting
150+
pred = torch.where(land_mask, torch.full_like(pred, float("nan")), pred)
149151

150152
return pred.detach().cpu().numpy()

0 commit comments

Comments
 (0)