diff --git a/dataflow/operators/hallucination_detection/__init__.py b/dataflow/operators/hallucination_detection/__init__.py new file mode 100644 index 00000000..564475dd --- /dev/null +++ b/dataflow/operators/hallucination_detection/__init__.py @@ -0,0 +1,29 @@ +""" +Hallucination Detection Operators for DataFlow. + +This module provides operators for creating hallucination detection datasets, +including filtering by token length, injecting hallucinations, and parsing +span annotations. + +Operators: +- LongContextFilterOperator: Filter samples by token count (8K+, 12K+, etc.) +- HallucinationInjectionOperator: Inject RAGTruth-style hallucinations +- SpanAnnotationOperator: Parse tags to character positions +- HallucinationDetectionEvaluator: Evaluate hallucination detection models +""" + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .filter.long_context_filter import LongContextFilterOperator + from .generate.hallucination_injection import HallucinationInjectionOperator + from .generate.span_annotation import SpanAnnotationOperator +else: + import sys + from dataflow.utils.registry import LazyLoader, generate_import_structure_from_type_checking + + cur_path = "dataflow/operators/hallucination_detection/" + + _import_structure = generate_import_structure_from_type_checking(__file__, cur_path) + sys.modules[__name__] = LazyLoader(__name__, "dataflow/operators/hallucination_detection/", _import_structure) + diff --git a/dataflow/operators/hallucination_detection/filter/long_context_filter.py b/dataflow/operators/hallucination_detection/filter/long_context_filter.py new file mode 100644 index 00000000..290fcc02 --- /dev/null +++ b/dataflow/operators/hallucination_detection/filter/long_context_filter.py @@ -0,0 +1,193 @@ +""" +Long Context Filter Operator. + +Filters samples based on token count to create long-context evaluation datasets. +Useful for benchmarking models with extended context windows (8K+, 12K+, 16K+, etc.). 
+""" + +import pandas as pd +from typing import Optional, Union, List +from dataflow.utils.registry import OPERATOR_REGISTRY +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow import get_logger + +try: + from transformers import AutoTokenizer + HAS_TRANSFORMERS = True +except ImportError: + HAS_TRANSFORMERS = False + + +@OPERATOR_REGISTRY.register() +class LongContextFilterOperator(OperatorABC): + """Filter samples by token count for long-context evaluation. + + This operator tokenizes text fields and filters samples based on + minimum and maximum token counts. Useful for creating evaluation + datasets that test models with extended context windows. + + Example: + >>> from dataflow.operators.hallucination_detection import LongContextFilterOperator + >>> from transformers import AutoTokenizer + >>> + >>> tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base") + >>> filter_op = LongContextFilterOperator( + ... tokenizer=tokenizer, + ... min_tokens=8000, + ... max_tokens=24000, + ... ) + >>> # Use in pipeline + """ + + def __init__( + self, + tokenizer: Optional["AutoTokenizer"] = None, + tokenizer_name: str = "answerdotai/ModernBERT-base", + min_tokens: int = 8000, + max_tokens: int = 32000, + text_fields: Optional[List[str]] = None, + add_token_count: bool = True, + ): + """Initialize the LongContextFilterOperator. + + Args: + tokenizer: Pre-loaded HuggingFace tokenizer. If None, loads from tokenizer_name. + tokenizer_name: HuggingFace model name to load tokenizer from. + min_tokens: Minimum token count (inclusive). + max_tokens: Maximum token count (inclusive). + text_fields: List of fields to concatenate for token counting. + Defaults to ["prompt", "answer"] or ["text"]. + add_token_count: If True, adds a 'num_tokens' column to output. + """ + self.logger = get_logger() + + if not HAS_TRANSFORMERS: + raise ImportError( + "transformers is required for LongContextFilterOperator. 
" + "Install with: pip install transformers" + ) + + if tokenizer is not None: + self.tokenizer = tokenizer + else: + self.logger.info(f"Loading tokenizer from {tokenizer_name}") + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + self.min_tokens = min_tokens + self.max_tokens = max_tokens + self.text_fields = text_fields or ["prompt", "answer"] + self.add_token_count = add_token_count + + @staticmethod + def get_desc(lang: str = "en") -> str: + """Returns a description of the operator's functionality.""" + if lang == "zh": + return ( + "根据token数量过滤样本的算子,用于创建长上下文评估数据集。\n\n" + "__init__参数:\n" + "- tokenizer: HuggingFace tokenizer对象,用于token计数\n" + "- tokenizer_name: tokenizer模型名称,默认'answerdotai/ModernBERT-base'\n" + "- min_tokens: 最小token数(含),默认8000\n" + "- max_tokens: 最大token数(含),默认32000\n" + "- text_fields: 需要计算token的文本字段列表,默认['prompt', 'answer']\n" + "- add_token_count: 是否添加num_tokens列,默认True\n\n" + "run参数:\n" + "- storage: DataFlow存储对象\n" + "- input_key: 输入数据的键名\n" + "- output_key: 输出数据的键名\n\n" + "输出:过滤后的DataFrame,包含符合token范围的样本。" + ) + else: + return ( + "An operator that filters samples by token count for long-context evaluation datasets.\n\n" + "__init__ Parameters:\n" + "- tokenizer: HuggingFace tokenizer object for token counting\n" + "- tokenizer_name: Tokenizer model name, default 'answerdotai/ModernBERT-base'\n" + "- min_tokens: Minimum token count (inclusive), default 8000\n" + "- max_tokens: Maximum token count (inclusive), default 32000\n" + "- text_fields: List of text fields to count tokens from, default ['prompt', 'answer']\n" + "- add_token_count: Whether to add num_tokens column, default True\n\n" + "run Parameters:\n" + "- storage: DataFlow storage object\n" + "- input_key: Key for input data\n" + "- output_key: Key for output data\n\n" + "Output: Filtered DataFrame containing samples within the token range." 
+ ) + + def _count_tokens(self, row: pd.Series) -> int: + """Count tokens for a single row.""" + texts = [] + for field in self.text_fields: + if field in row and pd.notna(row[field]): + texts.append(str(row[field])) + + combined_text = " ".join(texts) + tokens = self.tokenizer.encode(combined_text, add_special_tokens=True) + return len(tokens) + + def run( + self, + storage: DataFlowStorage, + input_key: str = "dataframe", + output_key: str = "filtered_dataframe", + ) -> None: + """Run the filter operation. + + Args: + storage: DataFlow storage object containing the dataframe. + input_key: Key for the input dataframe in storage. + output_key: Key for the output filtered dataframe. + """ + df = storage.get(input_key) + + if not isinstance(df, pd.DataFrame): + raise ValueError(f"Expected DataFrame, got {type(df)}") + + self.logger.info(f"Filtering {len(df)} samples by token count [{self.min_tokens}, {self.max_tokens}]") + + # Detect available text fields + available_fields = [f for f in self.text_fields if f in df.columns] + if not available_fields: + # Fallback to 'text' if present + if "text" in df.columns: + available_fields = ["text"] + else: + raise ValueError( + f"None of the text_fields {self.text_fields} found in dataframe. 
" + f"Available columns: {list(df.columns)}" + ) + + self.text_fields = available_fields + self.logger.info(f"Using text fields: {self.text_fields}") + + # Count tokens for each row + from tqdm import tqdm + tqdm.pandas(desc="Counting tokens") + df["_token_count"] = df.progress_apply(self._count_tokens, axis=1) + + # Filter by token count + mask = (df["_token_count"] >= self.min_tokens) & (df["_token_count"] <= self.max_tokens) + filtered_df = df[mask].copy() + + # Rename or drop token count column + if self.add_token_count: + filtered_df = filtered_df.rename(columns={"_token_count": "num_tokens"}) + else: + filtered_df = filtered_df.drop(columns=["_token_count"]) + + # Log statistics + self.logger.info( + f"Filtered: {len(filtered_df)}/{len(df)} samples " + f"({len(filtered_df)/len(df)*100:.1f}%) in token range [{self.min_tokens}, {self.max_tokens}]" + ) + + if len(filtered_df) > 0 and self.add_token_count: + self.logger.info( + f"Token stats: min={filtered_df['num_tokens'].min()}, " + f"max={filtered_df['num_tokens'].max()}, " + f"mean={filtered_df['num_tokens'].mean():.0f}" + ) + + storage.set(output_key, filtered_df) + diff --git a/dataflow/operators/hallucination_detection/generate/hallucination_injection.py b/dataflow/operators/hallucination_detection/generate/hallucination_injection.py new file mode 100644 index 00000000..07f6142b --- /dev/null +++ b/dataflow/operators/hallucination_detection/generate/hallucination_injection.py @@ -0,0 +1,305 @@ +""" +Hallucination Injection Operator. + +Injects RAGTruth-style hallucinations into LLM-generated answers for creating +evaluation datasets. 
Supports multiple hallucination types: +- Evident Conflict: Direct contradiction of source facts +- Evident Baseless Info: Fabricated facts not in source +- Subtle Baseless Info: Implied but unstated claims +- Subtle Conflict: Nuanced contradictions +""" + +import pandas as pd +import random +import re +from typing import Optional, List, Literal +from dataflow.utils.registry import OPERATOR_REGISTRY +from dataflow.core import OperatorABC, LLMServingABC +from dataflow.utils.storage import DataFlowStorage +from dataflow import get_logger +from tqdm import tqdm + + +# RAGTruth-style hallucination injection prompts +HALLUCINATION_PROMPTS = { + "Evident Conflict": """Modify the answer to contain a DIRECT CONTRADICTION of a fact in the reference. +Mark the contradicting part with <hal>...</hal> tags. + +Reference excerpt: +{reference} + +Original answer: +{answer} + +Return ONLY the modified answer with <hal> tags around the contradicting part. No explanation.""", + + "Evident Baseless Info": """Add a FABRICATED fact that is NOT in the reference but sounds plausible. +Mark the fabricated part with <hal>...</hal> tags. + +Reference excerpt: +{reference} + +Original answer: +{answer} + +Return ONLY the modified answer with <hal> tags around the fabricated part. No explanation.""", + + "Subtle Baseless Info": """Add an IMPLIED claim that goes beyond what the reference states. +Mark the implied claim with <hal>...</hal> tags. + +Reference excerpt: +{reference} + +Original answer: +{answer} + +Return ONLY the modified answer with <hal> tags around the implied claim. No explanation.""", + + "Subtle Conflict": """Modify the answer to contain a NUANCED CONTRADICTION - something that seems consistent +but actually conflicts with the reference upon careful reading. +Mark the conflicting part with <hal>...</hal> tags. + +Reference excerpt: +{reference} + +Original answer: +{answer} + +Return ONLY the modified answer with <hal> tags around the conflicting part. 
No explanation.""", +} + + +@OPERATOR_REGISTRY.register() +class HallucinationInjectionOperator(OperatorABC): + """Inject RAGTruth-style hallucinations into answers. + + This operator takes QA pairs with reference context and injects + controlled hallucinations for creating evaluation datasets. + + Example: + >>> from dataflow.operators.hallucination_detection import HallucinationInjectionOperator + >>> from dataflow.serving import LocalHostLLMAPIServing_vllm + >>> + >>> llm = LocalHostLLMAPIServing_vllm( + ... hf_model_name_or_path="Qwen/Qwen2.5-72B-Instruct", + ... vllm_server_port=8000, + ... ) + >>> injector = HallucinationInjectionOperator( + ... llm_serving=llm, + ... hallucination_ratio=0.5, + ... hallucination_types=["Evident Conflict", "Evident Baseless Info"], + ... ) + """ + + def __init__( + self, + llm_serving: LLMServingABC, + hallucination_ratio: float = 0.5, + hallucination_types: Optional[List[str]] = None, + seed: int = 42, + max_reference_chars: int = 4000, + ): + """Initialize the HallucinationInjectionOperator. + + Args: + llm_serving: LLM serving backend for generating hallucinations. + hallucination_ratio: Fraction of samples to inject hallucinations (0-1). + hallucination_types: List of hallucination types to use. + Options: "Evident Conflict", "Evident Baseless Info", + "Subtle Baseless Info", "Subtle Conflict" + seed: Random seed for reproducibility. + max_reference_chars: Maximum characters from reference to include in prompt. + """ + self.logger = get_logger() + self.llm_serving = llm_serving + self.hallucination_ratio = hallucination_ratio + self.hallucination_types = hallucination_types or [ + "Evident Conflict", + "Evident Baseless Info", + ] + self.seed = seed + self.max_reference_chars = max_reference_chars + self.rng = random.Random(seed) + + # Validate hallucination types + for hal_type in self.hallucination_types: + if hal_type not in HALLUCINATION_PROMPTS: + raise ValueError( + f"Unknown hallucination type: {hal_type}. 
" + f"Options: {list(HALLUCINATION_PROMPTS.keys())}" + ) + + @staticmethod + def get_desc(lang: str = "en") -> str: + """Returns a description of the operator's functionality.""" + if lang == "zh": + return ( + "向LLM生成的答案中注入RAGTruth风格幻觉的算子,用于创建幻觉检测训练数据。\n\n" + "__init__参数:\n" + "- llm_serving: LLM服务对象,用于生成带幻觉的答案\n" + "- hallucination_ratio: 注入幻觉的样本比例(0-1),默认0.5\n" + "- hallucination_types: 幻觉类型列表,可选'Evident Conflict'、'Evident Baseless Info'、'Subtle Baseless Info'、'Subtle Conflict'\n" + "- seed: 随机种子,默认42\n" + "- max_reference_chars: 参考文本最大字符数,默认4000\n\n" + "run参数:\n" + "- storage: DataFlow存储对象\n" + "- input_key: 输入数据的键名\n" + "- output_key: 输出数据的键名\n" + "- input_context_field: 上下文字段名,默认'context'\n" + "- input_answer_field: 答案字段名,默认'answer'\n\n" + "输出:DataFrame包含has_hallucination、hallucination_type、labels等字段。" + ) + else: + return ( + "An operator that injects RAGTruth-style hallucinations into LLM answers for creating detection training data.\n\n" + "__init__ Parameters:\n" + "- llm_serving: LLM serving object for generating hallucinated answers\n" + "- hallucination_ratio: Fraction of samples to inject hallucinations (0-1), default 0.5\n" + "- hallucination_types: List of hallucination types, options: 'Evident Conflict', 'Evident Baseless Info', 'Subtle Baseless Info', 'Subtle Conflict'\n" + "- seed: Random seed, default 42\n" + "- max_reference_chars: Max chars from reference context, default 4000\n\n" + "run Parameters:\n" + "- storage: DataFlow storage object\n" + "- input_key: Key for input data\n" + "- output_key: Key for output data\n" + "- input_context_field: Column name for context, default 'context'\n" + "- input_answer_field: Column name for answer, default 'answer'\n\n" + "Output: DataFrame with has_hallucination, hallucination_type, labels fields." 
+ ) + + def _get_reference_excerpt(self, context: str) -> str: + """Get a truncated excerpt from the context for the prompt.""" + if len(context) <= self.max_reference_chars: + return context + + # Take beginning and end + half = self.max_reference_chars // 2 + return context[:half] + "\n...\n" + context[-half:] + + def _inject_hallucination( + self, + answer: str, + context: str, + hal_type: str, + ) -> Optional[str]: + """Inject a hallucination into an answer using the LLM.""" + reference = self._get_reference_excerpt(context) + prompt = HALLUCINATION_PROMPTS[hal_type].format( + reference=reference, + answer=answer, + ) + + try: + response = self.llm_serving.generate(prompt) + if isinstance(response, list): + response = response[0] + return response.strip() + except Exception as e: + self.logger.warning(f"Hallucination injection failed: {e}") + return None + + def _parse_hal_tags(self, text: str) -> tuple: + """Parse <hal>...</hal> tags to extract span positions. Returns (labels, clean_text).""" + labels = [] + # Remove tags and track positions + clean_text = text + for match in re.finditer(r"<hal>(.*?)</hal>", text, re.DOTALL): + hal_text = match.group(1) + labels.append({ + "text": hal_text, + "label": "hallucinated", + }) + + # Clean the text + clean_text = re.sub(r"<hal>(.*?)</hal>", r"\1", text, flags=re.DOTALL) + + # Find positions in clean text + for label in labels: + start = clean_text.find(label["text"]) + if start >= 0: + label["start"] = start + label["end"] = start + len(label["text"]) + + return labels, clean_text + + def run( + self, + storage: DataFlowStorage, + input_key: str = "dataframe", + output_key: str = "hallucinated_dataframe", + input_context_field: str = "context", + input_answer_field: str = "answer", + ) -> None: + """Run the hallucination injection operation. + + Args: + storage: DataFlow storage object. + input_key: Key for the input dataframe. + output_key: Key for the output dataframe. + input_context_field: Column name for the reference context. 
+ input_answer_field: Column name for the answer to modify. + """ + df = storage.get(input_key) + + if not isinstance(df, pd.DataFrame): + raise ValueError(f"Expected DataFrame, got {type(df)}") + + # Validate required columns + for col in [input_context_field, input_answer_field]: + if col not in df.columns: + raise ValueError(f"Missing required column: {col}") + + n_samples = len(df) + n_to_inject = int(n_samples * self.hallucination_ratio) + inject_indices = set(self.rng.sample(range(n_samples), n_to_inject)) + + self.logger.info( + f"Injecting hallucinations into {n_to_inject}/{n_samples} samples " + f"({self.hallucination_ratio*100:.0f}%)" + ) + + results = [] + stats = {"total": 0, "injected": 0, "failed": 0, "by_type": {}} + + for idx, row in tqdm(df.iterrows(), total=len(df), desc="Injecting hallucinations"): + result = row.to_dict() + result["has_hallucination"] = False + result["hallucination_type"] = None + result["labels"] = [] + + if idx in inject_indices: + # Select hallucination type + hal_type = self.rng.choice(self.hallucination_types) + + # Inject hallucination + modified = self._inject_hallucination( + answer=row[input_answer_field], + context=row[input_context_field], + hal_type=hal_type, + ) + + if modified and "<hal>" in modified: + labels, clean_answer = self._parse_hal_tags(modified) + result[input_answer_field] = clean_answer + result["has_hallucination"] = True + result["hallucination_type"] = hal_type + result["labels"] = labels + stats["injected"] += 1 + stats["by_type"][hal_type] = stats["by_type"].get(hal_type, 0) + 1 + else: + stats["failed"] += 1 + + stats["total"] += 1 + results.append(result) + + output_df = pd.DataFrame(results) + + # Log statistics + self.logger.info(f"Injection complete: {stats}") + self.logger.info( + f"Success rate: {stats['injected']}/{stats['injected']+stats['failed']} " + f"({stats['injected']/(stats['injected']+stats['failed']+1e-9)*100:.1f}%)" + ) + + storage.set(output_key, output_df) + diff --git 
a/dataflow/operators/hallucination_detection/generate/span_annotation.py b/dataflow/operators/hallucination_detection/generate/span_annotation.py new file mode 100644 index 00000000..637e19da --- /dev/null +++ b/dataflow/operators/hallucination_detection/generate/span_annotation.py @@ -0,0 +1,233 @@ +""" +Span Annotation Operator. + +Converts document-level hallucination labels to span-level annotations using +Natural Language Inference (NLI). Useful for converting datasets like HaluEval +to token-classification format. +""" + +import pandas as pd +import re +from typing import Optional, List +from dataflow.utils.registry import OPERATOR_REGISTRY +from dataflow.core import OperatorABC +from dataflow.utils.storage import DataFlowStorage +from dataflow import get_logger +from tqdm import tqdm + +try: + from transformers import pipeline + HAS_TRANSFORMERS = True +except ImportError: + HAS_TRANSFORMERS = False + + +@OPERATOR_REGISTRY.register() +class SpanAnnotationOperator(OperatorABC): + """Convert document-level labels to span-level using NLI. + + This operator takes answers with document-level hallucination labels + and identifies which specific sentences are hallucinated using NLI. + + Example: + >>> from dataflow.operators.hallucination_detection import SpanAnnotationOperator + >>> + >>> annotator = SpanAnnotationOperator( + ... nli_model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli", + ... contradiction_threshold=0.7, + ... ) + """ + + def __init__( + self, + nli_model: str = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli", + contradiction_threshold: float = 0.7, + device: str = "cuda", + batch_size: int = 8, + ): + """Initialize the SpanAnnotationOperator. + + Args: + nli_model: HuggingFace model for NLI classification. + contradiction_threshold: Threshold for labeling as contradiction. + device: Device to run the model on ("cuda" or "cpu"). + batch_size: Batch size for NLI inference. 
+ """ + self.logger = get_logger() + + if not HAS_TRANSFORMERS: + raise ImportError( + "transformers is required for SpanAnnotationOperator. " + "Install with: pip install transformers" + ) + + self.nli_model_name = nli_model + self.contradiction_threshold = contradiction_threshold + self.device = device + self.batch_size = batch_size + self.nli_pipeline = None # Lazy loading + + def _load_nli_pipeline(self): + """Lazy load the NLI pipeline.""" + if self.nli_pipeline is None: + self.logger.info(f"Loading NLI model: {self.nli_model_name}") + self.nli_pipeline = pipeline( + "zero-shot-classification", + model=self.nli_model_name, + device=0 if self.device == "cuda" else -1, + ) + + @staticmethod + def get_desc(lang: str = "en") -> str: + """Returns a description of the operator's functionality.""" + if lang == "zh": + return ( + "使用NLI将文档级幻觉标签转换为span级标注的算子。\n\n" + "__init__参数:\n" + "- nli_model: NLI模型名称,默认'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli'\n" + "- contradiction_threshold: 矛盾判定阈值,默认0.7\n" + "- device: 运行设备'cuda'或'cpu',默认'cuda'\n" + "- batch_size: NLI推理批次大小,默认8\n\n" + "run参数:\n" + "- storage: DataFlow存储对象\n" + "- input_key: 输入数据的键名\n" + "- output_key: 输出数据的键名\n" + "- input_context_field: 上下文字段名,默认'context'\n" + "- input_answer_field: 答案字段名,默认'answer'\n" + "- input_is_hallucinated_field: 幻觉标记字段名,默认'is_hallucinated'\n\n" + "输出:DataFrame包含labels字段(含text、start、end、confidence)。" + ) + else: + return ( + "An operator that converts document-level hallucination labels to span-level using NLI.\n\n" + "__init__ Parameters:\n" + "- nli_model: NLI model name, default 'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli'\n" + "- contradiction_threshold: Threshold for contradiction detection, default 0.7\n" + "- device: 'cuda' or 'cpu', default 'cuda'\n" + "- batch_size: Batch size for NLI inference, default 8\n\n" + "run Parameters:\n" + "- storage: DataFlow storage object\n" + "- input_key: Key for input data\n" + "- output_key: Key for output data\n" + "- 
input_context_field: Column name for context, default 'context'\n" + "- input_answer_field: Column name for answer, default 'answer'\n" + "- input_is_hallucinated_field: Column for hallucination flag, default 'is_hallucinated'\n\n" + "Output: DataFrame with labels field containing text, start, end, confidence." + ) + + def _split_sentences(self, text: str) -> List[str]: + """Split text into sentences.""" + # Simple sentence splitting + sentences = re.split(r'(?<=[.!?])\s+', text) + return [s.strip() for s in sentences if s.strip()] + + def _find_sentence_position(self, text: str, sentence: str) -> tuple: + """Find the start and end position of a sentence in text.""" + start = text.find(sentence) + if start == -1: + return None, None + end = start + len(sentence) + return start, end + + def _check_contradiction(self, premise: str, hypothesis: str) -> float: + """Check if hypothesis contradicts premise using NLI.""" + self._load_nli_pipeline() + + try: + result = self.nli_pipeline( + hypothesis, + candidate_labels=["entailment", "neutral", "contradiction"], + hypothesis_template="{}", + multi_label=False, + ) + + # Find contradiction score + for label, score in zip(result["labels"], result["scores"]): + if label == "contradiction": + return score + return 0.0 + except Exception as e: + self.logger.warning(f"NLI check failed: {e}") + return 0.0 + + def run( + self, + storage: DataFlowStorage, + input_key: str = "dataframe", + output_key: str = "annotated_dataframe", + input_context_field: str = "context", + input_answer_field: str = "answer", + input_is_hallucinated_field: str = "is_hallucinated", + ) -> None: + """Run the span annotation operation. + + Args: + storage: DataFlow storage object. + input_key: Key for the input dataframe. + output_key: Key for the output dataframe. + input_context_field: Column name for the reference context. + input_answer_field: Column name for the answer. + input_is_hallucinated_field: Column name indicating if sample is hallucinated. 
+ """ + df = storage.get(input_key) + + if not isinstance(df, pd.DataFrame): + raise ValueError(f"Expected DataFrame, got {type(df)}") + + # Validate required columns + for col in [input_context_field, input_answer_field]: + if col not in df.columns: + raise ValueError(f"Missing required column: {col}") + + self.logger.info(f"Annotating {len(df)} samples with span-level labels") + + results = [] + stats = {"total": 0, "annotated": 0, "spans_found": 0} + + for idx, row in tqdm(df.iterrows(), total=len(df), desc="Annotating spans"): + result = row.to_dict() + result["labels"] = [] + + answer = row[input_answer_field] + context = row[input_context_field] + is_hallucinated = row.get(input_is_hallucinated_field, True) + + if is_hallucinated: + # Split answer into sentences + sentences = self._split_sentences(answer) + + for sentence in sentences: + if len(sentence) < 10: # Skip very short sentences + continue + + # Check contradiction + score = self._check_contradiction(context, sentence) + + if score >= self.contradiction_threshold: + start, end = self._find_sentence_position(answer, sentence) + if start is not None: + result["labels"].append({ + "text": sentence, + "start": start, + "end": end, + "label": "hallucinated", + "confidence": score, + }) + stats["spans_found"] += 1 + + if result["labels"]: + stats["annotated"] += 1 + + stats["total"] += 1 + results.append(result) + + output_df = pd.DataFrame(results) + + # Log statistics + self.logger.info( + f"Annotation complete: {stats['annotated']}/{stats['total']} samples annotated, " + f"{stats['spans_found']} total spans found" + ) + + storage.set(output_key, output_df) + diff --git a/dataflow/utils/storage.py b/dataflow/utils/storage.py index 50fcbe1c..26fd9cc9 100644 --- a/dataflow/utils/storage.py +++ b/dataflow/utils/storage.py @@ -404,6 +404,7 @@ def __init__( cache_type: Literal["json", "jsonl", "csv", "parquet", "pickle", None] = None ): self._data = None + self._store = {} # Key-value store for get/set 
operations self.cache_path = cache_path self.file_name_prefix = file_name_prefix self.cache_type = cache_type @@ -413,12 +414,33 @@ def set_data(self, data: Any): Set data to be written later. """ self._data = data + + def get(self, key: str) -> Any: + """ + Get data by key from the key-value store. + """ + return self._store.get(key, self._data) + + def set(self, key: str, data: Any): + """ + Set data by key in the key-value store. + """ + self._store[key] = data + self._data = data # Also update _data for compatibility def set_file_name_prefix(self, file_name_prefix: str): """ Set the file name prefix for cache files. """ self.file_name_prefix = file_name_prefix + + def get_keys_from_dataframe(self) -> list[str]: + """ + Get keys from the dataframe stored in the storage. + """ + if isinstance(self._data, pd.DataFrame): + return self._data.columns.tolist() + return [] def read(self, output_type: Literal["dataframe", "dict"] = "dataframe") -> Any: return self._data diff --git a/test/test_hallucination_detection.py b/test/test_hallucination_detection.py new file mode 100644 index 00000000..e6c1e3c9 --- /dev/null +++ b/test/test_hallucination_detection.py @@ -0,0 +1,165 @@ +""" +Tests for Hallucination Detection Operators. 
+ +Run with: pytest test/test_hallucination_detection.py -v +""" + +import pytest +import pandas as pd +from unittest.mock import Mock, patch + + +class TestLongContextFilterOperator: + """Tests for LongContextFilterOperator.""" + + def test_filter_by_token_count(self): + """Test filtering samples by token count.""" + from dataflow.operators.hallucination_detection import LongContextFilterOperator + from dataflow.utils.storage import DummyStorage + + # Create mock tokenizer + mock_tokenizer = Mock() + mock_tokenizer.encode = lambda text, **kwargs: list(range(len(text.split()))) + + # Create test data + df = pd.DataFrame({ + "text": [ + " ".join(["word"] * 100), # 100 tokens + " ".join(["word"] * 500), # 500 tokens + " ".join(["word"] * 1000), # 1000 tokens + ] + }) + + # Create operator + op = LongContextFilterOperator( + tokenizer=mock_tokenizer, + min_tokens=200, + max_tokens=800, + text_fields=["text"], + ) + + # Run using DummyStorage + storage = DummyStorage() + storage.set("dataframe", df) + op.run(storage, input_key="dataframe", output_key="filtered") + + # Check result + result = storage.get("filtered") + assert len(result) == 1 # Only the 500-token sample + assert "num_tokens" in result.columns + + def test_multiple_text_fields(self): + """Test filtering with multiple text fields.""" + from dataflow.operators.hallucination_detection import LongContextFilterOperator + from dataflow.utils.storage import DummyStorage + + mock_tokenizer = Mock() + mock_tokenizer.encode = lambda text, **kwargs: list(range(len(text.split()))) + + df = pd.DataFrame({ + "prompt": [" ".join(["word"] * 100)], + "answer": [" ".join(["word"] * 50)], + }) + + op = LongContextFilterOperator( + tokenizer=mock_tokenizer, + min_tokens=100, + max_tokens=200, + text_fields=["prompt", "answer"], + ) + + storage = DummyStorage() + storage.set("dataframe", df) + op.run(storage, input_key="dataframe", output_key="filtered") + + result = storage.get("filtered") + assert len(result) == 1 + assert 
result["num_tokens"].iloc[0] == 150 # 100 + 50 + + +class TestHallucinationInjectionOperator: + """Tests for HallucinationInjectionOperator.""" + + def test_injection_ratio(self): + """Test that hallucination ratio is respected.""" + from dataflow.operators.hallucination_detection import HallucinationInjectionOperator + from dataflow.utils.storage import DummyStorage + + # Create mock LLM serving + mock_llm = Mock() + mock_llm.generate = Mock(return_value="The capital is Berlin.") + + df = pd.DataFrame({ + "context": ["France is in Europe."] * 10, + "answer": ["The capital is Paris."] * 10, + }) + + op = HallucinationInjectionOperator( + llm_serving=mock_llm, + hallucination_ratio=0.5, + seed=42, + ) + + storage = DummyStorage() + storage.set("dataframe", df) + op.run(storage, input_key="dataframe", output_key="output") + + result = storage.get("output") + n_hallucinated = result["has_hallucination"].sum() + + # Should be approximately 50% (±2 due to randomness) + assert 3 <= n_hallucinated <= 7 + + def test_parse_hal_tags(self): + """Test parsing of tags.""" + from dataflow.operators.hallucination_detection import HallucinationInjectionOperator + + mock_llm = Mock() + op = HallucinationInjectionOperator(llm_serving=mock_llm) + + text = "The capital is Berlin, a beautiful city." + labels, clean = op._parse_hal_tags(text) + + assert clean == "The capital is Berlin, a beautiful city." 
+ assert len(labels) == 1 + assert labels[0]["text"] == "Berlin" + assert labels[0]["start"] == 15 + assert labels[0]["end"] == 21 + + +class TestSpanAnnotationOperator: + """Tests for SpanAnnotationOperator.""" + + def test_sentence_splitting(self): + """Test sentence splitting.""" + # Import the module to get the actual class + from dataflow.operators.hallucination_detection import SpanAnnotationOperator + + # Create instance without initializing (to avoid transformers import) + op = SpanAnnotationOperator.__new__(SpanAnnotationOperator) + op.logger = Mock() + + text = "This is sentence one. This is sentence two! Is this three?" + sentences = op._split_sentences(text) + + assert len(sentences) == 3 + assert sentences[0] == "This is sentence one." + assert sentences[1] == "This is sentence two!" + assert sentences[2] == "Is this three?" + + def test_position_finding(self): + """Test finding sentence positions.""" + from dataflow.operators.hallucination_detection import SpanAnnotationOperator + + op = SpanAnnotationOperator.__new__(SpanAnnotationOperator) + op.logger = Mock() + + text = "The quick brown fox jumps." + start, end = op._find_sentence_position(text, "brown fox") + + assert start == 10 + assert end == 19 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])