Conversation
Code Review
This pull request introduces CPU offloading and lazy loading capabilities for the Flux2 model to optimize memory usage. Key changes include a new offload configuration, fallback mechanisms for pinned memory allocation in utils.py, and logic in flux2_runner.py to dynamically load and unload text encoders and VAE modules during inference. Feedback focuses on improving the robustness of device module retrieval, narrowing exception handling when allocating pinned memory, replacing assertions with explicit value errors for configuration validation, and ensuring consistent attribute deletion when unloading modules.
```python
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE


torch_device_module = getattr(torch, AI_DEVICE)
```
Using getattr(torch, AI_DEVICE) is risky. If AI_DEVICE is a device string like "cuda:0" or "cpu", this will raise an AttributeError. Typically, AI_DEVICE refers to the device identifier used with torch.device(), while getattr expects a module name like "cuda" or "mps". Additionally, torch does not have a cpu attribute that acts as a device module. Consider extracting the device type (e.g., AI_DEVICE.split(':')[0]) and handling the "cpu" case explicitly to avoid a crash.
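A minimal sketch of the safer pattern described above. The helper name `get_device_module` is illustrative, not from the PR, and the torch module is passed in as a parameter only to keep the sketch self-contained:

```python
def get_device_module(torch_module, ai_device: str):
    """Resolve a torch accelerator module (e.g. torch.cuda) from a device string.

    Extracts the device type so strings like "cuda:0" work, and returns None
    for "cpu", which has no accelerator module with the torch.cuda interface.
    """
    device_type = ai_device.split(":")[0]  # "cuda:0" -> "cuda"
    if device_type == "cpu":
        return None  # callers must skip accelerator-only calls (empty_cache, etc.)
    module = getattr(torch_module, device_type, None)
    if module is None:
        raise ValueError(f"No torch device module for device type {device_type!r}")
    return module
```

Returning `None` for CPU (rather than raising) lets call sites guard accelerator-only operations with a simple `if device_module is not None:` check.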
```diff
-pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
+try:
+    pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
+except Exception as e:
```
Catching the generic Exception class is too broad and can mask unrelated errors. For PyTorch memory allocation failures, it is better to catch RuntimeError specifically, as that is what torch.empty typically raises when pinned memory allocation fails.
Suggested change:

```diff
-except Exception as e:
+except RuntimeError as e:
```
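The narrowed exception handling, combined with the PR's fallback idea, can be sketched generically. Allocator callables stand in for the `torch.empty` calls, since the full call site isn't shown in the diff:

```python
def allocate_with_fallback(alloc_pinned, alloc_pageable):
    """Try pinned (page-locked) allocation first; on RuntimeError, which is
    what torch.empty typically raises when pinned memory is exhausted, fall
    back to ordinary pageable memory instead of crashing.
    """
    try:
        return alloc_pinned()
    except RuntimeError:
        return alloc_pageable()

# With PyTorch this would be used roughly as:
#   pin_tensor = allocate_with_fallback(
#       lambda: torch.empty(tensor.shape, pin_memory=True, dtype=dtype),
#       lambda: torch.empty(tensor.shape, dtype=dtype),
#   )
```

Catching only `RuntimeError` keeps genuine bugs (`TypeError`, `KeyboardInterrupt`, etc.) from being silently swallowed by the fallback path.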
```python
    self.load_model()
    self.model.set_scheduler(self.scheduler)
elif self.config.get("lazy_load", False):
    assert self.config.get("cpu_offload", False)
```
Using assert for runtime configuration validation is discouraged because assertions can be disabled in optimized Python execution (using the -O flag). It is better to raise a ValueError to ensure the check is always performed.
Suggested change:

```diff
-assert self.config.get("cpu_offload", False)
+if not self.config.get("cpu_offload", False):
+    raise ValueError("cpu_offload must be enabled when lazy_load is true")
```
```python
text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
torch.cuda.empty_cache()
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
    del self.text_encoders[0]
```
Deleting only the first element of the list leaves self.text_encoders as an empty list []. It is cleaner and more consistent with how self.vae is handled (line 281) to delete the entire attribute, which also avoids potential IndexError if the list is accessed elsewhere while empty.
Suggested change:

```diff
-del self.text_encoders[0]
+del self.text_encoders
```
```python
self.text_encoders = self.load_text_encoder()
text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
    del self.text_encoders[0]
```
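The load-run-unload cycle above, with the whole attribute deleted as suggested, can be sketched as follows. `RunnerSketch` and its placeholder loader are illustrative stand-ins for the real runner, not code from the PR:

```python
class RunnerSketch:
    """Minimal stand-in for the runner's lazy text-encoder lifecycle."""

    def load_text_encoder(self):
        # Placeholder for the real loader, which would read weights from disk.
        return ["t5_encoder"]

    def encode(self, lazy_load: bool):
        # Reload only when the attribute is absent, i.e. after a prior unload.
        if not hasattr(self, "text_encoders"):
            self.text_encoders = self.load_text_encoder()
        output = len(self.text_encoders)  # stand-in for running the encoders
        if lazy_load:
            # Delete the whole attribute (mirroring how self.vae is handled)
            # so the hasattr() check above triggers a clean reload next time,
            # instead of leaving an empty list behind.
            del self.text_encoders
        return output
```

Deleting the attribute rather than its first element makes the "is it loaded?" check unambiguous: `hasattr` either sees a populated list or nothing, never an empty shell.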