feat: add detection workflow column plugins#192
Conversation
|
Added a local parity check for the detection-column migration. This was intentionally run outside the committed diff: it compares the old `CustomColumnConfig` generators against the new plugin column generators.

Parity harness:

from __future__ import annotations
import copy
import json
from types import SimpleNamespace
from typing import Any
from data_designer.config.column_configs import CustomColumnConfig
from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator
from anonymizer.engine.constants import (
COL_AUGMENTED_ENTITIES,
COL_DETECTED_ENTITIES,
COL_MERGED_ENTITIES,
COL_RAW_DETECTED,
COL_SEED_ENTITIES,
COL_SEED_ENTITIES_JSON,
COL_SEED_VALIDATION_CANDIDATES,
COL_TEXT,
COL_VALIDATED_ENTITIES,
COL_VALIDATION_DECISIONS,
)
from anonymizer.engine.detection.chunked_validation import ChunkedValidationParams, make_chunked_validation_generator
from anonymizer.engine.detection.custom_columns import (
apply_validation_and_finalize,
apply_validation_to_seed_entities,
enrich_validation_decisions,
merge_and_build_candidates,
parse_detected_entities,
prepare_validation_inputs,
)
from anonymizer.engine.workflow_columns.detection.config import (
ChunkedValidationConfig,
DetectionTransformConfig,
DetectionTransformOperation,
)
from anonymizer.engine.workflow_columns.detection.impl import ChunkedValidationGenerator, DetectionTransformGenerator
class FakeFacade:
    """Stand-in model facade that records every call and replays one canned response."""

    def __init__(self, response: dict[str, Any]) -> None:
        # The same response object is serialized again on every generate() call.
        self.response = response
        # One entry per generate() call, in call order.
        self.calls: list[dict[str, Any]] = []

    def generate(self, *, prompt: str, parser, system_prompt: str | None = None, purpose: str | None = None, **kwargs):
        """Log the call, then hand the canned response to *parser*.

        The response is JSON-encoded and wrapped in a fenced ``json`` block
        before being passed to *parser*.

        Returns:
            A ``(parser(raw), [])`` tuple; the second element is always an
            empty list.
        """
        call_record = dict(
            prompt=prompt,
            system_prompt=system_prompt,
            purpose=purpose,
            kwargs=kwargs,
        )
        self.calls.append(call_record)
        payload = json.dumps(self.response)
        fenced = "```json\n" + payload + "\n```"
        return parser(fenced), []
class FakeModelRegistry:
    """Model-registry double that serves a single facade under the alias ``"v0"``."""

    def __init__(self, facade: FakeFacade) -> None:
        self.facade = facade

    def get_model(self, *, model_alias: str) -> FakeFacade:
        """Return the wrapped facade; any alias other than ``"v0"`` trips an assertion."""
        expected_alias = "v0"
        assert model_alias == expected_alias
        return self.facade
def resource_provider(facade: FakeFacade) -> SimpleNamespace:
    """Build a minimal resource provider whose ``model_registry`` serves *facade*."""
    registry = FakeModelRegistry(facade)
    return SimpleNamespace(model_registry=registry)
def custom_generator(name: str, fn, **kwargs: Any) -> CustomColumnGenerator:
    """Wrap *fn* in a legacy ``CustomColumnGenerator`` for column *name*.

    Extra keyword arguments are forwarded to ``CustomColumnConfig``. The
    generator receives an empty ``SimpleNamespace`` as its resource provider.
    """
    config = CustomColumnConfig(name=name, generator_function=fn, **kwargs)
    return CustomColumnGenerator(config, resource_provider=SimpleNamespace())
def compare_transform(
    label: str,
    *,
    name: str,
    fn,
    operation: DetectionTransformOperation,
    row: dict[str, Any],
) -> dict[str, Any]:
    """Run one transform through the legacy and plugin generators and require equal output.

    Each generator receives its own deep copy of *row*, so neither run can
    affect the other or the caller's dict.

    Args:
        label: Name used in the PASS line and in the mismatch message.
        name: Output column name given to both configs.
        fn: Legacy generator function wrapped by ``custom_generator``.
        operation: Plugin operation expected to match *fn*.
        row: Input record.

    Returns:
        The legacy generator's output.

    Raises:
        AssertionError: If the two outputs differ.
    """
    legacy_generator = custom_generator(name, fn)
    legacy_out = legacy_generator.generate(copy.deepcopy(row))

    plugin_config = DetectionTransformConfig(name=name, operation=operation)
    plugin_generator = DetectionTransformGenerator(plugin_config, resource_provider=SimpleNamespace())
    plugin_out = plugin_generator.generate(copy.deepcopy(row))

    assert legacy_out == plugin_out, f"{label} mismatch:\nold={legacy_out}\nnew={plugin_out}"
    print(f"PASS {label}")
    return legacy_out
def compare_chunked_validation(row: dict[str, Any]) -> None:
    """Check that legacy and plugin chunked validation agree on output and model calls.

    Both generators get their own ``FakeFacade`` replaying the same canned
    decisions and their own deep copy of *row*.

    Because the fake facade returns the same response on every call, equal
    outputs alone do not prove the two implementations issued the same
    requests. The ``purpose`` of every recorded call is therefore compared,
    which also catches a differing number of calls.

    Raises:
        AssertionError: If the outputs differ, if the legacy generator never
            called the model, or if the recorded call purposes differ.
    """
    response = {
        "decisions": [
            {"id": "first_name_0_5", "decision": "keep", "reason": "correct"},
            {"id": "city_15_22", "decision": "keep", "reason": "correct"},
        ]
    }
    params = ChunkedValidationParams(
        pool=["v0"],
        max_entities_per_call=10,
        excerpt_window_chars=20,
        prompt_template=(
            "Tagged: {{ _seed_tagged_text }}\n"
            "Skeleton: {{ _validation_skeleton }}\n"
            "Notation: {{ _tag_notation }}"
        ),
    )

    old_facade = FakeFacade(response)
    old = CustomColumnGenerator(
        CustomColumnConfig(
            name=COL_VALIDATION_DECISIONS,
            generator_function=make_chunked_validation_generator(["v0"]),
            generator_params=params,
            drop=True,
        ),
        resource_provider=resource_provider(old_facade),
    ).generate(copy.deepcopy(row))

    new_facade = FakeFacade(response)
    new = ChunkedValidationGenerator(
        ChunkedValidationConfig(
            name=COL_VALIDATION_DECISIONS,
            pool=["v0"],
            max_entities_per_call=params.max_entities_per_call,
            excerpt_window_chars=params.excerpt_window_chars,
            prompt_template=params.prompt_template,
            drop=True,
        ),
        resource_provider=resource_provider(new_facade),
    ).generate(copy.deepcopy(row))

    assert old == new, f"chunked validation mismatch:\nold={old}\nnew={new}"

    old_purposes = [call["purpose"] for call in old_facade.calls]
    new_purposes = [call["purpose"] for call in new_facade.calls]
    # Fail with a readable message instead of an IndexError when no call was made.
    assert old_purposes, "chunked validation never invoked the model"
    assert old_purposes == new_purposes, (
        f"chunked validation call mismatch:\nold={old_purposes}\nnew={new_purposes}"
    )
    print("PASS chunked validation")
def main() -> None:
    """Drive one sample record through every detection step, checking parity at each.

    The record is threaded through the steps in order: each transform's
    output becomes the next step's input. Validation decisions and augmented
    entities are injected by hand between steps.
    """
    detected = [
        {"text": "Alice", "label": "first_name", "start": 0, "end": 5, "score": 0.99},
        {"text": "Seattle", "label": "city", "start": 15, "end": 22, "score": 0.96},
    ]
    row: dict[str, Any] = {
        COL_TEXT: "Alice moved to Seattle.",
        COL_RAW_DETECTED: json.dumps({"entities": detected}),
    }

    def advance(label: str, column: str, fn, operation: DetectionTransformOperation) -> None:
        # Replace the working row with the parity-checked output of this step.
        nonlocal row
        row = compare_transform(label, name=column, fn=fn, operation=operation, row=row)

    advance(
        "parse_detected_entities",
        COL_SEED_ENTITIES,
        parse_detected_entities,
        DetectionTransformOperation.PARSE_DETECTED_ENTITIES,
    )
    advance(
        "prepare_validation_inputs",
        COL_SEED_VALIDATION_CANDIDATES,
        prepare_validation_inputs,
        DetectionTransformOperation.PREPARE_VALIDATION_INPUTS,
    )

    compare_chunked_validation(row)

    # Hand-written decisions: one "keep" and one "reclass" for the later steps.
    row[COL_VALIDATION_DECISIONS] = {
        "decisions": [
            {"id": "first_name_0_5", "decision": "keep", "reason": "correct"},
            {"id": "city_15_22", "decision": "reclass", "proposed_label": "location", "reason": "more precise"},
        ]
    }
    advance(
        "enrich_validation_decisions",
        COL_VALIDATED_ENTITIES,
        enrich_validation_decisions,
        DetectionTransformOperation.ENRICH_VALIDATION_DECISIONS,
    )
    advance(
        "apply_validation_to_seed_entities",
        COL_SEED_ENTITIES_JSON,
        apply_validation_to_seed_entities,
        DetectionTransformOperation.APPLY_VALIDATION_TO_SEED_ENTITIES,
    )

    # No augmenter runs in this harness; supply an empty augmentation result.
    row[COL_AUGMENTED_ENTITIES] = {"entities": []}
    advance(
        "merge_and_build_candidates",
        COL_MERGED_ENTITIES,
        merge_and_build_candidates,
        DetectionTransformOperation.MERGE_AND_BUILD_CANDIDATES,
    )
    advance(
        "apply_validation_and_finalize",
        COL_DETECTED_ENTITIES,
        apply_validation_and_finalize,
        DetectionTransformOperation.APPLY_VALIDATION_AND_FINALIZE,
    )

    assert row[COL_DETECTED_ENTITIES]["entities"]
    print("All parity checks passed.")
if __name__ == "__main__":
    main()

Output:
|
Added a real-provider smoke/benchmark harness for this PR and ran it locally against the Brev endpoints. Endpoints (GLiNER/gpt-oss) both on brev. Interpretation:
2-record smoke:

4-record benchmark, 2 repeats:

Local harness:

from __future__ import annotations
import argparse
import json
import os
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pandas as pd
from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig
from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
from data_designer.interface.data_designer import DataDesigner
from anonymizer.config.models import DetectionModelSelection
from anonymizer.engine.constants import (
COL_AUGMENTED_ENTITIES,
COL_DETECTED_ENTITIES,
COL_FINAL_ENTITIES,
COL_MERGED_ENTITIES,
COL_RAW_DETECTED,
COL_SEED_ENTITIES,
COL_SEED_ENTITIES_JSON,
COL_SEED_VALIDATION_CANDIDATES,
COL_TEXT,
COL_VALIDATED_ENTITIES,
COL_VALIDATION_DECISIONS,
_jinja,
)
from anonymizer.engine.detection.chunked_validation import ChunkedValidationParams, make_chunked_validation_generator
from anonymizer.engine.detection.custom_columns import (
apply_validation_and_finalize,
apply_validation_to_seed_entities,
enrich_validation_decisions,
merge_and_build_candidates,
parse_detected_entities,
prepare_validation_inputs,
)
from anonymizer.engine.detection.detection_workflow import (
EntityDetectionResult,
EntityDetectionWorkflow,
_get_augment_prompt,
_get_validation_prompt,
_resolve_detection_labels,
)
from anonymizer.engine.ndd.adapter import NddAdapter
from anonymizer.engine.ndd.model_loader import resolve_model_alias, resolve_model_aliases
from anonymizer.engine.schemas import AugmentedEntitiesSchema, EntitiesSchema
# Entity labels passed to the workflow as ``entity_labels``.
LABELS = [
    # person names
    "first_name",
    "last_name",
    # contact details
    "email",
    "phone_number",
    # places and organizations
    "city",
    "organization",
    "street_address",
    # dates and identifiers
    "date",
    "ssn",
]

# Synthetic input texts; ``input_df`` cycles through them to build a frame of
# any requested size.
ROWS = [
    "Alice Johnson moved from Seattle to Denver on March 3, 2024. Call her at (555) 123-4567.",
    "Bob from Acme Corp emailed bob.smith@example.com after visiting 221B Baker Street.",
    "Maria Garcia lives in Miami and listed SSN 123-45-6789 on the onboarding form.",
    "Chen Wei joined Contoso in Boston; his backup number is 415-555-0199.",
]
@dataclass(frozen=True)
class RunResult:
    """Immutable record of one timed workflow run."""

    # Label of the workflow variant; ``main`` uses "legacy" and "plugin".
    mode: str
    # Wall-clock duration of the timed ``workflow.run`` call, in seconds.
    elapsed_s: float
    # Number of records the workflow reported as failed.
    failed_records: int
    # Output frame returned by the workflow.
    dataframe: pd.DataFrame
class LegacyEntityDetectionWorkflow(EntityDetectionWorkflow):
    """Benchmark baseline that builds the detection pipeline from ``CustomColumnConfig`` columns.

    Only ``detect_and_validate_entities`` is overridden; everything else is
    inherited from ``EntityDetectionWorkflow``.
    """

    def detect_and_validate_entities(
        self,
        dataframe: pd.DataFrame,
        *,
        model_configs: list[ModelConfig],
        selected_models: DetectionModelSelection,
        gliner_detection_threshold: float,
        validation_max_entities_per_call: int,
        validation_excerpt_window_chars: int,
        entity_labels: list[str] | None = None,
        data_summary: str | None = None,
        preview_num_records: int | None = None,
    ) -> EntityDetectionResult:
        """Run detect -> validate -> augment -> merge -> finalize using custom columns.

        Returns:
            An ``EntityDetectionResult`` carrying a copy of the adapter's
            output frame and its failed records.
        """
        labels = _resolve_detection_labels(entity_labels)
        workflow_model_configs = self._inject_detector_params(
            model_configs=model_configs,
            selected_models=selected_models,
            labels=labels,
            gliner_detection_threshold=gliner_detection_threshold,
        )

        # Resolve the model alias(es) for each pipeline role.
        detection_alias = resolve_model_alias("entity_detector", selected_models)
        validator_aliases = resolve_model_aliases("entity_validator", selected_models)
        augmenter_alias = resolve_model_alias("entity_augmenter", selected_models)

        chunked_validator = make_chunked_validation_generator(validator_aliases)
        chunked_validator_params = ChunkedValidationParams(
            pool=list(validator_aliases),
            max_entities_per_call=validation_max_entities_per_call,
            excerpt_window_chars=validation_excerpt_window_chars,
            prompt_template=_get_validation_prompt(data_summary=data_summary, labels=labels),
        )

        # Column order is the pipeline order.
        columns = [
            LLMTextColumnConfig(
                name=COL_RAW_DETECTED,
                prompt=_jinja(COL_TEXT),
                model_alias=detection_alias,
            ),
            CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities),
            CustomColumnConfig(
                name=COL_SEED_VALIDATION_CANDIDATES,
                generator_function=prepare_validation_inputs,
            ),
            CustomColumnConfig(
                name=COL_VALIDATION_DECISIONS,
                generator_function=chunked_validator,
                generator_params=chunked_validator_params,
                drop=True,
            ),
            CustomColumnConfig(name=COL_VALIDATED_ENTITIES, generator_function=enrich_validation_decisions),
            CustomColumnConfig(
                name=COL_SEED_ENTITIES_JSON,
                generator_function=apply_validation_to_seed_entities,
            ),
            LLMStructuredColumnConfig(
                name=COL_AUGMENTED_ENTITIES,
                prompt=_get_augment_prompt(
                    data_summary=data_summary,
                    labels=labels,
                    # Strict only when the caller supplied an explicit label list.
                    strict_labels=entity_labels is not None,
                ),
                model_alias=augmenter_alias,
                output_format=AugmentedEntitiesSchema,
            ),
            CustomColumnConfig(name=COL_MERGED_ENTITIES, generator_function=merge_and_build_candidates),
            CustomColumnConfig(
                name=COL_DETECTED_ENTITIES,
                generator_function=apply_validation_and_finalize,
            ),
        ]

        outcome = self._adapter.run_workflow(
            dataframe,
            model_configs=workflow_model_configs,
            columns=columns,
            workflow_name="legacy-entity-detection",
            preview_num_records=preview_num_records,
        )
        return EntityDetectionResult(
            dataframe=outcome.dataframe.copy(),
            failed_records=outcome.failed_records,
        )
def model_providers(api_key_env: str) -> list[ModelProvider]:
    """Return the two benchmark providers: "gliner" and "gptoss".

    Args:
        api_key_env: Name of the environment variable holding the API key.
            The name itself, not the value, is passed to the "gptoss"
            provider as ``api_key``.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    # NOTE(review): GLINER_ENDPOINT and GPTOSS_ENDPOINT are not defined in the
    # visible harness; presumably module-level constants supplied locally.
    api_key_present = bool(os.environ.get(api_key_env))
    if not api_key_present:
        raise RuntimeError(f"{api_key_env} must be set")

    gliner = ModelProvider(
        name="gliner",
        endpoint=GLINER_ENDPOINT,
        provider_type="openai",
    )
    gptoss = ModelProvider(
        name="gptoss",
        endpoint=GPTOSS_ENDPOINT,
        provider_type="openai",
        api_key=api_key_env,
    )
    return [gliner, gptoss]
def model_configs() -> list[ModelConfig]:
    """Return the model configs used by the benchmark: the GLiNER detector, then gpt-oss."""
    detector_params = ChatCompletionInferenceParams(
        max_parallel_requests=4,
        timeout=120,
    )
    detector = ModelConfig(
        alias="gliner-pii-detector",
        model="fastino/gliner2-privacy-filter-PII-multi",
        provider="gliner",
        skip_health_check=True,
        inference_parameters=detector_params,
    )

    llm_params = ChatCompletionInferenceParams(
        max_parallel_requests=2,
        max_tokens=4096,
        temperature=0.0,
        top_p=1.0,
        timeout=300,
    )
    llm = ModelConfig(
        alias="gpt-oss-120b",
        model="gpt-oss-120b",
        provider="gptoss",
        skip_health_check=True,
        inference_parameters=llm_params,
    )
    return [detector, llm]
def selection() -> DetectionModelSelection:
    """Map each detection role to a model alias: GLiNER detects, gpt-oss does the rest."""
    llm_alias = "gpt-oss-120b"
    return DetectionModelSelection(
        entity_detector="gliner-pii-detector",
        entity_validator=[llm_alias],
        entity_augmenter=llm_alias,
        latent_detector=llm_alias,
    )
def input_df(records: int) -> pd.DataFrame:
    """Build a single-column input frame with *records* rows, cycling through ``ROWS``."""
    texts: list[str] = []
    for index in range(records):
        # Wrap around so any record count can be served from the fixed samples.
        texts.append(ROWS[index % len(ROWS)])
    return pd.DataFrame({COL_TEXT: texts})
def run_workflow(
    mode: str,
    workflow_cls: type[EntityDetectionWorkflow],
    providers: list[ModelProvider],
    records: int,
) -> RunResult:
    """Run one workflow variant end to end and time the ``run`` call.

    Artifacts go to a temporary directory that is removed when the run
    finishes, so repeated benchmark runs no longer leave one directory per
    invocation behind. The output frame is copied before the directory is
    removed, so the returned result does not share state with the run.

    Args:
        mode: Label stored on the result and used in the directory prefix.
        workflow_cls: Workflow class to instantiate.
        providers: Model providers handed to ``DataDesigner``.
        records: Number of input rows; also used as ``preview_num_records``.

    Returns:
        A ``RunResult`` with the elapsed time, failed-record count and output
        frame. Only ``workflow.run`` is timed; setup is excluded.
    """
    with tempfile.TemporaryDirectory(prefix=f"anonymizer_{mode}_bench_") as artifact_dir:
        artifact_root = Path(artifact_dir)
        data_designer = DataDesigner(artifact_path=artifact_root, model_providers=providers)
        workflow = workflow_cls(adapter=NddAdapter(data_designer))

        start = time.perf_counter()
        result = workflow.run(
            input_df(records),
            model_configs=model_configs(),
            selected_models=selection(),
            gliner_detection_threshold=0.5,
            validation_max_entities_per_call=8,
            validation_excerpt_window_chars=160,
            entity_labels=LABELS,
            data_summary="Short synthetic contact records containing names, locations, organizations, emails, phones, and IDs.",
            tag_latent_entities=False,
            compute_grouped_entities=True,
            preview_num_records=records,
        )
        elapsed = time.perf_counter() - start

        return RunResult(
            mode=mode,
            elapsed_s=elapsed,
            failed_records=len(result.failed_records),
            dataframe=result.dataframe.copy(),
        )
def signatures(df: pd.DataFrame) -> list[list[tuple[str, str, int, int]]]:
    """Return, per row, the sorted ``(value, label, start, end)`` tuples of its final entities.

    Sorting makes the result independent of the order in which entities were
    emitted, so two runs can be compared with ``==``.
    """

    def row_signature(raw) -> list[tuple[str, str, int, int]]:
        parsed = EntitiesSchema.from_raw(raw)
        entries = [
            (entity.value, entity.label, entity.start_position, entity.end_position)
            for entity in parsed.entities
        ]
        entries.sort()
        return entries

    return [row_signature(raw) for raw in df[COL_FINAL_ENTITIES].tolist()]
def row_counts(df: pd.DataFrame) -> list[int]:
    """Return the number of final entities for each row of *df*."""
    return list(map(len, signatures(df)))
def emit_summary(run_results: list[RunResult]) -> None:
    """Print per-mode timing stats and, when both modes ran, a last-run output comparison.

    Each mode is reported as one JSON line. The comparison uses the entity
    signatures of the last "legacy" run and the last "plugin" run; on a
    mismatch both signature lists are printed.
    """
    # Group runs by mode, keeping first-seen order.
    grouped: dict[str, list[RunResult]] = {}
    for run in run_results:
        if run.mode not in grouped:
            grouped[run.mode] = []
        grouped[run.mode].append(run)

    print("\nRuntime summary")
    for mode, runs in grouped.items():
        durations = [run.elapsed_s for run in runs]
        summary = {
            "mode": mode,
            "runs": len(durations),
            "elapsed_s": [round(seconds, 3) for seconds in durations],
            "mean_s": round(statistics.mean(durations), 3),
            "failed_records": [run.failed_records for run in runs],
            "entity_counts_last": row_counts(runs[-1].dataframe),
        }
        print(json.dumps(summary, sort_keys=True))

    # The comparison needs at least one run of each mode.
    if "legacy" not in grouped or "plugin" not in grouped:
        return

    legacy_sig = signatures(grouped["legacy"][-1].dataframe)
    plugin_sig = signatures(grouped["plugin"][-1].dataframe)
    identical = legacy_sig == plugin_sig
    print("\nLast-run output comparison")
    print(json.dumps({"exact_entity_signature_match": identical}, sort_keys=True))
    if not identical:
        print(json.dumps({"legacy_last": legacy_sig, "plugin_last": plugin_sig}, default=str))
def main() -> None:
    """Parse CLI options, alternate legacy/plugin runs, and print per-run and summary stats."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--records", type=int, default=4)
    cli.add_argument("--repeats", type=int, default=1)
    cli.add_argument("--api-key-env", default="GPTOSS_API_KEY")
    options = cli.parse_args()

    providers = model_providers(options.api_key_env)

    # Each repeat runs legacy first, then plugin.
    one_round: list[tuple[str, type[EntityDetectionWorkflow]]] = [
        ("legacy", LegacyEntityDetectionWorkflow),
        ("plugin", EntityDetectionWorkflow),
    ]
    schedule = one_round * options.repeats

    results: list[RunResult] = []
    for mode, cls in schedule:
        print(f"\nRunning {mode} workflow on {options.records} records...")
        outcome = run_workflow(mode, cls, providers, options.records)
        results.append(outcome)
        per_run = {
            "mode": mode,
            "elapsed_s": round(outcome.elapsed_s, 3),
            "failed_records": outcome.failed_records,
            "entity_counts": row_counts(outcome.dataframe),
        }
        print(json.dumps(per_run, sort_keys=True))

    emit_summary(results)
if __name__ == "__main__":
    main()
Greptile Summary

This PR replaces all

Confidence Score: 4/5

The detection workflow behavior is unchanged — all six transform steps and the chunked validation step are mapped correctly to the new plugin column types, with required_columns, side_effect_columns, and pool metadata all matching the original custom_column_generator decorators exactly. The migration is thorough and the column metadata audit against custom_columns.py confirms parity. The untested async bridge in impl.py is a real gap — a DataDesigner engine-mode change could break that path silently — but it does not affect the current default (sync) execution path where tests do run end-to-end.

Files needing attention: src/anonymizer/engine/workflow_columns/detection/impl.py — specifically the _AsyncBridgedModelFacade async bridge path, which exercises DataDesigner internals (SyncClientUnavailableError, ensure_async_engine_loop) that are not covered by any test in this PR.

Important Files Changed
|
| @property | ||
| def required_columns(self) -> list[str]: | ||
| return self._REQUIRED_COLUMNS[DetectionTransformOperation(self.operation)] | ||
|
|
||
| @property | ||
| def side_effect_columns(self) -> list[str]: | ||
| return self._SIDE_EFFECT_COLUMNS[DetectionTransformOperation(self.operation)] |
There was a problem hiding this comment.
self.operation is already a DetectionTransformOperation instance — Pydantic coerces the field to the enum type on model construction, so the DetectionTransformOperation(self.operation) wrapping here (and in side_effect_columns) is redundant. It's harmless today, but it makes the lookup slightly misleading as if the value might be a raw string at access time.
Current:

    @property
    def required_columns(self) -> list[str]:
        return self._REQUIRED_COLUMNS[DetectionTransformOperation(self.operation)]

    @property
    def side_effect_columns(self) -> list[str]:
        return self._SIDE_EFFECT_COLUMNS[DetectionTransformOperation(self.operation)]

Suggested change:

    @property
    def required_columns(self) -> list[str]:
        return self._REQUIRED_COLUMNS[self.operation]

    @property
    def side_effect_columns(self) -> list[str]:
        return self._SIDE_EFFECT_COLUMNS[self.operation]
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| class _AsyncBridgedModelFacade: | ||
| def __init__(self, facade: Any) -> None: | ||
| self._facade = facade | ||
|
|
||
| def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]: | ||
| from data_designer.engine.models.clients.errors import SyncClientUnavailableError | ||
|
|
||
| try: | ||
| return self._facade.generate(*args, **kwargs) | ||
| except SyncClientUnavailableError: | ||
| pass | ||
|
|
||
| try: | ||
| asyncio.get_running_loop() | ||
| except RuntimeError: | ||
| pass | ||
| else: | ||
| raise RuntimeError("model.generate() cannot be bridged from the running event loop.") | ||
|
|
||
| from data_designer.engine.dataset_builders.utils.async_concurrency import ensure_async_engine_loop | ||
|
|
||
| timeout_override = kwargs.get("timeout") | ||
| request_timeout = float(timeout_override) if timeout_override is not None else self._facade.request_timeout | ||
| bridge_timeout = _compute_bridge_timeout( | ||
| request_timeout=request_timeout, | ||
| max_correction_steps=int(kwargs.get("max_correction_steps", 0) or 0), | ||
| max_conversation_restarts=int(kwargs.get("max_conversation_restarts", 0) or 0), | ||
| ) | ||
| loop = ensure_async_engine_loop() | ||
| future = asyncio.run_coroutine_threadsafe(self._facade.agenerate(*args, **kwargs), loop) | ||
| try: | ||
| return future.result(timeout=bridge_timeout) | ||
| except concurrent.futures.TimeoutError as exc: | ||
| future.cancel() | ||
| from data_designer.engine.models.errors import ModelTimeoutError | ||
|
|
||
| raise ModelTimeoutError(f"model.generate() bridge timed out after {bridge_timeout:.0f}s") from exc | ||
|
|
There was a problem hiding this comment.
Async bridge path has no unit test coverage. The
SyncClientUnavailableError → run_coroutine_threadsafe bridge is only reachable when DataDesigner's async engine is active and the sync client is unavailable. None of the new tests exercise this path — test_detection_workflow_uses_plugin_transform_columns and siblings mock adapter.run_workflow before any generator executes. A future DataDesigner upgrade that flips the async engine on by default could silently regress this path. Consider adding a unit test that stubs self._facade.generate to raise SyncClientUnavailableError and verifies the facade falls through to agenerate.
Summary
- `data_designer.plugins` entry points.
- `CustomColumnConfig` for parse, validation prep, chunked validation, merge, and finalize steps.
- `plans/custom-column-plugins/`, including the note that native NER transport is a separate follow-up from pluginizing custom columns.
- `Anonymizer` export lazy so DataDesigner plugin discovery does not hit a circular import.

Type of Change
Testing
- `make test` passes locally
- `make check` passes locally (format + lint + typecheck + lock-check)

Targeted validation:

- `.venv/bin/ruff check --fix .`
- `.venv/bin/ruff format .`
- `.venv/bin/ruff check src/anonymizer/engine/detection src/anonymizer/engine/workflow_columns tests/engine/test_detection_workflow.py tests/engine/test_chunked_validation.py`
- `.venv/bin/pytest tests/engine/test_detection_postprocess.py tests/engine/test_chunked_validation.py tests/engine/test_detection_workflow.py`
- `anonymizer-chunked-validation` and `anonymizer-detection-transform` load via DataDesigner registry.

Documentation

- `make docs-build` passes locally

Not run; this PR adds a migration plan under `plans/`, not rendered docs-site content.

NER Transport Note
This PR does not remove the chat-completions-compatible NER path. The detector remains an `LLMTextColumnConfig`; avoiding the extra chat completions head would require a separate native detector workflow column or provider/client path.

Related Issues