Skip to content

Add input validation in predict() for feature shape and NaN handling#561

Open
2024itb047samata wants to merge 6 commits into
dswah:mainfrom
2024itb047samata:fix-predict-validation
Open

Add input validation in predict() for feature shape and NaN handling#561
2024itb047samata wants to merge 6 commits into
dswah:mainfrom
2024itb047samata:fix-predict-validation

Conversation

@2024itb047samata

Copy link
Copy Markdown

This PR improves the robustness of the predict() method by adding input validation for feature shape and NaN values.

  • Ensures input X is converted to a 2D NumPy array
  • Automatically reshapes 1D input to (n_samples, 1)
  • Validates that the number of features matches the trained model (n_features_)
  • Raises a clear error if input contains NaN values

Previously, predict() could silently accept invalid inputs (e.g., incorrect shape or NaN values), leading to unexpected behavior or errors deeper in the pipeline. This change aligns the behavior with standard practices in libraries like scikit-learn and improves usability.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant