security(opt): enable weights_only=True by default for secure checkpoint loading#1056
RinZ27 wants to merge 1 commit into NVIDIA:main
Conversation
Note
Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings.
📝 Walkthrough
Adds a safe deserialization helper (`safe_load`) that enforces `weights_only=True` by default and routes checkpoint loading through it.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed
🧹 Nitpick comments (2)
modelopt/torch/opt/conversion.py (2)
630-632: Consider adding a security comment for consistency.

The `weights_only=True` default is correct and aligns with the security guidelines. For consistency with `load_modelopt_state` (line 526), consider adding a brief inline comment explaining this is a security measure for ModelOpt-generated checkpoints.

📝 Optional: Add security comment for consistency

```diff
 # load checkpoint
 kwargs.setdefault("map_location", "cpu")
+# Security NOTE: weights_only=True is used here on ModelOpt-generated checkpoints
 kwargs.setdefault("weights_only", True)
 objs = torch.load(f, **kwargs)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/opt/conversion.py` around lines 630-632: add a brief inline security comment next to `kwargs.setdefault("weights_only", True)` (near the `torch.load` call) explaining that setting `weights_only=True` is a security measure for loading ModelOpt-generated checkpoints, mirroring the comment style used in `load_modelopt_state` to make the intent consistent and clear.
553: Documentation example could reinforce secure loading pattern.

The example shows `torch.load("model_weights.pt")` without `weights_only=True`. While PyTorch >= 2.6 defaults to `True`, explicitly showing the secure pattern in documentation helps reinforce best practices for users on older PyTorch versions.

📝 Optional: Update docstring example

```diff
-    model.load_state_dict(torch.load("model_weights.pt"), ...)  # Load the model weights
+    model.load_state_dict(torch.load("model_weights.pt", weights_only=True), ...)  # Load the model weights
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/opt/conversion.py` at line 553: update the docstring example to explicitly use the secure loading pattern by passing `weights_only=True` to `torch.load` in the example call (i.e., change the call shown next to `model.load_state_dict` to use `torch.load("model_weights.pt", weights_only=True)`), so readers see the explicit secure flag.
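Both nitpicks above push toward the same principle: never unpickle untrusted bytes with unrestricted global lookups. As a rough stdlib illustration of the class of attack `weights_only=True` guards against (this is not ModelOpt or PyTorch code — `restricted_loads` and its allowlist are hypothetical), a custom `Unpickler` can veto global resolution:

```python
import io
import os
import pickle

class RestrictedUnpickler(pickle.Unpickler):
    # Hypothetical allowlist; plain containers need no globals at all.
    ALLOWED = {("builtins", "dict"), ("builtins", "list")}

    def find_class(self, module, name):
        # Called only when the stream references a global (class/function).
        if (module, name) not in self.ALLOWED:
            raise pickle.UnpicklingError(f"blocked global: {module}.{name}")
        return super().find_class(module, name)

def restricted_loads(data: bytes):
    return RestrictedUnpickler(io.BytesIO(data)).load()

print(restricted_loads(pickle.dumps({"lr": 0.1})))  # {'lr': 0.1}

# A pickled callable reference is rejected instead of being resolved.
try:
    restricted_loads(pickle.dumps(os.getcwd))
except pickle.UnpicklingError as exc:
    print("rejected:", exc)
```

PyTorch's `weights_only=True` path applies the same idea with an allowlist of tensor-related types, which is why untrusted checkpoints cannot smuggle in arbitrary callables.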
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a5f245a3-81e0-4c09-8930-881be6f870b2
📥 Commits
Reviewing files that changed from the base of the PR and between 00fa5bd and f54922151bc7dec8a49bcf9791f8a65e7dc7b966.
📒 Files selected for processing (1)
modelopt/torch/opt/conversion.py
Expanded this PR to cover the additional locations and actions suggested in the triage analysis (#1055). Changes include:
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/opt/plugins/mcore_dist_checkpointing.py (1)
112-135: ⚠️ Potential issue | 🔴 Critical

Metadata for the sharded checkpoint format is not persisted or reconstructed.

The save path stores only `modelopt_state` with no metadata persistence. However, the restore path expects a `metadata` parameter (defaulting to `None`), and `_load_extra_state_from_sharded_checkpoint()` depends on version-sensitive fields like `singleton_local_shards` (introduced in megatron-core 0.15.0 as a breaking change).

All call sites invoke `restore_sharded_modelopt_state()` without passing `metadata`, so it defaults to `None`. With no metadata persisted or reconstructed, the restore path cannot correctly interpret the sharded state_dict format across different megatron-core versions. Either:

- Persist metadata alongside `modelopt_state` during save, or
- Reconstruct metadata from the environment/current model state during restore.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/opt/plugins/mcore_dist_checkpointing.py` around lines 112-135: `save_sharded_modelopt_state` currently writes only `modelopt_state` (via `dist_checkpointing.save`), so `restore_sharded_modelopt_state` and `_load_extra_state_from_sharded_checkpoint` lack the version/format metadata (e.g., `singleton_local_shards`) needed to interpret the sharded layout across megatron-core versions. Fix by persisting a metadata dict alongside the saved `modelopt_state` (include at minimum the megatron-core version, `sharded_strategy` details, and a `singleton_local_shards` boolean) when `save_sharded_modelopt_state` calls `dist_checkpointing.save`, and update `restore_sharded_modelopt_state`/`_load_extra_state_from_sharded_checkpoint` to read that metadata if present (falling back to reconstructing it from the runtime environment/model when metadata is absent) so all call sites that pass no metadata still restore correctly.

modelopt/torch/opt/plugins/megatron.py (1)
105-113: ⚠️ Potential issue | 🟠 Major

Add `map_location="cpu"` to the nested deserialization to ensure all tensors are restored to CPU.

The serialized `extra_state` dict may contain tensors that are recreated on their original device during deserialization. Without `map_location`, this can fail on CPU-only restores or cause unexpected device placement.

Suggested fix

```diff
-        extra_state = safe_load(state.detach().cpu().numpy().tobytes())
+        extra_state = safe_load(
+            state.detach().cpu().numpy().tobytes(),
+            map_location="cpu",
+        )
```

As per coding guidelines (SECURITY.md): checkpoint/state loading should use `map_location="cpu"` for safe deserialization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/opt/plugins/megatron.py` around lines 105-113: the nested deserialization of `extra_state` uses `safe_load(state.detach().cpu().numpy().tobytes())`, which can recreate tensors on their original device; update the call to pass `map_location="cpu"` so all tensors in `extra_state` are restored to CPU (i.e., call `safe_load(..., map_location="cpu")`), and ensure the `safe_load` invocation where `extra_state` is assigned includes the `map_location` argument to enforce CPU placement.
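To see why `map_location` matters for this bytes-packed `extra_state` pattern, here is a standalone sketch (plain `torch.load` standing in for the project's `safe_load`, with the payload built locally rather than taken from a real Megatron module):

```python
import io

import torch

# Pack a serialized dict into a uint8 tensor, as extra_state round-trips do.
buf = io.BytesIO()
torch.save({"scale": torch.tensor([0.5])}, buf)
packed = torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)

# Deserialize the raw bytes; map_location="cpu" pins every restored tensor
# to CPU even if the checkpoint was written from a GPU process.
restored = torch.load(
    io.BytesIO(packed.numpy().tobytes()),
    weights_only=True,
    map_location="cpu",
)
print(restored["scale"].device)  # cpu
```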
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/opt/conversion.py`:
- Around line 517-528: The load_modelopt_state function currently calls
safe_load(modelopt_state_path, **kwargs) without a map_location; update
load_modelopt_state to pass map_location="cpu" to safe_load by default (but
allow callers to override via kwargs) so its behavior matches restore; modify
the safe_load invocation in load_modelopt_state to set
kwargs.setdefault("map_location", "cpu") before calling safe_load.
---
Outside diff comments:
In `@modelopt/torch/opt/plugins/mcore_dist_checkpointing.py`:
- Around line 112-135: save_sharded_modelopt_state currently writes only
modelopt_state (via dist_checkpointing.save) so restore_sharded_modelopt_state
and _load_extra_state_from_sharded_checkpoint lack the version/format metadata
(e.g., singleton_local_shards) needed to interpret sharded layout across
megatron-core versions; fix by persisting a metadata dict alongside the saved
modelopt_state (include at minimum megatron-core version, sharded_strategy
details, and singleton_local_shards boolean) when save_sharded_modelopt_state
calls dist_checkpointing.save, and update
restore_sharded_modelopt_state/_load_extra_state_from_sharded_checkpoint to read
that metadata if present (fall back to reconstructing it from the runtime
environment/model when metadata is absent) so all call sites that pass no
metadata still restore correctly.
In `@modelopt/torch/opt/plugins/megatron.py`:
- Around line 105-113: The nested deserialization of extra_state uses
safe_load(state.detach().cpu().numpy().tobytes()) which can recreate tensors on
their original device; update the call to pass map_location="cpu" so all tensors
in extra_state are restored to CPU (i.e., call safe_load(...,
map_location="cpu")) — edit the deserialization in megatron.py where extra_state
is assigned and ensure safe_load includes the map_location argument to enforce
CPU placement.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: afafdd41-cbaf-43c3-9d96-8285b89209ec
📥 Commits
Reviewing files that changed from the base of the PR and between f54922151bc7dec8a49bcf9791f8a65e7dc7b966 and 771ebade5858963ed84643632e8dee7c97c99e87.
📒 Files selected for processing (6)
- modelopt/torch/export/distribute.py
- modelopt/torch/opt/conversion.py
- modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
- modelopt/torch/opt/plugins/megatron.py
- modelopt/torch/utils/__init__.py
- modelopt/torch/utils/serialization.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/utils/__init__.py
You need to sign your commits with an SSH key for CI/CD to run. Please take a look at https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md#%EF%B8%8F-signing-your-work
kevalmorabia97 left a comment
Great work! Left some comments. Please update CHANGELOG.rst under 0.44 section regarding this potentially backward breaking change as well
```python
print_rank_0(f"Loading searcher state from {checkpoint}...")
# Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
state_dict = torch.load(checkpoint, weights_only=False)
```
```python
# ) # bandit throws error here
# quant_class = model_module.__dict__[new_class_name]
```

```python
# Security NOTE: compile() is used here on internally-generated AST,
```
This is not fixed. Could you include a fix for this or re-add the security note?
/ok to test 28ce7c2
/ok to test 98c30a3
Failing unit tests
security(opt): enable weights_only=True by default for secure checkpoint loading

Signed-off-by: RinZ27 <222222878+RinZ27@users.noreply.github.com>
What does this PR do?
Type of change: Bug fix (Security)
Following the triage analysis in #1055, I've expanded the security hardening to cover all checkpoint loading operations across the library and its plugins.
Key updates:
- Added `modelopt.torch.utils.serialization.safe_load`, which enforces `weights_only=True` by default.
- Registered `ModeloptBaseConfig` subclasses in `__init_subclass__` to support PyTorch's restricted loading mode without manual boilerplate.
- Replaced `pickle.loads` in the Megatron plugin with `safe_load` and switched to `torch.save` for internal state serialization.

These changes provide a robust defense against RCE from malicious state files while maintaining full backward compatibility.
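As a rough sketch of the helper's contract (the actual `safe_load` implementation and the allowlisting of `ModeloptBaseConfig` subclasses live in `modelopt.torch.utils.serialization`; this standalone version only demonstrates the secure defaults, and its signature is an assumption):

```python
import io

import torch

def safe_load(f, **kwargs):
    # Secure defaults; explicit caller overrides still win.
    kwargs.setdefault("weights_only", True)
    kwargs.setdefault("map_location", "cpu")
    return torch.load(f, **kwargs)

# Round-trip an internal state dict through an in-memory buffer.
buf = io.BytesIO()
torch.save({"step": 7, "w": torch.zeros(2)}, buf)
buf.seek(0)
state = safe_load(buf)
print(state["step"], state["w"].device)  # 7 cpu
```

Using `setdefault` rather than hard-coding the flags keeps the helper backward compatible: call sites that genuinely need `weights_only=False` on trusted, ModelOpt-generated files can still opt out explicitly.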
Before your PR is "Ready for review"
CONTRIBUTING.md: N/A

Additional Information
Related to issue #1055.