diff --git a/datasets-error-bounds/.gitignore b/datasets-error-bounds/.gitignore new file mode 100644 index 0000000..36fe1de --- /dev/null +++ b/datasets-error-bounds/.gitignore @@ -0,0 +1 @@ +/*/error_bounds.json diff --git a/scripts/compress.py b/scripts/compress.py index 5f7c365..e350823 100644 --- a/scripts/compress.py +++ b/scripts/compress.py @@ -1,10 +1,16 @@ import argparse import json +import traceback from pathlib import Path +from typing import Hashable import numcodecs_observers import xarray as xr -from climatebenchpress.compressor.compressors.abc import Compressor +from climatebenchpress.compressor.compressors.abc import ( + Compressor, + ErrorBound, + NamedPerVariableCodec, +) from dask.diagnostics.progress import ProgressBar from numcodecs.abc import Codec from numcodecs_combinators.stack import CodecStack @@ -19,6 +25,7 @@ def main(exclude_dataset, include_dataset, exclude_compressor, include_compressor): datasets = REPO.parent / "data-loader" / "datasets" compressed_datasets = REPO / "compressed-datasets" + datasets_error_bounds = REPO / "datasets-error-bounds" for dataset in datasets.iterdir(): if dataset.name == ".gitignore" or dataset.name in exclude_dataset: @@ -27,43 +34,70 @@ def main(exclude_dataset, include_dataset, exclude_compressor, include_compresso continue dataset /= "standardized.zarr" - + ds = xr.open_dataset(dataset, chunks=dict(), engine="zarr") + ds_dtypes, ds_abs_mins, ds_abs_maxs = dict(), dict(), dict() + for v in ds: + abs_vals = xr.ufuncs.abs(ds[v]) + ds_abs_mins[v] = abs_vals.min().values.item() + ds_abs_maxs[v] = abs_vals.max().values.item() + ds_dtypes[v] = ds[v].dtype + + error_bounds = get_error_bounds(datasets_error_bounds, dataset.parent.name) for compressor in Compressor.registry.values(): if compressor.name in exclude_compressor: continue if include_compressor and compressor.name not in include_compressor: continue - compressed_dataset = ( - compressed_datasets / dataset.parent.name / compressor.name + 
compressor_variants: dict[str, list[NamedPerVariableCodec]] = ( + compressor.build(ds_dtypes, ds_abs_mins, ds_abs_maxs, error_bounds) ) - compressed_dataset.mkdir(parents=True, exist_ok=True) - - compressed_dataset_path = compressed_dataset / "decompressed.zarr" - - if compressed_dataset_path.exists(): - continue - - print( - f"Compressing {dataset.parent.name} with {compressor.description} ..." - ) - - ds = xr.open_dataset(dataset, chunks=dict(), engine="zarr") - ds_new, measurements = compress_decompress(compressor.build(), ds) - - with (compressed_dataset / "measurements.json").open("w") as f: - json.dump(measurements, f) - - with ProgressBar(): - ds_new.to_zarr( - compressed_dataset_path, encoding=dict(), compute=False - ).compute() - - -def compress_decompress(codec: Codec, ds: xr.Dataset) -> tuple[xr.Dataset, dict]: - if not isinstance(codec, CodecStack): - codec = CodecStack(codec) + for compr_name, named_codecs in compressor_variants.items(): + for named_codec in named_codecs: + compressed_dataset = ( + compressed_datasets + / dataset.parent.name + / named_codec.name + / compr_name + ) + compressed_dataset.mkdir(parents=True, exist_ok=True) + + compressed_dataset_path = compressed_dataset / "decompressed.zarr" + + if compressed_dataset_path.exists(): + continue + + print( + f"Compressing {dataset.parent.name} with {compressor.description} ..." 
+ ) + + try: + ds_new, measurements = compress_decompress( + named_codec.codecs, ds + ) + except Exception as e: + print( + f"Error compressing {dataset.parent.name} with {compressor.name}: {e}" + ) + with (compressed_dataset / "error.out").open("w") as error_file: + error_file.write(traceback.format_exc()) + print("Skipping...") + continue + + with (compressed_dataset / "measurements.json").open("w") as f: + json.dump(measurements, f) + + with ProgressBar(): + ds_new.to_zarr( + compressed_dataset_path, encoding=dict(), compute=False + ).compute() + + +def compress_decompress( + codecs: dict[Hashable, Codec], + ds: xr.Dataset, +) -> tuple[xr.Dataset, dict]: variables = dict() measurements = dict() @@ -72,6 +106,10 @@ def compress_decompress(codec: Codec, ds: xr.Dataset) -> tuple[xr.Dataset, dict] timing = WalltimeObserver() instructions = WasmCodecInstructionCounterObserver() + codec = codecs[v] + if not isinstance(codec, CodecStack): + codec = CodecStack(codec) + with numcodecs_observers.observe( codec, observers=[ @@ -104,6 +142,23 @@ def compress_decompress(codec: Codec, ds: xr.Dataset) -> tuple[xr.Dataset, dict] return xr.Dataset(variables, coords=ds.coords, attrs=ds.attrs), measurements +def get_error_bounds( + datasets_error_bounds: Path, dataset_name: str +) -> list[dict[str, ErrorBound]]: + if not datasets_error_bounds.exists(): + raise FileNotFoundError( + f"Expected error bounds to be defined in {datasets_error_bounds}. Run `scripts/create_error_bounds.py` to create them." 
+ ) + + dataset_error_bounds = datasets_error_bounds / dataset_name + with open(dataset_error_bounds / "error_bounds.json") as f: + error_bounds = json.load(f) + return [ + {var_name: ErrorBound(**eb) for var_name, eb in eb_per_var.items()} + for eb_per_var in error_bounds + ] + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--exclude-dataset", type=str, nargs="+", default=[]) diff --git a/scripts/create_error_bounds.py b/scripts/create_error_bounds.py new file mode 100644 index 0000000..16ab3c6 --- /dev/null +++ b/scripts/create_error_bounds.py @@ -0,0 +1,43 @@ +import json +from pathlib import Path + +import xarray as xr + +REPO = Path(__file__).parent.parent + + +def main(): + datasets = REPO.parent / "data-loader" / "datasets" + datasets_error_bounds = REPO / "datasets-error-bounds" + + for dataset in datasets.iterdir(): + if dataset.name == ".gitignore": + continue + + print(dataset.name) + ds = xr.open_dataset( + dataset / "standardized.zarr", + chunks=dict(), + engine="zarr", + decode_times=False, + ) + + # TODO: This is a temporary solution that should be replaced by a more + # principled method to select the error bounds. 
+ low_error_bounds, mid_error_bounds, high_error_bounds = dict(), dict(), dict() + for v in ds: + data_range = (ds[v].max() - ds[v].min()).values.item() + low_error_bounds[v] = {"abs_error": 0.0001 * data_range, "rel_error": None} + mid_error_bounds[v] = {"abs_error": 0.001 * data_range, "rel_error": None} + high_error_bounds[v] = {"abs_error": 0.01 * data_range, "rel_error": None} + + error_bounds = [low_error_bounds, mid_error_bounds, high_error_bounds] + + dataset_error_bounds = datasets_error_bounds / dataset.name + dataset_error_bounds.mkdir(parents=True, exist_ok=True) + with open(dataset_error_bounds / "error_bounds.json", "w") as f: + json.dump(error_bounds, f) + + +if __name__ == "__main__": + main() diff --git a/src/climatebenchpress/compressor/compressors/abc.py b/src/climatebenchpress/compressor/compressors/abc.py index 63039f8..1a16813 100644 --- a/src/climatebenchpress/compressor/compressors/abc.py +++ b/src/climatebenchpress/compressor/compressors/abc.py @@ -1,12 +1,54 @@ -__all__ = ["Compressor"] +__all__ = ["Compressor", "NamedPerVariableCodec", "ErrorBound"] from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Mapping +from dataclasses import dataclass from types import MappingProxyType +from typing import Optional +import numpy as np from numcodecs.abc import Codec from typed_classproperties import classproperty +type ErrorBoundName = str +type VariableName = str +type VariantName = str + + +@dataclass +class NamedPerVariableCodec: + name: ErrorBoundName + codecs: dict[VariableName, Codec] + + +@dataclass +class ErrorBound: + abs_error: Optional[float] = None + rel_error: Optional[float] = None + + def __post_init__(self): + if self.abs_error is not None and self.rel_error is not None: + raise ValueError( + "Only one of 'abs_error' or 'rel_error' can be specified, not both." 
+ ) + if self.abs_error is None and self.rel_error is None: + raise ValueError( + "At least one of 'abs_error' or 'rel_error' must be specified." + ) + + self.name = ( + f"abs_error={self.abs_error}" + if self.abs_error is not None + else f"rel_error={self.rel_error}" + ) + + +@dataclass +class VariantErrorBoundPerVariable: + name: VariantName + error_bounds: dict[VariableName, ErrorBound] + class Compressor(ABC): # Abstract interface, must be implemented by subclasses @@ -15,9 +57,86 @@ class Compressor(ABC): @staticmethod @abstractmethod - def build() -> Codec: + def abs_bound_codec(dtype: np.dtype, error_bound: float) -> Codec: pass + @staticmethod + @abstractmethod + def rel_bound_codec(dtype: np.dtype, error_bound: float) -> Codec: + pass + + @classmethod + def build( + cls, + dtypes: dict[VariableName, np.dtype], + data_abs_min: dict[VariableName, float], + data_abs_max: dict[VariableName, float], + error_bounds: list[dict[VariableName, ErrorBound]], + ) -> dict[VariantName, list[NamedPerVariableCodec]]: + """ + Constructs a dictionary of codecs based on the provided error bounds. + The dictionary has a separate entry for each compressor variant. Compressor + variants are created when transforming between absolute and relative + error bounds (each variant accounts for a different way to transform the + error bound). The dictionary values are lists of `NamedPerVariableCodec` instances + where each element in the list corresponds to a different value for the + error bound. + + Parameters + ---------- + dtype : dict[VariableName, numpy.dtype] + Dict mapping from variable name to data type of the input data. + data_abs_min : dict[VariableName, float] + Dict mapping from variable name to minimum absolute value for the variable. + data_abs_max : dict[VariableName, float] + Dict mapping from variable name to maximum absolute value for the variable. + error_bounds: list[ErrorBound] + List of error bounds to use for the compressor. 
+ + Returns + ------- + dict[VariantName, list[NamedPerVariableCodec]] + A dictionary where keys are codec variant names (for separate error bound conversions) + and values are lists of `NamedPerVariableCodec` instances configured with the specified error bounds. + """ + codecs: dict[VariantName, list[NamedPerVariableCodec]] = defaultdict(list) + transformed_bounds: list[VariantErrorBoundPerVariable] = [] + + # Loop over all the error bounds and ensure that they are compatible with the + # compressor. If the error bound is not compatible, transform it into a new + # error bound that is compatible. + for eb_per_var in error_bounds: + transformed_bounds += cls._get_variant_bounds( + data_abs_min, data_abs_max, cls.name, eb_per_var + ) + + # For each error bound, create a new codec. + for variant_info in transformed_bounds: + variant_name, eb_per_var = variant_info.name, variant_info.error_bounds + new_codecs: dict[VariableName, Codec] = dict() + for var, eb in eb_per_var.items(): + if eb.abs_error is not None and cls.has_abs_error_impl: + new_codecs[var] = cls.abs_bound_codec(dtypes[var], eb.abs_error) + elif eb.rel_error is not None and cls.has_rel_error_impl: + new_codecs[var] = cls.rel_bound_codec(dtypes[var], eb.rel_error) + else: + # This should never happen as we have already transformed the error bounds. + # If this happens, it means there is a bug in the implementation. + # We raise an error here to avoid silent failures. + raise ValueError( + "Error bound is not compatible with the compressor." + ) + + # Sort the error bounds by variable name to ensure consistent ordering. 
+ error_bound_name = "_".join( + f"{var}-{eb.name}" for var, eb in sorted(eb_per_var.items()) + ) + codecs[variant_name].append( + NamedPerVariableCodec(name=error_bound_name, codecs=new_codecs) + ) + + return codecs + # Class interface @classproperty def registry(cls) -> Mapping: @@ -26,6 +145,14 @@ def registry(cls) -> Mapping: # Implementation details _registry: dict[str, type["Compressor"]] = dict() + @classproperty + def has_abs_error_impl(cls) -> bool: + return "abs_bound_codec" in cls.__dict__ + + @classproperty + def has_rel_error_impl(cls) -> bool: + return "rel_bound_codec" in cls.__dict__ + @classmethod def __init_subclass__(cls: type["Compressor"]) -> None: name = getattr(cls, "name", None) @@ -33,6 +160,11 @@ def __init_subclass__(cls: type["Compressor"]) -> None: if name is None: raise TypeError(f"Compressor {cls} must have a name") + if not (cls.has_abs_error_impl or cls.has_rel_error_impl): + raise TypeError( + f"Compressor {cls} must implement at least one of `abs_bound_codec` and `rel_bound_codec`." + ) + if name in Compressor._registry: raise TypeError( f"duplicate Compressor name {name} for {cls} vs {Compressor._registry[name]}" @@ -41,3 +173,118 @@ def __init_subclass__(cls: type["Compressor"]) -> None: Compressor._registry[name] = cls return super().__init_subclass__() + + @classmethod + def _get_variant_bounds( + cls, + data_abs_min: dict[VariableName, float], + data_abs_max: dict[VariableName, float], + variant_name: VariantName, + error_bounds: dict[VariableName, ErrorBound], + ) -> list[VariantErrorBoundPerVariable]: + """ + Check whether the supplied `error_bounds` are compatible with the current + compressor. If they are not compatible return a list of new transformed + error bounds. 
+ """ + converted_bounds: dict[VariableName, dict[VariantName, ErrorBound]] = dict() + variant_names = {cls.name} + for var, error_bound in error_bounds.items(): + abs_bound_codec = ( + error_bound.abs_error is not None and cls.has_abs_error_impl + ) + rel_bound_codec = ( + error_bound.rel_error is not None and cls.has_rel_error_impl + ) + if abs_bound_codec or rel_bound_codec: + # If codec is compatible with the error bound no transformation + # is needed. + continue + + converted_bounds[var] = convert_error_bound( + variant_name, data_abs_min[var], data_abs_max[var], error_bound + ) + if variant_names == {cls.name}: + # This is the first time we are transforming the error bounds, + # therefore we need to update the names of the generated variants. + variant_names = set(converted_bounds[var].keys()) + else: + # For all the variables if we are converting the error bounds + # they should lead to the same number of variants. + # If this is not the case, we are somehow using different mechanisms + # to transform the same type of error bound which should be avoided. + # This holds true as long as we have only two types of error bounds + # (absolute and relative). If we add more types of error bounds then + # this property no longer holds. + assert variant_names == set(converted_bounds[var].keys()), ( + "Error bounds for different variables must have the same variant names." + ) + + if len(converted_bounds) == 0: + # The error bounds for all variables are compatible with the codec. + # Just return the original error bounds. + return [VariantErrorBoundPerVariable(variant_name, error_bounds)] + + # converted_bounds contains entries for all variables for which we needed + # to transform the error bounds. We now transform the dictionary + # dict[VariableName, dict[VariantName, ErrorBound]] into a list in which + # each entry represents one way to transform the error bound (i.e. one + # *variant* of the error bound). 
Additionally, each variant needs to contain + # information about the error bounds for all variables. + variable_names = set(error_bounds.keys()) + result: list[VariantErrorBoundPerVariable] = [] + for variant in variant_names: + eb_per_variable: dict[VariableName, ErrorBound] = dict() + for variable in variable_names: + if variable in converted_bounds: + eb_per_variable[variable] = converted_bounds[variable][variant] + else: + eb_per_variable[variable] = error_bounds[variable] + result.append( + VariantErrorBoundPerVariable(name=variant, error_bounds=eb_per_variable) + ) + + return result + + +def convert_error_bound( + name: str, + data_abs_min: float, + data_abs_max: float, + error_bound: ErrorBound, +) -> dict[VariantName, ErrorBound]: + if error_bound.abs_error is not None: + new_ebs = convert_abs_error_to_rel_error(name, data_abs_max, error_bound) + else: + new_ebs = convert_rel_error_to_abs_error(name, data_abs_min, error_bound) + + # Keep the old name for all the new error bounds. This ensures we can group + # together all transformed error bounds that came from the same original bound. + for n in new_ebs.keys(): + new_ebs[n].name = error_bound.name + + return new_ebs + + +def convert_rel_error_to_abs_error( + name: str, data_abs_min: float, old_error: ErrorBound +) -> dict[VariantName, ErrorBound]: + # In general, rel_error = abs_error / abs(data). This transformation + # gives us the relative error bound that ensures the absolute error bound is + # not exceeded for this dataset. + assert old_error.rel_error is not None, "Expected relative error to be set." + + new_name = f"{name}-conservative-abs" + error_bound = ErrorBound(abs_error=old_error.rel_error * data_abs_min) + return {new_name: error_bound} + + +def convert_abs_error_to_rel_error( + name: str, data_abs_max: float, old_error: ErrorBound +) -> dict[VariantName, ErrorBound]: + # Same reasoning for error bound transformation as in `convert_rel_error_to_abs_error`. 
+ assert old_error.abs_error is not None, "Expected absolute error to be set." + + new_name = f"{name}-conservative-rel" + error_bound = ErrorBound(rel_error=old_error.abs_error / data_abs_max) + return {new_name: error_bound} diff --git a/src/climatebenchpress/compressor/compressors/bitround.py b/src/climatebenchpress/compressor/compressors/bitround.py index 650d825..599a7a9 100644 --- a/src/climatebenchpress/compressor/compressors/bitround.py +++ b/src/climatebenchpress/compressor/compressors/bitround.py @@ -2,10 +2,10 @@ import numcodecs_wasm_bit_round import numcodecs_wasm_zlib -from numcodecs.abc import Codec from numcodecs_combinators.stack import CodecStack from .abc import Compressor +from .utils import compute_keepbits class BitRound(Compressor): @@ -13,8 +13,9 @@ class BitRound(Compressor): description = "Bit Rounding" @staticmethod - def build() -> Codec: + def rel_bound_codec(dtype, error_bound): + keepbits = compute_keepbits(dtype, error_bound) return CodecStack( - numcodecs_wasm_bit_round.BitRound(keepbits=9), + numcodecs_wasm_bit_round.BitRound(keepbits=keepbits), numcodecs_wasm_zlib.Zlib(level=6), ) diff --git a/src/climatebenchpress/compressor/compressors/bitround_pco.py b/src/climatebenchpress/compressor/compressors/bitround_pco.py index abc2369..5fda857 100644 --- a/src/climatebenchpress/compressor/compressors/bitround_pco.py +++ b/src/climatebenchpress/compressor/compressors/bitround_pco.py @@ -1,11 +1,12 @@ __all__ = ["BitRoundPco"] + import numcodecs_wasm_bit_round import numcodecs_wasm_pco -from numcodecs.abc import Codec from numcodecs_combinators.stack import CodecStack from .abc import Compressor +from .utils import compute_keepbits class BitRoundPco(Compressor): @@ -13,9 +14,10 @@ class BitRoundPco(Compressor): description = "Bit Rounding + PCodec" @staticmethod - def build() -> Codec: + def rel_bound_codec(dtype, error_bound): + keepbits = compute_keepbits(dtype, error_bound) return CodecStack( - 
numcodecs_wasm_bit_round.BitRound(keepbits=9), + numcodecs_wasm_bit_round.BitRound(keepbits=keepbits), + numcodecs_wasm_pco.Pco( level=8, mode="auto", diff --git a/src/climatebenchpress/compressor/compressors/jpeg2000.py b/src/climatebenchpress/compressor/compressors/jpeg2000.py index 9e0e950..fd643fc 100644 --- a/src/climatebenchpress/compressor/compressors/jpeg2000.py +++ b/src/climatebenchpress/compressor/compressors/jpeg2000.py @@ -1,10 +1,11 @@ __all__ = ["Jpeg2000"] +import math + import numcodecs.astype import numcodecs_wasm_fixed_offset_scale import numcodecs_wasm_jpeg2000 import numcodecs_wasm_round -from numcodecs.abc import Codec from numcodecs_combinators.stack import CodecStack from .abc import Compressor @@ -15,9 +16,19 @@ class Jpeg2000(Compressor): description = "JPEG 2000" @staticmethod - def build() -> Codec: - precision = 0.01 - rate = 10.0 # x10 factor compression + def abs_bound_codec(dtype, error_bound): + # Currently, the input is transformed into the range + # round(min_pixel_val/ error_bound) <= x <= round(max_pixel_val / error_bound) + # This means any values outside this range will incur a larger error. + precision = error_bound + max_pixel_val = 2**25 - 1 # maximum pixel value for our integer encoding. + + # Here we use the formula for the PSNR (https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio) + # to convert between the absolute error and the PSNR value. + # The original PSNR formula uses the root mean square error (RMSE), + # therefore JPEG does not guarantee pointwise error bounds but only + # average error bounds. 
+ psnr = 20 * (math.log10(max_pixel_val) - math.log10(error_bound)) return CodecStack( numcodecs_wasm_fixed_offset_scale.FixedOffsetScale( @@ -29,5 +40,5 @@ def build() -> Codec: encode_dtype="int32", decode_dtype="float32", ), - numcodecs_wasm_jpeg2000.Jpeg2000(mode="rate", rate=rate), + numcodecs_wasm_jpeg2000.Jpeg2000(mode="psnr", psnr=psnr), ) diff --git a/src/climatebenchpress/compressor/compressors/stochround.py b/src/climatebenchpress/compressor/compressors/stochround.py index 5eeb406..bcdab95 100644 --- a/src/climatebenchpress/compressor/compressors/stochround.py +++ b/src/climatebenchpress/compressor/compressors/stochround.py @@ -3,7 +3,6 @@ import numcodecs_wasm_round import numcodecs_wasm_uniform_noise import numcodecs_wasm_zlib -from numcodecs.abc import Codec from numcodecs_combinators.stack import CodecStack from .abc import Compressor @@ -14,9 +13,8 @@ class StochRound(Compressor): description = "Stochastic Rounding" @staticmethod - def build() -> Codec: - precision = 0.01 - + def abs_bound_codec(dtype, error_bound): + precision = error_bound return CodecStack( numcodecs_wasm_uniform_noise.UniformNoise(scale=precision / 2, seed=42), numcodecs_wasm_round.Round(precision=precision), diff --git a/src/climatebenchpress/compressor/compressors/sz3.py b/src/climatebenchpress/compressor/compressors/sz3.py index f643838..ddc148b 100644 --- a/src/climatebenchpress/compressor/compressors/sz3.py +++ b/src/climatebenchpress/compressor/compressors/sz3.py @@ -1,7 +1,6 @@ __all__ = ["Sz3"] import numcodecs_wasm_sz3 -from numcodecs.abc import Codec from .abc import Compressor @@ -11,5 +10,13 @@ class Sz3(Compressor): description = "SZ3" @staticmethod - def build() -> Codec: - return numcodecs_wasm_sz3.Sz3(eb_mode="abs", eb_abs=0.01) + def abs_bound_codec(dtype, error_bound): + return numcodecs_wasm_sz3.Sz3(eb_mode="abs", eb_abs=error_bound) + + @staticmethod + def rel_bound_codec(dtype, error_bound): + # SZ3 will not ensure that the relative error bound is strictly 
met. + # Internally, SZ3 transforms the relative error bound to an absolute error bound + # based on the range of the input data: + # https://github.com/szcompressor/SZ3/blob/e8a6b1569067abdd6b7d4276e91eced115be4f14/include/SZ3/utils/Statistic.hpp#L36 + return numcodecs_wasm_sz3.Sz3(eb_mode="rel", eb_rel=error_bound) diff --git a/src/climatebenchpress/compressor/compressors/tthresh.py b/src/climatebenchpress/compressor/compressors/tthresh.py index de9599e..68ba853 100644 --- a/src/climatebenchpress/compressor/compressors/tthresh.py +++ b/src/climatebenchpress/compressor/compressors/tthresh.py @@ -1,7 +1,6 @@ __all__ = ["Tthresh"] import numcodecs_wasm_tthresh -from numcodecs.abc import Codec from .abc import Compressor @@ -11,5 +10,9 @@ class Tthresh(Compressor): description = "tthresh" @staticmethod - def build() -> Codec: - return numcodecs_wasm_tthresh.Tthresh(eb_mode="rmse", eb_rmse=0.0001) + def abs_bound_codec(dtype, error_bound): + return numcodecs_wasm_tthresh.Tthresh(eb_mode="rmse", eb_rmse=error_bound) + + @staticmethod + def rel_bound_codec(dtype, error_bound): + return numcodecs_wasm_tthresh.Tthresh(eb_mode="eps", eb_eps=error_bound) diff --git a/src/climatebenchpress/compressor/compressors/utils.py b/src/climatebenchpress/compressor/compressors/utils.py new file mode 100644 index 0000000..bee70dc --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/utils.py @@ -0,0 +1,39 @@ +__all__ = [ + "compute_keepbits", +] + +import math + +import numpy as np + +MANTISSA_BITS = { + np.dtype("float32"): 23, + np.dtype("float64"): 52, + np.dtype("float16"): 10, +} + + +def compute_keepbits(dtype: np.dtype, rel_error: float) -> int: + """ + Computes the number of mantissa bits to keep in order to satisfy a relative error bound. + + Parameters + ---------- + dtype : numpy.dtype + Data type of the input array. + rel_error : float + Relative error bound. + + Returns + ------- + int + Number of mantissa bits to keep. 
+ """ + # - log2(rel_error) specifies the number of mantissa bits needed to satisfy + # the rel_error bound (https://en.wikipedia.org/wiki/Machine_epsilon). + # We need to round up to the nearest integer to ensure the error bound is not + # exceeded. + keepbits = -math.floor(math.log2(rel_error)) - 1 + # Ensure that keepbits is within the range of the mantissa bits of single precision. + keepbits = max(min(keepbits, MANTISSA_BITS[dtype]), 0) + return keepbits diff --git a/src/climatebenchpress/compressor/compressors/zfp.py b/src/climatebenchpress/compressor/compressors/zfp.py index b670152..a6611b8 100644 --- a/src/climatebenchpress/compressor/compressors/zfp.py +++ b/src/climatebenchpress/compressor/compressors/zfp.py @@ -1,7 +1,6 @@ __all__ = ["Zfp"] import numcodecs_wasm_zfp -from numcodecs.abc import Codec from .abc import Compressor @@ -10,6 +9,14 @@ class Zfp(Compressor): name = "zfp" description = "ZFP" + # NOTE: + # ZFP mechanism for strictly supporting relative error bounds is to + # truncate the floating point bit representation and then use ZFP's lossless + # mode for compression. This is essentially equivalent to the BitRound + # compressors we are already implementing (with a difference what the lossless + # compression algorithm is). + # See https://zfp.readthedocs.io/en/release1.0.1/faq.html#q-relerr for more details. + @staticmethod - def build() -> Codec: - return numcodecs_wasm_zfp.Zfp(mode="fixed-accuracy", tolerance=0.01) + def abs_bound_codec(dtype, error_bound): + return numcodecs_wasm_zfp.Zfp(mode="fixed-accuracy", tolerance=error_bound)