diff --git a/changes/4101.misc.md b/changes/4101.misc.md new file mode 100644 index 0000000000..459ebad2cc --- /dev/null +++ b/changes/4101.misc.md @@ -0,0 +1,12 @@ +Replaced the `donfig`-based configuration with a statically-typed +configuration object. `zarr.config` now provides precise static types for +attribute access (`zarr.config.array.order`) and for the dotted-string API +(`zarr.config.get("array.order")`). The string API, environment-variable +ingestion (`ZARR_FOO__BAR`), YAML config files, `config.set` (permanent and +as a context manager), `config.reset`, `config.enable_gpu`, and the +`deprecations` mechanism are all preserved. The `donfig` dependency has been +removed. + +Note: `zarr.config.defaults` now returns a nested `dict` directly; donfig +previously returned a one-element `list[dict]`, so callers that used +`config.defaults[0]` must be updated to use `config.defaults`. diff --git a/docs/user-guide/config.md b/docs/user-guide/config.md index 71c021b070..554c34917d 100644 --- a/docs/user-guide/config.md +++ b/docs/user-guide/config.md @@ -1,7 +1,8 @@ # Runtime configuration -[`zarr.config`][] is responsible for managing the configuration of zarr and -is based on the [donfig](https://github.com/pytroll/donfig) Python library. +[`zarr.config`][] is a `ZarrConfigManager` instance that manages all runtime +settings for zarr. It provides both typed attribute access and a dotted-string +key API. Configuration values can be set using code like the following: @@ -18,12 +19,13 @@ zarr.config.set({'array.order': 'F'}) print(zarr.config.get('array.order')) ``` -Alternatively, configuration values can be set using environment variables, e.g. +Alternatively, configuration values can be set using environment variables. +The variable name uses a `ZARR_` prefix, with `__` to denote nesting, e.g. `ZARR_ARRAY__ORDER=F`. -The configuration can also be read from a YAML file in standard locations. 
-For more information, see the -[donfig documentation](https://donfig.readthedocs.io/en/latest/). +The configuration can also be read from YAML files. Place a `zarr.yaml` (or +any `.yaml`/`.yml` file) in `~/.config/zarr/`, or point the `ZARR_CONFIG` +environment variable at a specific file path. Configuration options include the following: @@ -46,8 +48,5 @@ This is the current default configuration: ```python exec="true" session="config" source="above" result="ansi" from pprint import pprint -import io -output = io.StringIO() -zarr.config.pprint(stream=output, width=60) -print(output.getvalue()) +pprint(zarr.config.to_dict()) ``` diff --git a/docs/user-guide/installation.md b/docs/user-guide/installation.md index c902acf171..a710f417bc 100644 --- a/docs/user-guide/installation.md +++ b/docs/user-guide/installation.md @@ -10,7 +10,7 @@ Required dependencies include: - [numcodecs](https://numcodecs.readthedocs.io) (0.14 or later) - [google-crc32c](https://github.com/googleapis/python-crc32c) (1.5 or later) - [typing_extensions](https://typing-extensions.readthedocs.io) (4.9 or later) -- [donfig](https://donfig.readthedocs.io) (0.8 or later) +- [pyyaml](https://pyyaml.org) (6 or later) ## pip diff --git a/pyproject.toml b/pyproject.toml index 6a7238ff8f..1479265884 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ 'numcodecs>=0.14', 'google-crc32c>=1.5', 'typing_extensions>=4.14', - 'donfig>=0.8', + 'pyyaml>=6', ] dynamic = [ @@ -242,7 +242,6 @@ extra-dependencies = [ 's3fs @ git+https://github.com/fsspec/s3fs', 'universal_pathlib @ git+https://github.com/fsspec/universal_pathlib', 'typing_extensions @ git+https://github.com/python/typing_extensions', - 'donfig @ git+https://github.com/pytroll/donfig', 'obstore @ git+https://github.com/developmentseed/obstore@main#subdirectory=obstore', ] @@ -268,7 +267,7 @@ extra-dependencies = [ 's3fs==2023.10.0', 'universal_pathlib==0.2.0', 'typing_extensions==4.14.*', - 'donfig==0.8.*', + 
'pyyaml==6.*', 'obstore==0.5.*', ] diff --git a/src/zarr/__init__.py b/src/zarr/__init__.py index cdf3840c3b..3322f76c2f 100644 --- a/src/zarr/__init__.py +++ b/src/zarr/__init__.py @@ -68,7 +68,7 @@ def print_packages(packages: list[str]) -> None: "numpy", "numcodecs", "typing_extensions", - "donfig", + "pyyaml", ] optional = [ "botocore", diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index 08d2a50ace..cdc27f25ab 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -1,76 +1,648 @@ """ -The config module is responsible for managing the configuration of zarr and is based on the Donfig python library. -For selecting custom implementations of codecs, pipelines, buffers and ndbuffers, first register the implementations -in the registry and then select them in the config. +Typed configuration for zarr. -Example: - An implementation of the bytes codec in a class ``your.module.NewBytesCodec`` requires the value of ``codecs.bytes`` - to be ``your.module.NewBytesCodec``. Donfig can be configured programmatically, by environment variables, or from - YAML files in standard locations. +The module exposes a single `config` object (a `ZarrConfigManager` instance) that +holds all runtime settings. Values can be read, overridden, and restored through a +simple string-key API: - ```python - from your.module import NewBytesCodec - from zarr.core.config import register_codec, config +- `config.get(key)` — read a dotted-key value (e.g. `config.get("async.concurrency")`). +- `config.set({key: value})` — permanent override; also usable as a context manager to + restore the previous state on exit. +- `config.reset()` — rebuild from defaults + environment. +- `config.refresh()` — alias for `reset`; called by the registry after env changes. +- `config.defaults` — nested dict of built-in default values. +- `config.enable_gpu()` — switch buffer/ndbuffer to GPU implementations. 
- register_codec("bytes", NewBytesCodec) - config.set({"codecs.bytes": "your.module.NewBytesCodec"}) - ``` +Environment variables use the `ZARR_` prefix and `__` for nesting: - Instead of setting the value programmatically with ``config.set``, you can also set the value with an environment - variable. The environment variable ``ZARR_CODECS__BYTES`` can be set to ``your.module.NewBytesCodec``. The double - underscore ``__`` is used to indicate nested access. +```bash +export ZARR_CODECS__BYTES="your.module.NewBytesCodec" +``` - ```bash - export ZARR_CODECS__BYTES="your.module.NewBytesCodec" - ``` +Programmatic override: -For more information, see the Donfig documentation at https://github.com/pytroll/donfig. +```python +from your.module import NewBytesCodec +from zarr.core.config import config + +config.set({"codecs.bytes": "your.module.NewBytesCodec"}) +``` + +For selecting custom implementations of codecs, pipelines, buffers, and ndbuffers, +register the implementation in the registry first, then set the path via `config.set`. 
""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, cast +import ast +import contextlib +import difflib +import os +import warnings +from collections.abc import Mapping +from contextvars import ContextVar, Token +from dataclasses import dataclass, field, fields, is_dataclass, replace +from typing import Any, Literal, Self, cast, overload -from donfig import Config as DConfig +from zarr.errors import ZarrDeprecationWarning, ZarrUserWarning -if TYPE_CHECKING: - from donfig.config_obj import ConfigSet +DEFAULT_CODECS: dict[str, str] = { + "blosc": "zarr.codecs.blosc.BloscCodec", + "gzip": "zarr.codecs.gzip.GzipCodec", + "zstd": "zarr.codecs.zstd.ZstdCodec", + "bytes": "zarr.codecs.bytes.BytesCodec", + "endian": "zarr.codecs.bytes.BytesCodec", + "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", + "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", + "transpose": "zarr.codecs.transpose.TransposeCodec", + "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", + "vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec", + "numcodecs.bz2": "zarr.codecs.numcodecs.BZ2", + "numcodecs.crc32": "zarr.codecs.numcodecs.CRC32", + "numcodecs.crc32c": "zarr.codecs.numcodecs.CRC32C", + "numcodecs.lz4": "zarr.codecs.numcodecs.LZ4", + "numcodecs.lzma": "zarr.codecs.numcodecs.LZMA", + "numcodecs.zfpy": "zarr.codecs.numcodecs.ZFPY", + "numcodecs.adler32": "zarr.codecs.numcodecs.Adler32", + "numcodecs.astype": "zarr.codecs.numcodecs.AsType", + "numcodecs.bitround": "zarr.codecs.numcodecs.BitRound", + "numcodecs.blosc": "zarr.codecs.numcodecs.Blosc", + "numcodecs.delta": "zarr.codecs.numcodecs.Delta", + "numcodecs.fixedscaleoffset": "zarr.codecs.numcodecs.FixedScaleOffset", + "numcodecs.fletcher32": "zarr.codecs.numcodecs.Fletcher32", + "numcodecs.gzip": "zarr.codecs.numcodecs.GZip", + "numcodecs.jenkins_lookup3": "zarr.codecs.numcodecs.JenkinsLookup3", + "numcodecs.pcodec": "zarr.codecs.numcodecs.PCodec", + "numcodecs.packbits": 
"zarr.codecs.numcodecs.PackBits", + "numcodecs.shuffle": "zarr.codecs.numcodecs.Shuffle", + "numcodecs.quantize": "zarr.codecs.numcodecs.Quantize", + "numcodecs.zlib": "zarr.codecs.numcodecs.Zlib", + "numcodecs.zstd": "zarr.codecs.numcodecs.Zstd", +} +# Map serialized dotted-key segments to Python field names where they differ +# (Python keywords cannot be used as identifiers). +_FIELD_ALIASES: dict[str, str] = {"async": "async_"} +_SERIALIZED_NAMES: dict[str, str] = {v: k for k, v in _FIELD_ALIASES.items()} -class BadConfigError(ValueError): - _msg = "bad Config: %r" + +@dataclass(frozen=True, slots=True) +class ArraySettings: + order: Literal["C", "F"] = "C" + write_empty_chunks: bool = False + read_missing_chunks: bool = True + target_shard_size_bytes: int | None = None + rectilinear_chunks: bool = False + sharding_coalesce_max_gap_bytes: int = 1 << 20 + sharding_coalesce_max_bytes: int = 16 << 20 + + +@dataclass(frozen=True, slots=True) +class AsyncSettings: + concurrency: int = 10 + timeout: float | None = None + + +@dataclass(frozen=True, slots=True) +class ThreadingSettings: + max_workers: int | None = None + + +@dataclass(frozen=True, slots=True) +class CodecPipelineSettings: + path: str = "zarr.core.codec_pipeline.BatchedCodecPipeline" + batch_size: int = 1 + + +@dataclass(frozen=True, slots=True) +class ZarrConfig: + default_zarr_format: Literal[2, 3] = 3 + array: ArraySettings = field(default_factory=ArraySettings) + async_: AsyncSettings = field(default_factory=AsyncSettings) + threading: ThreadingSettings = field(default_factory=ThreadingSettings) + json_indent: int = 2 + codec_pipeline: CodecPipelineSettings = field(default_factory=CodecPipelineSettings) + codecs: Mapping[str, str] = field(default_factory=lambda: dict(DEFAULT_CODECS)) + buffer: str = "zarr.buffer.cpu.Buffer" + ndbuffer: str = "zarr.buffer.cpu.NDBuffer" -class Config(DConfig): # type: ignore[misc] - """The Config will collect configuration from config files and environment variables 
def make_default_config() -> ZarrConfig:
    """Return a fresh `ZarrConfig` populated with the built-in defaults."""
    return ZarrConfig()


def _resolve_field(obj: object, segment: str) -> str:
    """Translate a serialized key segment to the dataclass field name.

    Segments listed in `_FIELD_ALIASES` are mapped to their Python field name;
    every other segment is returned unchanged. ``obj`` is accepted for call-site
    symmetry and is not inspected.
    """
    return _FIELD_ALIASES.get(segment, segment)


def _lookup_field(obj: object, segment: str) -> str | None:
    """Return the declared dataclass field that ``segment`` addresses on ``obj``.

    Returns `None` when ``obj`` is not a dataclass instance or when the
    (alias-resolved) segment is not one of its declared fields. Only declared
    fields count: ordinary Python attributes such as ``__class__`` or, on a
    string leaf, ``upper`` are never treated as config keys.
    """
    if not is_dataclass(obj) or isinstance(obj, type):
        return None
    field_name = _resolve_field(obj, segment)
    if any(f.name == field_name for f in fields(obj)):
        return field_name
    return None


def get_path(cfg: ZarrConfig, key: str) -> object:
    """Read a dotted-string key from a `ZarrConfig` snapshot.

    Parameters
    ----------
    cfg : ZarrConfig
        The snapshot to read from.
    key : str
        Dotted key, e.g. ``"array.order"``.

    Returns
    -------
    object
        The value stored at ``key``; this may be a whole sub-section.

    Raises
    ------
    KeyError
        If the key does not resolve to a value. The error carries the full
        dotted ``key``.
    """
    node: object = cfg
    segments = key.split(".")
    for index, segment in enumerate(segments):
        if isinstance(node, Mapping):
            # Once an open mapping (e.g. ``codecs``) is reached, everything
            # that remains is one key, dots included ("numcodecs.bz2").
            remainder = ".".join(segments[index:])
            try:
                return node[remainder]
            except KeyError:
                raise KeyError(key) from None
        field_name = _lookup_field(node, segment)
        if field_name is None:
            raise KeyError(key)
        node = getattr(node, field_name)
    return node


def replace_path(cfg: ZarrConfig, key: str, value: object) -> ZarrConfig:
    """Return a new `ZarrConfig` with the dotted-string key set to ``value``.

    ``cfg`` itself is not modified. Raises `KeyError` (carrying ``key``) when
    the key does not address a declared field.
    """
    segments = key.split(".")
    # Quoted target: `cast` is a typing-only no-op at runtime.
    return cast("ZarrConfig", _replace_recursive(cfg, segments, value, key))


# `obj: Any` is deliberate: the function dispatches dynamically between a
# `Mapping` (the codecs subtree) and a dataclass instance, and
# `dataclasses.replace` wants a dataclass-typed argument that `object` would
# not satisfy for the type checker.
def _replace_recursive(obj: Any, segments: list[str], value: object, key: str) -> object:
    """Rebuild ``obj`` with the value at ``segments`` replaced by ``value``.

    Mappings are copied with the joined remainder of ``segments`` as a single
    key; dataclass nodes are rebuilt with `dataclasses.replace`. ``key`` is the
    full dotted key and is only used for the `KeyError` raised when a segment
    does not name a declared field.
    """
    segment = segments[0]
    if isinstance(obj, Mapping):
        return {**obj, ".".join(segments): value}
    field_name = _lookup_field(obj, segment)
    if field_name is None:
        # Covers unknown fields and attempts to descend into a scalar leaf;
        # the latter used to reach `dataclasses.replace` and raise TypeError.
        raise KeyError(key)
    if len(segments) == 1:
        return replace(obj, **{field_name: value})
    new_child = _replace_recursive(getattr(obj, field_name), segments[1:], value, key)
    return replace(obj, **{field_name: new_child})


# Maximum number of valid keys listed in an unknown-key error message.
_ROSTER_LIMIT = 10


def _children(obj: object) -> list[str]:
    """Return the immediate child key names of a config node (else an empty list)."""
    if isinstance(obj, Mapping):
        return list(obj)
    if is_dataclass(obj):
        # Report serialized names ("async"), not Python field names ("async_").
        return [_SERIALIZED_NAMES.get(f.name, f.name) for f in fields(obj)]
    return []


def _resolve_for_suggestion(cfg: ZarrConfig, key: str) -> tuple[str, list[str], str]:
    """Walk ``key`` as far as it resolves.

    Returns a 3-tuple:

    1. the deepest resolvable dotted prefix (``""`` if nothing resolved),
    2. the child key names of the node at that prefix,
    3. the first segment that failed to resolve, or ``""`` if the whole key
       resolved. Once an open mapping such as ``codecs`` is reached, the whole
       remainder is returned as that single failed key.

    For ``"array.bogus"`` the result is ``("array", [...keys under array...],
    "bogus")``; for an unknown top-level key ``"bogus"`` it is
    ``("", [...top-level keys...], "bogus")``.
    """
    node: object = cfg
    prefix = ""
    segments = key.split(".")
    for index, segment in enumerate(segments):
        if isinstance(node, Mapping):
            # the remainder indexes into an open mapping as a single key
            return prefix, _children(node), ".".join(segments[index:])
        field_name = _lookup_field(node, segment)
        if field_name is None:
            return prefix, _children(node), segment
        node = getattr(node, field_name)
        prefix = f"{prefix}.{segment}" if prefix else segment
    return prefix, _children(node), ""


def _unknown_key_error(key: str, cfg: ZarrConfig) -> KeyError:
    """Build (not raise) a `KeyError` for an unknown config key.

    Resolves ``key`` to the deepest valid level, then suggests the closest child
    key there if `difflib` finds one similar enough; otherwise lists the
    available keys at that level (sorted, capped at `_ROSTER_LIMIT`).
    """
    msg = f"{key!r} is not a valid configuration key."
    prefix, children, failed = _resolve_for_suggestion(cfg, key)
    matches = difflib.get_close_matches(failed, children, n=1) if failed != "" else []
    if len(matches) > 0:
        suggestion = f"{prefix}.{matches[0]}" if prefix != "" else matches[0]
        return KeyError(f"{msg} Did you mean {suggestion!r}?")
    if len(children) > 0:
        shown = sorted(children)
        roster = ", ".join(shown[:_ROSTER_LIMIT])
        if len(shown) > _ROSTER_LIMIT:
            roster += f", ... ({len(shown) - _ROSTER_LIMIT} more)"
        where = f" under {prefix!r}" if prefix != "" else ""
        msg = f"{msg} Valid keys{where}: {roster}."
    return KeyError(msg)


def to_nested_dict(cfg: ZarrConfig) -> dict[str, Any]:
    """Convert a `ZarrConfig` to a nested dict keyed by serialized names.

    Dataclass nodes become dicts (field names mapped through
    `_SERIALIZED_NAMES`), mappings are shallow-copied into plain dicts, and
    anything else is returned as-is. The result is a heterogeneous tree that
    callers navigate by key, hence the `Any` values.
    """

    # `obj: Any` is deliberate: `dataclasses.fields` wants a dataclass-typed
    # argument that `object` would not satisfy for the type checker.
    def convert(obj: Any) -> Any:
        if isinstance(obj, Mapping):
            return dict(obj)
        if is_dataclass(obj) and not isinstance(obj, type):
            out: dict[str, Any] = {}
            for f in fields(obj):
                serialized = _SERIALIZED_NAMES.get(f.name, f.name)
                out[serialized] = convert(getattr(obj, f.name))
            return out
        return obj

    return convert(cfg)  # type: ignore[no-any-return]


ENV_PREFIX = "ZARR_"

# Meta-variables that control WHERE config is loaded from, not config values
# themselves. They are excluded from the env-override map so they are not
# reported as unrecognized config keys.
_ENV_META_VARS: frozenset[str] = frozenset({"ZARR_CONFIG"})


def _parse_env_value(raw: str) -> object:
    """Parse an env value with ``ast.literal_eval``; fall back to the raw string.

    Every exception `ast.literal_eval` is documented to raise on malformed
    input is treated as "not a literal", so an odd value (for example
    ``{[1]: 2}``, which raises `TypeError`) is kept as text instead of
    propagating out of config loading.
    """
    try:
        return ast.literal_eval(raw)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return raw


def collect_env(environ: Mapping[str, str]) -> dict[str, object]:
    """Collect ``ZARR_*`` environment variables into a flat dotted-key map.

    ``ZARR_FOO__BAR_BAZ=1`` becomes ``{"foo.bar_baz": 1}``: the prefix is
    stripped, the key is lower-cased and ``__`` denotes nested access. Values
    are parsed with `_parse_env_value`.

    Variables listed in ``_ENV_META_VARS`` (e.g. ``ZARR_CONFIG``) are
    directives about where config lives and are skipped.
    """
    out: dict[str, object] = {}
    for name, raw in environ.items():
        if not name.startswith(ENV_PREFIX) or name in _ENV_META_VARS:
            continue
        dotted = name[len(ENV_PREFIX) :].lower().replace("__", ".")
        out[dotted] = _parse_env_value(raw)
    return out


def _config_search_paths(environ: Mapping[str, str]) -> list[str]:
    """Return the YAML config locations in ascending order of priority.

    `collect_yaml` merges with last-wins semantics, so the per-user directory
    ``~/.config/zarr`` comes first and the path named by ``ZARR_CONFIG`` (if
    set and non-empty) comes last: an explicitly requested file must override
    the default location, not the other way round.
    """
    paths = [os.path.join(os.path.expanduser("~"), ".config", "zarr")]
    env_path = environ.get("ZARR_CONFIG")
    if env_path:
        paths.append(env_path)
    return paths


def _yaml_candidates(path: str) -> list[str]:
    """List the YAML files that ``path`` contributes.

    A directory contributes its ``.yaml``/``.yml`` entries in sorted order, a
    regular file contributes itself, and anything else (including a directory
    that cannot be listed) contributes nothing.
    """
    if os.path.isdir(path):
        try:
            names = sorted(os.listdir(path))
        except OSError:
            return []
        return [os.path.join(path, fn) for fn in names if fn.endswith((".yaml", ".yml"))]
    if os.path.isfile(path):
        return [path]
    return []


def collect_yaml(paths: list[str]) -> dict[str, object]:
    """Merge YAML config files found at ``paths`` into a flat dotted-key map.

    Files are read as UTF-8 and merged in order, later files overriding earlier
    ones. Files that cannot be opened or read (`OSError`: missing, unreadable,
    a directory named ``*.yaml``, ...) are skipped; YAML syntax errors are not
    caught here and propagate to the caller. Documents whose top level is not
    a mapping are ignored.
    """
    import yaml

    merged: dict[str, object] = {}
    for path in paths:
        for candidate in _yaml_candidates(path):
            try:
                with open(candidate, encoding="utf-8") as fh:
                    data = yaml.safe_load(fh)
            except OSError:
                # Best-effort discovery: an unreadable file is skipped.
                continue
            if isinstance(data, Mapping):
                merged.update(_flatten_mapping(data))
    return merged


def _flatten_mapping(data: Mapping[str, object], prefix: str = "") -> dict[str, object]:
    """Flatten nested mappings into ``{"a.b.c": value}`` form.

    ``prefix`` is the dotted path of ``data`` itself (empty at the top level).
    Nested mappings are recursed into; every other value becomes a leaf.
    """
    out: dict[str, object] = {}
    for k, v in data.items():
        key = f"{prefix}.{k}" if prefix else f"{k}"
        if isinstance(v, Mapping):
            out.update(_flatten_mapping(v, key))
        else:
            out[key] = v
    return out


def apply_overrides(cfg: ZarrConfig, overrides: Mapping[str, object]) -> ZarrConfig:
    """Apply a flat dotted-key override map to a snapshot.

    Each entry is applied with `replace_path`. Keys that `replace_path` rejects
    with `KeyError` are skipped with a `ZarrUserWarning` rather than raising,
    so a stray environment variable or extra YAML key does not abort config
    loading.
    """
    for key, value in overrides.items():
        try:
            cfg = replace_path(cfg, key, value)
        except KeyError:
            warnings.warn(
                f"Unrecognized zarr config key {key!r} from environment or YAML — ignoring.",
                ZarrUserWarning,
                stacklevel=2,
            )
    return cfg


def build_config(environ: Mapping[str, str] | None = None) -> ZarrConfig:
    """Build the base snapshot: defaults < YAML files < environment variables.

    ``environ`` defaults to `os.environ`; it supplies both the ``ZARR_CONFIG``
    file location and the ``ZARR_*`` overrides.
    """
    if environ is None:
        environ = os.environ
    from_yaml = apply_overrides(
        make_default_config(), collect_yaml(_config_search_paths(environ))
    )
    return apply_overrides(from_yaml, collect_env(environ))
# Sentinel for "no default supplied", so that `None` stays a usable default.
_MISSING = object()


class _ConfigSet:
    """Context manager returned by ``ZarrConfigManager.set``.

    The change has already been applied when this object is created. Using the
    object in a ``with`` block restores the prior state on exit; not using it
    leaves the change in place.
    """

    def __init__(
        self, manager: ZarrConfigManager, prev_base: ZarrConfig, token: Token[ZarrConfig]
    ) -> None:
        self._manager = manager
        self._prev_base = prev_base
        self._token = token

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc: object) -> None:
        # Exceptions are not suppressed; the previous state is restored either way.
        self._manager._restore(self._prev_base, self._token)


class ZarrConfigManager:
    """Typed, donfig-compatible configuration object.

    State is held in two places: ``_base``, a plain attribute holding the last
    snapshot written, and ``_scope``, a `ContextVar` that, when set in the
    current context, takes precedence over ``_base`` for reads.
    """

    def __init__(self) -> None:
        self._base: ZarrConfig = build_config()
        self._scope: ContextVar[ZarrConfig] = ContextVar("zarr_config_scope")

    # --- state resolution -------------------------------------------------
    def _current(self) -> ZarrConfig:
        """Return the active snapshot: the context-local one if set, else ``_base``."""
        return self._scope.get(self._base)

    def _restore(self, prev_base: ZarrConfig, token: Token[ZarrConfig]) -> None:
        """Undo one `set` call: reinstate ``prev_base`` and reset the context variable."""
        self._base = prev_base
        self._scope.reset(token)

    # --- typed attribute access ------------------------------------------
    @property
    def default_zarr_format(self) -> Literal[2, 3]:
        return self._current().default_zarr_format

    @property
    def array(self) -> ArraySettings:
        return self._current().array

    @property
    def async_(self) -> AsyncSettings:
        return self._current().async_

    @property
    def threading(self) -> ThreadingSettings:
        return self._current().threading

    @property
    def codec_pipeline(self) -> CodecPipelineSettings:
        return self._current().codec_pipeline

    @property
    def json_indent(self) -> int:
        return self._current().json_indent

    @property
    def codecs(self) -> Mapping[str, str]:
        return self._current().codecs

    @property
    def buffer(self) -> str:
        return self._current().buffer

    @property
    def ndbuffer(self) -> str:
        return self._current().ndbuffer

    # --- string API: get --------------------------------------------------
    @overload
    def get(self, key: Literal["default_zarr_format"]) -> Literal[2, 3]: ...
    @overload
    def get(self, key: Literal["array.order"]) -> Literal["C", "F"]: ...
    @overload
    def get(self, key: Literal["array.write_empty_chunks"]) -> bool: ...
    @overload
    def get(self, key: Literal["array.read_missing_chunks"]) -> bool: ...
    @overload
    def get(self, key: Literal["array.target_shard_size_bytes"]) -> int | None: ...
    @overload
    def get(self, key: Literal["array.rectilinear_chunks"]) -> bool: ...
    @overload
    def get(self, key: Literal["array.sharding_coalesce_max_gap_bytes"]) -> int: ...
    @overload
    def get(self, key: Literal["array.sharding_coalesce_max_bytes"]) -> int: ...
    @overload
    def get(self, key: Literal["async.concurrency"]) -> int: ...
    @overload
    def get(self, key: Literal["async.timeout"]) -> float | None: ...
    @overload
    def get(self, key: Literal["threading.max_workers"]) -> int | None: ...
    @overload
    def get(self, key: Literal["json_indent"]) -> int: ...
    @overload
    def get(self, key: Literal["codec_pipeline.path"]) -> str: ...
    @overload
    def get(self, key: Literal["codec_pipeline.batch_size"]) -> int: ...
    @overload
    def get(self, key: Literal["buffer"]) -> str: ...
    @overload
    def get(self, key: Literal["ndbuffer"]) -> str: ...
    @overload
    # The fallback `-> Any` is deliberate: it lets `config.get("codecs", {})` be
    # used as a mapping and supports unknown keys. `object` here would force
    # every such call site to narrow first.
    def get(self, key: str, default: object = ...) -> Any: ...

    def get(self, key: str, default: object = _MISSING) -> Any:
        """Read the value at dotted ``key`` from the active snapshot.

        Deprecated keys are first resolved through `_apply_deprecation`. If the
        key was removed or does not resolve, ``default`` is returned when one
        was supplied; otherwise a `KeyError` is raised (for unresolved keys,
        the one built by `_unknown_key_error`).
        """
        resolved = self._apply_deprecation(key, raise_on_removed=False)
        if resolved is None:
            # Key was removed; treat as absent and honour the caller's default.
            if default is _MISSING:
                raise KeyError(key)
            return default
        current = self._current()
        try:
            return get_path(current, resolved)
        except KeyError:
            if default is _MISSING:
                raise _unknown_key_error(key, current) from None
            return default

    # --- string API: set --------------------------------------------------
    #
    # NOTE: unlike `get`, which is typed per key via overloads, `set` accepts an
    # arbitrary `Mapping[str, object]`, so values are NOT statically checked:
    # `config.set({"array.order": "Q"})` is not a type error. Nothing in `set`
    # itself validates values either; it only resolves keys.
    #
    # Static value typing would need an *open* TypedDict: declared structured
    # keys validated by type, plus arbitrary `codecs.NAME` string keys
    # (PEP 728 `extra_items`). A *closed* TypedDict would instead reject the
    # codec-selection idiom `config.set({"codecs.bytes": "your.module.Codec"})`
    # and any dynamically built dict, because the `codecs` namespace is extended
    # at runtime and its keys cannot be enumerated statically.
    #
    # So `set` is intentionally permissive: unknown structured keys raise at
    # runtime, while `codecs.*` stays writable. REVISIT once the type checker
    # in use supports PEP 728 open TypedDicts.
    def set(self, updates: Mapping[str, object] | None = None, **kwargs: object) -> _ConfigSet:
        """Apply one or more config overrides.

        Accepts either a mapping of dotted keys to values, keyword arguments
        (for top-level keys), or both; keyword arguments are applied after the
        mapping::

            config.set({"array.order": "F"})
            config.set(default_zarr_format=2)

        The change takes effect immediately. The returned object can be used as
        a context manager to restore the previous state on exit.

        Keys are checked: an unknown key raises `KeyError` and a removed key
        raises `BadConfigError`, in both cases before any state is changed.
        Values are not type-checked by this method, statically or at runtime.
        """
        all_updates: dict[str, object] = {}
        if updates:
            all_updates.update(updates)
        all_updates.update(kwargs)
        prev_base = self._base
        new = self._current()
        for key, value in all_updates.items():
            resolved = self._apply_deprecation(key, raise_on_removed=True)
            try:
                new = replace_path(new, resolved, value)
            except KeyError:
                raise _unknown_key_error(key, new) from None
        # Commit only after every update resolved, so a bad key changes nothing.
        self._base = new
        token = self._scope.set(new)
        return _ConfigSet(self, prev_base, token)

    # --- lifecycle --------------------------------------------------------
    def reset(self) -> None:
        """Discard all overrides and rebuild from defaults, YAML and environment."""
        self._base = build_config()
        # Sync the scope so _current() returns the new base in this context;
        # otherwise an earlier set()/reset() scope entry would shadow it.
        self._scope.set(self._base)

    def refresh(self) -> None:
        """Alias for `reset`."""
        self.reset()

    def enable_gpu(self) -> _ConfigSet:
        """Switch ``buffer`` and ``ndbuffer`` to the GPU implementations.

        Returns the same context-manager object as `set`.
        """
        return self.set(
            {"buffer": "zarr.buffer.gpu.Buffer", "ndbuffer": "zarr.buffer.gpu.NDBuffer"}
        )

    # --- compat / introspection ------------------------------------------
    @property
    def defaults(self) -> dict[str, Any]:
        """Nested dict of the built-in default values (a fresh dict on each access)."""
        return to_nested_dict(make_default_config())

    def to_dict(self) -> dict[str, Any]:
        """Return the active snapshot as a nested dict."""
        return to_nested_dict(self._current())

    def update(self, updates: Mapping[str, object]) -> None:
        """Apply ``updates`` permanently; equivalent to `set` without the context manager."""
        self.set(updates)

    def pprint(self, **kwargs: Any) -> None:
        """Pretty-print the active configuration.

        Keyword arguments (e.g. ``stream``, ``width``) are forwarded to
        `pprint.pprint`, so existing calls such as
        ``config.pprint(stream=buf, width=60)`` keep working.
        """
        import pprint as _pp

        _pp.pprint(self.to_dict(), **kwargs)

    # --- deprecations -----------------------------------------------------
    @overload
    def _apply_deprecation(self, key: str, *, raise_on_removed: Literal[True]) -> str: ...
    @overload
    def _apply_deprecation(self, key: str, *, raise_on_removed: Literal[False]) -> str | None: ...

    def _apply_deprecation(self, key: str, *, raise_on_removed: bool) -> str | None:
        """Resolve a possibly-deprecated config key.

        Parameters
        ----------
        key : str
            The dotted config key supplied by the caller.
        raise_on_removed : bool
            When `True` (used by `set`), raise `BadConfigError` if the key has been
            removed. When `False` (used by `get`), return `None` instead so the
            caller can treat the key as absent and honour its default.

        Returns
        -------
        str or None
            ``key`` itself if it is not deprecated; the new key (after emitting
            a `ZarrDeprecationWarning`) if it was renamed; `None` when the key
            was removed and `raise_on_removed` is `False`.

        Raises
        ------
        BadConfigError
            If the key was removed and `raise_on_removed` is `True`.
        """
        if key not in deprecations:
            return key
        new_key = deprecations[key]
        if new_key is None:
            if raise_on_removed:
                raise BadConfigError(
                    f"Configuration key {key!r} has been removed and no longer has any effect."
                )
            return None
        warnings.warn(
            f"Configuration key {key!r} has been renamed to {new_key!r}.",
            ZarrDeprecationWarning,
            stacklevel=3,
        )
        return new_key


class BadConfigError(ValueError):
    """Raised by `ZarrConfigManager.set` when a removed configuration key is used."""

    _msg = "bad Config: %r"
"zarr.codecs.numcodecs.CRC32C", - "numcodecs.lz4": "zarr.codecs.numcodecs.LZ4", - "numcodecs.lzma": "zarr.codecs.numcodecs.LZMA", - "numcodecs.zfpy": "zarr.codecs.numcodecs.ZFPY", - "numcodecs.adler32": "zarr.codecs.numcodecs.Adler32", - "numcodecs.astype": "zarr.codecs.numcodecs.AsType", - "numcodecs.bitround": "zarr.codecs.numcodecs.BitRound", - "numcodecs.blosc": "zarr.codecs.numcodecs.Blosc", - "numcodecs.delta": "zarr.codecs.numcodecs.Delta", - "numcodecs.fixedscaleoffset": "zarr.codecs.numcodecs.FixedScaleOffset", - "numcodecs.fletcher32": "zarr.codecs.numcodecs.Fletcher32", - "numcodecs.gzip": "zarr.codecs.numcodecs.GZip", - "numcodecs.jenkins_lookup3": "zarr.codecs.numcodecs.JenkinsLookup3", - "numcodecs.pcodec": "zarr.codecs.numcodecs.PCodec", - "numcodecs.packbits": "zarr.codecs.numcodecs.PackBits", - "numcodecs.shuffle": "zarr.codecs.numcodecs.Shuffle", - "numcodecs.quantize": "zarr.codecs.numcodecs.Quantize", - "numcodecs.zlib": "zarr.codecs.numcodecs.Zlib", - "numcodecs.zstd": "zarr.codecs.numcodecs.Zstd", - }, - "buffer": "zarr.buffer.cpu.Buffer", - "ndbuffer": "zarr.buffer.cpu.NDBuffer", - } - ], - deprecations=deprecations, -) - - -def parse_indexing_order(data: Any) -> Literal["C", "F"]: +config = ZarrConfigManager() + + +def parse_indexing_order(data: object) -> Literal["C", "F"]: if data in ("C", "F"): - return cast("Literal['C', 'F']", data) + # the membership check narrows `data` to Literal["C", "F"] + return data msg = f"Expected one of ('C', 'F'), got {data} instead." 
raise ValueError(msg) diff --git a/tests/test_config.py b/tests/test_config.py index a758378dc7..1eac0a1253 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -22,7 +22,7 @@ from zarr.core.buffer import NDBuffer from zarr.core.buffer.core import Buffer from zarr.core.codec_pipeline import BatchedCodecPipeline -from zarr.core.config import BadConfigError, config +from zarr.core.config import DEFAULT_CODECS, BadConfigError, config from zarr.core.indexing import SelectorTuple from zarr.errors import ChunkNotFoundError, ZarrUserWarning from zarr.registry import ( @@ -45,66 +45,28 @@ def test_config_defaults_set() -> None: - # regression test for available defaults - assert ( - config.defaults - == [ - { - "default_zarr_format": 3, - "array": { - "order": "C", - "write_empty_chunks": False, - "read_missing_chunks": True, - "target_shard_size_bytes": None, - "rectilinear_chunks": False, - "sharding_coalesce_max_gap_bytes": 1 << 20, - "sharding_coalesce_max_bytes": 16 << 20, - }, - "async": {"concurrency": 10, "timeout": None}, - "threading": {"max_workers": None}, - "json_indent": 2, - "codec_pipeline": { - "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", - "batch_size": 1, - }, - "codecs": { - "blosc": "zarr.codecs.blosc.BloscCodec", - "gzip": "zarr.codecs.gzip.GzipCodec", - "zstd": "zarr.codecs.zstd.ZstdCodec", - "bytes": "zarr.codecs.bytes.BytesCodec", - "endian": "zarr.codecs.bytes.BytesCodec", # compatibility with earlier versions of ZEP1 - "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", - "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", - "transpose": "zarr.codecs.transpose.TransposeCodec", - "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", - "vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec", - "numcodecs.bz2": "zarr.codecs.numcodecs.BZ2", - "numcodecs.crc32": "zarr.codecs.numcodecs.CRC32", - "numcodecs.crc32c": "zarr.codecs.numcodecs.CRC32C", - "numcodecs.lz4": "zarr.codecs.numcodecs.LZ4", - "numcodecs.lzma": 
"zarr.codecs.numcodecs.LZMA", - "numcodecs.zfpy": "zarr.codecs.numcodecs.ZFPY", - "numcodecs.adler32": "zarr.codecs.numcodecs.Adler32", - "numcodecs.astype": "zarr.codecs.numcodecs.AsType", - "numcodecs.bitround": "zarr.codecs.numcodecs.BitRound", - "numcodecs.blosc": "zarr.codecs.numcodecs.Blosc", - "numcodecs.delta": "zarr.codecs.numcodecs.Delta", - "numcodecs.fixedscaleoffset": "zarr.codecs.numcodecs.FixedScaleOffset", - "numcodecs.fletcher32": "zarr.codecs.numcodecs.Fletcher32", - "numcodecs.gzip": "zarr.codecs.numcodecs.GZip", - "numcodecs.jenkins_lookup3": "zarr.codecs.numcodecs.JenkinsLookup3", - "numcodecs.pcodec": "zarr.codecs.numcodecs.PCodec", - "numcodecs.packbits": "zarr.codecs.numcodecs.PackBits", - "numcodecs.shuffle": "zarr.codecs.numcodecs.Shuffle", - "numcodecs.quantize": "zarr.codecs.numcodecs.Quantize", - "numcodecs.zlib": "zarr.codecs.numcodecs.Zlib", - "numcodecs.zstd": "zarr.codecs.numcodecs.Zstd", - }, - "buffer": "zarr.buffer.cpu.Buffer", - "ndbuffer": "zarr.buffer.cpu.NDBuffer", - } - ] - ) + assert config.defaults == { + "default_zarr_format": 3, + "array": { + "order": "C", + "write_empty_chunks": False, + "read_missing_chunks": True, + "target_shard_size_bytes": None, + "rectilinear_chunks": False, + "sharding_coalesce_max_gap_bytes": 1 << 20, + "sharding_coalesce_max_bytes": 16 << 20, + }, + "async": {"concurrency": 10, "timeout": None}, + "threading": {"max_workers": None}, + "json_indent": 2, + "codec_pipeline": { + "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", + "batch_size": 1, + }, + "codecs": dict(DEFAULT_CODECS), + "buffer": "zarr.buffer.cpu.Buffer", + "ndbuffer": "zarr.buffer.cpu.NDBuffer", + } assert config.get("array.order") == "C" assert config.get("async.concurrency") == 10 assert config.get("async.timeout") is None @@ -156,9 +118,6 @@ def test_config_codec_pipeline_class(store: Store) -> None: # has default value assert get_pipeline_class().__name__ != "" - config.set({"codec_pipeline.name": 
"zarr.core.codec_pipeline.BatchedCodecPipeline"}) - assert get_pipeline_class() == zarr.core.codec_pipeline.BatchedCodecPipeline - _mock = Mock() class MockCodecPipeline(BatchedCodecPipeline): @@ -206,7 +165,7 @@ class MockEnvCodecPipeline(CodecPipeline): @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) def test_config_codec_implementation(store: Store) -> None: # has default value - assert fully_qualified_name(get_codec_class("blosc")) == config.defaults[0]["codecs"]["blosc"] + assert fully_qualified_name(get_codec_class("blosc")) == config.defaults["codecs"]["blosc"] _mock = Mock() @@ -259,7 +218,7 @@ def test_config_ndbuffer_implementation(store: Store) -> None: def test_config_buffer_implementation() -> None: # has default value - assert config.defaults[0]["buffer"] == "zarr.buffer.cpu.Buffer" + assert config.defaults["buffer"] == "zarr.buffer.cpu.Buffer" arr = zeros(shape=(100,), store=StoreExpectingTestBuffer()) diff --git a/tests/test_config_typed.py b/tests/test_config_typed.py new file mode 100644 index 0000000000..878562049a --- /dev/null +++ b/tests/test_config_typed.py @@ -0,0 +1,641 @@ +from __future__ import annotations + +import dataclasses +import typing +from concurrent.futures import ThreadPoolExecutor + +import pytest + +from tests.conftest import Expect, ExpectFail +from zarr.core.config import ( + _SERIALIZED_NAMES, + DEFAULT_CODECS, + BadConfigError, + ZarrConfig, + ZarrConfigManager, + apply_overrides, + build_config, + collect_env, + get_path, + make_default_config, + replace_path, + to_nested_dict, +) + +if typing.TYPE_CHECKING: + import pathlib + +# --------------------------------------------------------------------------- +# Module-level constants used in parametrize lists (evaluated at collection time) +# --------------------------------------------------------------------------- + +_REMOVED_KEY = "array.v2_default_compressor.numeric" +_DEFAULT = make_default_config() + +# 
--------------------------------------------------------------------------- +# 1. get_path — success cases +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect(input="array.order", output="C", id="array-order"), + Expect(input="async.concurrency", output=10, id="async-concurrency-alias"), + Expect(input="json_indent", output=2, id="json-indent"), + Expect(input="codecs", output=DEFAULT_CODECS, id="codecs-dict"), + Expect(input="codecs.blosc", output="zarr.codecs.blosc.BloscCodec", id="codecs-blosc"), + ], + ids=lambda c: c.id, +) +def test_get_path(case: Expect[str, object]) -> None: + assert get_path(make_default_config(), case.input) == case.output + + +@pytest.mark.parametrize( + "case", + [ + ExpectFail(input="array.nonexistent", exception=KeyError, id="nonexistent-key"), + ], + ids=lambda c: c.id, +) +def test_get_path_raises(case: ExpectFail[str]) -> None: + with case.raises(): + get_path(make_default_config(), case.input) + + +# --------------------------------------------------------------------------- +# 2. 
replace_path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect(input=("array.order", "F"), output="F", id="array-order"), + Expect(input=("async.concurrency", 99), output=99, id="async-concurrency-alias"), + Expect( + input=("codecs.my_codec", "my.module.MyCodec"), + output="my.module.MyCodec", + id="codec-new-key", + ), + ], + ids=lambda c: c.id, +) +def test_replace_path(case: Expect[tuple[str, object], object]) -> None: + key, value = case.input + result = replace_path(make_default_config(), key, value) + assert get_path(result, key) == case.output + + +def test_replace_path_is_immutable() -> None: + """Original config is unchanged after replace_path (frozen dataclass).""" + cfg = make_default_config() + _ = replace_path(cfg, "array.order", "F") + assert cfg.array.order == "C" + # the open `codecs` dict must not be mutated in place either: a frozen + # dataclass forbids attribute re-assignment but not `dict.__setitem__`. + _ = replace_path(cfg, "codecs.my_codec", "my.module.MyCodec") + assert "my_codec" not in cfg.codecs + + +# --------------------------------------------------------------------------- +# 3. 
collect_env +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input={ + "ZARR_ARRAY__ORDER": "F", + "ZARR_ASYNC__CONCURRENCY": "32", + "ZARR_CODECS__MY_CODEC": "my.module.MyCodec", + "UNRELATED": "ignored", + }, + output={ + "array.order": "F", + "async.concurrency": 32, + "codecs.my_codec": "my.module.MyCodec", + }, + id="nested-and-literal", + ), + Expect( + input={"ZARR_CONFIG": "/some/path.yaml", "ZARR_ARRAY__ORDER": "F"}, + output={"array.order": "F"}, + id="zarr-config-meta-var-skipped", + ), + ], + ids=lambda c: c.id, +) +def test_collect_env(case: Expect[dict[str, str], dict[str, object]]) -> None: + assert collect_env(case.input) == case.output + + +# --------------------------------------------------------------------------- +# 4. build_config +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect(input={}, output=_DEFAULT, id="empty-environ"), + Expect( + input={"ZARR_CONFIG": "/nonexistent/path.yaml"}, + output=_DEFAULT, + id="zarr-config-nonexistent", + ), + Expect( + input={"ZARR_JSON_INDENT": "4"}, + output=replace_path(_DEFAULT, "json_indent", 4), + id="json-indent-env", + ), + ], + ids=lambda c: c.id, +) +def test_build_config(case: Expect[dict[str, str], ZarrConfig]) -> None: + assert build_config(environ=case.input) == case.output + + +# --------------------------------------------------------------------------- +# 5. 
apply_overrides +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input={"array.order": "F", "codecs.x": "pkg.X"}, + output=replace_path(replace_path(_DEFAULT, "array.order", "F"), "codecs.x", "pkg.X"), + id="array-order-and-codec", + ), + ], + ids=lambda c: c.id, +) +def test_apply_overrides(case: Expect[dict[str, object], ZarrConfig]) -> None: + assert apply_overrides(build_config(environ={}), case.input) == case.output + + +# --------------------------------------------------------------------------- +# 6. to_nested_dict +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=make_default_config(), + output=("C", 10, "zarr.codecs.blosc.BloscCodec"), + id="default-serialized-keys", + ), + ], + ids=lambda c: c.id, +) +def test_to_nested_dict(case: Expect[ZarrConfig, tuple[str, int, str]]) -> None: + nested = to_nested_dict(case.input) + order, concurrency, blosc = case.output + assert nested["array"]["order"] == order + assert nested["async"]["concurrency"] == concurrency + assert "async_" not in nested # serialized key, not the Python attribute name + assert nested["codecs"]["blosc"] == blosc + + +# --------------------------------------------------------------------------- +# 7. 
ZarrConfigManager.get — proxy string access +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect(input="array.order", output="C", id="array-order"), + Expect(input="async.concurrency", output=10, id="async-concurrency-alias"), + Expect(input="codecs", output=DEFAULT_CODECS, id="codecs-dict"), + ], + ids=lambda c: c.id, +) +def test_proxy_get(case: Expect[str, object]) -> None: + assert ZarrConfigManager().get(case.input) == case.output + + +@pytest.mark.parametrize( + "case", + [ + Expect(input=("does.not.exist", "fallback"), output="fallback", id="default-fallback"), + ], + ids=lambda c: c.id, +) +def test_proxy_get_with_default(case: Expect[tuple[str, object], object]) -> None: + key, default = case.input + assert ZarrConfigManager().get(key, default) == case.output + + +# --------------------------------------------------------------------------- +# 8. Removed-deprecated-key behavior +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect(input="fallback", output="fallback", id="get-with-default"), + ], + ids=lambda c: c.id, +) +def test_removed_deprecated_key_get_default(case: Expect[str, str]) -> None: + """get() with a removed deprecated key and a default returns the default silently.""" + assert ZarrConfigManager().get(_REMOVED_KEY, case.input) == case.output + + +@pytest.mark.parametrize( + "case", + [ + ExpectFail(input=_REMOVED_KEY, exception=KeyError, id="get-no-default"), + ], + ids=lambda c: c.id, +) +def test_removed_deprecated_key_get_raises(case: ExpectFail[str]) -> None: + """get() with a removed deprecated key and no default raises KeyError.""" + with case.raises(): + ZarrConfigManager().get(case.input) + + +# --------------------------------------------------------------------------- +# 9. 
set() must raise for both removed-deprecated keys and totally unknown keys +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + ExpectFail( + input={_REMOVED_KEY: "some_value"}, + exception=BadConfigError, + id="set-removed-deprecated", + ), + ExpectFail( + input={"totally.bogus.key": 1}, + exception=KeyError, + id="set-unknown-key", + ), + ], + ids=lambda c: c.id, +) +def test_set_invalid_key_raises(case: ExpectFail[dict[str, object]]) -> None: + """set() raises for both removed deprecated keys and totally unknown structured keys.""" + with case.raises(): + ZarrConfigManager().set(case.input) + + +# --------------------------------------------------------------------------- +# 10. Unknown keys produce a helpful "did you mean" message (get and set) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + # close match at the deepest resolvable level -> "Did you mean ...?" 
+ ExpectFail( + input="arr4y", exception=KeyError, msg=r"Did you mean .array.", id="suggest-top" + ), + ExpectFail( + input="array.0rder", + exception=KeyError, + msg=r"Did you mean .array\.order.", + id="suggest-nested", + ), + ExpectFail( + input="codecs.bl0sc", + exception=KeyError, + msg=r"Did you mean .codecs\.blosc.", + id="suggest-codec", + ), + # no close match -> roster of available keys at the last resolvable level + ExpectFail(input="foo", exception=KeyError, msg=r"Valid keys: .*array", id="roster-top"), + ExpectFail( + input="array.foo", + exception=KeyError, + msg=r"Valid keys under .array.: .*order", + id="roster-nested", + ), + ExpectFail( + input="codecs.zzzzzzzz", + exception=KeyError, + msg=r"under .codecs.: .*more\)", + id="roster-truncated", + ), + ], + ids=lambda c: c.id, +) +def test_get_unknown_key_message(case: ExpectFail[str]) -> None: + """get() on an unknown key suggests the closest key or lists what's available.""" + with case.raises(): + ZarrConfigManager().get(case.input) + + +@pytest.mark.parametrize( + "case", + [ + ExpectFail( + input={"array.0rder": "F"}, + exception=KeyError, + msg=r"Did you mean .array\.order.", + id="set-suggest", + ), + ExpectFail( + input={"array.foo": "F"}, + exception=KeyError, + msg=r"Valid keys under .array.: .*order", + id="set-roster", + ), + ], + ids=lambda c: c.id, +) +def test_set_unknown_key_message(case: ExpectFail[dict[str, object]]) -> None: + """set() shares the same helpful unknown-key error as get().""" + with case.raises(): + ZarrConfigManager().set(case.input) + + +# --------------------------------------------------------------------------- +# Default config values (dedicated — direct attribute assertions are clearest here) +# --------------------------------------------------------------------------- + + +def test_default_config_values() -> None: + cfg = make_default_config() + assert cfg.default_zarr_format == 3 + assert cfg.array.order == "C" + assert cfg.array.sharding_coalesce_max_bytes 
== 16 << 20 + assert cfg.async_.concurrency == 10 + assert cfg.async_.timeout is None + assert cfg.threading.max_workers is None + assert cfg.json_indent == 2 + assert cfg.codec_pipeline.path == "zarr.core.codec_pipeline.BatchedCodecPipeline" + assert cfg.codecs["blosc"] == "zarr.codecs.blosc.BloscCodec" + assert cfg.codecs == DEFAULT_CODECS + # proxy attribute access via ZarrConfigManager + mgr = ZarrConfigManager() + assert mgr.array.order == "C" + + +# --------------------------------------------------------------------------- +# Stateful / behavioral tests (kept as dedicated functions) +# --------------------------------------------------------------------------- + + +def test_set_permanent_and_context() -> None: + cfg = ZarrConfigManager() + cfg.set({"array.order": "F"}) + assert cfg.get("array.order") == "F" # permanent + with cfg.set({"array.order": "C"}): + assert cfg.get("array.order") == "C" + assert cfg.get("array.order") == "F" # restored to permanent value + cfg.reset() + assert cfg.get("array.order") == "C" + + +def test_permanent_set_visible_in_worker_thread() -> None: + cfg = ZarrConfigManager() + cfg.set({"async.concurrency": 77}) + try: + with ThreadPoolExecutor(max_workers=1) as ex: + seen = ex.submit(lambda: cfg.get("async.concurrency")).result() + assert seen == 77 # ThreadPoolExecutor does not copy contextvars + finally: + cfg.reset() + + +def test_defaults_and_enable_gpu() -> None: + cfg = ZarrConfigManager() + assert cfg.defaults["array"]["order"] == "C" + with cfg.set({"buffer": "x"}): + pass + cfg.enable_gpu() + try: + assert cfg.get("buffer") == "zarr.buffer.gpu.Buffer" + assert cfg.get("ndbuffer") == "zarr.buffer.gpu.NDBuffer" + finally: + cfg.reset() + + +def test_refresh_not_shadowed_by_prior_scope(monkeypatch: pytest.MonkeyPatch) -> None: + """refresh() must be visible in the calling context even after a prior set()/reset().""" + mgr = ZarrConfigManager() + # plant a scope entry in this thread/context (as reset()/set() would) + 
mgr.set({"array.order": "F"}) + assert mgr.get("array.order") == "F" + # change the environment so a rebuild differs, then refresh + monkeypatch.setenv("ZARR_JSON_INDENT", "7") + mgr.refresh() + # refresh must be visible in THIS context, not shadowed by the prior scope + assert mgr.get("json_indent") == 7 + assert mgr.get("array.order") == "C" # the prior permanent set is gone after rebuild + + +# --------------------------------------------------------------------------- +# Tolerant ingest: unknown env/YAML keys must warn and be skipped, not crash +# --------------------------------------------------------------------------- + + +def test_build_config_unknown_env_key_warns_and_skips() -> None: + """build_config with an unrecognized env var warns and skips it; known keys still apply.""" + with pytest.warns(UserWarning, match="future.key"): + cfg = build_config(environ={"ZARR_FUTURE__KEY": "1", "ZARR_ARRAY__ORDER": "F"}) + # Known key was applied + assert cfg.array.order == "F" + # All other fields are still at default + default = make_default_config() + from dataclasses import fields as dc_fields + + for f in dc_fields(default): + if f.name != "array": + assert getattr(cfg, f.name) == getattr(default, f.name) + + +def test_apply_overrides_unknown_key_warns_and_returns_default() -> None: + """apply_overrides with a totally unknown key warns and returns an otherwise-default config.""" + default = make_default_config() + with pytest.warns(UserWarning, match="totally.bogus.key"): + result = apply_overrides(default, {"totally.bogus.key": 123}) + assert result == default + + +# --------------------------------------------------------------------------- +# donfig not imported +# --------------------------------------------------------------------------- + + +def test_donfig_not_imported() -> None: + import sys + + import zarr # noqa: F401 + + assert "donfig" not in sys.modules + + +# --------------------------------------------------------------------------- +# YAML codec 
block merging — regression for the "wipes all defaults" bug +# --------------------------------------------------------------------------- + + +def test_yaml_codecs_block_merges_not_replaces(tmp_path: pathlib.Path) -> None: + """A YAML file with a codecs: block must MERGE into the defaults, not replace them.""" + yaml_file = tmp_path / "zarr.yaml" + yaml_file.write_text("codecs:\n bytes: my.custom.BytesCodec\n mycodec: my.Mod.MyCodec\n") + cfg = build_config(environ={"ZARR_CONFIG": str(yaml_file)}) + # overrides applied + assert cfg.codecs["bytes"] == "my.custom.BytesCodec" + assert cfg.codecs["mycodec"] == "my.Mod.MyCodec" + # defaults PRESERVED + assert cfg.codecs["blosc"] == "zarr.codecs.blosc.BloscCodec" + assert cfg.codecs["zstd"] == "zarr.codecs.zstd.ZstdCodec" + # exactly one net-new key added ("bytes" overwrites existing; "mycodec" is new) + assert len(cfg.codecs) == len(DEFAULT_CODECS) + 1 + + +def test_yaml_dotted_codec_name_merges(tmp_path: pathlib.Path) -> None: + """Dotted codec keys like numcodecs.bz2 in YAML must merge, not replace the whole dict.""" + yaml_file = tmp_path / "zarr.yaml" + yaml_file.write_text("codecs:\n numcodecs.bz2: my.Override\n") + cfg = build_config(environ={"ZARR_CONFIG": str(yaml_file)}) + # dotted key correctly round-tripped + assert cfg.codecs["numcodecs.bz2"] == "my.Override" + # all other defaults preserved + assert cfg.codecs["blosc"] == "zarr.codecs.blosc.BloscCodec" + assert len(cfg.codecs) == len(DEFAULT_CODECS) # bz2 was already there; just overwritten + + +def test_build_config_environ_yaml_path_is_read(tmp_path: pathlib.Path) -> None: + """ZARR_CONFIG supplied via build_config(environ=...) 
must actually be read.""" + yaml_file = tmp_path / "zarr.yaml" + yaml_file.write_text("json_indent: 9\n") + cfg = build_config(environ={"ZARR_CONFIG": str(yaml_file)}) + assert cfg.json_indent == 9 + # Non-existent path must still not raise + cfg2 = build_config(environ={"ZARR_CONFIG": "/nonexistent/path.yaml"}) + assert cfg2.json_indent == make_default_config().json_indent + + +# --------------------------------------------------------------------------- +# Drift-protection: every structured leaf key must have a get() overload +# --------------------------------------------------------------------------- + + +def _structured_leaf_specs(cfg_cls: type, prefix: str = "") -> dict[str, object]: + """Walk a settings dataclass recursively and return ``{dotted_key: resolved_type}``. + + Uses ``typing.get_type_hints`` instead of ``f.type`` so that the + ``from __future__ import annotations`` string-annotation form is resolved + to real types before ``dataclasses.is_dataclass`` is called. The open + ``codecs`` mapping is intentionally excluded. 
+ """ + specs: dict[str, object] = {} + resolved_hints = typing.get_type_hints(cfg_cls) + for f in dataclasses.fields(cfg_cls): + serialized = _SERIALIZED_NAMES.get(f.name, f.name) + key = f"{prefix}.{serialized}" if prefix else serialized + resolved_type = resolved_hints[f.name] + if dataclasses.is_dataclass(resolved_type): + specs.update(_structured_leaf_specs(typing.cast(type, resolved_type), key)) + elif f.name == "codecs": + # open mapping — intentionally not enumerated + continue + else: + specs[key] = resolved_type + return specs + + +def _structured_leaf_keys(cfg_cls: type, prefix: str = "") -> list[str]: + """Return every dotted leaf key for a settings dataclass (derived from specs).""" + return list(_structured_leaf_specs(cfg_cls, prefix)) + + +def test_every_structured_key_has_a_get_overload() -> None: + """Enumerate every typed leaf key in ZarrConfig and assert a matching get() overload exists.""" + overloads = typing.get_overloads(ZarrConfigManager.get) + literal_keys: set[str] = set() + for ov in overloads: + hints = typing.get_type_hints(ov) + key_hint = hints.get("key") + if typing.get_origin(key_hint) is typing.Literal: + literal_keys.update(typing.get_args(key_hint)) + leaf_keys = _structured_leaf_keys(ZarrConfig) + missing = set(leaf_keys) - literal_keys + assert not missing, f"get() overloads missing for: {sorted(missing)}" + + +def test_get_overload_return_types_match_fields() -> None: + """Assert that each get() overload's return type matches the dataclass field type. + + Builds two maps using ``typing.get_type_hints`` — one from the dataclass + field annotations, one from the overload return hints — then compares them + key by key. A mismatch (e.g. ``-> str`` instead of ``-> Literal["C","F"]``) + is reported as a clear failure rather than a missing-overload failure. 
+ """ + # Build map: key -> return type from overloads + overloads = typing.get_overloads(ZarrConfigManager.get) + overload_return: dict[str, object] = {} + for ov in overloads: + hints = typing.get_type_hints(ov) + key_hint = hints.get("key") + if typing.get_origin(key_hint) is typing.Literal: + (literal_val,) = typing.get_args(key_hint) + overload_return[literal_val] = hints["return"] + + # Build map: key -> field type from the dataclass schema + field_specs = _structured_leaf_specs(ZarrConfig) + + missing: list[str] = [] + mismatched: list[str] = [] + for key, expected_type in field_specs.items(): + if key not in overload_return: + missing.append(f" {key!r}: missing overload") + elif overload_return[key] != expected_type: + mismatched.append( + f" {key!r}: overload returns {overload_return[key]!r}," + f" field type is {expected_type!r}" + ) + + errors: list[str] = [] + if missing: + errors.append("get() overloads missing for keys:\n" + "\n".join(missing)) + if mismatched: + errors.append( + "get() overload return types do not match field types:\n" + "\n".join(mismatched) + ) + assert not errors, "\n\n".join(errors) + + +# --------------------------------------------------------------------------- +# Static-typing smoke test (only checked by mypy, not executed at runtime) +# --------------------------------------------------------------------------- + +if typing.TYPE_CHECKING: + + def _typing_smoke(cfg: ZarrConfigManager) -> None: + # --- positive assertions: each distinct return shape --- + typing.assert_type(cfg.get("array.order"), typing.Literal["C", "F"]) + typing.assert_type(cfg.get("async.concurrency"), int) + typing.assert_type(cfg.get("array.write_empty_chunks"), bool) + typing.assert_type(cfg.get("async.timeout"), float | None) + typing.assert_type(cfg.get("threading.max_workers"), int | None) + typing.assert_type(cfg.get("default_zarr_format"), typing.Literal[2, 3]) + typing.assert_type(cfg.get("buffer"), str) + typing.assert_type(cfg.array.order, 
typing.Literal["C", "F"]) + + # --- negative: precision-from-above guards --- + # The return type is Literal["C","F"], which is narrower than str. + # If the overload were widened to -> str, assert_type would pass and + # the ignore below would become unused, causing warn_unused_ignores to + # fail CI. + typing.assert_type(cfg.get("array.order"), str) # type: ignore[assert-type] + typing.assert_type(cfg.get("default_zarr_format"), int) # type: ignore[assert-type] + + # --- negative: bad key type must be rejected by all overloads --- + cfg.get(123) # type: ignore[call-overload] diff --git a/uv.lock b/uv.lock index 799ea6e45a..a39badf7ee 100644 --- a/uv.lock +++ b/uv.lock @@ -915,18 +915,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl", hash = "sha256:d0013f540772d1420576855455d050a2180186c91c15779301ac2ccb3eeb68de", size = 633196, upload-time = "2025-12-18T19:00:18.077Z" }, ] -[[package]] -name = "donfig" -version = "0.8.1.post1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/25/71/80cc718ff6d7abfbabacb1f57aaa42e9c1552bfdd01e64ddd704e4a03638/donfig-0.8.1.post1.tar.gz", hash = "sha256:3bef3413a4c1c601b585e8d297256d0c1470ea012afa6e8461dc28bfb7c23f52", size = 19506, upload-time = "2024-05-23T14:14:31.513Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, -] - [[package]] name = "execnet" version = "2.1.2" @@ -3954,11 +3942,11 @@ wheels = [ name = "zarr" source = { editable = "." 
} dependencies = [ - { name = "donfig" }, { name = "google-crc32c" }, { name = "numcodecs" }, { name = "numpy" }, { name = "packaging" }, + { name = "pyyaml" }, { name = "typing-extensions" }, ] @@ -4075,13 +4063,13 @@ test = [ requires-dist = [ { name = "cast-value-rs", marker = "extra == 'cast-value-rs'" }, { name = "cupy-cuda12x", marker = "sys_platform != 'darwin' and extra == 'gpu'" }, - { name = "donfig", specifier = ">=0.8" }, { name = "fsspec", marker = "extra == 'remote'", specifier = ">=2023.10.0" }, { name = "google-crc32c", specifier = ">=1.5" }, { name = "numcodecs", specifier = ">=0.14" }, { name = "numpy", specifier = ">=2" }, { name = "obstore", marker = "extra == 'remote'", specifier = ">=0.5.1" }, { name = "packaging", specifier = ">=22.0" }, + { name = "pyyaml", specifier = ">=6" }, { name = "typer", marker = "extra == 'cli'" }, { name = "typing-extensions", specifier = ">=4.14" }, { name = "universal-pathlib", marker = "extra == 'optional'" },