diff --git a/packages/data-designer-config/src/data_designer/config/base.py b/packages/data-designer-config/src/data_designer/config/base.py index a4e55fa27..3f02cd881 100644 --- a/packages/data-designer-config/src/data_designer/config/base.py +++ b/packages/data-designer-config/src/data_designer/config/base.py @@ -31,14 +31,19 @@ class SingleColumnConfig(ConfigBase, ABC): name: Unique name of the column to be generated. drop: If True, the column will be generated but removed from the final dataset. Useful for intermediate columns that are dependencies for other columns. + allow_resize: If True, the column is allowed to be resized during generation. column_type: Discriminator field that identifies the specific column type. Subclasses must override this field to specify the column type with a `Literal` value. """ - name: str - drop: bool = False - allow_resize: bool = False - column_type: str + name: str = Field(description="Unique name of the column to be generated") + drop: bool = Field( + default=False, description="If True, the column will be generated but removed from the final dataset" + ) + allow_resize: bool = Field( + default=False, description="If True, the column is allowed to be resized during generation" + ) + column_type: str = Field(description="Discriminator field that identifies the specific column type") @staticmethod def get_column_emoji() -> str: diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index 49dbb8311..151202902 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -56,11 +56,22 @@ class SamplerColumnConfig(SingleColumnConfig): ``` """ - sampler_type: SamplerType - params: Annotated[SamplerParamsT, Discriminator("sampler_type")] - conditional_params: dict[str, Annotated[SamplerParamsT, 
Discriminator("sampler_type")]] = {} - convert_to: str | None = None - column_type: Literal["sampler"] = "sampler" + sampler_type: SamplerType = Field( + description="Type of sampler to use (e.g., uuid, category, uniform, gaussian, person, datetime)" + ) + params: Annotated[SamplerParamsT, Discriminator("sampler_type")] = Field( + description="Parameters specific to the chosen sampler type" + ) + conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = Field( + default_factory=dict, + description="Optional dictionary for conditional parameters; keys are conditions, values are params to use when met", + ) + convert_to: str | None = Field( + default=None, description="Optional type conversion after sampling: 'float', 'int', or 'str'" + ) + column_type: Literal["sampler"] = Field( + default="sampler", description="Discriminator field, always 'sampler' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -136,14 +147,28 @@ class LLMTextColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "llm-text" for this configuration type. 
""" - prompt: str - model_alias: str - system_prompt: str | None = None - multi_modal_context: list[ImageContext] | None = None - tool_alias: str | None = None - with_trace: TraceType = TraceType.NONE - extract_reasoning_content: bool = False - column_type: Literal["llm-text"] = "llm-text" + prompt: str = Field( + description="Jinja2 template for the LLM prompt; can reference other columns via {{ column_name }}" + ) + model_alias: str = Field(description="Alias of the model configuration to use for generation") + system_prompt: str | None = Field( + default=None, description="Optional system prompt to set model behavior and constraints" + ) + multi_modal_context: list[ImageContext] | None = Field( + default=None, description="Optional list of ImageContext for vision model inputs" + ) + tool_alias: str | None = Field( + default=None, description="Optional alias of the tool configuration to use for MCP tool calls" + ) + with_trace: TraceType = Field( + default=TraceType.NONE, description="Trace capture mode: NONE, LAST_MESSAGE, or ALL_MESSAGES" + ) + extract_reasoning_content: bool = Field( + default=False, description="If True, capture chain-of-thought in {name}__reasoning_content column" + ) + column_type: Literal["llm-text"] = Field( + default="llm-text", description="Discriminator field, always 'llm-text' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -219,8 +244,12 @@ class LLMCodeColumnConfig(LLMTextColumnConfig): column containing the reasoning content from the final assistant response. 
""" - code_lang: CodeLang - column_type: Literal["llm-code"] = "llm-code" + code_lang: CodeLang = Field( + description="Target programming language or SQL dialect for code extraction from LLM response" + ) + column_type: Literal["llm-code"] = Field( + default="llm-code", description="Discriminator field, always 'llm-code' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -252,8 +281,12 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig): column containing the reasoning content from the final assistant response. """ - output_format: dict | type[BaseModel] - column_type: Literal["llm-structured"] = "llm-structured" + output_format: dict | type[BaseModel] = Field( + description="Pydantic model or JSON schema dict defining the expected structured output shape" + ) + column_type: Literal["llm-structured"] = Field( + default="llm-structured", description="Discriminator field, always 'llm-structured' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -317,8 +350,12 @@ class LLMJudgeColumnConfig(LLMTextColumnConfig): column containing the reasoning content from the final assistant response. """ - scores: list[Score] = Field(..., min_length=1) - column_type: Literal["llm-judge"] = "llm-judge" + scores: list[Score] = Field( + ..., min_length=1, description="List of Score objects defining rubric criteria for LLM judge evaluation" + ) + column_type: Literal["llm-judge"] = Field( + default="llm-judge", description="Discriminator field, always 'llm-judge' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -341,10 +378,13 @@ class ExpressionColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "expression" for this configuration type. 
""" - name: str - expr: str - dtype: Literal["int", "float", "str", "bool"] = "str" - column_type: Literal["expression"] = "expression" + expr: str = Field(description="Jinja2 expression to compute the column value from other columns") + dtype: Literal["int", "float", "str", "bool"] = Field( + default="str", description="Data type for expression result: 'int', 'float', 'str', or 'bool'" + ) + column_type: Literal["expression"] = Field( + default="expression", description="Discriminator field, always 'expression' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -410,11 +450,13 @@ class ValidationColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "validation" for this configuration type. """ - target_columns: list[str] - validator_type: ValidatorType - validator_params: ValidatorParamsT + target_columns: list[str] = Field(description="List of column names to validate") + validator_type: ValidatorType = Field(description="Validation method: 'code', 'local_callable', or 'remote'") + validator_params: ValidatorParamsT = Field(description="Validator-specific parameters (e.g., CodeValidatorParams)") batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch") - column_type: Literal["validation"] = "validation" + column_type: Literal["validation"] = Field( + default="validation", description="Discriminator field, always 'validation' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -441,7 +483,9 @@ class SeedDatasetColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "seed-dataset" for this configuration type. 
""" - column_type: Literal["seed-dataset"] = "seed-dataset" + column_type: Literal["seed-dataset"] = Field( + default="seed-dataset", description="Discriminator field, always 'seed-dataset' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -468,9 +512,11 @@ class EmbeddingColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "embedding" for this configuration type. """ - target_column: str - model_alias: str - column_type: Literal["embedding"] = "embedding" + target_column: str = Field(description="Name of the text column to generate embeddings for") + model_alias: str = Field(description="Alias of the model to use for embedding generation") + column_type: Literal["embedding"] = Field( + default="embedding", description="Discriminator field, always 'embedding' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -502,10 +548,16 @@ class ImageColumnConfig(SingleColumnConfig): column_type: Discriminator field, always "image" for this configuration type. 
""" - prompt: str - model_alias: str - multi_modal_context: list[ImageContext] | None = None - column_type: Literal["image"] = "image" + prompt: str = Field( + description="Jinja2 template for the image generation prompt; can reference other columns via {{ column_name }}" + ) + model_alias: str = Field(description="Alias of the model to use for image generation") + multi_modal_context: list[ImageContext] | None = Field( + default=None, description="Optional list of ImageContext for image-to-image generation inputs" + ) + column_type: Literal["image"] = Field( + default="image", description="Discriminator field, always 'image' for this configuration type" + ) @staticmethod def get_column_emoji() -> str: @@ -562,7 +614,9 @@ class CustomColumnConfig(SingleColumnConfig): default=None, description="Optional typed configuration object passed as second argument to generator function", ) - column_type: Literal["custom"] = "custom" + column_type: Literal["custom"] = Field( + default="custom", description="Discriminator field, always 'custom' for this configuration type" + ) @field_validator("generator_function") @classmethod diff --git a/packages/data-designer-config/src/data_designer/config/mcp.py b/packages/data-designer-config/src/data_designer/config/mcp.py index fe870fa86..e0683e3c7 100644 --- a/packages/data-designer-config/src/data_designer/config/mcp.py +++ b/packages/data-designer-config/src/data_designer/config/mcp.py @@ -33,10 +33,12 @@ class MCPProvider(ConfigBase): ... 
) """ - provider_type: Literal["sse"] = "sse" - name: str - endpoint: str - api_key: str | None = None + provider_type: Literal["sse"] = Field( + default="sse", description="Transport type discriminator, always 'sse' for remote MCP providers" + ) + name: str = Field(description="Unique name used to reference this MCP provider") + endpoint: str = Field(description="SSE endpoint URL for connecting to the remote MCP server") + api_key: str | None = Field(default=None, description="Optional API key for authentication") class LocalStdioMCPProvider(ConfigBase): @@ -63,11 +65,15 @@ class LocalStdioMCPProvider(ConfigBase): ... ) """ - provider_type: Literal["stdio"] = "stdio" - name: str - command: str - args: list[str] = Field(default_factory=list) - env: dict[str, str] = Field(default_factory=dict) + provider_type: Literal["stdio"] = Field( + default="stdio", description="Transport type discriminator, always 'stdio' for local subprocess MCP providers" + ) + name: str = Field(description="Unique name used to reference this MCP provider") + command: str = Field(description="Executable to launch the MCP server via stdio transport") + args: list[str] = Field(default_factory=list, description="Arguments passed to the MCP server executable") + env: dict[str, str] = Field( + default_factory=dict, description="Environment variables passed to the MCP server subprocess" + ) MCPProviderT: TypeAlias = Annotated[MCPProvider | LocalStdioMCPProvider, Field(discriminator="provider_type")] @@ -102,8 +108,12 @@ class ToolConfig(ConfigBase): ... 
) """ - tool_alias: str - providers: list[str] - allow_tools: list[str] | None = None - max_tool_call_turns: int = Field(default=5, ge=1) - timeout_sec: float | None = Field(default=None, gt=0) + tool_alias: str = Field(description="User-defined alias to reference this tool configuration in column configs") + providers: list[str] = Field(description="Names of the MCP providers to use for tool calls") + allow_tools: list[str] | None = Field( + default=None, description="Optional allowlist of tool names that restricts which tools are permitted" + ) + max_tool_call_turns: int = Field( + default=5, ge=1, description="Maximum number of tool-calling turns permitted in a single generation" + ) + timeout_sec: float | None = Field(default=None, gt=0, description="Timeout in seconds for MCP tool calls") diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 578b34eec..76b66ac0f 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -56,9 +56,11 @@ class DistributionType(str, Enum): class ModalityContext(ABC, BaseModel): - modality: Modality - column_name: str - data_type: ModalityDataType | None = None + modality: Modality = Field(description="The modality type for this context") + column_name: str = Field(description="Name of the column containing the modality data") + data_type: ModalityDataType | None = Field( + default=None, description="Format of the modality data ('url' or 'base64')" + ) @abstractmethod def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: ... @@ -76,8 +78,8 @@ class ImageContext(ModalityContext): image_format: Image format (required when data_type is explicitly "base64"). 
""" - modality: Modality = Modality.IMAGE - image_format: ImageFormat | None = None + modality: Modality = Field(default=Modality.IMAGE, description="The modality type, always 'image' for ImageContext") + image_format: ImageFormat | None = Field(default=None, description="Image format (required for base64 data)") def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: """Get the contexts for the image modality. @@ -179,8 +181,8 @@ def _validate_image_format(self) -> Self: class Distribution(ABC, ConfigBase, Generic[DistributionParamsT]): - distribution_type: DistributionType - params: DistributionParamsT + distribution_type: DistributionType = Field(description="Type of distribution for sampling") + params: DistributionParamsT = Field(description="Parameters for the distribution") @abstractmethod def sample(self) -> float: ... @@ -194,8 +196,10 @@ class ManualDistributionParams(ConfigBase): weights: Optional list of weights for each value. If not provided, all values have equal probability. """ - values: list[float] = Field(min_length=1) - weights: list[float] | None = None + values: list[float] = Field(min_length=1, description="List of possible values to sample from") + weights: list[float] | None = Field( + default=None, description="Optional probability weights for each value; automatically normalized to sum to 1.0" + ) @model_validator(mode="after") def _normalize_weights(self) -> Self: @@ -221,8 +225,10 @@ class ManualDistribution(Distribution[ManualDistributionParams]): params: Distribution parameters (values, weights). 
""" - distribution_type: DistributionType | None = "manual" - params: ManualDistributionParams + distribution_type: DistributionType | None = Field( + default="manual", description="Type of distribution, always 'manual' for this class" + ) + params: ManualDistributionParams = Field(description="Manual distribution parameters (values and optional weights)") def sample(self) -> float: """Sample a value from the manual distribution. @@ -241,8 +247,8 @@ class UniformDistributionParams(ConfigBase): high: Upper bound (exclusive). """ - low: float - high: float + low: float = Field(description="Lower bound of the uniform distribution (inclusive)") + high: float = Field(description="Upper bound of the uniform distribution (exclusive)") @model_validator(mode="after") def _validate_low_lt_high(self) -> Self: @@ -262,8 +268,10 @@ class UniformDistribution(Distribution[UniformDistributionParams]): params: Distribution parameters (low, high). """ - distribution_type: DistributionType | None = "uniform" - params: UniformDistributionParams + distribution_type: DistributionType | None = Field( + default="uniform", description="Type of distribution, always 'uniform' for this class" + ) + params: UniformDistributionParams = Field(description="Uniform distribution parameters (low and high bounds)") def sample(self) -> float: """Sample a value from the uniform distribution. @@ -293,10 +301,14 @@ class BaseInferenceParams(ConfigBase, ABC): extra_body: Additional parameters to pass to the model API. 
""" - generation_type: GenerationType - max_parallel_requests: int = Field(default=4, ge=1) - timeout: int | None = Field(default=None, ge=1) - extra_body: dict[str, Any] | None = None + generation_type: GenerationType = Field(description="Type of generation (chat-completion, embedding, or image)") + max_parallel_requests: int = Field( + default=4, ge=1, description="Maximum number of parallel requests to the model API" + ) + timeout: int | None = Field(default=None, ge=1, description="Timeout in seconds for each request") + extra_body: dict[str, Any] | None = Field( + default=None, description="Additional parameters to pass to the model API" + ) @property def generate_kwargs(self) -> dict[str, Any]: @@ -361,10 +373,19 @@ class ChatCompletionInferenceParams(BaseInferenceParams): max_tokens: Maximum number of tokens to generate in the response. """ - generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION - temperature: float | DistributionT | None = None - top_p: float | DistributionT | None = None - max_tokens: int | None = Field(default=None, ge=1) + generation_type: Literal[GenerationType.CHAT_COMPLETION] = Field( + default=GenerationType.CHAT_COMPLETION, + description="Type of generation, always 'chat-completion' for this class", + ) + temperature: float | DistributionT | None = Field( + default=None, description="Sampling temperature (0.0-2.0); can be a fixed value or a distribution" + ) + top_p: float | DistributionT | None = Field( + default=None, description="Nucleus sampling probability (0.0-1.0); can be a fixed value or a distribution" + ) + max_tokens: int | None = Field( + default=None, ge=1, description="Maximum number of tokens to generate in the response" + ) @property def generate_kwargs(self) -> dict[str, Any]: @@ -446,9 +467,13 @@ class EmbeddingInferenceParams(BaseInferenceParams): dimensions: Number of dimensions for the embedding. 
""" - generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING - encoding_format: Literal["float", "base64"] = "float" - dimensions: int | None = None + generation_type: Literal[GenerationType.EMBEDDING] = Field( + default=GenerationType.EMBEDDING, description="Type of generation, always 'embedding' for this class" + ) + encoding_format: Literal["float", "base64"] = Field( + default="float", description="Format of the embedding encoding ('float' or 'base64')" + ) + dimensions: int | None = Field(default=None, description="Number of dimensions for the embedding") @property def generate_kwargs(self) -> dict[str, float | int]: @@ -489,7 +514,9 @@ class ImageInferenceParams(BaseInferenceParams): ``` """ - generation_type: Literal[GenerationType.IMAGE] = GenerationType.IMAGE + generation_type: Literal[GenerationType.IMAGE] = Field( + default=GenerationType.IMAGE, description="Type of generation, always 'image' for this class" + ) InferenceParamsT: TypeAlias = Annotated[ @@ -510,11 +537,14 @@ class ModelConfig(ConfigBase): skip_health_check: Whether to skip the health check for this model. Defaults to False. 
""" - alias: str - model: str - inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams) - provider: str | None = None - skip_health_check: bool = False + alias: str = Field(description="User-defined alias to reference in column configurations") + model: str = Field(description="Model identifier (e.g., from build.nvidia.com or other providers)") + inference_parameters: InferenceParamsT = Field( + default_factory=ChatCompletionInferenceParams, + description="Inference parameters for the model (temperature, top_p, max_tokens, etc.)", + ) + provider: str | None = Field(default=None, description="Optional model provider name if using custom providers") + skip_health_check: bool = Field(default=False, description="Whether to skip the health check for this model") @property def generation_type(self) -> GenerationType: @@ -551,12 +581,12 @@ class ModelProvider(ConfigBase): extra_headers: Additional headers to pass in API requests. """ - name: str - endpoint: str - provider_type: str = "openai" - api_key: str | None = None - extra_body: dict[str, Any] | None = None - extra_headers: dict[str, str] | None = None + name: str = Field(description="Name of the model provider") + endpoint: str = Field(description="API endpoint URL for the provider") + provider_type: str = Field(default="openai", description="Provider type. 
Determines the API format to use") + api_key: str | None = Field(default=None, description="Optional API key for authentication") + extra_body: dict[str, Any] | None = Field(default=None, description="Additional parameters to pass in API requests") + extra_headers: dict[str, str] | None = Field(default=None, description="Additional headers to pass in API requests") def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]: diff --git a/packages/data-designer-config/src/data_designer/config/processors.py b/packages/data-designer-config/src/data_designer/config/processors.py index 733dd5ab5..6ac61c0e8 100644 --- a/packages/data-designer-config/src/data_designer/config/processors.py +++ b/packages/data-designer-config/src/data_designer/config/processors.py @@ -57,7 +57,10 @@ class DropColumnsProcessorConfig(ProcessorConfig): """ column_names: list[str] = Field(description="List of column names to drop from the output dataset.") - processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS + processor_type: Literal[ProcessorType.DROP_COLUMNS] = Field( + default=ProcessorType.DROP_COLUMNS, + description="Discriminator field, always 'drop_columns' for this processor type", + ) class SchemaTransformProcessorConfig(ProcessorConfig): @@ -97,7 +100,10 @@ class SchemaTransformProcessorConfig(ProcessorConfig): References to columns "col1" and "col2" in the templates will be replaced with the actual values of the columns in the dataset. 
""", ) - processor_type: Literal[ProcessorType.SCHEMA_TRANSFORM] = ProcessorType.SCHEMA_TRANSFORM + processor_type: Literal[ProcessorType.SCHEMA_TRANSFORM] = Field( + default=ProcessorType.SCHEMA_TRANSFORM, + description="Discriminator field, always 'schema_transform' for this processor type", + ) @field_validator("template") def validate_template(cls, v: dict[str, Any]) -> dict[str, Any]: diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py b/packages/data-designer-config/src/data_designer/config/run_config.py index 03b2ed297..538844add 100644 --- a/packages/data-designer-config/src/data_designer/config/run_config.py +++ b/packages/data-designer-config/src/data_designer/config/run_config.py @@ -35,13 +35,41 @@ class RunConfig(ConfigBase): Default is 0. """ - disable_early_shutdown: bool = False - shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0) - shutdown_error_window: int = Field(default=10, ge=0) - buffer_size: int = Field(default=1000, gt=0) - non_inference_max_parallel_workers: int = Field(default=4, ge=1) - max_conversation_restarts: int = Field(default=5, ge=0) - max_conversation_correction_steps: int = Field(default=0, ge=0) + disable_early_shutdown: bool = Field( + default=False, + description="If True, disables early-shutdown behavior; generation continues regardless of error rate", + ) + shutdown_error_rate: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Error rate threshold (0.0-1.0) that triggers early shutdown when early shutdown is enabled", + ) + shutdown_error_window: int = Field( + default=10, + ge=0, + description="Minimum number of completed tasks before error rate monitoring begins", + ) + buffer_size: int = Field( + default=1000, + gt=0, + description="Number of records to process in each batch during dataset generation", + ) + non_inference_max_parallel_workers: int = Field( + default=4, + ge=1, + description="Maximum number of worker threads used for non-inference cell-by-cell 
generators", + ) + max_conversation_restarts: int = Field( + default=5, + ge=0, + description="Maximum number of full conversation restarts permitted per ModelFacade.generate() call", + ) + max_conversation_correction_steps: int = Field( + default=0, + ge=0, + description="Maximum number of correction rounds permitted within a single conversation", + ) @model_validator(mode="after") def normalize_shutdown_settings(self) -> Self: diff --git a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py index 86dc2c09c..eb6470159 100644 --- a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py +++ b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from enum import Enum +from pydantic import Field from typing_extensions import TypeAlias from data_designer.config.base import ConfigBase @@ -24,7 +25,9 @@ class InequalityOperator(str, Enum): class Constraint(ConfigBase, ABC): - target_column: str + """Base class for sampler column constraints.""" + + target_column: str = Field(description="Name of the sampler column this constraint applies to") @property @abstractmethod @@ -32,8 +35,10 @@ def constraint_type(self) -> ConstraintType: ... 
class ScalarInequalityConstraint(Constraint): - rhs: float - operator: InequalityOperator + """Sampler constraint that compares a sampler column's generated values against a scalar threshold.""" + + rhs: float = Field(description="Scalar value to compare against") + operator: InequalityOperator = Field(description="Comparison operator (lt, le, gt, ge)") @property def constraint_type(self) -> ConstraintType: @@ -41,8 +46,10 @@ def constraint_type(self) -> ConstraintType: class ColumnInequalityConstraint(Constraint): - rhs: str - operator: InequalityOperator + """Sampler constraint that compares a sampler column's generated values against another sampler column's values.""" + + rhs: str = Field(description="Name of the other column to compare against") + operator: InequalityOperator = Field(description="Comparison operator (lt, le, gt, ge)") @property def constraint_type(self) -> ConstraintType: diff --git a/packages/data-designer-config/src/data_designer/config/sampler_params.py b/packages/data-designer-config/src/data_designer/config/sampler_params.py index aafad16db..01af4c983 100644 --- a/packages/data-designer-config/src/data_designer/config/sampler_params.py +++ b/packages/data-designer-config/src/data_designer/config/sampler_params.py @@ -68,7 +68,9 @@ class CategorySamplerParams(ConfigBase): "Larger values will be sampled with higher probability." ), ) - sampler_type: Literal[SamplerType.CATEGORY] = SamplerType.CATEGORY + sampler_type: Literal[SamplerType.CATEGORY] = Field( + default=SamplerType.CATEGORY, description="Sampler type discriminator, always 'category' for this sampler" + ) @model_validator(mode="after") def _normalize_weights_if_needed(self) -> Self: @@ -109,7 +111,9 @@ class DatetimeSamplerParams(ConfigBase): default="D", description="Sampling units, e.g. 
the smallest possible time interval between samples.", ) - sampler_type: Literal[SamplerType.DATETIME] = SamplerType.DATETIME + sampler_type: Literal[SamplerType.DATETIME] = Field( + default=SamplerType.DATETIME, description="Sampler type discriminator, always 'datetime' for this sampler" + ) @field_validator("start", "end") @classmethod @@ -140,7 +144,9 @@ class SubcategorySamplerParams(ConfigBase): ..., description="Mapping from each value of parent category to a list of subcategory values.", ) - sampler_type: Literal[SamplerType.SUBCATEGORY] = SamplerType.SUBCATEGORY + sampler_type: Literal[SamplerType.SUBCATEGORY] = Field( + default=SamplerType.SUBCATEGORY, description="Sampler type discriminator, always 'subcategory' for this sampler" + ) class TimeDeltaSamplerParams(ConfigBase): @@ -192,7 +198,9 @@ class TimeDeltaSamplerParams(ConfigBase): default="D", description="Sampling units, e.g. the smallest possible time interval between samples.", ) - sampler_type: Literal[SamplerType.TIMEDELTA] = SamplerType.TIMEDELTA + sampler_type: Literal[SamplerType.TIMEDELTA] = Field( + default=SamplerType.TIMEDELTA, description="Sampler type discriminator, always 'timedelta' for this sampler" + ) @model_validator(mode="after") def _validate_min_less_than_max(self) -> Self: @@ -225,7 +233,9 @@ class UUIDSamplerParams(ConfigBase): default=False, description="If true, all letters in the UUID will be capitalized.", ) - sampler_type: Literal[SamplerType.UUID] = SamplerType.UUID + sampler_type: Literal[SamplerType.UUID] = Field( + default=SamplerType.UUID, description="Sampler type discriminator, always 'uuid' for this sampler" + ) @property def last_index(self) -> int: @@ -264,7 +274,9 @@ class ScipySamplerParams(ConfigBase): decimal_places: int | None = Field( default=None, description="Number of decimal places to round the sampled values to." 
) - sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY + sampler_type: Literal[SamplerType.SCIPY] = Field( + default=SamplerType.SCIPY, description="Sampler type discriminator, always 'scipy' for this sampler" + ) class BinomialSamplerParams(ConfigBase): @@ -281,7 +293,9 @@ class BinomialSamplerParams(ConfigBase): n: int = Field(..., description="Number of trials.") p: float = Field(..., description="Probability of success on each trial.", ge=0.0, le=1.0) - sampler_type: Literal[SamplerType.BINOMIAL] = SamplerType.BINOMIAL + sampler_type: Literal[SamplerType.BINOMIAL] = Field( + default=SamplerType.BINOMIAL, description="Sampler type discriminator, always 'binomial' for this sampler" + ) class BernoulliSamplerParams(ConfigBase): @@ -297,7 +311,9 @@ class BernoulliSamplerParams(ConfigBase): """ p: float = Field(..., description="Probability of success.", ge=0.0, le=1.0) - sampler_type: Literal[SamplerType.BERNOULLI] = SamplerType.BERNOULLI + sampler_type: Literal[SamplerType.BERNOULLI] = Field( + default=SamplerType.BERNOULLI, description="Sampler type discriminator, always 'bernoulli' for this sampler" + ) class BernoulliMixtureSamplerParams(ConfigBase): @@ -337,7 +353,10 @@ class BernoulliMixtureSamplerParams(ConfigBase): ..., description="Parameters of the scipy.stats distribution given in `dist_name`.", ) - sampler_type: Literal[SamplerType.BERNOULLI_MIXTURE] = SamplerType.BERNOULLI_MIXTURE + sampler_type: Literal[SamplerType.BERNOULLI_MIXTURE] = Field( + default=SamplerType.BERNOULLI_MIXTURE, + description="Sampler type discriminator, always 'bernoulli_mixture' for this sampler", + ) class GaussianSamplerParams(ConfigBase): @@ -361,7 +380,9 @@ class GaussianSamplerParams(ConfigBase): decimal_places: int | None = Field( default=None, description="Number of decimal places to round the sampled values to." 
) - sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN + sampler_type: Literal[SamplerType.GAUSSIAN] = Field( + default=SamplerType.GAUSSIAN, description="Sampler type discriminator, always 'gaussian' for this sampler" + ) class PoissonSamplerParams(ConfigBase): @@ -381,7 +402,9 @@ class PoissonSamplerParams(ConfigBase): """ mean: float = Field(..., description="Mean number of events in a fixed interval.") - sampler_type: Literal[SamplerType.POISSON] = SamplerType.POISSON + sampler_type: Literal[SamplerType.POISSON] = Field( + default=SamplerType.POISSON, description="Sampler type discriminator, always 'poisson' for this sampler" + ) class UniformSamplerParams(ConfigBase): @@ -403,7 +426,9 @@ class UniformSamplerParams(ConfigBase): decimal_places: int | None = Field( default=None, description="Number of decimal places to round the sampled values to." ) - sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM + sampler_type: Literal[SamplerType.UNIFORM] = Field( + default=SamplerType.UNIFORM, description="Sampler type discriminator, always 'uniform' for this sampler" + ) ######################################### @@ -481,7 +506,9 @@ class PersonSamplerParams(ConfigBase): default=False, description="If True, then append synthetic persona columns to each generated person.", ) - sampler_type: Literal[SamplerType.PERSON] = SamplerType.PERSON + sampler_type: Literal[SamplerType.PERSON] = Field( + default=SamplerType.PERSON, description="Sampler type discriminator, always 'person' for this sampler" + ) @property def generator_kwargs(self) -> list[str]: @@ -564,7 +591,10 @@ class PersonFromFakerSamplerParams(ConfigBase): min_length=2, max_length=2, ) - sampler_type: Literal[SamplerType.PERSON_FROM_FAKER] = SamplerType.PERSON_FROM_FAKER + sampler_type: Literal[SamplerType.PERSON_FROM_FAKER] = Field( + default=SamplerType.PERSON_FROM_FAKER, + description="Sampler type discriminator, always 'person_from_faker' for this sampler", + ) @property def 
generator_kwargs(self) -> list[str]: diff --git a/packages/data-designer-config/src/data_designer/config/seed.py b/packages/data-designer-config/src/data_designer/config/seed.py index bdd9dae29..c791f954a 100644 --- a/packages/data-designer-config/src/data_designer/config/seed.py +++ b/packages/data-designer-config/src/data_designer/config/seed.py @@ -111,6 +111,11 @@ class SeedConfig(ConfigBase): ) """ - source: SeedSourceT - sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED - selection_strategy: IndexRange | PartitionBlock | None = None + source: SeedSourceT = Field(description="A SeedSource defining where the seed data exists") + sampling_strategy: SamplingStrategy = Field( + default=SamplingStrategy.ORDERED, + description="Strategy for how to sample rows: ORDERED (sequential) or SHUFFLE (random)", + ) + selection_strategy: IndexRange | PartitionBlock | None = Field( + default=None, description="Optional strategy to select a subset of the dataset (IndexRange or PartitionBlock)" + ) diff --git a/packages/data-designer-config/src/data_designer/config/seed_source.py b/packages/data-designer-config/src/data_designer/config/seed_source.py index c9f31eb46..7244a6e41 100644 --- a/packages/data-designer-config/src/data_designer/config/seed_source.py +++ b/packages/data-designer-config/src/data_designer/config/seed_source.py @@ -26,13 +26,17 @@ class SeedSource(BaseModel, ABC): This serves as a discriminated union discriminator. 
""" - seed_type: str + seed_type: str = Field(description="Discriminator field identifying the seed source type") class LocalFileSeedSource(SeedSource): - seed_type: Literal["local"] = "local" + """Seed source that reads data from a local file (e.g., Parquet, CSV, JSONL).""" - path: str + seed_type: Literal["local"] = Field( + default="local", description="Seed source type discriminator, always 'local' for local file sources" + ) + + path: str = Field(description="Path to the local seed dataset file") @field_validator("path", mode="after") def validate_path(cls, v: str) -> str: @@ -53,7 +57,11 @@ def from_dataframe(cls, df: pd.DataFrame, path: str) -> Self: class HuggingFaceSeedSource(SeedSource): - seed_type: Literal["hf"] = "hf" + """Seed source that reads data from a HuggingFace dataset repository.""" + + seed_type: Literal["hf"] = Field( + default="hf", description="Seed source type discriminator, always 'hf' for HuggingFace sources" + ) path: str = Field( ..., diff --git a/packages/data-designer/src/data_designer/cli/commands/agent_helpers/__init__.py b/packages/data-designer/src/data_designer/cli/commands/agent_helpers/__init__.py new file mode 100644 index 000000000..f1ea03ddb --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/commands/agent_helpers/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations diff --git a/packages/data-designer/src/data_designer/cli/commands/agent_helpers/inspect.py b/packages/data-designer/src/data_designer/cli/commands/agent_helpers/inspect.py new file mode 100644 index 000000000..24794b97c --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/commands/agent_helpers/inspect.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import typer + +from data_designer.cli.controllers.introspection_controller import IntrospectionController + +inspect_app = typer.Typer( + name="inspect", + help="Inspect detailed schemas for configuration objects and the Python API.", + no_args_is_help=True, +) + + +@inspect_app.command(name="column") +def columns_command( + type_name: str = typer.Argument(help="Type name (e.g. 'llm-text', 'expression'), or 'all'."), +) -> None: + """Show schema for a column config type.""" + IntrospectionController().show_columns(type_name) + + +@inspect_app.command(name="sampler") +def samplers_command( + type_name: str = typer.Argument(help="Type name (e.g. 'category', 'uniform'), or 'all'."), +) -> None: + """Show schema for a sampler params type.""" + IntrospectionController().show_samplers(type_name) + + +@inspect_app.command(name="validator") +def validators_command( + type_name: str = typer.Argument(help="Type name (e.g. 'code', 'python'), or 'all'."), +) -> None: + """Show schema for a validator params type.""" + IntrospectionController().show_validators(type_name) + + +@inspect_app.command(name="processor") +def processors_command( + type_name: str = typer.Argument(help="Type name (e.g. 
'drop_columns'), or 'all'."), +) -> None: + """Show schema for a processor config type.""" + IntrospectionController().show_processors(type_name) + + +@inspect_app.command(name="sampler-constraints") +def constraints_command() -> None: + """Show constraint schemas for sampler columns.""" + IntrospectionController().show_sampler_constraints() + + +@inspect_app.command(name="config-builder") +def config_builder_command() -> None: + """Show DataDesignerConfigBuilder method signatures and docstrings.""" + IntrospectionController().show_builder() diff --git a/packages/data-designer/src/data_designer/cli/commands/agent_helpers/list.py b/packages/data-designer/src/data_designer/cli/commands/agent_helpers/list.py new file mode 100644 index 000000000..813f6cc21 --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/commands/agent_helpers/list.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import typer + +from data_designer.cli.controllers.list_controller import ListController +from data_designer.config.utils.constants import DATA_DESIGNER_HOME + +list_app = typer.Typer( + name="list", + help="List available types, model aliases, and persona datasets.", + no_args_is_help=True, +) + + +@list_app.command(name="model-aliases") +def model_aliases_command() -> None: + """List configured model aliases and backing models. 
Needed for model_alias on LLM columns.""" + ListController(DATA_DESIGNER_HOME).list_model_aliases() + + +@list_app.command(name="persona-datasets") +def persona_datasets_command() -> None: + """List Nemotron-Persona datasets and install status.""" + ListController(DATA_DESIGNER_HOME).list_persona_datasets() + + +@list_app.command(name="columns") +def column_types_command() -> None: + """List column type names and config classes.""" + ListController(DATA_DESIGNER_HOME).list_column_types() + + +@list_app.command(name="samplers") +def sampler_types_command() -> None: + """List sampler type names and params classes.""" + ListController(DATA_DESIGNER_HOME).list_sampler_types() + + +@list_app.command(name="validators") +def validator_types_command() -> None: + """List validator type names and params classes.""" + ListController(DATA_DESIGNER_HOME).list_validator_types() + + +@list_app.command(name="processors") +def processor_types_command() -> None: + """List processor type names and config classes.""" + ListController(DATA_DESIGNER_HOME).list_processor_types() diff --git a/packages/data-designer/src/data_designer/cli/controllers/__init__.py b/packages/data-designer/src/data_designer/cli/controllers/__init__.py index 3d2894b76..5f59c1bba 100644 --- a/packages/data-designer/src/data_designer/cli/controllers/__init__.py +++ b/packages/data-designer/src/data_designer/cli/controllers/__init__.py @@ -2,3 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations + +from data_designer.cli.controllers.download_controller import DownloadController +from data_designer.cli.controllers.generation_controller import GenerationController +from data_designer.cli.controllers.introspection_controller import IntrospectionController +from data_designer.cli.controllers.list_controller import ListController +from data_designer.cli.controllers.model_controller import ModelController +from data_designer.cli.controllers.provider_controller import ProviderController + 
+__all__ = [ + "DownloadController", + "GenerationController", + "IntrospectionController", + "ListController", + "ModelController", + "ProviderController", +] diff --git a/packages/data-designer/src/data_designer/cli/controllers/introspection_controller.py b/packages/data-designer/src/data_designer/cli/controllers/introspection_controller.py new file mode 100644 index 000000000..3f0a1e058 --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/controllers/introspection_controller.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass + +import typer + +from data_designer.cli.services.introspection.discovery import ( + discover_column_configs, + discover_constraint_types, + discover_processor_configs, + discover_sampler_types, + discover_validator_types, +) +from data_designer.cli.services.introspection.formatters import format_method_info_text, format_type_list_text +from data_designer.cli.services.introspection.method_inspector import inspect_class_methods +from data_designer.cli.services.introspection.pydantic_inspector import format_model_text, get_brief_description +from data_designer.config.config_builder import DataDesignerConfigBuilder + + +@dataclass(frozen=True) +class _TypedCommandSpec: + """Configuration for typed introspection commands.""" + + discover_items: Callable[[], dict[str, type]] + type_key: str + type_label: str + class_label: str + header_title: str + case_insensitive: bool = False + + +_CONFIG_IMPORT = "import data_designer.config as dd" + + +class IntrospectionController: + """Controller for introspect CLI commands. + + Orchestrates discovery, inspection, formatting, and output for all + introspect subcommands. 
+ """ + + _TYPED_COMMAND_SPECS: dict[str, _TypedCommandSpec] = { + "columns": _TypedCommandSpec( + discover_items=discover_column_configs, + type_key="column_type", + type_label="column_type", + class_label="config_class", + header_title="Data Designer Column Types Reference", + case_insensitive=True, + ), + "samplers": _TypedCommandSpec( + discover_items=discover_sampler_types, + type_key="sampler_type", + type_label="sampler_type", + class_label="params_class", + header_title="Data Designer Sampler Types Reference", + case_insensitive=True, + ), + "validators": _TypedCommandSpec( + discover_items=discover_validator_types, + type_key="validator_type", + type_label="validator_type", + class_label="params_class", + header_title="Data Designer Validator Types Reference", + case_insensitive=True, + ), + "processors": _TypedCommandSpec( + discover_items=discover_processor_configs, + type_key="processor_type", + type_label="processor_type", + class_label="config_class", + header_title="Data Designer Processor Types Reference", + case_insensitive=True, + ), + } + + def _emit_import_hint(self, import_stmt: str, access: str | None = None) -> None: + """Print a one-line import hint.""" + line = f"# {import_stmt}" + if access: + line += f" \u2192 {access}" + typer.echo(line) + typer.echo("") + + def show_columns(self, type_name: str | None) -> None: + """Show column configuration types.""" + self._show_typed_command(command_name="columns", type_name=type_name) + + def show_samplers(self, type_name: str | None) -> None: + """Show sampler types and their param classes.""" + self._show_typed_command(command_name="samplers", type_name=type_name) + + def show_validators(self, type_name: str | None) -> None: + """Show validator types and their param classes.""" + self._show_typed_command(command_name="validators", type_name=type_name) + + def show_processors(self, type_name: str | None) -> None: + """Show processor types and their config classes.""" + 
self._show_typed_command(command_name="processors", type_name=type_name) + + def show_builder(self) -> None: + """Show DataDesignerConfigBuilder method signatures and docs.""" + self._emit_import_hint(_CONFIG_IMPORT, "dd.DataDesignerConfigBuilder") + methods = inspect_class_methods(DataDesignerConfigBuilder) + typer.echo(format_method_info_text(methods, class_name="DataDesignerConfigBuilder")) + + def show_sampler_constraints(self) -> None: + """Show sampler constraint types.""" + self._emit_import_hint(_CONFIG_IMPORT) + items = discover_constraint_types() + self._show_all_schemas(items, "Data Designer Constraint Types Reference") + + def _show_typed_command(self, command_name: str, type_name: str | None) -> None: + """Resolve a typed-command spec and render it.""" + spec = self._TYPED_COMMAND_SPECS[command_name] + items = spec.discover_items() + + if type_name is None: + self._emit_import_hint(_CONFIG_IMPORT) + typer.echo(format_type_list_text(items, spec.type_label, spec.class_label)) + return + + self._show_typed_items( + items=items, + type_name=type_name, + type_key=spec.type_key, + header_title=spec.header_title, + case_insensitive=spec.case_insensitive, + ) + + def _show_typed_items( + self, + items: dict[str, type], + type_name: str, + type_key: str, + header_title: str, + case_insensitive: bool = False, + ) -> None: + """Shared logic for type-based commands (columns, samplers, validators, processors).""" + if type_name.lower() == "all": + self._show_all_typed(items, type_key, header_title) + return + + canonical_value: str | None = None + cls: type | None = None + if case_insensitive: + matched = {k.lower(): (k, v) for k, v in items.items()}.get(type_name.lower()) + if matched is not None: + canonical_value, cls = matched + else: + if type_name in items: + canonical_value = type_name + cls = items[type_name] + + if canonical_value is None or cls is None: + available = ", ".join(sorted(items.keys())) + typer.echo(f"Error: Unknown {type_key} '{type_name}'", 
err=True) + typer.echo(f"Available types: {available}", err=True) + raise typer.Exit(code=1) + + self._emit_import_hint(_CONFIG_IMPORT, f"dd.{cls.__name__}") + typer.echo(format_model_text(cls, type_key=type_key, type_value=canonical_value)) + + def _show_all_typed( + self, + items: dict[str, type], + type_key: str, + header_title: str, + ) -> None: + """Show all types for a typed command.""" + self._emit_import_hint(_CONFIG_IMPORT, "dd.") + sorted_types = sorted(items.keys()) + + seen_schemas: set[str] = set() + lines = [f"# {header_title}", f"# {len(sorted_types)} types discovered from data_designer.config", ""] + for type_value in sorted_types: + cls = items[type_value] + lines.append(format_model_text(cls, type_key=type_key, type_value=type_value, seen_schemas=seen_schemas)) + lines.append("") + typer.echo("\n".join(lines)) + + def _show_all_schemas(self, items: dict[str, type], header_title: str) -> None: + """Show all schemas for simple discovery commands (e.g. constraints).""" + seen_schemas: set[str] = set() + lines = [f"# {header_title}", f"# {len(items)} types", ""] + for name in sorted(items.keys()): + cls = items[name] + if hasattr(cls, "model_fields"): + lines.append(format_model_text(cls, seen_schemas=seen_schemas)) + else: + lines.append(f"{cls.__name__}:") + if cls.__doc__: + lines.append(f" description: {get_brief_description(cls)}") + if hasattr(cls, "__members__"): + members = [str(m.value) for m in cls] + lines.append(f" values: [{', '.join(members)}]") + lines.append("") + typer.echo("\n".join(lines)) diff --git a/packages/data-designer/src/data_designer/cli/controllers/list_controller.py b/packages/data-designer/src/data_designer/cli/controllers/list_controller.py new file mode 100644 index 000000000..98a376c0e --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/controllers/list_controller.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from pathlib import Path + +import typer + +from data_designer.cli.repositories.model_repository import ModelRepository +from data_designer.cli.repositories.persona_repository import PersonaRepository +from data_designer.cli.repositories.provider_repository import ProviderRepository +from data_designer.cli.services.download_service import DownloadService +from data_designer.cli.services.introspection.discovery import ( + discover_column_configs, + discover_processor_configs, + discover_sampler_types, + discover_validator_types, +) +from data_designer.config.default_model_settings import get_providers_with_missing_api_keys + +_IMPORT_HINT = "# import data_designer.config as dd" + + +class ListController: + """Controller for listing valid configuration values.""" + + def __init__(self, config_dir: Path) -> None: + self._config_dir = config_dir + self._model_repository = ModelRepository(config_dir) + self._provider_repository = ProviderRepository(config_dir) + self._persona_repository = PersonaRepository() + self._download_service = DownloadService(config_dir, self._persona_repository) + + def list_model_aliases(self) -> None: + """List configured model aliases. + + Only shows aliases whose backing provider has a valid API key. + """ + provider_registry = self._provider_repository.load() + + if not provider_registry or not provider_registry.providers: + typer.echo("No model providers configured. Run `data-designer config models` to configure your models.") + return + + missing_key_providers = get_providers_with_missing_api_keys(provider_registry.providers) + valid_provider_names = {p.name for p in provider_registry.providers} - {p.name for p in missing_key_providers} + + if not valid_provider_names: + typer.echo( + "No model providers are configured with valid API keys. " + "Run `data-designer config models` to configure your models." 
+ ) + return + + default_provider = provider_registry.default or provider_registry.providers[0].name + + model_registry = self._model_repository.load() + configs = model_registry.model_configs if model_registry else [] + + if not configs: + typer.echo("No model aliases configured.") + typer.echo("Run `data-designer config models` to add models.") + return + + filtered = [mc for mc in configs if (mc.provider or default_provider) in valid_provider_names] + + if not filtered: + typer.echo( + "All configured model aliases use providers without valid API keys. " + "Run `data-designer config models` to configure your models." + ) + return + + c1, c2, c3 = "model_alias", "model", "provider" + w1 = max(len(c1), max(len(mc.alias) for mc in filtered)) + w2 = max(len(c2), max(len(mc.model) for mc in filtered)) + w3 = max(len(c3), max(len(mc.provider or "default") for mc in filtered)) + typer.echo(f"{c1:<{w1}} {c2:<{w2}} {c3}") + typer.echo(f"{'-' * w1} {'-' * w2} {'-' * w3}") + for mc in filtered: + typer.echo(f"{mc.alias:<{w1}} {mc.model:<{w2}} {mc.provider or 'default'}") + + if len(filtered) < len(configs): + typer.echo(f"\n({len(configs) - len(filtered)} model alias(es) hidden — providers missing API keys)") + + def list_persona_datasets(self) -> None: + """List persona datasets available for PersonSamplerParams.""" + managed_locales = self._persona_repository.list_all() + if not managed_locales: + typer.echo("No persona datasets found.") + return + + entries: list[dict[str, str | bool]] = [] + for locale in managed_locales: + installed = self._download_service.is_locale_downloaded(locale.code) + entries.append({"locale": locale.code, "installed": installed}) + + typer.echo(_IMPORT_HINT) + typer.echo("") + col1 = "locale" + col2 = "status" + max_width = max(len(col1), max(len(str(entry["locale"])) for entry in entries)) + typer.echo(f"{col1:<{max_width}} {col2}") + typer.echo(f"{'-' * max_width} {'-' * len('not installed')}") + for entry in entries: + status = "installed" 
if entry["installed"] else "not installed" + typer.echo(f"{str(entry['locale']):<{max_width}} {status}") + typer.echo("") + typer.echo("Use the PersonSamplerParams locale parameter to select a dataset.") + typer.echo("Run `data-designer download personas --locale ` to install a dataset.") + + def _print_type_table( + self, + items: dict[str, type], + col1: str, + col2: str, + inspect_command: str, + ) -> None: + """Print a two-column table of discovered types with an inspect tip.""" + if not items: + typer.echo("No items found.") + return + + sorted_types = sorted(items.keys()) + max_width = max(len(col1), max(len(t) for t in sorted_types)) + + typer.echo(_IMPORT_HINT) + typer.echo("") + typer.echo(f"{col1:<{max_width}} {col2}") + typer.echo(f"{'-' * max_width} {'-' * max(len(items[t].__name__) for t in sorted_types)}") + for t in sorted_types: + typer.echo(f"{t:<{max_width}} {items[t].__name__}") + typer.echo("") + typer.echo(f"Run `data-designer inspect {inspect_command}` to see that type's full schema.") + + def list_column_types(self) -> None: + """List available column configuration types.""" + self._print_type_table(discover_column_configs(), "column_type", "config_class", "column ") + + def list_sampler_types(self) -> None: + """List available sampler types.""" + self._print_type_table(discover_sampler_types(), "sampler_type", "params_class", "sampler ") + + def list_validator_types(self) -> None: + """List available validator types.""" + self._print_type_table( + discover_validator_types(), "validator_type", "params_class", "validator " + ) + + def list_processor_types(self) -> None: + """List available processor types.""" + self._print_type_table( + discover_processor_configs(), "processor_type", "config_class", "processor " + ) diff --git a/packages/data-designer/src/data_designer/cli/main.py b/packages/data-designer/src/data_designer/cli/main.py index a45276c4e..aac5d316a 100644 --- a/packages/data-designer/src/data_designer/cli/main.py +++ 
b/packages/data-designer/src/data_designer/cli/main.py @@ -12,7 +12,7 @@ # Initialize Typer app with custom configuration app = typer.Typer( name="data-designer", - help="Data Designer CLI - Configure model providers and models for synthetic data generation", + help="Data Designer CLI for humans and agents.", cls=create_lazy_typer_group( { "preview": { @@ -98,9 +98,96 @@ no_args_is_help=True, ) +# Create list command group +list_app = typer.Typer( + name="list", + help="List available types, model aliases, and persona datasets.", + cls=create_lazy_typer_group( + { + "model-aliases": { + "module": f"{_CMD}.agent_helpers.list", + "attr": "model_aliases_command", + "help": "List configured model aliases and backing models", + }, + "persona-datasets": { + "module": f"{_CMD}.agent_helpers.list", + "attr": "persona_datasets_command", + "help": "List Nemotron-Persona datasets and install status", + }, + "columns": { + "module": f"{_CMD}.agent_helpers.list", + "attr": "column_types_command", + "help": "List column type names and config classes", + }, + "samplers": { + "module": f"{_CMD}.agent_helpers.list", + "attr": "sampler_types_command", + "help": "List sampler type names and params classes", + }, + "validators": { + "module": f"{_CMD}.agent_helpers.list", + "attr": "validator_types_command", + "help": "List validator type names and params classes", + }, + "processors": { + "module": f"{_CMD}.agent_helpers.list", + "attr": "processor_types_command", + "help": "List processor type names and config classes", + }, + } + ), + no_args_is_help=True, +) + +# Create inspect command group +inspect_app = typer.Typer( + name="inspect", + help="Inspect detailed schemas for configuration objects and the Python API.", + cls=create_lazy_typer_group( + { + "column": { + "module": f"{_CMD}.agent_helpers.inspect", + "attr": "columns_command", + "help": "Show schema for a column config type", + }, + "sampler": { + "module": f"{_CMD}.agent_helpers.inspect", + "attr": "samplers_command", + 
"help": "Show schema for a sampler params type", + }, + "validator": { + "module": f"{_CMD}.agent_helpers.inspect", + "attr": "validators_command", + "help": "Show schema for a validator params type", + }, + "processor": { + "module": f"{_CMD}.agent_helpers.inspect", + "attr": "processors_command", + "help": "Show schema for a processor config type", + }, + "sampler-constraints": { + "module": f"{_CMD}.agent_helpers.inspect", + "attr": "constraints_command", + "help": "Show constraint schemas for sampler columns", + }, + "config-builder": { + "module": f"{_CMD}.agent_helpers.inspect", + "attr": "config_builder_command", + "help": "Show DataDesignerConfigBuilder method signatures and docstrings", + }, + } + ), + no_args_is_help=True, +) + # Add setup command groups -app.add_typer(config_app, name="config", rich_help_panel="Setup") -app.add_typer(download_app, name="download", rich_help_panel="Setup") +app.add_typer(config_app, name="config", rich_help_panel="Setup Commands") +app.add_typer(download_app, name="download", rich_help_panel="Setup Commands") + +# Add agent command groups +title_agent_helpers = "Agent-Helper Commands" +app.add_typer(list_app, name="list", rich_help_panel=title_agent_helpers) +app.add_typer(inspect_app, name="inspect", rich_help_panel=title_agent_helpers) def main() -> None: diff --git a/packages/data-designer/src/data_designer/cli/services/introspection/__init__.py b/packages/data-designer/src/data_designer/cli/services/introspection/__init__.py new file mode 100644 index 000000000..4396183e5 --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/services/introspection/__init__.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.cli.services.introspection.discovery import ( + discover_column_configs, + discover_constraint_types, + discover_processor_configs, + discover_sampler_types, + discover_validator_types, +) +from data_designer.cli.services.introspection.formatters import ( + format_method_info_text, + format_type_list_text, +) +from data_designer.cli.services.introspection.method_inspector import ( + MethodInfo, + ParamInfo, + inspect_class_methods, +) +from data_designer.cli.services.introspection.pydantic_inspector import ( + format_model_text, + format_type, + get_brief_description, +) + +__all__ = [ + "discover_column_configs", + "discover_constraint_types", + "discover_processor_configs", + "discover_sampler_types", + "discover_validator_types", + "format_method_info_text", + "format_model_text", + "format_type_list_text", + "format_type", + "get_brief_description", + "inspect_class_methods", + "MethodInfo", + "ParamInfo", +] diff --git a/packages/data-designer/src/data_designer/cli/services/introspection/discovery.py b/packages/data-designer/src/data_designer/cli/services/introspection/discovery.py new file mode 100644 index 000000000..663efccf6 --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/services/introspection/discovery.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import inspect +from enum import Enum +from typing import Any, Literal, get_args, get_origin + +import data_designer.config as dd + + +def _extract_literal_discriminator_value(annotation: Any) -> str | None: + """Extract the first literal discriminator value from a type annotation. + + Supports ``Literal["value"]`` and ``Literal[SomeEnum.MEMBER]``. + Returns ``None`` when the annotation is not a literal discriminator. 
+ """ + if get_origin(annotation) is not Literal: + return None + + args = get_args(annotation) + if not args: + return None + + value = args[0] + if isinstance(value, Enum): + return str(value.value) + return str(value) + + +def _discover_configs_by_discriminator( + class_name_suffix: str, + discriminator_field: str, + exclude_class_names: set[str] | None = None, +) -> dict[str, type]: + """Discover config classes whose discriminator field is a Literal value. + + Args: + class_name_suffix: Class-name suffix to select candidate classes. + discriminator_field: Pydantic field name containing the discriminator. + exclude_class_names: Optional set of class names to skip. + + Returns: + Dict mapping discriminator values to config classes. + """ + excluded = exclude_class_names or set() + discovered: dict[str, type] = {} + + for name in dir(dd): + if name in excluded or not name.endswith(class_name_suffix): + continue + + obj = getattr(dd, name) + if not (inspect.isclass(obj) and hasattr(obj, "model_fields")): + continue + if discriminator_field not in obj.model_fields: + continue + + annotation = obj.model_fields[discriminator_field].annotation + discriminator_value = _extract_literal_discriminator_value(annotation) + if discriminator_value is not None: + discovered[discriminator_value] = obj + + return discovered + + +def _discover_params_by_discriminator( + params_class_suffix: str, + discriminator_field: str, + enum_name: str, +) -> dict[str, type]: + """Discover params classes keyed by their literal discriminator value. + + Args: + params_class_suffix: Class-name suffix to select params classes. + discriminator_field: Field name that stores the literal discriminator. + enum_name: Enum class name to use for fallback name-matching. + + Returns: + Dict mapping discriminator values to params classes. 
+ """ + discovered: dict[str, type] = {} + normalized_name_map: dict[str, type] = {} + + for name in dir(dd): + if not name.endswith(params_class_suffix): + continue + + obj = getattr(dd, name) + if not (inspect.isclass(obj) and hasattr(obj, "model_fields")): + continue + + if discriminator_field in obj.model_fields: + annotation = obj.model_fields[discriminator_field].annotation + discriminator_value = _extract_literal_discriminator_value(annotation) + if discriminator_value is not None: + discovered[discriminator_value] = obj + continue + + normalized_name = name.removesuffix(params_class_suffix).replace("_", "").lower() + normalized_name_map[normalized_name] = obj + + enum_cls = getattr(dd, enum_name, None) + if enum_cls is None or not (inspect.isclass(enum_cls) and issubclass(enum_cls, Enum)): + return discovered + + for member in enum_cls: + value = str(member.value) + if value in discovered: + continue + normalized_value = value.replace("_", "").lower() + params_cls = normalized_name_map.get(normalized_value) + if params_cls is not None: + discovered[value] = params_cls + + return discovered + + +def discover_column_configs() -> dict[str, type]: + """Dynamically discover all ColumnConfig classes from data_designer.config. + + Returns: + Dict mapping column_type literal values (e.g., 'llm-text') to their config classes. + """ + return _discover_configs_by_discriminator( + class_name_suffix="ColumnConfig", + discriminator_field="column_type", + ) + + +def discover_sampler_types() -> dict[str, type]: + """Dynamically discover sampler types and params classes from data_designer.config. + + Returns: + Dict mapping sampler type names (e.g., 'category') to their params classes. 
+ """ + return _discover_params_by_discriminator( + params_class_suffix="SamplerParams", + discriminator_field="sampler_type", + enum_name="SamplerType", + ) + + +def discover_validator_types() -> dict[str, type]: + """Dynamically discover validator types and params classes from data_designer.config. + + Returns: + Dict mapping validator type names to their params classes. + """ + return _discover_params_by_discriminator( + params_class_suffix="ValidatorParams", + discriminator_field="validator_type", + enum_name="ValidatorType", + ) + + +def discover_processor_configs() -> dict[str, type]: + """Dynamically discover all ProcessorConfig classes from data_designer.config. + + Returns: + Dict mapping processor_type values to their config classes. + """ + return _discover_configs_by_discriminator( + class_name_suffix="ProcessorConfig", + discriminator_field="processor_type", + exclude_class_names={"ProcessorConfig"}, + ) + + +def _discover_by_modules(*module_suffixes: str) -> dict[str, type]: + """Discover config types by filtering _LAZY_IMPORTS on source-module suffix. + + Args: + module_suffixes: One or more module suffixes to match against + (e.g., ``"models"``, ``"seed"``). + + Returns: + Dict mapping class/object names to their resolved types. + """ + lazy_imports: dict[str, tuple[str, str]] = getattr(dd, "_LAZY_IMPORTS", {}) + prefix = "data_designer.config." + result: dict[str, type] = {} + for name, (module_path, _attr) in lazy_imports.items(): + suffix = module_path.removeprefix(prefix) if module_path.startswith(prefix) else module_path + if suffix in module_suffixes: + obj = getattr(dd, name, None) + if obj is not None: + result[name] = obj + return result + + +def discover_constraint_types() -> dict[str, type]: + """Return constraint-related classes from data_designer.config. + + Returns: + Dict mapping class names to their types. 
+ """ + return _discover_by_modules("sampler_constraints") diff --git a/packages/data-designer/src/data_designer/cli/services/introspection/formatters.py b/packages/data-designer/src/data_designer/cli/services/introspection/formatters.py new file mode 100644 index 000000000..e2c7f8c44 --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/services/introspection/formatters.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.cli.services.introspection.method_inspector import MethodInfo, ParamInfo + +_MIN_CLASS_COL_WIDTH = 25 + + +def _format_param_text(param: ParamInfo, indent: int) -> str: + """Format a single method parameter as a text line.""" + pad = " " * indent + line = f"{pad}{param.name}: {param.type_str}" + if param.default is not None: + line += f" = {param.default}" + if param.description: + line += f" \u2014 {param.description}" + return line + + +def format_method_info_text(methods: list[MethodInfo], class_name: str | None = None) -> str: + """Format a list of MethodInfo as readable text with signatures and parameter details.""" + lines: list[str] = [] + if class_name: + lines.append(f"{class_name} Methods:") + lines.append("") + + for method in methods: + lines.append(f" {method.signature}") + if method.description: + lines.append(f" {method.description}") + if method.parameters: + lines.append(" Parameters:") + for param in method.parameters: + lines.append(_format_param_text(param, indent=6)) + lines.append("") + + return "\n".join(lines).rstrip() + + +def format_type_list_text(items: dict[str, type], type_label: str, class_label: str) -> str: + """Format a summary table of type->class mappings, matching the existing print_list_table style.""" + sorted_items = sorted(items.items()) + if not sorted_items: + return f"{type_label} {class_label}\n(no items)" + + type_width = 
max(len(type_value) for type_value, _ in sorted_items) + type_width = max(type_width, len(type_label)) + + lines: list[str] = [] + lines.append(f"{type_label:<{type_width}} {class_label}") + lines.append(f"{'-' * type_width} {'-' * max(len(class_label), _MIN_CLASS_COL_WIDTH)}") + + for type_value, cls in sorted_items: + lines.append(f"{type_value:<{type_width}} {cls.__name__}") + + return "\n".join(lines) diff --git a/packages/data-designer/src/data_designer/cli/services/introspection/method_inspector.py b/packages/data-designer/src/data_designer/cli/services/introspection/method_inspector.py new file mode 100644 index 000000000..1ce2d5c4c --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/services/introspection/method_inspector.py @@ -0,0 +1,271 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import inspect +import re +from dataclasses import dataclass, field + + +@dataclass +class ParamInfo: + name: str + type_str: str + default: str | None + description: str + + +@dataclass +class MethodInfo: + name: str + signature: str + description: str + return_type: str + parameters: list[ParamInfo] = field(default_factory=list) + + +_DEFAULT_INIT_DOCSTRING = "Initialize self. See help(type(self)) for accurate signature." + + +def _parse_google_docstring_args(docstring: str | None) -> dict[str, str]: + """Parse Args section from a Google-style docstring. + + Returns: + Dict mapping parameter names to their descriptions. 
+ """ + if not docstring: + return {} + + lines = docstring.split("\n") + result: dict[str, str] = {} + in_args_section = False + current_param: str | None = None + current_desc_lines: list[str] = [] + args_indent: int | None = None + + section_pattern = re.compile(r"^(\s*)(Args|Returns|Raises|Yields|Note|Notes|Example|Examples|Attributes)\s*:") + + for line in lines: + if re.match(r"^\s*Args\s*:\s*$", line): + in_args_section = True + args_indent = len(line) - len(line.lstrip()) + continue + + if not in_args_section: + continue + + if not line.strip(): + if current_param is not None: + current_desc_lines.append("") + continue + + match = section_pattern.match(line) + if match and match.group(2) != "Args": + section_indent = len(line) - len(line.lstrip()) + if args_indent is not None and section_indent <= args_indent: + break + + line_indent = len(line) - len(line.lstrip()) + stripped = line.strip() + + param_match = re.match(r"^(\*{0,2}\w+)\s*(?:\(.+?\))?\s*:\s*(.*)$", stripped) + if param_match and args_indent is not None and line_indent > args_indent: + if current_param is not None: + result[current_param] = _join_desc_lines(current_desc_lines) + current_param = param_match.group(1) + current_desc_lines = [param_match.group(2).strip()] + elif current_param is not None: + if args_indent is not None and line_indent <= args_indent: + break + current_desc_lines.append(stripped) + + if current_param is not None: + result[current_param] = _join_desc_lines(current_desc_lines) + + return result + + +def _join_desc_lines(lines: list[str]) -> str: + """Join description lines, collapsing whitespace and stripping trailing blanks.""" + return " ".join(part for part in lines if part) + + +def _format_annotation(annotation: type | str) -> str: + """Format a type annotation to a readable string.""" + if annotation is inspect.Parameter.empty: + return "Any" + + if isinstance(annotation, str): + return annotation + + if hasattr(annotation, "__name__"): + return 
annotation.__name__ + + return str(annotation).replace("typing.", "").replace("typing_extensions.", "") + + +def _format_signature(method_name: str, sig: inspect.Signature) -> str: + """Format a method signature as a readable string, skipping 'self'.""" + params: list[str] = [] + seen_keyword_only = False + has_var_positional = any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()) + + for param in sig.parameters.values(): + if param.name == "self": + continue + + if param.kind == inspect.Parameter.KEYWORD_ONLY and not seen_keyword_only and not has_var_positional: + seen_keyword_only = True + params.append("*") + + type_str = _format_annotation(param.annotation) + default_str = "" + if param.default is not inspect.Parameter.empty: + default_str = ( + f" = {param.default!r}" if not isinstance(param.default, type) else f" = {param.default.__name__}" + ) + + if param.kind == inspect.Parameter.VAR_POSITIONAL: + params.append(f"*{param.name}: {type_str}") + elif param.kind == inspect.Parameter.VAR_KEYWORD: + params.append(f"**{param.name}") + else: + params.append(f"{param.name}: {type_str}{default_str}") + + return_type = _format_return_type(sig) + params_str = ", ".join(params) + + return f"{method_name}({params_str}) -> {return_type}" + + +def _format_return_type(sig: inspect.Signature) -> str: + """Extract and format the return type from a signature.""" + if sig.return_annotation is inspect.Parameter.empty: + return "None" + + formatted = _format_annotation(sig.return_annotation) + if formatted == "Self": + return "Self" + + return formatted + + +def _get_first_docstring_line(docstring: str | None) -> str: + """Extract the first non-empty line from a docstring as the description.""" + if not docstring: + return "" + for line in docstring.strip().split("\n"): + stripped = line.strip() + if stripped: + return stripped + return "" + + +def _build_param_info(sig: inspect.Signature, docstring_args: dict[str, str]) -> list[ParamInfo]: + """Build 
ParamInfo list from a signature and parsed docstring args.""" + params: list[ParamInfo] = [] + for param in sig.parameters.values(): + if param.name == "self": + continue + if param.kind == inspect.Parameter.VAR_KEYWORD: + name = f"**{param.name}" + elif param.kind == inspect.Parameter.VAR_POSITIONAL: + name = f"*{param.name}" + else: + name = param.name + + type_str = _format_annotation(param.annotation) + default: str | None = None + if param.default is not inspect.Parameter.empty: + default = repr(param.default) if not isinstance(param.default, type) else param.default.__name__ + + raw_name = param.name + description = docstring_args.get(raw_name, "") + if not description: + description = docstring_args.get(f"**{raw_name}", "") + if not description: + description = docstring_args.get(f"*{raw_name}", "") + + params.append(ParamInfo(name=name, type_str=type_str, default=default, description=description)) + + return params + + +def _is_dunder(name: str) -> bool: + """Check if a method name is a dunder method (excluding __init__).""" + return name.startswith("__") and name.endswith("__") and name != "__init__" + + +def _is_private(name: str) -> bool: + """Check if a method name is private (starts with underscore, not dunder).""" + return name.startswith("_") and not (name.startswith("__") and name.endswith("__")) + + +def _is_default_init_docstring(docstring: str | None) -> bool: + """Check if a docstring is the unhelpful default __init__ docstring.""" + if not docstring: + return False + normalized = " ".join(docstring.strip().split()) + return normalized == _DEFAULT_INIT_DOCSTRING + + +def inspect_class_methods(cls: type, include_private: bool = False) -> list[MethodInfo]: + """Introspect public methods of a class using inspect.signature() and docstring parsing. + + Detects regular methods, classmethods, and handles __init__ docstring fallback + to the class docstring when the default is unhelpful. + + Args: + cls: The class to introspect. 
+ include_private: If True, include methods starting with underscore. + + Returns: + List of MethodInfo objects for each method. + """ + methods: list[MethodInfo] = [] + + # inspect.isfunction finds regular methods; inspect.ismethod finds classmethods + seen: set[str] = set() + candidates: list[tuple[str, object]] = [] + candidates.extend(inspect.getmembers(cls, predicate=inspect.isfunction)) + candidates.extend(inspect.getmembers(cls, predicate=inspect.ismethod)) + + for name, method in candidates: + if name in seen: + continue + seen.add(name) + + if _is_dunder(name): + continue + if _is_private(name) and not include_private: + continue + + try: + sig = inspect.signature(method) + except (ValueError, TypeError): + continue + + docstring = inspect.getdoc(method) + + if name == "__init__" and _is_default_init_docstring(docstring): + docstring = inspect.getdoc(cls) or "" + + docstring_args = _parse_google_docstring_args(docstring) + + signature_str = _format_signature(name, sig) + description = _get_first_docstring_line(docstring) + return_type = _format_return_type(sig) + parameters = _build_param_info(sig, docstring_args) + + methods.append( + MethodInfo( + name=name, + signature=signature_str, + description=description, + return_type=return_type, + parameters=parameters, + ) + ) + + return methods diff --git a/packages/data-designer/src/data_designer/cli/services/introspection/pydantic_inspector.py b/packages/data-designer/src/data_designer/cli/services/introspection/pydantic_inspector.py new file mode 100644 index 000000000..bb1ec7f5f --- /dev/null +++ b/packages/data-designer/src/data_designer/cli/services/introspection/pydantic_inspector.py @@ -0,0 +1,299 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import re +import types +import typing +from enum import Enum +from typing import Any, get_args, get_origin + +from pydantic import BaseModel +from pydantic_core import PydanticUndefined + +_NO_DESCRIPTION = "No description available." + + +def _is_basemodel_subclass(cls: Any) -> bool: + """Return True if cls is a concrete BaseModel subclass (not BaseModel itself).""" + return isinstance(cls, type) and issubclass(cls, BaseModel) and cls is not BaseModel + + +def _is_enum_subclass(cls: Any) -> bool: + """Return True if cls is an Enum subclass (not Enum itself).""" + return isinstance(cls, type) and issubclass(cls, Enum) and cls is not Enum + + +def _extract_enum_class(annotation: Any) -> type | None: + """Unwrap a type annotation to find an Enum class, if present. + + Handles X, X | None, Annotated[X, ...]. + Returns the Enum class or None. + """ + if annotation is None: + return None + + # Unwrap Annotated[X, ...] + if get_origin(annotation) is typing.Annotated: + annotation = get_args(annotation)[0] + + if _is_enum_subclass(annotation): + return annotation + + origin = get_origin(annotation) + if origin is typing.Union or origin is types.UnionType: + for arg in get_args(annotation): + if arg is type(None): + continue + if _is_enum_subclass(arg): + return arg + + return None + + +def _extract_nested_basemodel(annotation: Any) -> type | None: + """Unwrap a type annotation to find a single nested BaseModel subclass. + + Handles: X, list[X], X | None, list[X] | None, dict[K, V], Annotated[X, ...]. + Returns None for unions of 2+ BaseModel subclasses (discriminated unions), + primitives, enums, or BaseModel itself. + """ + if annotation is None: + return None + + # Unwrap Annotated[X, ...] 
+ if get_origin(annotation) is typing.Annotated: + annotation = get_args(annotation)[0] + + if _is_basemodel_subclass(annotation): + return annotation + + origin = get_origin(annotation) + + # list[X] -> check X + if origin is list: + args = get_args(annotation) + if args and _is_basemodel_subclass(args[0]): + return args[0] + return None + + # dict[K, V] -> check V + if origin is dict: + args = get_args(annotation) + if len(args) >= 2 and _is_basemodel_subclass(args[1]): + return args[1] + return None + + # Union: X | None, list[X] | None, or discriminated unions + if origin is typing.Union or origin is types.UnionType: + non_none_args = [a for a in get_args(annotation) if a is not type(None)] + basemodel_classes = [m for a in non_none_args if (m := _extract_nested_basemodel(a)) is not None] + if len(basemodel_classes) == 1: + return basemodel_classes[0] + return None + + return None + + +def _unwrap_annotated_discriminator(annotation: Any) -> Any: + """Strip Annotated wrapper containing a Discriminator.""" + if get_origin(annotation) is not typing.Annotated: + return annotation + args = get_args(annotation) + if len(args) >= 2 and any("Discriminator" in str(a) for a in args[1:]): + return args[0] + return annotation + + +def format_type(annotation: Any) -> str: + """Format a type annotation for readable display. + + Strips module prefixes and simplifies complex types. 
+ """
+    annotation = _unwrap_annotated_discriminator(annotation)
+    type_str = str(annotation)
+
+    # Remove module prefixes
+    type_str = re.sub(r"data_designer\.config\.\w+\.", "", type_str)
+    type_str = re.sub(r"pydantic\.main\.", "", type_str)
+    type_str = re.sub(r"typing\.", "", type_str)
+
+    # Clean up enum members used inside Literal or other contexts: <Enum.MEMBER: 'value'> -> 'value'
+    type_str = re.sub(r"<\w+\.\w+: '([^']+)'>", r"'\1'", type_str)
+
+    # Clean up enum types BEFORE other replacements: <enum 'EnumName'> -> EnumName
+    type_str = re.sub(r"<enum '(\w+)'>", r"\1", type_str)
+
+    # Clean up class types: <class 'str'> -> str
+    type_str = re.sub(r"<class '(?:\w+\.)*(\w+)'>", r"\1", type_str)
+
+    type_str = type_str.replace("NoneType", "None")
+
+    if "Literal[" in type_str:
+        match = re.search(r"Literal\[([^\]]+)\]", type_str)
+        if match:
+            type_str = f"Literal[{match.group(1)}]"
+
+    return type_str
+
+
+def get_brief_description(cls: type) -> str:
+    """Extract first line from class docstring."""
+    if cls.__doc__:
+        for line in cls.__doc__.strip().split("\n"):
+            stripped = line.strip()
+            if stripped:
+                return stripped
+    return _NO_DESCRIPTION
+
+
+def _extract_constraints(field_info: Any) -> dict[str, Any] | None:
+    """Extract numeric/string constraints from a Pydantic FieldInfo's metadata."""
+    constraint_keys = {"ge", "le", "gt", "lt", "min_length", "max_length"}
+    constraints: dict[str, Any] = {}
+    for meta in getattr(field_info, "metadata", []):
+        for key in constraint_keys:
+            val = getattr(meta, key, None)
+            if val is not None:
+                constraints[key] = val
+    return constraints or None
+
+
+def _default_to_json(value: Any) -> Any:
+    """Convert a Pydantic default value to a JSON-serializable value.
+
+    Returns the value unchanged if it is already JSON-serializable (bool, int, float,
+    str, None, list, dict with JSON-serializable values). Enum members are converted
+    to their .value. Other types are returned as a string representation for stability.
+ """ + if value is None: + return None + if isinstance(value, Enum): + return value.value + if isinstance(value, (bool, int, float, str)): + return value + if isinstance(value, list): + return [_default_to_json(item) for item in value] + if isinstance(value, dict): + return {k: _default_to_json(v) for k, v in value.items()} + return repr(value) + + +def _format_field( + field_name: str, + field_info: Any, + indent: int, + seen_schemas: set[str] | None, + seen_types: set[type], + max_depth: int, + depth: int, +) -> list[str]: + """Format a single Pydantic field as YAML-style text lines, recursing into nested schemas.""" + pad = " " * indent + lines: list[str] = [] + + type_str = format_type(field_info.annotation) + description: str = field_info.description or "" + required: bool = field_info.is_required() + + header = f"{pad}{field_name}: {type_str}" + if not required: + if field_info.default_factory is not None: + factory_name = getattr(field_info.default_factory, "__name__", repr(field_info.default_factory)) + header += f" = {factory_name}()" + elif field_info.default is not PydanticUndefined: + header += f" = {_default_to_json(field_info.default)!r}" + if required: + header += " [required]" + lines.append(header) + + if description: + lines.append(f"{pad} description: {description}") + + enum_cls = _extract_enum_class(field_info.annotation) + if enum_cls is not None: + enum_values = [str(member.value) for member in enum_cls] + lines.append(f"{pad} values: [{', '.join(enum_values)}]") + + constraints = _extract_constraints(field_info) + if constraints: + constraint_parts = [f"{k}={v}" for k, v in constraints.items()] + lines.append(f"{pad} constraints: {', '.join(constraint_parts)}") + + nested_cls = _extract_nested_basemodel(field_info.annotation) + if nested_cls is not None and nested_cls not in seen_types and depth < max_depth: + schema_key = f"{nested_cls.__module__}.{nested_cls.__qualname__}" + schema_name = nested_cls.__name__ + if seen_schemas is not None 
and schema_key in seen_schemas: + lines.append(f"{pad} schema: (see {schema_name} above)") + else: + if seen_schemas is not None: + seen_schemas.add(schema_key) + lines.append(f"{pad} schema ({schema_name}):") + next_seen = seen_types | {nested_cls} + nested_model_fields: dict[str, Any] = getattr(nested_cls, "model_fields", {}) + for nested_name, nested_info in nested_model_fields.items(): + lines.extend( + _format_field( + field_name=nested_name, + field_info=nested_info, + indent=indent + 4, + seen_schemas=seen_schemas, + seen_types=next_seen, + max_depth=max_depth, + depth=depth + 1, + ) + ) + + return lines + + +def format_model_text( + cls: type, + type_key: str | None = None, + type_value: str | None = None, + indent: int = 0, + seen_schemas: set[str] | None = None, + max_depth: int = 3, + seen_types: set[type] | None = None, + depth: int = 0, +) -> str: + """Format a Pydantic model as YAML-style text for agent context. + + Args: + cls: The Pydantic model class to format. + type_key: Optional discriminator key name (e.g., "column_type"). + type_value: Optional discriminator value (e.g., "llm-text"). + indent: Base indentation level. + seen_schemas: Set of schema refs already rendered (mutated for cross-model dedup). + max_depth: Maximum recursion depth for nested models. + seen_types: Set of types already rendered (prevents infinite recursion). + depth: Current recursion depth. 
+ """ + if seen_types is None: + seen_types = set() + + pad = " " * indent + lines: list[str] = [] + lines.append(f"{pad}{cls.__name__}:") + if type_key and type_value: + lines.append(f"{pad} {type_key}: {type_value}") + lines.append(f"{pad} description: {get_brief_description(cls)}") + lines.append(f"{pad} fields:") + + model_fields: dict[str, Any] = getattr(cls, "model_fields", {}) + for field_name, field_info in model_fields.items(): + lines.extend( + _format_field( + field_name=field_name, + field_info=field_info, + indent=indent + 4, + seen_schemas=seen_schemas, + seen_types=seen_types, + max_depth=max_depth, + depth=depth, + ) + ) + + return "\n".join(lines) diff --git a/packages/data-designer/tests/cli/commands/__init__.py b/packages/data-designer/tests/cli/commands/__init__.py new file mode 100644 index 000000000..e5725ea5a --- /dev/null +++ b/packages/data-designer/tests/cli/commands/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/data-designer/tests/cli/commands/agent_helpers/__init__.py b/packages/data-designer/tests/cli/commands/agent_helpers/__init__.py new file mode 100644 index 000000000..f1ea03ddb --- /dev/null +++ b/packages/data-designer/tests/cli/commands/agent_helpers/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations diff --git a/packages/data-designer/tests/cli/commands/agent_helpers/test_introspection_commands.py b/packages/data-designer/tests/cli/commands/agent_helpers/test_introspection_commands.py new file mode 100644 index 000000000..bdaca7c25 --- /dev/null +++ b/packages/data-designer/tests/cli/commands/agent_helpers/test_introspection_commands.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typer.testing import CliRunner + +from data_designer.cli.main import app + +runner = CliRunner() + + +# --------------------------------------------------------------------------- +# help +# --------------------------------------------------------------------------- + + +def test_inspect_help() -> None: + result = runner.invoke(app, ["inspect", "--help"]) + assert result.exit_code == 0 + assert "column" in result.output + + +# --------------------------------------------------------------------------- +# columns +# --------------------------------------------------------------------------- + + +def test_columns_no_arg_fails() -> None: + result = runner.invoke(app, ["inspect", "column"]) + assert result.exit_code != 0 + + +def test_columns_specific_type() -> None: + result = runner.invoke(app, ["inspect", "column", "llm-text"]) + assert result.exit_code == 0 + assert "LLMTextColumnConfig" in result.output + + +def test_columns_nonexistent_exits_with_error() -> None: + result = runner.invoke(app, ["inspect", "column", "nonexistent"]) + assert result.exit_code == 1 + + +# --------------------------------------------------------------------------- +# samplers +# --------------------------------------------------------------------------- + + +def test_samplers_specific() -> None: + result = runner.invoke(app, ["inspect", "sampler", "category"]) + assert result.exit_code == 0 + assert "sampler_type: category" in result.output + + +def test_samplers_all_case_insensitive() -> None: + result = runner.invoke(app, ["inspect", "sampler", "ALL"]) + assert result.exit_code == 0 + assert "Data Designer Sampler Types Reference" in result.output + assert "sampler_type: category" in result.output + + +def test_samplers_no_arg_fails() -> None: + result = runner.invoke(app, ["inspect", "sampler"]) + assert result.exit_code != 0 + + +# 
--------------------------------------------------------------------------- +# validators +# --------------------------------------------------------------------------- + + +def test_validators_no_arg_fails() -> None: + result = runner.invoke(app, ["inspect", "validator"]) + assert result.exit_code != 0 + + +def test_validators_specific() -> None: + result = runner.invoke(app, ["inspect", "validator", "code"]) + assert result.exit_code == 0 + assert "validator_type: code" in result.output + + +def test_validators_all_case_insensitive() -> None: + result = runner.invoke(app, ["inspect", "validator", "ALL"]) + assert result.exit_code == 0 + assert "Data Designer Validator Types Reference" in result.output + assert "validator_type: code" in result.output + + +# --------------------------------------------------------------------------- +# processors +# --------------------------------------------------------------------------- + + +def test_processors_no_arg_fails() -> None: + result = runner.invoke(app, ["inspect", "processor"]) + assert result.exit_code != 0 + + +def test_processors_specific_type() -> None: + result = runner.invoke(app, ["inspect", "processor", "drop_columns"]) + assert result.exit_code == 0 + assert "DropColumnsProcessorConfig" in result.output + + +def test_processors_all() -> None: + result = runner.invoke(app, ["inspect", "processor", "all"]) + assert result.exit_code == 0 + assert "Data Designer Processor Types Reference" in result.output + + +# --------------------------------------------------------------------------- +# config-builder +# --------------------------------------------------------------------------- + + +def test_config_builder() -> None: + result = runner.invoke(app, ["inspect", "config-builder"]) + assert result.exit_code == 0 + assert "add_column" in result.output + assert "DataDesignerConfigBuilder" in result.output + assert "Parameters:" in result.output + + +# 
--------------------------------------------------------------------------- +# constraints +# --------------------------------------------------------------------------- + + +def test_constraints() -> None: + result = runner.invoke(app, ["inspect", "sampler-constraints"]) + assert result.exit_code == 0 + output = result.output + assert "ScalarInequalityConstraint" in output or "InequalityOperator" in output + + +# --------------------------------------------------------------------------- +# import hints +# --------------------------------------------------------------------------- + + +def test_import_hint_shown_in_text_output() -> None: + result = runner.invoke(app, ["inspect", "column", "llm-text"]) + assert result.exit_code == 0 + assert "import data_designer.config as dd" in result.output + assert "dd.LLMTextColumnConfig" in result.output + + +# --------------------------------------------------------------------------- +# list +# --------------------------------------------------------------------------- + + +def test_list_help() -> None: + result = runner.invoke(app, ["list", "--help"]) + assert result.exit_code == 0 + for subcmd in ("model-aliases", "persona-datasets", "columns", "samplers", "validators", "processors"): + assert subcmd in result.output + + +def test_list_model_aliases() -> None: + result = runner.invoke(app, ["list", "model-aliases"]) + assert result.exit_code == 0 + + +def test_list_persona_datasets() -> None: + result = runner.invoke(app, ["list", "persona-datasets"]) + assert result.exit_code == 0 + assert "locale" in result.output + + +def test_list_column_types() -> None: + result = runner.invoke(app, ["list", "columns"]) + assert result.exit_code == 0 + assert "llm-text" in result.output + + +def test_list_sampler_types() -> None: + result = runner.invoke(app, ["list", "samplers"]) + assert result.exit_code == 0 + assert "category" in result.output + + +def test_list_validator_types() -> None: + result = runner.invoke(app, ["list", 
"validators"]) + assert result.exit_code == 0 + assert "code" in result.output + + +def test_list_processor_types() -> None: + result = runner.invoke(app, ["list", "processors"]) + assert result.exit_code == 0 diff --git a/packages/data-designer/tests/cli/commands/agent_helpers/test_list_command.py b/packages/data-designer/tests/cli/commands/agent_helpers/test_list_command.py new file mode 100644 index 000000000..0b97a0847 --- /dev/null +++ b/packages/data-designer/tests/cli/commands/agent_helpers/test_list_command.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from data_designer.cli.commands.agent_helpers.list import ( + column_types_command, + model_aliases_command, + persona_datasets_command, + processor_types_command, + sampler_types_command, + validator_types_command, +) +from data_designer.cli.controllers.list_controller import ListController +from data_designer.config.utils.constants import DATA_DESIGNER_HOME + +_PATCH_TARGET = "data_designer.cli.commands.agent_helpers.list.ListController" + + +# --------------------------------------------------------------------------- +# model-aliases +# --------------------------------------------------------------------------- + + +@patch(_PATCH_TARGET) +def test_model_aliases_delegates_text(mock_cls: MagicMock) -> None: + mock_ctrl = MagicMock(spec=ListController) + mock_cls.return_value = mock_ctrl + + model_aliases_command() + + mock_cls.assert_called_once_with(DATA_DESIGNER_HOME) + mock_ctrl.list_model_aliases.assert_called_once_with() + + +# --------------------------------------------------------------------------- +# persona-datasets +# --------------------------------------------------------------------------- + + +@patch(_PATCH_TARGET) +def test_persona_datasets_delegates(mock_cls: MagicMock) -> None: + mock_ctrl = 
MagicMock(spec=ListController) + mock_cls.return_value = mock_ctrl + + persona_datasets_command() + + mock_cls.assert_called_once_with(DATA_DESIGNER_HOME) + mock_ctrl.list_persona_datasets.assert_called_once_with() + + +# --------------------------------------------------------------------------- +# columns +# --------------------------------------------------------------------------- + + +@patch(_PATCH_TARGET) +def test_column_types_delegates(mock_cls: MagicMock) -> None: + mock_ctrl = MagicMock(spec=ListController) + mock_cls.return_value = mock_ctrl + + column_types_command() + + mock_cls.assert_called_once_with(DATA_DESIGNER_HOME) + mock_ctrl.list_column_types.assert_called_once_with() + + +# --------------------------------------------------------------------------- +# samplers +# --------------------------------------------------------------------------- + + +@patch(_PATCH_TARGET) +def test_sampler_types_delegates(mock_cls: MagicMock) -> None: + mock_ctrl = MagicMock(spec=ListController) + mock_cls.return_value = mock_ctrl + + sampler_types_command() + + mock_cls.assert_called_once_with(DATA_DESIGNER_HOME) + mock_ctrl.list_sampler_types.assert_called_once_with() + + +# --------------------------------------------------------------------------- +# validators +# --------------------------------------------------------------------------- + + +@patch(_PATCH_TARGET) +def test_validator_types_delegates(mock_cls: MagicMock) -> None: + mock_ctrl = MagicMock(spec=ListController) + mock_cls.return_value = mock_ctrl + + validator_types_command() + + mock_cls.assert_called_once_with(DATA_DESIGNER_HOME) + mock_ctrl.list_validator_types.assert_called_once_with() + + +# --------------------------------------------------------------------------- +# processors +# --------------------------------------------------------------------------- + + +@patch(_PATCH_TARGET) +def test_processor_types_delegates(mock_cls: MagicMock) -> None: + mock_ctrl = MagicMock(spec=ListController) + 
mock_cls.return_value = mock_ctrl + + processor_types_command() + + mock_cls.assert_called_once_with(DATA_DESIGNER_HOME) + mock_ctrl.list_processor_types.assert_called_once_with() diff --git a/packages/data-designer/tests/cli/commands/agent_helpers/test_usage_scenarios.py b/packages/data-designer/tests/cli/commands/agent_helpers/test_usage_scenarios.py new file mode 100644 index 000000000..12c55ddbd --- /dev/null +++ b/packages/data-designer/tests/cli/commands/agent_helpers/test_usage_scenarios.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import re +import types +from pathlib import Path +from unittest.mock import patch + +from typer.testing import CliRunner + +from data_designer.cli.main import app + +runner = CliRunner() + +ANSI_ESCAPE_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]") + + +class _AlwaysTTY: + def isatty(self) -> bool: + return True + + +def _normalize_text(text: str) -> str: + without_ansi = ANSI_ESCAPE_RE.sub("", text) + return re.sub(r"\s+", " ", without_ansi).strip().lower() + + +def _write_usage_config(tmp_path: Path) -> Path: + config_path = tmp_path / "usage_config.py" + config_path.write_text( + """from __future__ import annotations + +import data_designer.config as dd + + +def load_config_builder() -> dd.DataDesignerConfigBuilder: + builder = dd.DataDesignerConfigBuilder() + builder.add_column( + dd.SamplerColumnConfig( + name="record_id", + sampler_type=dd.SamplerType.UUID, + params=dd.UUIDSamplerParams(), + ) + ) + builder.add_column( + dd.SamplerColumnConfig( + name="category", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams(values=["A", "B", "C"]), + ) + ) + builder.add_column( + dd.ExpressionColumnConfig( + name="summary", + expr="{{ category }}::{{ record_id }}", + ) + ) + return builder +""", + encoding="utf-8", + ) + return config_path + + +def 
def test_usage_preview_non_interactive_shows_records(tmp_path: Path) -> None:
    """Non-interactive preview prints every record plus a completion message."""
    config_path = _write_usage_config(tmp_path)

    result = runner.invoke(
        app,
        ["preview", str(config_path), "--num-records", "3", "--non-interactive"],
        color=False,
    )
    text = _normalize_text(result.output)

    assert result.exit_code == 0
    assert "record 1 of 3" in text
    assert "record 3 of 3" in text
    assert "preview complete" in text


def test_usage_interactive_preview_navigation(tmp_path: Path) -> None:
    """Interactive preview honours the n/p/q navigation keys (next, previous, quit)."""
    config_path = _write_usage_config(tmp_path)
    # Pretend stdin/stdout are TTYs so the interactive browsing path is taken.
    fake_sys = types.SimpleNamespace(stdin=_AlwaysTTY(), stdout=_AlwaysTTY())

    with (
        patch("data_designer.cli.controllers.generation_controller.sys", fake_sys),
        patch(
            "data_designer.cli.controllers.generation_controller.wait_for_navigation_key",
            side_effect=["n", "p", "q"],
        ),
    ):
        result = runner.invoke(
            app,
            ["preview", str(config_path), "--num-records", "3"],
            color=False,
        )

    text = _normalize_text(result.output)
    assert result.exit_code == 0
    assert "record 1 of 3" in text
    assert "record 2 of 3" in text
    assert "done browsing." in text


def test_usage_validate_unsupported_extension_is_actionable(tmp_path: Path) -> None:
    """validate on an unsupported file extension exits 1 with actionable guidance."""
    bad_config = tmp_path / "config.txt"
    bad_config.write_text("not supported", encoding="utf-8")

    result = runner.invoke(app, ["validate", str(bad_config)], color=False)
    text = _normalize_text(result.output)

    assert result.exit_code == 1
    assert "unsupported file extension" in text
    assert "supported extensions" in text


def test_usage_introspect_unknown_type_error_is_actionable() -> None:
    """inspect with an unknown column type exits 1 and lists the valid types."""
    result = runner.invoke(app, ["inspect", "column", "nonexistent"], color=False)
    text = _normalize_text(result.output)

    assert result.exit_code == 1
    assert "error: unknown column_type" in text
    assert "available types:" in text
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import click.exceptions
import pytest

from data_designer.cli.controllers.introspection_controller import IntrospectionController

# ---------------------------------------------------------------------------
# show_columns
# ---------------------------------------------------------------------------


def test_show_columns_list_mode(capsys: pytest.CaptureFixture[str]) -> None:
    """With no type given, every column type is listed."""
    IntrospectionController().show_columns(type_name=None)
    output = capsys.readouterr().out
    assert "llm-text" in output
    assert "sampler" in output


def test_show_columns_specific_type(capsys: pytest.CaptureFixture[str]) -> None:
    """Requesting a single column type prints its config class name."""
    IntrospectionController().show_columns(type_name="llm-text")
    assert "LLMTextColumnConfig" in capsys.readouterr().out


def test_show_columns_all(capsys: pytest.CaptureFixture[str]) -> None:
    """'all' prints the full column-type reference."""
    IntrospectionController().show_columns(type_name="all")
    output = capsys.readouterr().out
    assert "llm-text" in output
    assert "sampler" in output


def test_show_columns_nonexistent_type_exits() -> None:
    """An unknown column type raises a typer/click Exit."""
    with pytest.raises(click.exceptions.Exit):
        IntrospectionController().show_columns(type_name="nonexistent_type_xyz")


# ---------------------------------------------------------------------------
# show_samplers
# ---------------------------------------------------------------------------


def test_show_samplers_list(capsys: pytest.CaptureFixture[str]) -> None:
    """With no type given, sampler types are listed."""
    IntrospectionController().show_samplers(type_name=None)
    assert "category" in capsys.readouterr().out


def test_show_samplers_specific(capsys: pytest.CaptureFixture[str]) -> None:
    """Requesting a single sampler type prints its discriminator line."""
    IntrospectionController().show_samplers(type_name="category")
    assert "sampler_type: category" in capsys.readouterr().out


def test_show_samplers_all_case_insensitive(capsys: pytest.CaptureFixture[str]) -> None:
    """'ALL' (any case) prints the full sampler reference."""
    IntrospectionController().show_samplers(type_name="ALL")
    output = capsys.readouterr().out
    assert "Data Designer Sampler Types Reference" in output
    assert "sampler_type: category" in output


# ---------------------------------------------------------------------------
# show_builder
# ---------------------------------------------------------------------------


def test_show_builder(capsys: pytest.CaptureFixture[str]) -> None:
    """The builder reference mentions its key method."""
    IntrospectionController().show_builder()
    assert "add_column" in capsys.readouterr().out


# ---------------------------------------------------------------------------
# show_sampler_constraints
# ---------------------------------------------------------------------------


def test_show_sampler_constraints(capsys: pytest.CaptureFixture[str]) -> None:
    """The constraints reference mentions a known constraint class."""
    IntrospectionController().show_sampler_constraints()
    assert "ScalarInequalityConstraint" in capsys.readouterr().out


# ---------------------------------------------------------------------------
# show_validators
# ---------------------------------------------------------------------------


def test_show_validators_list_text(capsys: pytest.CaptureFixture[str]) -> None:
    """With no type given, validator types are listed with their params classes."""
    IntrospectionController().show_validators(type_name=None)
    output = capsys.readouterr().out
    assert "validator_type" in output
    assert "params_class" in output


def test_show_validators_specific_text(capsys: pytest.CaptureFixture[str]) -> None:
    """Requesting a single validator type prints its discriminator line."""
    IntrospectionController().show_validators(type_name="code")
    assert "validator_type: code" in capsys.readouterr().out


def test_show_validators_all_case_insensitive(capsys: pytest.CaptureFixture[str]) -> None:
    """'ALL' (any case) prints the full validator reference."""
    IntrospectionController().show_validators(type_name="ALL")
    output = capsys.readouterr().out
    assert "Data Designer Validator Types Reference" in output
    assert "validator_type: code" in output


# ---------------------------------------------------------------------------
# show_processors
# ---------------------------------------------------------------------------


def test_show_processors_list_text(capsys: pytest.CaptureFixture[str]) -> None:
    """With no type given, processor types are listed with their config classes."""
    IntrospectionController().show_processors(type_name=None)
    output = capsys.readouterr().out
    assert "processor_type" in output
    assert "config_class" in output


def test_show_processors_specific_type(capsys: pytest.CaptureFixture[str]) -> None:
    """Requesting a single processor type prints its config class name."""
    IntrospectionController().show_processors(type_name="drop_columns")
    assert "DropColumnsProcessorConfig" in capsys.readouterr().out


def test_show_processors_all(capsys: pytest.CaptureFixture[str]) -> None:
    """'all' prints the full processor reference."""
    IntrospectionController().show_processors(type_name="all")
    output = capsys.readouterr().out
    assert "Data Designer Processor Types Reference" in output
    assert "processor_type:" in output


def test_show_processors_nonexistent() -> None:
    """An unknown processor type raises a typer/click Exit."""
    with pytest.raises(click.exceptions.Exit):
        IntrospectionController().show_processors(type_name="badname")


# ---------------------------------------------------------------------------
# case-insensitive lookup (P1-3)
# ---------------------------------------------------------------------------


def test_show_columns_mixed_case(capsys: pytest.CaptureFixture[str]) -> None:
    """Column lookup is case-insensitive."""
    IntrospectionController().show_columns(type_name="LLM-TEXT")
    assert "LLMTextColumnConfig" in capsys.readouterr().out
test_show_samplers_mixed_case(capsys: pytest.CaptureFixture[str]) -> None: + controller = IntrospectionController() + controller.show_samplers(type_name="CATEGORY") + captured = capsys.readouterr() + assert "sampler_type: category" in captured.out diff --git a/packages/data-designer/tests/cli/controllers/test_list_controller.py b/packages/data-designer/tests/cli/controllers/test_list_controller.py new file mode 100644 index 000000000..293516644 --- /dev/null +++ b/packages/data-designer/tests/cli/controllers/test_list_controller.py @@ -0,0 +1,507 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from data_designer.cli.controllers.list_controller import ListController + +# --------------------------------------------------------------------------- +# fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def controller(tmp_path: Path) -> ListController: + """Controller with no datasets installed and no model configs.""" + return ListController(tmp_path) + + +@pytest.fixture +def controller_with_datasets(tmp_path: Path) -> ListController: + """Controller with en_US and ja_JP persona datasets installed.""" + managed = tmp_path / "managed-assets" / "datasets" + managed.mkdir(parents=True) + (managed / "en_US.parquet").touch() + (managed / "ja_JP.parquet").touch() + return ListController(tmp_path) + + +@pytest.fixture +def controller_all_installed(tmp_path: Path) -> ListController: + """Controller with ALL managed persona datasets installed.""" + ctrl = ListController(tmp_path) + managed = tmp_path / "managed-assets" / "datasets" + managed.mkdir(parents=True) + for locale in ctrl._persona_repository.list_all(): + (managed / f"{locale.code}.parquet").touch() + return ctrl + + +def 
def _make_model_config(alias: str, model: str, provider: str | None = None) -> MagicMock:
    """Build a mock model-config record with the given alias/model/provider."""
    config = MagicMock()
    config.alias = alias
    config.model = model
    config.provider = provider
    return config


def _make_provider(name: str, api_key: str | None = "sk-valid-key") -> MagicMock:
    """Build a mock provider record; api_key=None simulates a missing key."""
    provider = MagicMock()
    provider.name = name
    provider.api_key = api_key
    return provider


def _make_provider_registry(
    providers: list[MagicMock],
    default: str | None = None,
) -> MagicMock:
    """Build a mock provider registry holding *providers* with optional default."""
    registry = MagicMock()
    registry.providers = providers
    registry.default = default
    return registry


# ---------------------------------------------------------------------------
# list_model_aliases — text
# ---------------------------------------------------------------------------


def test_model_aliases_text_empty(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """No configured aliases produces an empty-state message with a hint."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry([_make_provider("nvidia")])
    models = MagicMock()
    models.model_configs = []
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "No model aliases configured." in output
    assert "data-designer config models" in output


def test_model_aliases_text_with_models(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """Configured aliases are listed with model, provider and default marker."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia"), _make_provider("openai")],
        default="nvidia",
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("my-model", "meta/llama-3.1-8b-instruct", "nvidia"),
        _make_model_config("judge", "openai/gpt-4o", None),
    ]
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "my-model" in output
    assert "meta/llama-3.1-8b-instruct" in output
    assert "nvidia" in output
    assert "judge" in output
    assert "default" in output


# NOTE(review): this duplicates test_model_aliases_text_empty's setup and the
# first of its assertions; kept because removing a public test changes the
# module's collected interface.
def test_model_aliases_text_empty_model_configs(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """An empty model_configs list produces the empty-state message."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry([_make_provider("nvidia")])
    models = MagicMock()
    models.model_configs = []
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "No model aliases configured." in output


# ---------------------------------------------------------------------------
# list_model_aliases — provider validation (text)
# ---------------------------------------------------------------------------


def test_model_aliases_text_no_provider_config(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """A missing provider registry produces the no-providers message with a hint."""
    ctrl = ListController(tmp_path)
    with patch.object(ctrl._provider_repository, "load", return_value=None):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "No model providers configured" in output
    assert "data-designer config models" in output


def test_model_aliases_text_empty_providers(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """An empty provider list produces the no-providers message."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry([])
    with patch.object(ctrl._provider_repository, "load", return_value=providers):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "No model providers configured" in output


def test_model_aliases_text_all_providers_missing_keys(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """When every provider lacks an API key, a dedicated message with hint appears."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [
            _make_provider("nvidia", api_key=None),
            _make_provider("openai", api_key=None),
        ]
    )
    with patch.object(ctrl._provider_repository, "load", return_value=providers):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "No model providers are configured with valid API keys" in output
    assert "data-designer config models" in output


def test_model_aliases_text_filters_by_provider(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """Aliases on providers without valid keys are filtered out of the listing."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia", api_key="sk-valid"), _make_provider("openai", api_key=None)],
        default="nvidia",
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("nv-model", "meta/llama-3.1-8b-instruct", "nvidia"),
        _make_model_config("oai-model", "openai/gpt-4o", "openai"),
    ]
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "nv-model" in output
    assert "oai-model" not in output


def test_model_aliases_text_default_provider_resolution(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """Model with provider=None resolves to default provider for filtering."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia", api_key="sk-valid"), _make_provider("openai", api_key=None)],
        default="nvidia",
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("my-model", "meta/llama-3.1-8b-instruct", None),
    ]
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "my-model" in output
    assert "default" in output


def test_model_aliases_text_default_provider_resolution_excluded(
    tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """Model with provider=None is excluded when default provider lacks a valid key."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia", api_key=None), _make_provider("openai", api_key="sk-valid")],
        default="nvidia",
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("my-model", "meta/llama-3.1-8b-instruct", None),
        _make_model_config("oai-model", "openai/gpt-4o", "openai"),
    ]
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "my-model" not in output
    assert "oai-model" in output
def test_model_aliases_text_all_models_filtered(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """When every alias is filtered out, a dedicated message with hint appears."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia", api_key="sk-valid"), _make_provider("openai", api_key=None)],
        default="nvidia",
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("oai-model", "openai/gpt-4o", "openai"),
    ]
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "All configured model aliases use providers without valid API keys" in output
    assert "data-designer config models" in output


def test_model_aliases_text_default_from_first_provider(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """When provider_registry.default is None, first provider is used as default."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia", api_key="sk-valid")],
        default=None,
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("my-model", "meta/llama-3.1-8b-instruct", None),
    ]
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "my-model" in output


def test_model_aliases_env_var_api_key_set(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """Provider whose api_key names an env var that IS set should be treated as valid."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia", api_key="NVIDIA_API_KEY")],
        default="nvidia",
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("nv-model", "meta/llama-3.1-8b-instruct", "nvidia"),
    ]
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
        patch.dict(os.environ, {"NVIDIA_API_KEY": "real-key"}),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "nv-model" in output


def test_model_aliases_env_var_api_key_unset(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """Provider whose api_key names an env var that is NOT set should be invalid."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia", api_key="NVIDIA_API_KEY")],
        default="nvidia",
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("nv-model", "meta/llama-3.1-8b-instruct", "nvidia"),
    ]
    # Rebuild the environment without the key so the lookup must fail.
    env = {k: v for k, v in os.environ.items() if k != "NVIDIA_API_KEY"}
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
        patch.dict(os.environ, env, clear=True),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "No model providers are configured with valid API keys" in output


def test_model_aliases_model_registry_returns_none(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """When _model_repository.load() returns None, show 'No model aliases configured.'"""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry([_make_provider("nvidia")])
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=None),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "No model aliases configured." in output


def test_model_aliases_multiple_models_same_provider(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """All models on a single valid provider should appear in the output."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia", api_key="sk-valid")],
        default="nvidia",
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("model-a", "meta/llama-3.1-8b-instruct", "nvidia"),
        _make_model_config("model-b", "meta/llama-3.1-70b-instruct", "nvidia"),
        _make_model_config("model-c", "nvidia/nemotron-4-340b", "nvidia"),
    ]
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "model-a" in output
    assert "model-b" in output
    assert "model-c" in output


def test_model_aliases_filtered_count_hint(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
    """Output should contain a hint about how many aliases were hidden by filtering."""
    ctrl = ListController(tmp_path)
    providers = _make_provider_registry(
        [_make_provider("nvidia", api_key="sk-valid"), _make_provider("openai", api_key=None)],
        default="nvidia",
    )
    models = MagicMock()
    models.model_configs = [
        _make_model_config("nv-model", "meta/llama-3.1-8b-instruct", "nvidia"),
        _make_model_config("oai-model", "openai/gpt-4o", "openai"),
    ]
    with (
        patch.object(ctrl._provider_repository, "load", return_value=providers),
        patch.object(ctrl._model_repository, "load", return_value=models),
    ):
        ctrl.list_model_aliases()
    output = capsys.readouterr().out
    assert "nv-model" in output
    assert "oai-model" not in output
    assert "1 model alias(es) hidden" in output
# ---------------------------------------------------------------------------
# list_persona_datasets — text
# ---------------------------------------------------------------------------


def test_persona_datasets_text_none_installed(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """With nothing installed, every locale shows 'not installed'."""
    controller.list_persona_datasets()
    output = capsys.readouterr().out
    assert "locale" in output
    assert "not installed" in output


def test_persona_datasets_text_some_installed(
    controller_with_datasets: ListController, capsys: pytest.CaptureFixture[str]
) -> None:
    """Installed locales are reported alongside the uninstalled ones."""
    controller_with_datasets.list_persona_datasets()
    output = capsys.readouterr().out
    assert "en_US" in output
    assert "installed" in output
    assert "ja_JP" in output


def test_persona_datasets_text_all_installed(
    controller_all_installed: ListController, capsys: pytest.CaptureFixture[str]
) -> None:
    """When every dataset is installed, no locale row says 'not installed'."""
    controller_all_installed.list_persona_datasets()
    output = capsys.readouterr().out
    # Keep only locale rows: drop separators and the header row mentioning 'status'.
    rows = [
        row
        for row in output.strip().splitlines()
        if "installed" in row and "---" not in row and "status" not in row
    ]
    assert rows
    assert all("not installed" not in row for row in rows)


# ---------------------------------------------------------------------------
# list_column_types — text
# ---------------------------------------------------------------------------


def test_column_types_text(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """Column listing shows headers, known types and the inspect hint."""
    controller.list_column_types()
    output = capsys.readouterr().out
    assert "column_type" in output
    assert "config_class" in output
    assert "llm-text" in output
    assert "sampler" in output
    assert "data-designer inspect column" in output


# ---------------------------------------------------------------------------
# list_sampler_types — text
# ---------------------------------------------------------------------------


def test_sampler_types_text(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """Sampler listing shows headers, a known type and the inspect hint."""
    controller.list_sampler_types()
    output = capsys.readouterr().out
    assert "sampler_type" in output
    assert "params_class" in output
    assert "category" in output
    assert "data-designer inspect sampler" in output


# ---------------------------------------------------------------------------
# list_validator_types — text
# ---------------------------------------------------------------------------


def test_validator_types_text(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """Validator listing shows headers and the inspect hint."""
    controller.list_validator_types()
    output = capsys.readouterr().out
    assert "validator_type" in output
    assert "params_class" in output
    assert "data-designer inspect validator" in output


# ---------------------------------------------------------------------------
# list_processor_types — text
# ---------------------------------------------------------------------------


def test_processor_types_text(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """Processor listing shows headers and the inspect hint."""
    controller.list_processor_types()
    output = capsys.readouterr().out
    assert "processor_type" in output
    assert "config_class" in output
    assert "data-designer inspect processor" in output


# ---------------------------------------------------------------------------
# list_*_types — empty discovery (P0-1)
# ---------------------------------------------------------------------------


def test_list_column_types_empty_discovery(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """Empty column discovery produces the 'No items found' message."""
    with patch("data_designer.cli.controllers.list_controller.discover_column_configs", return_value={}):
        controller.list_column_types()
    assert "No items found" in capsys.readouterr().out


def test_list_sampler_types_empty_discovery(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """Empty sampler discovery produces the 'No items found' message."""
    with patch("data_designer.cli.controllers.list_controller.discover_sampler_types", return_value={}):
        controller.list_sampler_types()
    assert "No items found" in capsys.readouterr().out


def test_list_validator_types_empty_discovery(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """Empty validator discovery produces the 'No items found' message."""
    with patch("data_designer.cli.controllers.list_controller.discover_validator_types", return_value={}):
        controller.list_validator_types()
    assert "No items found" in capsys.readouterr().out


def test_list_processor_types_empty_discovery(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """Empty processor discovery produces the 'No items found' message."""
    with patch("data_designer.cli.controllers.list_controller.discover_processor_configs", return_value={}):
        controller.list_processor_types()
    assert "No items found" in capsys.readouterr().out


# ---------------------------------------------------------------------------
# list_persona_datasets — empty (P0-2)
# ---------------------------------------------------------------------------


def test_list_persona_datasets_empty(controller: ListController, capsys: pytest.CaptureFixture[str]) -> None:
    """An empty persona repository produces the 'No persona datasets found' message."""
    with patch.object(controller._persona_repository, "list_all", return_value=[]):
        controller.list_persona_datasets()
    assert "No persona datasets found" in capsys.readouterr().out
+# SPDX-License-Identifier: Apache-2.0 diff --git a/packages/data-designer/tests/cli/services/introspection/test_discovery.py b/packages/data-designer/tests/cli/services/introspection/test_discovery.py new file mode 100644 index 000000000..e2ae5a921 --- /dev/null +++ b/packages/data-designer/tests/cli/services/introspection/test_discovery.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from enum import Enum +from typing import Literal + +from data_designer.cli.services.introspection.discovery import ( + _discover_by_modules, + _extract_literal_discriminator_value, + discover_column_configs, + discover_constraint_types, + discover_processor_configs, + discover_sampler_types, + discover_validator_types, +) + +# --------------------------------------------------------------------------- +# discover_column_configs +# --------------------------------------------------------------------------- + + +def test_discover_column_configs_returns_dict() -> None: + result = discover_column_configs() + assert isinstance(result, dict) + assert len(result) > 0 + + +def test_discover_column_configs_contains_expected_keys() -> None: + result = discover_column_configs() + for expected_key in ("llm-text", "sampler", "expression"): + assert expected_key in result, f"Expected key '{expected_key}' not found in {list(result.keys())}" + + +def test_discover_column_configs_values_are_classes() -> None: + result = discover_column_configs() + for cls in result.values(): + assert isinstance(cls, type) + assert hasattr(cls, "model_fields") + + +# --------------------------------------------------------------------------- +# discover_sampler_types +# --------------------------------------------------------------------------- + + +def test_discover_sampler_types_returns_dict() -> None: + result = discover_sampler_types() + assert 
isinstance(result, dict) + assert len(result) > 0 + + +def test_discover_sampler_types_contains_expected_keys() -> None: + result = discover_sampler_types() + for expected_key in ("category", "uniform", "person"): + assert expected_key in result, f"Expected key '{expected_key}' not found in {list(result.keys())}" + + +# --------------------------------------------------------------------------- +# discover_validator_types +# --------------------------------------------------------------------------- + + +def test_discover_validator_types_returns_dict() -> None: + result = discover_validator_types() + assert isinstance(result, dict) + assert len(result) > 0 + + +def test_discover_validator_types_contains_expected_keys() -> None: + result = discover_validator_types() + for expected_key in ("code", "remote"): + assert expected_key in result, f"Expected key '{expected_key}' not found in {list(result.keys())}" + + +# --------------------------------------------------------------------------- +# discover_processor_configs +# --------------------------------------------------------------------------- + + +def test_discover_processor_configs_returns_dict() -> None: + result = discover_processor_configs() + assert isinstance(result, dict) + assert len(result) > 0 + + +def test_discover_processor_configs_contains_expected_keys() -> None: + result = discover_processor_configs() + assert "drop_columns" in result, f"Expected 'drop_columns' not found in {list(result.keys())}" + + +# --------------------------------------------------------------------------- +# discover_constraint_types +# --------------------------------------------------------------------------- + + +def test_discover_constraint_types_returns_dict() -> None: + result = discover_constraint_types() + assert isinstance(result, dict) + assert len(result) > 0 + + +def test_discover_constraint_types_contains_expected_keys() -> None: + result = discover_constraint_types() + assert "ScalarInequalityConstraint" in 
result + + +# --------------------------------------------------------------------------- +# _discover_by_modules +# --------------------------------------------------------------------------- + + +def test_discover_by_modules_returns_only_matching_modules() -> None: + result = _discover_by_modules("models") + import data_designer.config as dd + + lazy_imports: dict[str, tuple[str, str]] = getattr(dd, "_LAZY_IMPORTS", {}) + model_names = {name for name, (mod, _) in lazy_imports.items() if mod == "data_designer.config.models"} + assert set(result.keys()) == model_names + + +def test_discover_by_modules_with_multiple_suffixes() -> None: + result = _discover_by_modules("seed", "seed_source") + assert "SeedConfig" in result + assert "LocalFileSeedSource" in result + + +def test_discover_by_modules_unknown_suffix_returns_empty() -> None: + result = _discover_by_modules("nonexistent_module") + assert result == {} + + +# --------------------------------------------------------------------------- +# _extract_literal_discriminator_value (P1-5) +# --------------------------------------------------------------------------- + + +class _TestEnum(str, Enum): + A = "alpha" + B = "beta" + + +def test_extract_literal_value_string() -> None: + assert _extract_literal_discriminator_value(Literal["foo"]) == "foo" + + +def test_extract_literal_value_enum() -> None: + result = _extract_literal_discriminator_value(Literal[_TestEnum.A]) + assert result == "alpha" + + +def test_extract_literal_non_literal() -> None: + assert _extract_literal_discriminator_value(str) is None + + +def test_extract_literal_int_value() -> None: + assert _extract_literal_discriminator_value(Literal[42]) == "42" diff --git a/packages/data-designer/tests/cli/services/introspection/test_field_descriptions.py b/packages/data-designer/tests/cli/services/introspection/test_field_descriptions.py new file mode 100644 index 000000000..edfbec7ab --- /dev/null +++ 
b/packages/data-designer/tests/cli/services/introspection/test_field_descriptions.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.cli.services.introspection.discovery import ( + discover_column_configs, + discover_constraint_types, + discover_processor_configs, + discover_sampler_types, + discover_validator_types, +) + + +def _collect_models_with_fields() -> list[tuple[str, str, type]]: + """Collect all discovered model classes and their fields. + + Returns: + List of (source_label, field_name, model_class) tuples. + """ + items: list[tuple[str, str, type]] = [] + + discovery_sources: list[tuple[str, dict[str, type]]] = [ + ("column_configs", discover_column_configs()), + ("sampler_types", discover_sampler_types()), + ("validator_types", discover_validator_types()), + ("processor_configs", discover_processor_configs()), + ("constraint_types", discover_constraint_types()), + ] + + for source_label, discovered in discovery_sources: + for type_name, cls in discovered.items(): + if not hasattr(cls, "model_fields"): + continue + for field_name in cls.model_fields: + items.append((f"{source_label}:{type_name}", field_name, cls)) + + return items + + +_ALL_FIELDS = _collect_models_with_fields() + + +@pytest.mark.parametrize( + "source_label,field_name,cls", + _ALL_FIELDS, + ids=[f"{src}.{field}" for src, field, _ in _ALL_FIELDS], +) +def test_all_discovered_fields_have_descriptions(source_label: str, field_name: str, cls: type) -> None: + """Every field in discovered config models must have a non-empty description.""" + field_info = cls.model_fields[field_name] + assert field_info.description, ( + f"{cls.__name__}.{field_name} (from {source_label}) has no Field(description=...). " + f"Add a description to this field in the source model." 
+ ) diff --git a/packages/data-designer/tests/cli/services/introspection/test_formatters.py b/packages/data-designer/tests/cli/services/introspection/test_formatters.py new file mode 100644 index 000000000..2bb19f2b1 --- /dev/null +++ b/packages/data-designer/tests/cli/services/introspection/test_formatters.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.cli.services.introspection.formatters import ( + format_method_info_text, + format_type_list_text, +) +from data_designer.cli.services.introspection.method_inspector import MethodInfo, ParamInfo + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_method( + name: str = "do_thing", + signature: str = "do_thing(x: int) -> str", + description: str = "Does a thing.", + return_type: str = "str", + parameters: list[ParamInfo] | None = None, +) -> MethodInfo: + return MethodInfo( + name=name, + signature=signature, + description=description, + return_type=return_type, + parameters=parameters or [ParamInfo(name="x", type_str="int", default=None, description="An integer")], + ) + + +# --------------------------------------------------------------------------- +# format_method_info_text +# --------------------------------------------------------------------------- + + +def test_format_method_info_text_basic() -> None: + methods = [_make_method()] + text = format_method_info_text(methods) + assert "do_thing(x: int) -> str" in text + assert "Does a thing." 
in text + assert "Parameters:" in text + + +def test_format_method_info_text_with_class_name() -> None: + methods = [_make_method()] + text = format_method_info_text(methods, class_name="MyBuilder") + assert "MyBuilder Methods:" in text + + +def test_format_method_info_text_no_class_name() -> None: + methods = [_make_method()] + text = format_method_info_text(methods, class_name=None) + assert "Methods:" not in text + + +# --------------------------------------------------------------------------- +# format_type_list_text +# --------------------------------------------------------------------------- + + +def test_format_type_list_text_basic() -> None: + class FakeA: + pass + + class FakeB: + pass + + items: dict[str, type] = {"alpha": FakeA, "beta": FakeB} + text = format_type_list_text(items, "type_name", "class_name") + assert "type_name" in text + assert "class_name" in text + assert "alpha" in text + assert "FakeA" in text + assert "beta" in text + assert "FakeB" in text + + +def test_format_type_list_text_alignment() -> None: + class C: + pass + + items: dict[str, type] = {"short": C, "very_long_name": C} + text = format_type_list_text(items, "Type", "Class") + lines = text.strip().split("\n") + # Header + separator + 2 data rows + assert len(lines) == 4 + + +def test_format_type_list_text_empty() -> None: + text = format_type_list_text({}, "Type", "Class") + assert "(no items)" in text + + +# --------------------------------------------------------------------------- +# format_method_info_text — edge cases (P1-7) +# --------------------------------------------------------------------------- + + +def test_format_method_info_text_empty_list() -> None: + text = format_method_info_text([], class_name="MyClass") + assert "MyClass Methods:" in text + lines = text.strip().split("\n") + assert len(lines) <= 2 + + +def test_format_method_info_text_no_description() -> None: + method = MethodInfo( + name="do_thing", + signature="do_thing() -> None", + description="", + 
return_type="None", + parameters=[], + ) + text = format_method_info_text([method]) + lines = text.strip().split("\n") + sig_line_idx = next(i for i, line in enumerate(lines) if "do_thing()" in line) + if sig_line_idx + 1 < len(lines): + next_line = lines[sig_line_idx + 1].strip() + assert next_line == "" or next_line.startswith("Parameters:") or "do_thing" not in next_line + + +def test_format_method_info_text_no_parameters() -> None: + method = MethodInfo( + name="do_thing", + signature="do_thing() -> None", + description="Does a thing.", + return_type="None", + parameters=[], + ) + text = format_method_info_text([method]) + assert "Parameters:" not in text diff --git a/packages/data-designer/tests/cli/services/introspection/test_method_inspector.py b/packages/data-designer/tests/cli/services/introspection/test_method_inspector.py new file mode 100644 index 000000000..8d1670208 --- /dev/null +++ b/packages/data-designer/tests/cli/services/introspection/test_method_inspector.py @@ -0,0 +1,361 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.cli.services.introspection.method_inspector import ( + MethodInfo, + _is_dunder, + _is_private, + _parse_google_docstring_args, + inspect_class_methods, +) + +# --------------------------------------------------------------------------- +# Test helper classes +# --------------------------------------------------------------------------- + + +class SampleClass: + """A sample class for testing method introspection.""" + + def public_method(self, x: int, y: str = "default") -> str: + """Do something public. + + Args: + x: The integer input. + y: An optional string. + + Returns: + A result string. 
+ """ + return f"{x}-{y}" + + def another_public(self) -> None: + """Another public method with no args.""" + + def _private_method(self, z: float) -> float: + """A private helper. + + Args: + z: A float value. + """ + return z * 2 + + def __dunder_method__(self) -> None: + """Should be excluded (dunder).""" + + def __init__(self) -> None: + """Init should be included.""" + + +class ClassWithClassmethod: + """A class with a classmethod for testing.""" + + @classmethod + def from_value(cls, value: int) -> ClassWithClassmethod: + """Create an instance from a value. + + Args: + value: The input value. + """ + return cls() + + def regular_method(self) -> str: + """A regular method.""" + return "hello" + + +class ClassWithDefaultInitDocstring: + """A useful class that does important things. + + This is a longer description of the class. + """ + + def __init__(self, x: int = 0) -> None: + self.x = x + + +class ClassWithCustomInitDocstring: + """Class-level docstring.""" + + def __init__(self, x: int) -> None: + """Custom init docstring. + + Args: + x: An integer. + """ + self.x = x + + +# --------------------------------------------------------------------------- +# _parse_google_docstring_args +# --------------------------------------------------------------------------- + + +def test_parse_google_docstring_args_basic() -> None: + docstring = """Do something. + + Args: + x: The first parameter. + y: The second parameter. + + Returns: + A result. + """ + result = _parse_google_docstring_args(docstring) + assert "x" in result + assert result["x"] == "The first parameter." + assert "y" in result + assert result["y"] == "The second parameter." + + +def test_parse_google_docstring_args_empty() -> None: + assert _parse_google_docstring_args(None) == {} + assert _parse_google_docstring_args("") == {} + + +def test_parse_google_docstring_args_no_args_section() -> None: + docstring = """Just a description. + + Returns: + Something. 
+ """ + result = _parse_google_docstring_args(docstring) + assert result == {} + + +def test_parse_google_docstring_args_multiline_description() -> None: + docstring = """Do something. + + Args: + x: First line of description + continued on second line. + y: Another param. + """ + result = _parse_google_docstring_args(docstring) + assert "x" in result + assert "continued" in result["x"] + assert "y" in result + + +# --------------------------------------------------------------------------- +# inspect_class_methods - exclude private +# --------------------------------------------------------------------------- + + +def test_inspect_class_methods_public_only() -> None: + methods = inspect_class_methods(SampleClass, include_private=False) + names = [m.name for m in methods] + assert "public_method" in names + assert "another_public" in names + assert "_private_method" not in names + assert "__dunder_method__" not in names + + +def test_inspect_class_methods_returns_method_info() -> None: + methods = inspect_class_methods(SampleClass, include_private=False) + assert all(isinstance(m, MethodInfo) for m in methods) + + +def test_inspect_class_methods_signature_content() -> None: + methods = inspect_class_methods(SampleClass, include_private=False) + public = next(m for m in methods if m.name == "public_method") + assert "x: int" in public.signature + assert "y: str" in public.signature + assert "str" in public.return_type + + +def test_inspect_class_methods_description() -> None: + methods = inspect_class_methods(SampleClass, include_private=False) + public = next(m for m in methods if m.name == "public_method") + assert public.description == "Do something public." 
+ + +def test_inspect_class_methods_parameters() -> None: + methods = inspect_class_methods(SampleClass, include_private=False) + public = next(m for m in methods if m.name == "public_method") + param_names = [p.name for p in public.parameters] + assert "x" in param_names + assert "y" in param_names + x_param = next(p for p in public.parameters if p.name == "x") + assert x_param.description == "The integer input." + + +# --------------------------------------------------------------------------- +# inspect_class_methods - include private +# --------------------------------------------------------------------------- + + +def test_inspect_class_methods_include_private() -> None: + methods = inspect_class_methods(SampleClass, include_private=True) + names = [m.name for m in methods] + assert "_private_method" in names + assert "__dunder_method__" not in names + + +def test_inspect_class_methods_init_included() -> None: + methods = inspect_class_methods(SampleClass, include_private=True) + names = [m.name for m in methods] + assert "__init__" in names + + +# --------------------------------------------------------------------------- +# inspect_class_methods - classmethod detection +# --------------------------------------------------------------------------- + + +def test_inspect_class_methods_detects_classmethod() -> None: + methods = inspect_class_methods(ClassWithClassmethod, include_private=False) + names = [m.name for m in methods] + assert "from_value" in names + assert "regular_method" in names + + +def test_inspect_class_methods_classmethod_signature() -> None: + methods = inspect_class_methods(ClassWithClassmethod, include_private=False) + from_value = next(m for m in methods if m.name == "from_value") + assert "value: int" in from_value.signature + + +def test_inspect_class_methods_classmethod_description() -> None: + methods = inspect_class_methods(ClassWithClassmethod, include_private=False) + from_value = next(m for m in methods if m.name == "from_value") 
+ assert from_value.description == "Create an instance from a value." + + +def test_inspect_class_methods_classmethod_parameters() -> None: + methods = inspect_class_methods(ClassWithClassmethod, include_private=False) + from_value = next(m for m in methods if m.name == "from_value") + param_names = [p.name for p in from_value.parameters] + assert "value" in param_names + value_param = next(p for p in from_value.parameters if p.name == "value") + assert value_param.description == "The input value." + + +# --------------------------------------------------------------------------- +# __init__ docstring fallback +# --------------------------------------------------------------------------- + + +def test_init_default_docstring_falls_back_to_class() -> None: + methods = inspect_class_methods(ClassWithDefaultInitDocstring, include_private=True) + init = next((m for m in methods if m.name == "__init__"), None) + assert init is not None + assert init.description == "A useful class that does important things." + + +def test_init_custom_docstring_preserved() -> None: + methods = inspect_class_methods(ClassWithCustomInitDocstring, include_private=True) + init = next((m for m in methods if m.name == "__init__"), None) + assert init is not None + assert init.description == "Custom init docstring." 
+ + +# --------------------------------------------------------------------------- +# inspect_class_methods — edge cases (P1-4) +# --------------------------------------------------------------------------- + + +class EmptyClass: + """A class with no public methods (no __init__ either).""" + + +class ClassWithBadSignature: + """A class where one method has an uninspectable signature.""" + + def good_method(self) -> str: + """Works fine.""" + return "ok" + + +class ClassWithVarArgs: + """A class with *args and **kwargs.""" + + def method_with_varargs(self, *args: str, **kwargs: int) -> None: + """A method with varargs.""" + + +def test_inspect_class_methods_empty_class() -> None: + methods = inspect_class_methods(EmptyClass, include_private=False) + assert methods == [] + + +def test_inspect_class_methods_signature_error_skipped() -> None: + import inspect as _inspect + from unittest.mock import patch + + original_sig = _inspect.signature + + def bad_signature(method: object) -> _inspect.Signature: + if getattr(method, "__name__", "") == "good_method": + raise ValueError("cannot inspect") + return original_sig(method) + + with patch( + "data_designer.cli.services.introspection.method_inspector.inspect.signature", side_effect=bad_signature + ): + methods = inspect_class_methods(ClassWithBadSignature, include_private=False) + + names = [m.name for m in methods] + assert "good_method" not in names + + +def test_inspect_class_methods_varargs_and_kwargs() -> None: + methods = inspect_class_methods(ClassWithVarArgs, include_private=False) + m = next(m for m in methods if m.name == "method_with_varargs") + assert "*args" in m.signature + assert "**kwargs" in m.signature + + +# --------------------------------------------------------------------------- +# _is_dunder / _is_private (P2-1) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("name", "expected"), + [ + ("__init__", False), + ("__str__", True), + 
("__repr__", True), + ("regular", False), + ("_private", False), + ], +) +def test_is_dunder(name: str, expected: bool) -> None: + assert _is_dunder(name) is expected + + +@pytest.mark.parametrize( + ("name", "expected"), + [ + ("_foo", True), + ("_private_method", True), + ("__init__", False), + ("__str__", False), + ("public", False), + ], +) +def test_is_private(name: str, expected: bool) -> None: + assert _is_private(name) is expected + + +# --------------------------------------------------------------------------- +# keyword-only params (P2-9) +# --------------------------------------------------------------------------- + + +class ClassWithKeywordOnly: + """A class with keyword-only parameters.""" + + def method_with_kw(self, *, kw: str = "x") -> None: + """A method with keyword-only arg.""" + + +def test_format_signature_keyword_only() -> None: + methods = inspect_class_methods(ClassWithKeywordOnly, include_private=False) + m = next(m for m in methods if m.name == "method_with_kw") + assert "*, " in m.signature or "*," in m.signature diff --git a/packages/data-designer/tests/cli/services/introspection/test_pydantic_inspector.py b/packages/data-designer/tests/cli/services/introspection/test_pydantic_inspector.py new file mode 100644 index 000000000..0889021e9 --- /dev/null +++ b/packages/data-designer/tests/cli/services/introspection/test_pydantic_inspector.py @@ -0,0 +1,540 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from enum import Enum +from typing import Annotated + +import pytest +from pydantic import BaseModel, Field + +from data_designer.cli.services.introspection.pydantic_inspector import ( + _default_to_json, + _extract_constraints, + _extract_enum_class, + _extract_nested_basemodel, + _is_basemodel_subclass, + _is_enum_subclass, + format_model_text, + format_type, + get_brief_description, +) + +# --------------------------------------------------------------------------- +# Test models / enums +# --------------------------------------------------------------------------- + + +class ColorEnum(str, Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +class SizeEnum(str, Enum): + SMALL = "small" + LARGE = "large" + + +class InnerModel(BaseModel): + x: int = 0 + y: str = "hello" + + +class OuterModel(BaseModel): + """Outer model for testing.""" + + plain: str = "foo" + nested: InnerModel = Field(default_factory=InnerModel) + my_enum: ColorEnum = ColorEnum.RED + + +class SelfRefModel(BaseModel): + """A model that references itself (for cycle testing).""" + + name: str = "" + child: SelfRefModel | None = None + + +class DeepB(BaseModel): + val: int = 0 + + +class DeepA(BaseModel): + val: int = 0 + b: DeepB | None = None + + +class SiblingNestedModel(BaseModel): + first: InnerModel = Field(default_factory=InnerModel) + second: InnerModel = Field(default_factory=InnerModel) + + +# Rebuild models that use forward references (required due to `from __future__ import annotations`) +SelfRefModel.model_rebuild() +DeepA.model_rebuild() + + +class RequiredFieldModel(BaseModel): + """Model with required and optional fields for testing.""" + + required_name: str + optional_name: str = "default_val" + + +class ConstrainedModel(BaseModel): + """Model with constrained fields for testing.""" + + score: float = Field(default=0.5, ge=0.0, le=1.0) + label: str = Field(default="", min_length=1, 
max_length=100) + count: int = Field(default=0, gt=-1, lt=1000) + + +# --------------------------------------------------------------------------- +# _is_basemodel_subclass +# --------------------------------------------------------------------------- + + +def test_is_basemodel_subclass_with_subclass() -> None: + assert _is_basemodel_subclass(InnerModel) is True + + +def test_is_basemodel_subclass_with_basemodel_itself() -> None: + assert _is_basemodel_subclass(BaseModel) is False + + +def test_is_basemodel_subclass_with_str() -> None: + assert _is_basemodel_subclass(str) is False + + +def test_is_basemodel_subclass_with_enum() -> None: + assert _is_basemodel_subclass(ColorEnum) is False + + +def test_is_basemodel_subclass_with_non_type() -> None: + assert _is_basemodel_subclass("not a type") is False + + +# --------------------------------------------------------------------------- +# _is_enum_subclass +# --------------------------------------------------------------------------- + + +def test_is_enum_subclass_with_enum_subclass() -> None: + assert _is_enum_subclass(ColorEnum) is True + + +def test_is_enum_subclass_with_enum_itself() -> None: + assert _is_enum_subclass(Enum) is False + + +def test_is_enum_subclass_with_str() -> None: + assert _is_enum_subclass(str) is False + + +def test_is_enum_subclass_with_non_type() -> None: + assert _is_enum_subclass(42) is False + + +# --------------------------------------------------------------------------- +# _extract_enum_class +# --------------------------------------------------------------------------- + + +def test_extract_enum_class_direct_enum() -> None: + assert _extract_enum_class(ColorEnum) is ColorEnum + + +def test_extract_enum_class_optional_enum() -> None: + assert _extract_enum_class(ColorEnum | None) is ColorEnum + + +def test_extract_enum_class_annotated_enum() -> None: + assert _extract_enum_class(Annotated[ColorEnum, "metadata"]) is ColorEnum + + +def test_extract_enum_class_non_enum() -> None: + 
assert _extract_enum_class(str) is None + + +def test_extract_enum_class_none() -> None: + assert _extract_enum_class(None) is None + + +# --------------------------------------------------------------------------- +# extract_nested_basemodel +# --------------------------------------------------------------------------- + + +def test_extract_nested_basemodel_direct() -> None: + assert _extract_nested_basemodel(InnerModel) is InnerModel + + +def test_extract_nested_basemodel_list() -> None: + assert _extract_nested_basemodel(list[InnerModel]) is InnerModel + + +def test_extract_nested_basemodel_optional() -> None: + assert _extract_nested_basemodel(InnerModel | None) is InnerModel + + +def test_extract_nested_basemodel_optional_list() -> None: + assert _extract_nested_basemodel(list[InnerModel] | None) is InnerModel + + +def test_extract_nested_basemodel_dict() -> None: + assert _extract_nested_basemodel(dict[str, InnerModel]) is InnerModel + + +def test_extract_nested_basemodel_annotated() -> None: + assert _extract_nested_basemodel(Annotated[InnerModel, "info"]) is InnerModel + + +def test_extract_nested_basemodel_discriminated_union_returns_none() -> None: + """Unions of 2+ BaseModel subclasses should return None.""" + assert _extract_nested_basemodel(InnerModel | OuterModel) is None + + +def test_extract_nested_basemodel_primitive_returns_none() -> None: + assert _extract_nested_basemodel(str) is None + assert _extract_nested_basemodel(int) is None + + +def test_extract_nested_basemodel_none_returns_none() -> None: + assert _extract_nested_basemodel(None) is None + + +def test_extract_nested_basemodel_basemodel_itself_returns_none() -> None: + assert _extract_nested_basemodel(BaseModel) is None + + +# --------------------------------------------------------------------------- +# format_type +# --------------------------------------------------------------------------- + + +def test_format_type_str() -> None: + result = format_type(str) + assert "str" in result + 
+ +def test_format_type_int() -> None: + result = format_type(int) + assert "int" in result + + +def test_format_type_optional() -> None: + result = format_type(str | None) + assert "str" in result + assert "None" in result + + +# --------------------------------------------------------------------------- +# get_brief_description +# --------------------------------------------------------------------------- + + +def test_get_brief_description_with_docstring() -> None: + result = get_brief_description(OuterModel) + assert result == "Outer model for testing." + + +def test_get_brief_description_without_docstring() -> None: + result = get_brief_description(InnerModel) + assert result == "No description available." + + +# --------------------------------------------------------------------------- +# _extract_constraints +# --------------------------------------------------------------------------- + + +def test_extract_constraints_from_constrained_model() -> None: + score_info = ConstrainedModel.model_fields["score"] + constraints = _extract_constraints(score_info) + assert constraints is not None + assert constraints["ge"] == 0.0 + assert constraints["le"] == 1.0 + + +def test_extract_constraints_gt_lt() -> None: + count_info = ConstrainedModel.model_fields["count"] + constraints = _extract_constraints(count_info) + assert constraints is not None + assert constraints["gt"] == -1 + assert constraints["lt"] == 1000 + + +def test_extract_constraints_string_lengths() -> None: + label_info = ConstrainedModel.model_fields["label"] + constraints = _extract_constraints(label_info) + assert constraints is not None + assert constraints["min_length"] == 1 + assert constraints["max_length"] == 100 + + +def test_extract_constraints_none_for_unconstrained() -> None: + x_info = InnerModel.model_fields["x"] + assert _extract_constraints(x_info) is None + + +def test_extract_constraints_helper_with_no_metadata() -> None: + """_extract_constraints returns None when field_info has no 
constraint metadata.""" + + class FakeFieldInfo: + metadata: list = [] + + assert _extract_constraints(FakeFieldInfo()) is None + + +# --------------------------------------------------------------------------- +# format_model_text +# --------------------------------------------------------------------------- + + +def test_format_model_text_basic_structure() -> None: + text = format_model_text(OuterModel) + assert "OuterModel:" in text + assert "description: Outer model for testing." in text + assert "fields:" in text + assert "plain:" in text + assert "nested:" in text + assert "my_enum:" in text + + +def test_format_model_text_with_type_key_and_value() -> None: + text = format_model_text(OuterModel, type_key="column_type", type_value="test") + assert "column_type: test" in text + + +def test_format_model_text_required_field() -> None: + text = format_model_text(RequiredFieldModel) + assert "required_name: str [required]" in text + + +def test_format_model_text_optional_field_default() -> None: + text = format_model_text(RequiredFieldModel) + assert "optional_name: str = 'default_val'" in text + assert "[required]" not in text.split("optional_name")[1].split("\n")[0] + + +def test_format_model_text_default_factory() -> None: + text = format_model_text(OuterModel) + assert "= InnerModel()" in text + + +def test_format_model_text_none_default() -> None: + text = format_model_text(SelfRefModel) + assert "child:" in text + assert "= None" in text + + +def test_format_model_text_enum_default_uses_member_value() -> None: + text = format_model_text(OuterModel) + assert "my_enum: ColorEnum = 'red'" in text + + +def test_format_model_text_enum_values() -> None: + text = format_model_text(OuterModel) + assert "values: [red, green, blue]" in text + + +def test_format_model_text_constraints() -> None: + text = format_model_text(ConstrainedModel) + assert "constraints: ge=0.0, le=1.0" in text + + +def test_format_model_text_nested_expansion() -> None: + text = 
format_model_text(OuterModel) + assert "schema (InnerModel):" in text + # Nested fields should appear indented under the schema + assert "x: int = 0" in text + assert "y: str = 'hello'" in text + + +def test_format_model_text_cycle_protection() -> None: + text = format_model_text(SelfRefModel) + # First level should expand + assert "schema (SelfRefModel):" in text + # The recursive child.child should NOT expand again (only one "schema (SelfRefModel):") + assert text.count("schema (SelfRefModel):") == 1 + + +def test_format_model_text_depth_limiting() -> None: + text = format_model_text(DeepA, max_depth=1) + # First level (DeepB) should expand + assert "schema (DeepB):" in text + + +def test_format_model_text_sibling_nested_expands_each() -> None: + """Sibling fields of the same nested type should each include a nested schema.""" + text = format_model_text(SiblingNestedModel) + # Both first and second fields should have InnerModel expanded + assert text.count("schema (InnerModel):") == 2 + + +def test_format_model_text_deduplication_with_seen_schemas() -> None: + """When seen_schemas is passed across calls, second occurrence shows a back-reference.""" + seen: set[str] = set() + text1 = format_model_text(OuterModel, seen_schemas=seen) + text2 = format_model_text(SiblingNestedModel, seen_schemas=seen) + assert "schema (InnerModel):" in text1 + assert "see InnerModel above" in text2 + + +def test_format_model_text_no_dedup_without_seen_set() -> None: + """Without seen_schemas, nested schemas always expand fully.""" + text = format_model_text(OuterModel) + assert "schema (InnerModel):" in text + + +def test_format_model_text_max_depth_zero_blocks_all_nesting() -> None: + """At max_depth=0, nested schemas should not expand.""" + text = format_model_text(OuterModel, max_depth=0) + assert "schema (InnerModel):" not in text + assert "nested:" in text # field still listed, just not expanded + + +def test_format_model_text_dedup_distinguishes_same_name_different_module() -> 
None: + """Schemas with same __name__ but different __module__ should not dedup.""" + + class SharedNameA(BaseModel): + x: int = 0 + + class SharedNameB(BaseModel): + y: str = "" + + # Make them look like same-named classes from different modules + SharedNameB.__name__ = "SharedNameA" + SharedNameB.__qualname__ = "SharedNameA" + SharedNameA.__module__ = "pkg.alpha" + SharedNameB.__module__ = "pkg.beta" + + class WrapperA(BaseModel): + a: SharedNameA = Field(default_factory=SharedNameA) + + class WrapperB(BaseModel): + b: SharedNameB = Field(default_factory=SharedNameB) + + WrapperA.model_rebuild() + WrapperB.model_rebuild() + + seen: set[str] = set() + text_a = format_model_text(WrapperA, seen_schemas=seen) + text_b = format_model_text(WrapperB, seen_schemas=seen) + + assert "schema (SharedNameA):" in text_a + assert "schema (SharedNameA):" in text_b + assert "see SharedNameA above" not in text_b + + +class Level3(BaseModel): + val: int = 0 + + +class Level2(BaseModel): + val: int = 0 + child: Level3 | None = None + + +class Level1(BaseModel): + val: int = 0 + child: Level2 | None = None + + +Level1.model_rebuild() +Level2.model_rebuild() + + +def test_format_model_text_depth_limiting_blocks_deeper_nesting() -> None: + """With max_depth=1, Level2 expands but Level3 does not.""" + text = format_model_text(Level1, max_depth=1) + assert "schema (Level2):" in text + assert "schema (Level3):" not in text + + +# --------------------------------------------------------------------------- +# _default_to_json (P1-6) +# --------------------------------------------------------------------------- + + +class _JsonTestEnum(str, Enum): + MEMBER = "member_value" + + +class _CustomObj: + def __repr__(self) -> str: + return "CustomObj()" + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (None, None), + (_JsonTestEnum.MEMBER, "member_value"), + (True, True), + (42, 42), + (3.14, 3.14), + ("hello", "hello"), + ([1, 2], [1, 2]), + ({"a": 1}, {"a": 1}), + ], +) +def 
test_default_to_json(value: object, expected: object) -> None: + assert _default_to_json(value) == expected + + +def test_default_to_json_custom_object() -> None: + obj = _CustomObj() + assert _default_to_json(obj) == "CustomObj()" + + +# --------------------------------------------------------------------------- +# format_type — regex branches (P1-8) +# --------------------------------------------------------------------------- + + +def test_format_type_none_type() -> None: + result = format_type(type(None)) + assert result == "None" + + +def test_format_type_enum_class() -> None: + result = format_type(ColorEnum) + assert result == "ColorEnum" + + +def test_format_type_module_prefix_stripping() -> None: + import data_designer.config as dd + + result = format_type(list[dd.CategorySamplerParams]) + assert "data_designer.config." not in result + assert "CategorySamplerParams" in result + + +def test_format_type_literal() -> None: + from typing import Literal + + result = format_type(Literal["foo", "bar"]) + assert "Literal[" in result + assert "foo" in result + assert "bar" in result + + +# --------------------------------------------------------------------------- +# format_model_text — empty model (P1-10) +# --------------------------------------------------------------------------- + + +class EmptyModel(BaseModel): + """An empty model with no fields.""" + + +def test_format_model_text_empty_model() -> None: + text = format_model_text(EmptyModel) + assert "EmptyModel:" in text + assert "fields:" in text + lines = text.strip().split("\n") + field_lines = [line for line in lines if line.startswith(" ") and ":" in line and "fields:" not in line] + assert len(field_lines) == 0