
security(opt): enable weights_only=True by default for secure checkpoint loading#1056

Open
RinZ27 wants to merge 1 commit into NVIDIA:main from RinZ27:feature/secure-state-loading

Conversation


@RinZ27 RinZ27 commented Mar 17, 2026

What does this PR do?

Type of change: Bug fix (Security)

Following the triage analysis in #1055, I've expanded the security hardening to cover all checkpoint loading operations across the library and its plugins.

Key updates:

  • Centralized Safe Loading: Introduced modelopt.torch.utils.serialization.safe_load which enforces weights_only=True by default.
  • Automatic Safe Globals Registration: Implemented automatic registration of ModeloptBaseConfig subclasses in __init_subclass__ to support PyTorch's restricted loading mode without manual boilerplate.
  • Plugin Support: Updated Megatron, Distributed, and MCore dist checkpointing plugins to use the secure loading helper.
  • Removed Unsafe Serialization: Replaced pickle.loads in the Megatron plugin with safe_load and switched to torch.save for internal state serialization.
  • Cleaned Up: Removed legacy "Security NOTE" comments and simplified imports across all modified modules.

These changes provide a robust defense against remote code execution (RCE) from malicious state files while maintaining full backward compatibility.
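The centralized helper described above can be sketched roughly as follows. This is an illustrative approximation, not the actual `modelopt.torch.utils.serialization` implementation; the bytes-handling branch and the `map_location` default are assumptions:

```python
# Illustrative sketch of a safe_load helper (assumed behavior, not the
# actual ModelOpt implementation): default to restricted unpickling and
# CPU placement, and accept raw bytes in addition to paths/streams.
import io

import torch


def safe_load(f, **kwargs):
    """Load a checkpoint with weights_only=True unless the caller overrides it."""
    kwargs.setdefault("weights_only", True)
    kwargs.setdefault("map_location", "cpu")
    if isinstance(f, bytes):
        f = io.BytesIO(f)  # wrap raw bytes so torch.load sees a stream
    return torch.load(f, **kwargs)


# Round-trip a trivial state dict through the helper.
buf = io.BytesIO()
torch.save({"w": torch.ones(2)}, buf)
state = safe_load(buf.getvalue())
```

Callers that genuinely need full unpickling can still pass `weights_only=False` explicitly, which keeps the unsafe path opt-in rather than the default.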

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅ (Added unit tests for serialization helpers)
  • Did you update Changelog?: ❌ (Will update if required)

Additional Information

Related to issue #1055.

@RinZ27 RinZ27 requested a review from a team as a code owner March 17, 2026 07:00
@RinZ27 RinZ27 requested a review from realAsma March 17, 2026 07:00

copy-pr-bot bot commented Mar 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Mar 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Adds a safe deserialization helper safe_load (and registration of ModelOpt types) and replaces direct torch.load(..., weights_only=False) and pickle-based serialization across multiple modules; re-exports the new helpers from modelopt.torch.utils.

Changes

  • Serialization helper (modelopt/torch/utils/serialization.py): New module; add_modelopt_safe_globals() registers ModelOpt classes as safe serialization globals; safe_load(f, **kwargs) wraps bytes/streams/paths and defaults weights_only=True before delegating to torch.load.
  • Utils exports (modelopt/torch/utils/__init__.py): Re-exported the serialization helpers via from .serialization import *.
  • Checkpoint conversion (modelopt/torch/opt/conversion.py): Replaced torch.load(...) with safe_load(...) in load_modelopt_state(...) and restore(...); removed explicit kwargs.setdefault("weights_only", False) usages and updated in-code comments and the example to reflect weights_only=True for ModelOpt artifacts.
  • Distributed export/restore (modelopt/torch/export/distribute.py): Replaced multiple torch.load(...) calls (NFS and shared-memory buffer paths) with safe_load(...); updated the adjacent security note to reference safe_load/weights_only=True.
  • MCore checkpointing plugin (modelopt/torch/opt/plugins/mcore_dist_checkpointing.py): Removed YAML run-config persistence and related parsing; added mcore_metadata to the saved modelopt_state; replaced loading of common replicated state from torch.load(..., weights_only=False) with safe_load(..., map_location="cpu"), extracting mcore_metadata when present.
  • Megatron plugin (modelopt/torch/opt/plugins/megatron.py): Replaced pickle-based extra-state serialization: extra state is now written with torch.save(extra_state, BytesIO()) and the bytes are stored as a torch.Tensor; deserialization uses safe_load(...) on the tensor bytes; removed the pickle import and adjusted imports/comments.
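The Megatron extra-state change (torch.save into a BytesIO buffer, stored as a byte tensor, then restored with a restricted load) can be illustrated roughly as below. The extra_state contents and variable names are made up for the example; the real plugin code may differ:

```python
# Sketch of replacing pickle-based extra-state serialization with
# torch.save + a uint8 byte tensor, restored via restricted unpickling.
# The extra_state payload here is invented for illustration.
import io

import torch

extra_state = {"version": 1, "scale": torch.tensor([0.5])}

# Serialize with torch.save instead of pickle.dumps, then wrap the raw
# bytes in a tensor so the framework can carry it as regular state.
buf = io.BytesIO()
torch.save(extra_state, buf)
state_tensor = torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)

# Deserialize from the tensor's raw bytes with weights_only=True and
# everything mapped to CPU.
raw = state_tensor.numpy().tobytes()
restored = torch.load(io.BytesIO(raw), weights_only=True, map_location="cpu")
```

Storing the serialized bytes as a tensor keeps the extra state compatible with checkpoint machinery that only understands tensors, while the restricted load closes the arbitrary-unpickling hole that `pickle.loads` left open.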

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 4 passed
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient (required threshold: 80.00%).
  • Security Anti-Patterns: ✅ Passed. The PR implements centralized safe loading with weights_only=True by default, replacing all direct torch.load() calls with safe_load() in the modified files, with no security anti-patterns.
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title 'security(opt): enable weights_only=True by default for secure checkpoint loading' clearly summarizes the main change: defaulting checkpoint loading to weights_only=True.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
modelopt/torch/opt/conversion.py (2)

630-632: Consider adding a security comment for consistency.

The weights_only=True default is correct and aligns with the security guidelines. For consistency with load_modelopt_state (line 526), consider adding a brief inline comment explaining this is a security measure for ModelOpt-generated checkpoints.

📝 Optional: Add security comment for consistency
     # load checkpoint
     kwargs.setdefault("map_location", "cpu")
+    # Security NOTE: weights_only=True is used here on ModelOpt-generated checkpoints
     kwargs.setdefault("weights_only", True)
     objs = torch.load(f, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/opt/conversion.py` around lines 630 - 632, Add a brief inline
security comment next to the kwargs.setdefault("weights_only", True) in
conversion.py (near the torch.load call) explaining that setting
weights_only=True is a security measure for loading ModelOpt-generated
checkpoints, mirroring the comment style used in load_modelopt_state to make
intent consistent and clear.

553-553: Documentation example could reinforce secure loading pattern.

The example shows torch.load("model_weights.pt") without weights_only=True. While PyTorch >= 2.6 defaults to True, explicitly showing the secure pattern in documentation helps reinforce best practices for users on older PyTorch versions.

📝 Optional: Update docstring example
-        model.load_state_dict(torch.load("model_weights.pt"), ...)  # Load the model weights
+        model.load_state_dict(torch.load("model_weights.pt", weights_only=True), ...)  # Load the model weights
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/opt/conversion.py` at line 553, Update the docstring/example
to explicitly use the secure loading pattern by passing weights_only=True to
torch.load in the example call (i.e., change the call shown next to
model.load_state_dict to use torch.load("model_weights.pt", weights_only=True));
reference the model.load_state_dict and torch.load usage in conversion.py so
readers see the explicit secure flag.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a5f245a3-81e0-4c09-8930-881be6f870b2

📥 Commits

Reviewing files that changed from the base of the PR and between 00fa5bd and f54922151bc7dec8a49bcf9791f8a65e7dc7b966.

📒 Files selected for processing (1)
  • modelopt/torch/opt/conversion.py

@kevalmorabia97 kevalmorabia97 self-requested a review March 17, 2026 07:11
@RinZ27 RinZ27 requested review from a team as code owners March 20, 2026 11:49
@RinZ27 RinZ27 requested a review from meenchen March 20, 2026 11:49

RinZ27 commented Mar 20, 2026

Expanded this PR to cover the additional locations and actions suggested in the triage analysis (#1055).

Changes include:

  • Centralized safe loading logic in modelopt.torch.utils.serialization.safe_load.
  • Implemented add_modelopt_safe_globals for internal class support with weights_only=True.
  • Updated Megatron, Distributed, and MCore dist checkpointing plugins.
  • Replaced unsafe pickle.loads with safe_load.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/opt/plugins/mcore_dist_checkpointing.py (1)

112-135: ⚠️ Potential issue | 🔴 Critical

Metadata for sharded checkpoint format is not persisted or reconstructed.

The save path stores only modelopt_state with no metadata persistence. However, the restore path expects a metadata parameter (defaulting to None) and _load_extra_state_from_sharded_checkpoint() depends on version-sensitive fields like singleton_local_shards (introduced in megatron-core 0.15.0 as a breaking change).

All call sites invoke restore_sharded_modelopt_state() without passing metadata, so it defaults to None. With no metadata persisted or reconstructed, the restore path cannot correctly interpret the sharded state_dict format across different megatron-core versions. Either:

  1. Persist metadata alongside modelopt_state during save, or
  2. Reconstruct metadata from the environment/current model state during restore
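Option 1 above could be sketched as follows. The mcore_metadata field names shown (e.g. singleton_local_shards) are assumptions taken from the review comment, not the plugin's actual schema, and the helper names are invented:

```python
# Sketch (assumed names/schema) of persisting format-sensitive metadata
# alongside the saved state so restore can interpret the sharded layout
# across megatron-core versions.
import os
import tempfile

import torch


def save_with_metadata(modelopt_state, path):
    payload = {
        "modelopt_state": modelopt_state,
        "mcore_metadata": {  # assumed fields, for illustration only
            "singleton_local_shards": True,
        },
    }
    torch.save(payload, path)


def restore_with_metadata(path):
    payload = torch.load(path, weights_only=True, map_location="cpu")
    # Fall back to an empty dict for older checkpoints with no metadata.
    metadata = payload.get("mcore_metadata") or {}
    return payload["modelopt_state"], metadata


with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "modelopt_state.pt")
    save_with_metadata({"mode": "quantize"}, p)
    state, meta = restore_with_metadata(p)
```

The `.get(...) or {}` fallback is what keeps existing call sites that pass no metadata working against checkpoints written before the metadata was persisted.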
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/opt/plugins/mcore_dist_checkpointing.py` around lines 112 -
135, save_sharded_modelopt_state currently writes only modelopt_state (via
dist_checkpointing.save) so restore_sharded_modelopt_state and
_load_extra_state_from_sharded_checkpoint lack the version/format metadata
(e.g., singleton_local_shards) needed to interpret sharded layout across
megatron-core versions; fix by persisting a metadata dict alongside the saved
modelopt_state (include at minimum megatron-core version, sharded_strategy
details, and singleton_local_shards boolean) when save_sharded_modelopt_state
calls dist_checkpointing.save, and update
restore_sharded_modelopt_state/_load_extra_state_from_sharded_checkpoint to read
that metadata if present (fall back to reconstructing it from the runtime
environment/model when metadata is absent) so all call sites that pass no
metadata still restore correctly.
modelopt/torch/opt/plugins/megatron.py (1)

105-113: ⚠️ Potential issue | 🟠 Major

Add map_location="cpu" to the nested deserialization to ensure all tensors are restored to CPU.

The serialized extra_state dict may contain tensors that are recreated on their original device during deserialization. Without map_location, this can fail on CPU-only restores or cause unexpected device placement.

Suggested fix
-        extra_state = safe_load(state.detach().cpu().numpy().tobytes())
+        extra_state = safe_load(
+            state.detach().cpu().numpy().tobytes(),
+            map_location="cpu",
+        )

As per coding guidelines SECURITY.md: checkpoint/state loading should use map_location="cpu" for safe deserialization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/opt/plugins/megatron.py` around lines 105 - 113, The nested
deserialization of extra_state uses
safe_load(state.detach().cpu().numpy().tobytes()) which can recreate tensors on
their original device; update the call to pass map_location="cpu" so all tensors
in extra_state are restored to CPU (i.e., call safe_load(...,
map_location="cpu")) — edit the deserialization in megatron.py where extra_state
is assigned and ensure safe_load includes the map_location argument to enforce
CPU placement.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/opt/conversion.py`:
- Around line 517-528: The load_modelopt_state function currently calls
safe_load(modelopt_state_path, **kwargs) without a map_location; update
load_modelopt_state to pass map_location="cpu" to safe_load by default (but
allow callers to override via kwargs) so its behavior matches restore; modify
the safe_load invocation in load_modelopt_state to set
kwargs.setdefault("map_location", "cpu") before calling safe_load.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: afafdd41-cbaf-43c3-9d96-8285b89209ec

📥 Commits

Reviewing files that changed from the base of the PR and between f54922151bc7dec8a49bcf9791f8a65e7dc7b966 and 771ebade5858963ed84643632e8dee7c97c99e87.

📒 Files selected for processing (6)
  • modelopt/torch/export/distribute.py
  • modelopt/torch/opt/conversion.py
  • modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
  • modelopt/torch/opt/plugins/megatron.py
  • modelopt/torch/utils/__init__.py
  • modelopt/torch/utils/serialization.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/utils/__init__.py

@kevalmorabia97

You need to sign your commits with an SSH key for CICD to run. Please take a look at https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md#%EF%B8%8F-signing-your-work

@RinZ27 RinZ27 force-pushed the feature/secure-state-loading branch 3 times, most recently from e97db79 to b33a191 Compare March 20, 2026 13:13
@RinZ27 RinZ27 force-pushed the feature/secure-state-loading branch 2 times, most recently from b6d6b9e to 058f1fc Compare March 23, 2026 14:27

@kevalmorabia97 kevalmorabia97 left a comment


Great work! Left some comments. Please also update CHANGELOG.rst under the 0.44 section regarding this potentially backward-breaking change.

@RinZ27 RinZ27 force-pushed the feature/secure-state-loading branch from 058f1fc to 0c8f39c Compare March 24, 2026 11:45
@RinZ27 RinZ27 requested a review from a team as a code owner March 24, 2026 11:45
@RinZ27 RinZ27 changed the title security(opt): enable weights_only=True by default security(opt): enable weights_only=True by default for secure checkpoint loading Mar 24, 2026
@RinZ27 RinZ27 force-pushed the feature/secure-state-loading branch 2 times, most recently from da4655a to 1d3baf8 Compare March 24, 2026 11:50

print_rank_0(f"Loading searcher state from {checkpoint}...")
# Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
state_dict = torch.load(checkpoint, weights_only=False)

here also

# ) # bandit throws error here
# quant_class = model_module.__dict__[new_class_name]

# Security NOTE: compile() is used here on internally-generated AST,

this is not fixed. Could you include a fix for this or re-add the security note?

@RinZ27 RinZ27 requested a review from ajrasane March 27, 2026 14:15
@RinZ27 RinZ27 force-pushed the feature/secure-state-loading branch 4 times, most recently from 8024a2b to 28ce7c2 Compare March 27, 2026 14:50
@kevalmorabia97
Copy link
Copy Markdown
Collaborator

/ok to test 28ce7c2

@kevalmorabia97 kevalmorabia97 requested review from Edwardf0t1 and removed request for a team, ajrasane, meenchen and ynankani March 27, 2026 15:02
@RinZ27 RinZ27 force-pushed the feature/secure-state-loading branch from 28ce7c2 to 98c30a3 Compare March 27, 2026 15:43
@RinZ27 RinZ27 closed this Mar 28, 2026
@RinZ27 RinZ27 reopened this Mar 28, 2026
@kevalmorabia97
Copy link
Copy Markdown
Collaborator

/ok to test 98c30a3

@kevalmorabia97
Copy link
Copy Markdown
Collaborator

Failing unit tests

@RinZ27 RinZ27 force-pushed the feature/secure-state-loading branch from 7280600 to 04f2b88 Compare March 29, 2026 03:56
@RinZ27 RinZ27 requested a review from a team as a code owner March 29, 2026 03:56
@RinZ27 RinZ27 force-pushed the feature/secure-state-loading branch from 04f2b88 to 9b6ae7e Compare March 29, 2026 03:58
…int loading

Signed-off-by: RinZ27 <222222878+RinZ27@users.noreply.github.com>
@RinZ27 RinZ27 force-pushed the feature/secure-state-loading branch from 9b6ae7e to a8d544b Compare March 29, 2026 04:08