
[feat]: support offload for flux2 #1034

Open

wangshankun wants to merge 1 commit into main from dev/flux2_offload

Conversation

@wangshankun (Collaborator)

No description provided.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces CPU offloading and lazy loading capabilities for the Flux2 model to optimize memory usage. Key changes include a new offload configuration, fallback mechanisms for pinned memory allocation in utils.py, and logic in flux2_runner.py to dynamically load and unload text encoders and VAE modules during inference. Feedback focuses on improving the robustness of device module retrieval, narrowing exception handling when allocating pinned memory, replacing assertions with explicit value errors for configuration validation, and ensuring consistent attribute deletion when unloading modules.

```python
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)
```

high

Using getattr(torch, AI_DEVICE) is risky. If AI_DEVICE is a device string like "cuda:0" or "cpu", this will raise an AttributeError. Typically, AI_DEVICE refers to the device identifier used with torch.device(), while getattr expects a module name like "cuda" or "mps". Additionally, torch does not have a cpu attribute that acts as a device module. Consider extracting the device type (e.g., AI_DEVICE.split(':')[0]) and handling the "cpu" case explicitly to avoid a crash.
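The suggested fix could be sketched as below. The helper names and the parameterized `torch_module` argument are illustrative (not the PR's actual code); the stand-in lets the sketch run without importing torch:

```python
# Safer device-module lookup, per the review: normalize "cuda:0" -> "cuda"
# before getattr, and treat "cpu" specially since torch exposes no CPU
# module with the cuda-style API (empty_cache, synchronize, ...).

def device_type_of(ai_device: str) -> str:
    """Extract the device type from a device string like 'cuda:0'."""
    return ai_device.split(":")[0]

def get_device_module(torch_module, ai_device: str):
    """Return the backend module (e.g. torch.cuda), or None for 'cpu'."""
    device_type = device_type_of(ai_device)
    if device_type == "cpu":
        return None
    return getattr(torch_module, device_type)
```

With real torch this would be called as `get_device_module(torch, AI_DEVICE)`, and callers would guard on `None` before invoking methods such as `empty_cache()`.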

```python
try:
    pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
except Exception as e:
```

medium

Catching the generic Exception class is too broad and can mask unrelated errors. For PyTorch memory allocation failures, it is better to catch RuntimeError specifically, as that is what torch.empty typically raises when pinned memory allocation fails.

Suggested change:

```diff
-except Exception as e:
+except RuntimeError as e:
```
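The full fallback pattern might look like this sketch. The function name and return shape are assumptions, and `alloc` stands in for `torch.empty` so the sketch is self-contained:

```python
# Narrowed fallback for pinned-memory allocation: catch RuntimeError
# (what torch.empty raises when page-locked memory cannot be allocated)
# and retry with pageable memory, instead of swallowing every exception.

def alloc_pinned_with_fallback(alloc, shape, dtype=None):
    """Try a pinned allocation; on RuntimeError, fall back to pageable.

    Returns (tensor, pinned_flag)."""
    try:
        return alloc(shape, pin_memory=True, dtype=dtype), True
    except RuntimeError:
        # Pinned allocation failed (e.g. page-locked memory exhausted);
        # fall back to ordinary memory so inference can continue.
        return alloc(shape, pin_memory=False, dtype=dtype), False
```

An unrelated error such as a `TypeError` from a bad `dtype` would still propagate, which is the point of narrowing the handler.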

```python
    self.load_model()
    self.model.set_scheduler(self.scheduler)
elif self.config.get("lazy_load", False):
    assert self.config.get("cpu_offload", False)
```

medium

Using assert for runtime configuration validation is discouraged because assertions can be disabled in optimized Python execution (using the -O flag). It is better to raise a ValueError to ensure the check is always performed.

Suggested change:

```diff
-assert self.config.get("cpu_offload", False)
+if not self.config.get("cpu_offload", False):
+    raise ValueError("cpu_offload must be enabled when lazy_load is true")
```

```python
text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
torch.cuda.empty_cache()
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
    del self.text_encoders[0]
```

medium

Deleting only the first element of the list leaves self.text_encoders as an empty list []. It is cleaner and more consistent with how self.vae is handled (line 281) to delete the entire attribute, which also avoids potential IndexError if the list is accessed elsewhere while empty.

Suggested change:

```diff
-del self.text_encoders[0]
+del self.text_encoders
```
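The behavioral difference the review points at can be seen in a minimal sketch (this `Runner` class is a stand-in for illustration, not the PR's code):

```python
# del self.text_encoders[0] empties the list but keeps the attribute,
# so later code sees [] and may hit IndexError on access; deleting the
# attribute itself removes it entirely, matching how self.vae is unloaded.

class Runner:
    def __init__(self):
        self.text_encoders = ["text_encoder"]

a = Runner()
del a.text_encoders[0]       # attribute survives as an empty list
print(a.text_encoders)       # -> []

b = Runner()
del b.text_encoders          # attribute is gone
print(hasattr(b, "text_encoders"))  # -> False
```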

```python
self.text_encoders = self.load_text_encoder()
text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
    del self.text_encoders[0]
```

medium

Deleting only the first element of the list leaves self.text_encoders as an empty list []. It is cleaner and more consistent with how self.vae is handled (line 281) to delete the entire attribute.

Suggested change:

```diff
-del self.text_encoders[0]
+del self.text_encoders
```
