feature: support native YOLO .pt models while ensuring compatibility with Torchvision models#495

Open
kashtennyson wants to merge 2 commits into JdeRobot:master from kashtennyson:issue-449

Conversation

@kashtennyson

Description

This PR adds support for loading native Ultralytics YOLOv8 .pt models while ensuring a consistent interface for the rest of the library. This is a fix for #449.

The Problem:
Native YOLO .pt models often return a tuple (inference_tensor, loss_tensor) rather than a raw tensor, which causes "too many values to unpack" errors in the inference and eval methods. Additionally, these models frequently use float16 (half) precision, leading to dtype mismatches with input images or NMS kernel errors on certain backends.

The Solution:
Following previous feedback, I have centralized the fix within the TorchImageDetectionModel class. I implemented a local Adapter class (DetectionModelWrapper) that standardizes the model's behavior at the source:

  • Tuple Unpacking: Automatically extracts the primary detection tensor.
  • Input Alignment: Automatically casts input images to match the model's native dtype (fixing "Float vs Half" errors).
  • Output Alignment: Ensures results are returned as float32 to maintain compatibility with torchvision.ops.nms.
  • Graceful Fallback: Wrapped the .pt loading logic to provide a clear error message suggesting the installation of ultralytics if it is missing.
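
For reference, the adapter behavior described above could look roughly like the sketch below. This is an illustrative reconstruction, not the PR's actual code: the class name DetectionModelWrapper comes from the description, but the method bodies and the stand-in _TupleModel are assumptions.

```python
import torch
import torch.nn as nn


class DetectionModelWrapper(nn.Module):
    """Adapter that normalizes a native YOLO .pt model's interface.

    Sketch of the behavior described in the PR; the actual
    implementation in the PR may differ.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Infer the model's native dtype (e.g. float16 for many YOLO weights).
        self.native_dtype = next(model.parameters()).dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input alignment: cast images to the model's native dtype
        # (avoids "Float vs Half" errors).
        out = self.model(x.to(self.native_dtype))
        # Tuple unpacking: native YOLO models may return (inference, loss);
        # keep only the primary detection tensor.
        if isinstance(out, (tuple, list)):
            out = out[0]
        # Output alignment: torchvision.ops.nms expects float32 boxes/scores.
        return out.float()


# Minimal check with a stand-in model that mimics YOLO's tuple output
# (hypothetical, for illustration only).
class _TupleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 3)

    def forward(self, x):
        return (self.lin(x), torch.zeros(1))


wrapped = DetectionModelWrapper(_TupleModel())
out = wrapped(torch.randn(2, 3))
assert out.dtype == torch.float32
```

Because the wrapper is itself an nn.Module, downstream code can call it exactly like the Torchvision models it sits alongside, which is what keeps the library interface consistent.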

This PR supersedes #469. It implements a more stable version by ensuring compatibility with Torchvision models alongside Ultralytics YOLO models.


Architectural Question for Maintainers

I have implemented the DetectionModelWrapper as a local class within the __init__ method of TorchImageDetectionModel to keep the fix strictly within the requested section and ensure that the normalization is context-specific to the model instance.

Do you prefer this local encapsulation, or would you like me to refactor the wrapper into a private, module-level class (e.g., _ModelNormalizationWrapper) at the top of the file to keep the __init__ method more concise?

@dpascualhe dpascualhe self-requested a review March 25, 2026 19:08
@dpascualhe dpascualhe self-assigned this Mar 25, 2026
@dpascualhe
Collaborator

Hi, thanks for your contribution! I'll review the PR thoroughly when I can since this is an important upgrade.

@kashtennyson
Author

Alright @dpascualhe. Thanks for the update!

I am also currently working on a broader refactor to provide global .pt support across all tasks (Detection, Segmentation, and LiDAR) by centralizing the loading and normalization logic into a shared BaseTorchModel utility. Your guidance and feedback are crucial for these architectural decisions. Looking forward to your thoughts!

