diff --git a/tests/test_text_embeddings_config_validation.py b/tests/test_text_embeddings_config_validation.py
new file mode 100644
index 0000000..dd3627e
--- /dev/null
+++ b/tests/test_text_embeddings_config_validation.py
@@ -0,0 +1,31 @@
+import ast
+from pathlib import Path
+
+
+def load_config_value_matches():
+    source_path = Path(__file__).resolve().parents[1] / "text_embeddings_connectors.py"
+    source = ast.parse(source_path.read_text())
+    helper = next(
+        node
+        for node in source.body
+        if isinstance(node, ast.FunctionDef) and node.name == "_config_value_matches"
+    )
+    module = ast.Module(body=[helper], type_ignores=[])
+    ast.fix_missing_locations(module)
+    namespace = {}
+    exec(compile(module, str(source_path), "exec"), namespace)
+    return namespace["_config_value_matches"]
+
+
+def test_text_encoder_norm_type_matches_both_casings():
+    matches = load_config_value_matches()
+
+    assert matches("per_token_rms", "per_token_rms")
+    assert matches("PER_TOKEN_RMS", "per_token_rms")
+
+
+def test_non_string_expectations_remain_strict():
+    matches = load_config_value_matches()
+
+    assert matches(False, False)
+    assert not matches("False", False)
diff --git a/text_embeddings_connectors.py b/text_embeddings_connectors.py
index a5a7688..a788b64 100644
--- a/text_embeddings_connectors.py
+++ b/text_embeddings_connectors.py
@@ -47,6 +47,13 @@ def _filter_sd(sd: dict, prefix: str) -> dict:
     return {k[len(prefix) :]: v for k, v in sd.items() if k.startswith(prefix)}
 
 
+def _config_value_matches(actual, expected_val) -> bool:
+    """Compare string config values case-insensitively while keeping other types strict."""
+    if isinstance(actual, str) and isinstance(expected_val, str):
+        return actual.casefold() == expected_val.casefold()
+    return actual == expected_val
+
+
 def _load_aggregate_embed(sd: dict, modality: str, dtype) -> nn.Linear:
     """Load an aggregate_embed Linear from the state dict.
 
@@ -359,7 +366,7 @@ def load_text_embeddings_pipeline( } for key, expected_val in _expected.items(): actual = transformer_config.get(key) - assert actual == expected_val, ( + assert _config_value_matches(actual, expected_val), ( f"Unexpected config for dual-aggregate model: " f"{key}={actual!r}, expected {expected_val!r}" )