Merged
23 commits
4e585fe  Add directory to store error bounds (treigerm, Mar 25, 2025)
e07f66c  Allow to configure each compressor with an error bound (treigerm, Mar 25, 2025)
1f43aed  Pass dtype to build fn, add docstrings, create bitround helper fn (treigerm, Mar 26, 2025)
ce5a989  Enforce named arguments (treigerm, Mar 28, 2025)
7c733b3  Refactor compressors to allow multiple error bound conversions (treigerm, Mar 31, 2025)
5ebe08b  Refactor compressor to have transformation logic in base class (treigerm, Mar 31, 2025)
a3530ba  Add docstring (treigerm, Mar 31, 2025)
cd9ef48  Merge remote-tracking branch 'origin/main' into error_bounds (treigerm, Mar 31, 2025)
8ba9df2  Adjust JPEG2000 for absolute error bounds (treigerm, Mar 31, 2025)
9b0001c  Refine docstring (treigerm, Mar 31, 2025)
1ec0fe3  Fix compressed datasets path (treigerm, Mar 31, 2025)
d4f16f3  Fix grammar in docstring (treigerm, Apr 1, 2025)
de6fd17  Fix relative error bound conversion bug (treigerm, Apr 1, 2025)
84c113a  Improved comments and error handling (treigerm, Apr 1, 2025)
3e6777b  Generate separate codec for each variable (treigerm, Apr 2, 2025)
6309efb  Fix JPEG2000 maximum pixel value (treigerm, Apr 2, 2025)
be3805a  Save full stacktrace when error occurs (treigerm, Apr 2, 2025)
bcf7f3f  Clarify control flow, address PR review comments (treigerm, Apr 2, 2025)
388ac13  Simplify control flow further (treigerm, Apr 3, 2025)
445df64  Adjust JPEG2000 precision (treigerm, Apr 3, 2025)
1383d9a  Clarifying comments (treigerm, Apr 3, 2025)
970f5ea  Comment about input transformation (treigerm, Apr 3, 2025)
d2a864d  Rename dataclasses with more detailed names (treigerm, Apr 3, 2025)
1 change: 1 addition & 0 deletions datasets-error-bounds/.gitignore
@@ -0,0 +1 @@
+/*/error_bounds.json
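
Note: the per-dataset error_bounds.json files ignored here are generated by scripts/create_error_bounds.py, added later in this PR. A plausible shape for one of these files, with a hypothetical variable name "tas" and illustrative values: a JSON list with one object per error-bound level, each mapping variable names to their bounds.

[
    {"tas": {"abs_error": 0.0042, "rel_error": null}},
    {"tas": {"abs_error": 0.042, "rel_error": null}},
    {"tas": {"abs_error": 0.42, "rel_error": null}}
]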
115 changes: 85 additions & 30 deletions scripts/compress.py
@@ -1,10 +1,16 @@
 import argparse
 import json
+import traceback
 from pathlib import Path
+from typing import Hashable

 import numcodecs_observers
 import xarray as xr
-from climatebenchpress.compressor.compressors.abc import Compressor
+from climatebenchpress.compressor.compressors.abc import (
+    Compressor,
+    ErrorBound,
+    NamedPerVariableCodec,
+)
 from dask.diagnostics.progress import ProgressBar
 from numcodecs.abc import Codec
 from numcodecs_combinators.stack import CodecStack
@@ -19,6 +25,7 @@
 def main(exclude_dataset, include_dataset, exclude_compressor, include_compressor):
     datasets = REPO.parent / "data-loader" / "datasets"
     compressed_datasets = REPO / "compressed-datasets"
+    datasets_error_bounds = REPO / "datasets-error-bounds"

     for dataset in datasets.iterdir():
         if dataset.name == ".gitignore" or dataset.name in exclude_dataset:
@@ -27,43 +34,70 @@ def main(exclude_dataset, include_dataset, exclude_compressor, include_compressor):
             continue

         dataset /= "standardized.zarr"

+        ds = xr.open_dataset(dataset, chunks=dict(), engine="zarr")
+        ds_dtypes, ds_abs_mins, ds_abs_maxs = dict(), dict(), dict()
+        for v in ds:
+            abs_vals = xr.ufuncs.abs(ds[v])
+            ds_abs_mins[v] = abs_vals.min().values.item()
+            ds_abs_maxs[v] = abs_vals.max().values.item()
+            ds_dtypes[v] = ds[v].dtype
+
+        error_bounds = get_error_bounds(datasets_error_bounds, dataset.parent.name)
         for compressor in Compressor.registry.values():
             if compressor.name in exclude_compressor:
                 continue
             if include_compressor and compressor.name not in include_compressor:
                 continue

-            compressed_dataset = (
-                compressed_datasets / dataset.parent.name / compressor.name
+            compressor_variants: dict[str, list[NamedPerVariableCodec]] = (
+                compressor.build(ds_dtypes, ds_abs_mins, ds_abs_maxs, error_bounds)
             )
-            compressed_dataset.mkdir(parents=True, exist_ok=True)
-
-            compressed_dataset_path = compressed_dataset / "decompressed.zarr"
-
-            if compressed_dataset_path.exists():
-                continue
-
-            print(
-                f"Compressing {dataset.parent.name} with {compressor.description} ..."
-            )
-
-            ds = xr.open_dataset(dataset, chunks=dict(), engine="zarr")
-            ds_new, measurements = compress_decompress(compressor.build(), ds)
-
-            with (compressed_dataset / "measurements.json").open("w") as f:
-                json.dump(measurements, f)
-
-            with ProgressBar():
-                ds_new.to_zarr(
-                    compressed_dataset_path, encoding=dict(), compute=False
-                ).compute()
-
-
-def compress_decompress(codec: Codec, ds: xr.Dataset) -> tuple[xr.Dataset, dict]:
-    if not isinstance(codec, CodecStack):
-        codec = CodecStack(codec)

+            for compr_name, named_codecs in compressor_variants.items():
+                for named_codec in named_codecs:
+                    compressed_dataset = (
+                        compressed_datasets
+                        / dataset.parent.name
+                        / named_codec.name
+                        / compr_name
+                    )
+                    compressed_dataset.mkdir(parents=True, exist_ok=True)
+
+                    compressed_dataset_path = compressed_dataset / "decompressed.zarr"
+
+                    if compressed_dataset_path.exists():
+                        continue
+
+                    print(
+                        f"Compressing {dataset.parent.name} with {compressor.description} ..."
+                    )
+
+                    try:
+                        ds_new, measurements = compress_decompress(
+                            named_codec.codecs, ds
+                        )
+                    except Exception as e:
+                        print(
+                            f"Error compressing {dataset.parent.name} with {compressor.name}: {e}"
+                        )
+                        with (compressed_dataset / "error.out").open("w") as error_file:
+                            error_file.write(traceback.format_exc())
+                        print("Skipping...")
+                        continue
+
+                    with (compressed_dataset / "measurements.json").open("w") as f:
+                        json.dump(measurements, f)
+
+                    with ProgressBar():
+                        ds_new.to_zarr(
+                            compressed_dataset_path, encoding=dict(), compute=False
+                        ).compute()
+
+
+def compress_decompress(
+    codecs: dict[Hashable, Codec],
+    ds: xr.Dataset,
+) -> tuple[xr.Dataset, dict]:
     variables = dict()
     measurements = dict()
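
The two nested loops added above mean each compressor can produce several configured variants, each evaluated at every error-bound level; the resulting on-disk layout looks like the following sketch (level and variant names are hypothetical, taken from named_codec.name and the keys returned by compressor.build):

compressed-datasets/
    <dataset>/
        <error-bound level>/
            <compressor variant>/
                decompressed.zarr
                measurements.json
                error.out    (only written if compression failed)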

@@ -72,6 +106,10 @@ def compress_decompress(codec: Codec, ds: xr.Dataset) -> tuple[xr.Dataset, dict]
         timing = WalltimeObserver()
         instructions = WasmCodecInstructionCounterObserver()

+        codec = codecs[v]
+        if not isinstance(codec, CodecStack):
+            codec = CodecStack(codec)
+
         with numcodecs_observers.observe(
             codec,
             observers=[
@@ -104,6 +142,23 @@ def compress_decompress(codec: Codec, ds: xr.Dataset) -> tuple[xr.Dataset, dict]
     return xr.Dataset(variables, coords=ds.coords, attrs=ds.attrs), measurements


+def get_error_bounds(
+    datasets_error_bounds: Path, dataset_name: str
+) -> list[dict[str, ErrorBound]]:
+    if not datasets_error_bounds.exists():
+        raise FileNotFoundError(
+            f"Expected error bounds to be defined in {datasets_error_bounds}. Run `scripts/create_error_bounds.py` to create them."
+        )
+
+    dataset_error_bounds = datasets_error_bounds / dataset_name
+    with open(dataset_error_bounds / "error_bounds.json") as f:
+        error_bounds = json.load(f)
+    return [
+        {var_name: ErrorBound(**eb) for var_name, eb in eb_per_var.items()}
+        for eb_per_var in error_bounds
+    ]
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--exclude-dataset", type=str, nargs="+", default=[])
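
For readers without the climatebenchpress source at hand, here is a minimal sketch of the two imported dataclasses, reconstructed purely from their usage in compress.py (ErrorBound(**eb) with abs_error/rel_error keys; named_codec.name used as a path component; named_codec.codecs passed as the per-variable codec mapping). The actual definitions in climatebenchpress.compressor.compressors.abc may differ:

# Hypothetical reconstruction, not the library's real code.
from dataclasses import dataclass
from typing import Hashable

from numcodecs.abc import Codec


@dataclass
class ErrorBound:
    # Either an absolute or a relative bound is set; compressors convert
    # between the two representations as needed (see the commit history).
    abs_error: float | None = None
    rel_error: float | None = None


@dataclass
class NamedPerVariableCodec:
    # Name of the error-bound level, used as a directory component.
    name: str
    # One configured codec (or CodecStack) per dataset variable.
    codecs: dict[Hashable, Codec]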
43 changes: 43 additions & 0 deletions scripts/create_error_bounds.py
@@ -0,0 +1,43 @@
+import json
+from pathlib import Path
+
+import xarray as xr
+
+REPO = Path(__file__).parent.parent
+
+
+def main():
+    datasets = REPO.parent / "data-loader" / "datasets"
+    datasets_error_bounds = REPO / "datasets-error-bounds"
+
+    for dataset in datasets.iterdir():
+        if dataset.name == ".gitignore":
+            continue
+
+        print(dataset.name)
+        ds = xr.open_dataset(
+            dataset / "standardized.zarr",
+            chunks=dict(),
+            engine="zarr",
+            decode_times=False,
+        )
+
+        # TODO: This is a temporary solution that should be replaced by a more
+        # principled method to select the error bounds.
+        low_error_bounds, mid_error_bounds, high_error_bounds = dict(), dict(), dict()
+        for v in ds:
+            data_range = (ds[v].max() - ds[v].min()).values.item()
+            low_error_bounds[v] = {"abs_error": 0.0001 * data_range, "rel_error": None}
+            mid_error_bounds[v] = {"abs_error": 0.001 * data_range, "rel_error": None}
+            high_error_bounds[v] = {"abs_error": 0.01 * data_range, "rel_error": None}
+
+        error_bounds = [low_error_bounds, mid_error_bounds, high_error_bounds]
+
+        dataset_error_bounds = datasets_error_bounds / dataset.name
+        dataset_error_bounds.mkdir(parents=True, exist_ok=True)
+        with open(dataset_error_bounds / "error_bounds.json", "w") as f:
+            json.dump(error_bounds, f)
+
+
+if __name__ == "__main__":
+    main()
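
Taken together, the assumed workflow for the two scripts (command lines inferred from the script locations and from the FileNotFoundError message in get_error_bounds; run from the repository root):

python scripts/create_error_bounds.py    # writes datasets-error-bounds/<dataset>/error_bounds.json
python scripts/compress.py               # compresses each dataset at every error-bound level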