Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/test_text_embeddings_config_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import ast
from pathlib import Path


def load_config_value_matches():
source_path = Path(__file__).resolve().parents[1] / "text_embeddings_connectors.py"
source = ast.parse(source_path.read_text())
helper = next(
node
for node in source.body
if isinstance(node, ast.FunctionDef) and node.name == "_config_value_matches"
)
module = ast.Module(body=[helper], type_ignores=[])
ast.fix_missing_locations(module)
namespace = {}
exec(compile(module, str(source_path), "exec"), namespace)
return namespace["_config_value_matches"]


def test_text_encoder_norm_type_matches_both_casings():
matches = load_config_value_matches()

assert matches("per_token_rms", "per_token_rms")
assert matches("PER_TOKEN_RMS", "per_token_rms")


def test_non_string_expectations_remain_strict():
matches = load_config_value_matches()

assert matches(False, False)
assert not matches("False", False)
9 changes: 8 additions & 1 deletion text_embeddings_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def _filter_sd(sd: dict, prefix: str) -> dict:
return {k[len(prefix) :]: v for k, v in sd.items() if k.startswith(prefix)}


def _config_value_matches(actual, expected_val) -> bool:
"""Compare string config values case-insensitively while keeping other types strict."""
if isinstance(actual, str) and isinstance(expected_val, str):
return actual.casefold() == expected_val.casefold()
return actual == expected_val


def _load_aggregate_embed(sd: dict, modality: str, dtype) -> nn.Linear:
"""Load an aggregate_embed Linear from the state dict.

Expand Down Expand Up @@ -359,7 +366,7 @@ def load_text_embeddings_pipeline(
}
for key, expected_val in _expected.items():
actual = transformer_config.get(key)
assert actual == expected_val, (
assert _config_value_matches(actual, expected_val), (
f"Unexpected config for dual-aggregate model: "
f"{key}={actual!r}, expected {expected_val!r}"
)
Expand Down