Skip to content

Commit 4fb2336

Browse files
mjanuszcopybara-github
authored andcommitted
Preload image and mask in EstimateMissingFlow.
This avoids repeated volume reads, which can be slow even when the underlying data is cached in memory (due to cache thrashing or the need to reassemble the image array out of the underlying chunks). PiperOrigin-RevId: 859183961
1 parent 2ccc663 commit 4fb2336

2 files changed

Lines changed: 249 additions & 38 deletions

File tree

processor/flow.py

Lines changed: 75 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Flow field estimation from SOFIMA."""
1616

1717
import dataclasses
18+
import gc
1819
import time
1920
from typing import Any, Sequence
2021

@@ -594,6 +595,7 @@ def __init__(
594595
)
595596

596597
self._config = config
598+
logging.info('EstimateMissingFlow running with config: %r', config)
597599

598600
def _build_mask(
599601
self,
@@ -661,6 +663,20 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
661663
out_box = out_box.adjusted_by(end=-offset)
662664
input_ndarray = input_ndarray[:, :, : out_box.size[1], : out_box.size[0]]
663665

666+
# The input flow forms the initial state of the output. We will try
667+
# to fill-in any invalid (NaN) pixels by computing flow against
668+
# earlier sections.
669+
ret = np.zeros([3] + list(out_box.size[::-1]))
670+
ret[:2, ...] = input_ndarray
671+
ret[2, ...] = self._config.delta_z
672+
673+
sel_mask = None
674+
if self._config.selection_mask_configs:
675+
sel_mask = self._build_mask(self._config.selection_mask_configs, out_box)
676+
677+
mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
678+
invalid = np.isnan(input_ndarray[0, ...])
679+
664680
patch_size = self._config.patch_size
665681
curr_image_box = bounding_box.BoundingBox(
666682
start=(
@@ -671,25 +687,55 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
671687
size=(
672688
(out_box.size[0] - 1) * stride + patch_size,
673689
(out_box.size[1] - 1) * stride + patch_size,
674-
1,
690+
invalid.shape[0],
675691
),
676692
)
677693
curr_image_box = image_volume.clip_box_to_volume(curr_image_box)
678694
assert curr_image_box is not None
679695

680-
# The input flow forms the initial state of the output. We will try
681-
# to fill-in any invalid (NaN) pixels by computing flow against
682-
# earlier sections.
683-
ret = np.zeros([3] + list(out_box.size[::-1]))
684-
ret[:2, ...] = input_ndarray
685-
ret[2, ...] = self._config.delta_z
696+
if self._config.delta_z > 0:
697+
search_deltas = range(
698+
self._config.delta_z + 1, self._config.max_delta_z + 1
699+
)
700+
load_start_z = out_box.start[2] - self._config.max_delta_z
701+
load_end_z = out_box.end[2]
702+
else:
703+
search_deltas = range(
704+
self._config.delta_z - 1, self._config.max_delta_z - 1, -1
705+
)
706+
load_start_z = out_box.start[2]
707+
# max_delta_z is negative.
708+
load_end_z = out_box.end[2] - self._config.max_delta_z
686709

687-
sel_mask = None
688-
if self._config.selection_mask_configs:
689-
sel_mask = self._build_mask(self._config.selection_mask_configs, out_box)
710+
load_box = bounding_box.BoundingBox(
711+
start=(
712+
prev_image_box.start[0],
713+
prev_image_box.start[1],
714+
load_start_z,
715+
),
716+
size=(
717+
prev_image_box.size[0],
718+
prev_image_box.size[1],
719+
load_end_z - load_start_z,
720+
),
721+
)
722+
load_box = image_volume.clip_box_to_volume(load_box)
723+
724+
logging.info('Loading image data: %r', load_box)
725+
full_image_stack = image_volume.asarray[load_box.to_slice4d()][0, ...]
726+
full_mask = None
727+
if self._config.mask_configs:
728+
full_mask = self._build_mask(self._config.mask_configs, load_box)
729+
logging.info('Loaaded mask: %r', full_mask.shape)
730+
731+
# The 'curr' image is a subset of the loaded stack, centered within the
732+
# 'prev' image (which includes the search radius).
733+
curr_rel_start = curr_image_box.start - load_box.start
734+
curr_slice = (
735+
slice(curr_rel_start[1], curr_rel_start[1] + curr_image_box.size[1]),
736+
slice(curr_rel_start[0], curr_rel_start[0] + curr_image_box.size[0]),
737+
)
690738

691-
mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
692-
invalid = np.isnan(input_ndarray[0, ...])
693739
for z in range(0, invalid.shape[0]):
694740
z0 = box.start[2] + z
695741
logging.info('Processing rel_z=%d abs_z=%d', z, z0)
@@ -698,12 +744,13 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
698744
beam_utils.counter(namespace, 'sections-already-valid').inc()
699745
continue
700746

701-
image_box = curr_image_box.translate([0, 0, z])
747+
curr_z_idx = (out_box.start[2] + z) - load_box.start[2]
748+
assert curr_z_idx >= 0
749+
assert curr_z_idx < full_image_stack.shape[0]
750+
702751
curr_mask = None
703752
if self._config.mask_configs:
704-
curr_mask = self._build_mask(
705-
self._config.mask_configs, image_box
706-
).squeeze()
753+
curr_mask = full_mask[curr_z_idx, ...][curr_slice]
707754
if np.all(curr_mask):
708755
beam_utils.counter(namespace, 'sections-masked').inc()
709756
continue
@@ -715,37 +762,23 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
715762
if sel_mask is not None:
716763
mask &= sel_mask[z, ...]
717764

718-
curr = image_volume.asarray[image_box.to_slice4d()].squeeze()
719-
720-
delta_z = self._config.delta_z
721-
if delta_z > 0:
722-
rng = range(delta_z + 1, self._config.max_delta_z + 1)
723-
else:
724-
rng = range(delta_z - 1, self._config.max_delta_z - 1, -1)
765+
curr = full_image_stack[curr_z_idx, ...][curr_slice]
725766

726-
for delta_z in rng:
727-
if (
728-
box.start[2] - delta_z < 0
729-
or box.end[2] - delta_z >= image_volume.volume_size[2]
730-
):
767+
for delta_z in search_deltas:
768+
prev_z_idx = curr_z_idx - delta_z
769+
if prev_z_idx < 0 or prev_z_idx >= full_image_stack.shape[0]:
731770
break
732771

733772
t_start = time.time()
734-
prev_box = prev_image_box.translate([0, 0, z - delta_z])
735-
logging.info('Trying delta_z=%d (%r)', delta_z, prev_box)
736-
prev = image_volume.asarray[prev_box.to_slice4d()].squeeze()
737-
logging.info('.. image loaded.')
773+
logging.info('Trying delta_z=%d', delta_z)
774+
prev_mask = None
775+
prev = full_image_stack[prev_z_idx, ...]
738776
t1 = time.time()
739777

740778
if self._config.mask_configs:
741-
prev_mask = self._build_mask(
742-
self._config.mask_configs, prev_box
743-
).squeeze()
779+
prev_mask = full_mask[prev_z_idx, ...]
744780
if np.all(prev_mask):
745781
continue
746-
else:
747-
prev_mask = None
748-
logging.info('.. mask loaded.')
749782

750783
# Limit the number of estimation attempts per voxel. Attempts
751784
# are only counted when voxels in both sections are unmasked.
@@ -804,4 +837,8 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
804837
t5 - t4,
805838
)
806839

840+
del full_image_stack
841+
del full_mask
842+
gc.collect()
843+
807844
return Subvolume(ret, out_box)

processor/flow_test.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# coding=utf-8
2+
# Copyright 2026 The Google Research Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from absl.testing import absltest
17+
from connectomics.common import bounding_box
18+
from connectomics.volume import subvolume
19+
import numpy as np
20+
from sofima.processor import flow
21+
22+
23+
class MockVolume:
  """In-memory stand-in for an image volume backed by a single array.

  Exposes the members this test file relies on: `clip_box_to_volume`,
  `asarray`, `volume_size` and direct indexing.
  """

  def __init__(self, data):
    """Stores the backing array.

    Args:
      data: array indexed as CZYX.
    """
    self._data = data  # CZYX

  def clip_box_to_volume(self, box):
    """Returns the intersection of `box` with the box spanning the volume."""
    vol_box = bounding_box.BoundingBox(start=(0, 0, 0), size=self.volume_size)
    return box.intersection(vol_box)

  @property
  def asarray(self):
    """The backing CZYX array itself (not a copy)."""
    return self._data

  @property
  def volume_size(self):
    """Spatial size as an (x, y, z) tuple, i.e. the array shape reversed."""
    # XYZ
    return (self._data.shape[3], self._data.shape[2], self._data.shape[1])

  def __getitem__(self, key):
    """Forwards indexing straight to the backing array."""
    return self._data[key]
43+
44+
45+
class TestEstimateMissingFlow(flow.EstimateMissingFlow):
  """`EstimateMissingFlow` variant that serves a pre-built volume object.

  `_open_volume` is overridden so that every requested path resolves to the
  volume supplied at construction time.
  NOTE(review): presumably the base class opens its image volume through
  `_open_volume` -- confirm against `flow.EstimateMissingFlow`.
  """

  # The class name starts with "Test" but this is a helper, not a test case;
  # opt out of pytest collection to avoid a collection warning.
  __test__ = False

  def __init__(self, config, image_vol):
    """Initializes the processor and remembers the volume to serve.

    Args:
      config: configuration forwarded unchanged to the base class.
      image_vol: volume object returned by `_open_volume`.
    """
    super().__init__(config)
    self.image_vol = image_vol

  def _open_volume(self, path):
    """Returns the injected volume; `path` is ignored."""
    return self.image_vol
53+
54+
55+
class EstimateMissingFlowTest(absltest.TestCase):
  """Exercises `EstimateMissingFlow.process` against an in-memory volume."""

  def _make_config(self, max_delta_z):
    """Returns the mask-free config shared by the tests.

    Args:
      max_delta_z: largest section offset the processor may search.
    """
    return flow.EstimateMissingFlow.Config(
        patch_size=16,
        stride=16,
        delta_z=1,
        max_delta_z=max_delta_z,
        max_attempts=1,
        mask_configs=None,
        mask_only_for_patch_selection=False,
        selection_mask_configs=None,
        min_peak_sharpness=0.0,
        min_peak_ratio=0.0,
        max_magnitude=0,
        batch_size=10,  # Must be > 0 for batch processing
        image_volinfo="dummy_path",
        image_cache_bytes=0,
        mask_cache_bytes=0,
        search_radius=16,
    )

  def test_process(self):
    """Missing flow at z=5 is recovered by matching against z=3."""
    config = self._make_config(max_delta_z=2)

    # Fixed seed so that any failure is reproducible.
    rng = np.random.default_rng(0)

    # Larger volume to avoid boundary clipping with required context size
    vol_shape = (1, 10, 128, 128)
    vol_data = rng.random(vol_shape).astype(np.float32)

    # Create a synthetic shift between z=3 and z=5.
    dx, dy = 2, 3
    prev_slice = vol_data[0, 3, :, :]
    shifted_slice = np.zeros_like(prev_slice)
    shifted_slice[dy:, dx:] = prev_slice[:-dy, :-dx]
    # Fill the border exposed by the shift with fresh noise.
    shifted_slice[:dy, :] = rng.random((dy, 128))
    shifted_slice[:, :dx] = rng.random((128, dx))

    vol_data[0, 5, :, :] = shifted_slice

    mock_vol = MockVolume(vol_data)
    processor = TestEstimateMissingFlow(config, mock_vol)

    # Start at 2,2,5 (flow coords) corresponds to 32,32,5 (image coords).
    box = bounding_box.BoundingBox((2, 2, 5), (2, 2, 1))

    # No pre-existing flow data.
    input_data = np.full((2, 1, 2, 2), np.nan, dtype=np.float32)
    subvol = subvolume.Subvolume(input_data, box)

    result_subvol = processor.process(subvol)

    self.assertEqual(result_subvol.data.shape, (3, 1, 2, 2))
    self.assertFalse(
        np.any(np.isnan(result_subvol.data)), "Result contains NaNs"
    )

    np.testing.assert_allclose(
        result_subvol.data[2, ...], 2, err_msg="delta_z incorrect"
    )
    np.testing.assert_allclose(
        result_subvol.data[0, 0, 0, 0],
        -dx,
        atol=0.5,
        err_msg="Flow X incorrect",
    )
    np.testing.assert_allclose(
        result_subvol.data[1, 0, 0, 0],
        -dy,
        atol=0.5,
        err_msg="Flow Y incorrect",
    )

  def test_process_clipped_context(self):
    """Flow stays NaN when every candidate section lies below the volume."""
    config = self._make_config(max_delta_z=5)  # Large lookback

    # Fixed seed so that any failure is reproducible.
    rng = np.random.default_rng(0)

    vol_shape = (1, 10, 128, 128)
    vol_data = rng.random(vol_shape).astype(np.float32)

    mock_vol = MockVolume(vol_data)
    processor = TestEstimateMissingFlow(config, mock_vol)

    box = bounding_box.BoundingBox(start=(2, 2, 1), size=(2, 2, 1))

    # No pre-existing flow data.
    input_data = np.full((2, 1, 2, 2), np.nan, dtype=np.float32)
    subvol = subvolume.Subvolume(input_data, box)

    result_subvol = processor.process(subvol)

    self.assertEqual(result_subvol.data.shape, (3, 1, 2, 2))

    # Result should be NaNs because z=1 only has z=0 as valid prev.
    # delta_z=1 (matching z=0) was not calculated (assumed missing).
    # delta_z=2,3,4,5 look at z < 0, which is out of bounds.
    self.assertTrue(
        np.all(np.isnan(result_subvol.data[0, ...])), "Result X should be NaN"
    )
    self.assertTrue(
        np.all(np.isnan(result_subvol.data[1, ...])), "Result Y should be NaN"
    )
    # Channel 2 is initialized to delta_z (1).
    self.assertEqual(result_subvol.data[2, 0, 0, 0], 1)
173+
# Run the absltest runner only when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()

0 commit comments

Comments
 (0)