-
Notifications
You must be signed in to change notification settings - Fork 320
fix(oom): multimodal OOM fix with auto-profile #1273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
fd2169b
37bbe2f
c8367d1
0daaf0a
2331016
a2d2cd3
84ece07
60c0509
d6b498b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |||||||||||||||||||||||||||||||||||||||||
| import gc | ||||||||||||||||||||||||||||||||||||||||||
| import copy | ||||||||||||||||||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||||||||||||||||||||||||||||
| import triton | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -134,7 +135,8 @@ def __init__(self, kvargs): | |||||||||||||||||||||||||||||||||||||||||
| self._init_cudagraph() | ||||||||||||||||||||||||||||||||||||||||||
| self._init_prefill_cuda_graph() | ||||||||||||||||||||||||||||||||||||||||||
| self._check_max_len_infer() | ||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.empty_cache() | ||||||||||||||||||||||||||||||||||||||||||
| self._check_decode_infer() | ||||||||||||||||||||||||||||||||||||||||||
| self._auto_profile_rebuild_and_validate() | ||||||||||||||||||||||||||||||||||||||||||
| set_model_init_status(True) | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -212,7 +214,12 @@ def _init_kv_move_buffer(self): | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _check_mem_size(self): | ||||||||||||||||||||||||||||||||||||||||||
| self.max_total_token_num = self.mem_manager.size | ||||||||||||||||||||||||||||||||||||||||||
| assert self.max_seq_length <= self.max_total_token_num | ||||||||||||||||||||||||||||||||||||||||||
| # Skip the max_seq_length assertion during the auto-profile probe | ||||||||||||||||||||||||||||||||||||||||||
| # phase: the probe KV is intentionally small (just enough for the | ||||||||||||||||||||||||||||||||||||||||||
| # stress forwards), not sized to hold a full-length request. The | ||||||||||||||||||||||||||||||||||||||||||
| # assertion is re-checked after Phase 3 rebuilds at the target size. | ||||||||||||||||||||||||||||||||||||||||||
| if getattr(self.mem_manager, "_probe_tokens", None) is None or self.mem_manager.size >= self.max_seq_length: | ||||||||||||||||||||||||||||||||||||||||||
| assert self.max_seq_length <= self.max_total_token_num | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _init_req_manager(self): | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -944,11 +951,6 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: | |||||||||||||||||||||||||||||||||||||||||
| @final | ||||||||||||||||||||||||||||||||||||||||||
| @torch.no_grad() | ||||||||||||||||||||||||||||||||||||||||||
| def _check_max_len_infer(self): | ||||||||||||||||||||||||||||||||||||||||||
| disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None | ||||||||||||||||||||||||||||||||||||||||||
| if disable_check_max_len_infer: | ||||||||||||||||||||||||||||||||||||||||||
| logger.info("disable_check_max_len_infer is true") | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # 做一次 同步 | ||||||||||||||||||||||||||||||||||||||||||
| torch.distributed.barrier() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1003,6 +1005,279 @@ def _check_max_len_infer(self): | |||||||||||||||||||||||||||||||||||||||||
| raise Exception(exception_str) | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @torch.no_grad() | ||||||||||||||||||||||||||||||||||||||||||
| def _allocate_decode_stress_slots(self, batch_size): | ||||||||||||||||||||||||||||||||||||||||||
| """Allocate req/mem slots for the decode stress forward. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Override this in mamba-aware subclasses (e.g. qwen3next) to also | ||||||||||||||||||||||||||||||||||||||||||
| allocate mamba buffers via req_manager.alloc_buffer_for_req so the | ||||||||||||||||||||||||||||||||||||||||||
| GDN layers exercise their full memory footprint. The base version | ||||||||||||||||||||||||||||||||||||||||||
| only touches the standard req/mem managers. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Returns a tuple (actual_batch, req_idxs, tokens_per_req, total_tokens, | ||||||||||||||||||||||||||||||||||||||||||
| mem_indexes) that _build_decode_model_input consumes. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| req_idxs = [] | ||||||||||||||||||||||||||||||||||||||||||
| for _ in range(batch_size): | ||||||||||||||||||||||||||||||||||||||||||
| idx = self.req_manager.alloc() | ||||||||||||||||||||||||||||||||||||||||||
| if idx is None: | ||||||||||||||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||||||||||||||
| req_idxs.append(idx) | ||||||||||||||||||||||||||||||||||||||||||
| actual_batch = len(req_idxs) | ||||||||||||||||||||||||||||||||||||||||||
| if actual_batch < 2: | ||||||||||||||||||||||||||||||||||||||||||
| return actual_batch, req_idxs, 0, 0, None | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| tokens_per_req = min( | ||||||||||||||||||||||||||||||||||||||||||
| self.batch_max_tokens, | ||||||||||||||||||||||||||||||||||||||||||
| max(1, self.max_total_token_num // actual_batch), | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| total_tokens = tokens_per_req * actual_batch | ||||||||||||||||||||||||||||||||||||||||||
| total_tokens = min(total_tokens, self.mem_manager.can_use_mem_size) | ||||||||||||||||||||||||||||||||||||||||||
| tokens_per_req = max(1, total_tokens // actual_batch) | ||||||||||||||||||||||||||||||||||||||||||
| total_tokens = tokens_per_req * actual_batch | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| mem_indexes = self.mem_manager.alloc(total_tokens).cuda() | ||||||||||||||||||||||||||||||||||||||||||
| return actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _build_decode_model_input(self, actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes): | ||||||||||||||||||||||||||||||||||||||||||
| """Construct a decode-shaped ModelInput for the stress forward. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Split out so mamba subclasses can also override just the ModelInput | ||||||||||||||||||||||||||||||||||||||||||
| shape if needed, without touching _check_decode_infer's control flow. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| dummy_input_ids = torch.ones(actual_batch, dtype=torch.int32, device="cuda") | ||||||||||||||||||||||||||||||||||||||||||
| b_req_idx = torch.tensor(req_idxs, dtype=torch.int32, device="cuda") | ||||||||||||||||||||||||||||||||||||||||||
| b_seq_len = torch.full((actual_batch,), tokens_per_req, dtype=torch.int32, device="cuda") | ||||||||||||||||||||||||||||||||||||||||||
| b_ready_cache_len = torch.zeros(actual_batch, dtype=torch.int32, device="cuda") | ||||||||||||||||||||||||||||||||||||||||||
| b_mtp_index = torch.zeros(actual_batch, dtype=torch.int32, device="cuda") | ||||||||||||||||||||||||||||||||||||||||||
| return ModelInput( | ||||||||||||||||||||||||||||||||||||||||||
| batch_size=actual_batch, | ||||||||||||||||||||||||||||||||||||||||||
| total_token_num=total_tokens, | ||||||||||||||||||||||||||||||||||||||||||
| max_q_seq_len=1, | ||||||||||||||||||||||||||||||||||||||||||
| max_kv_seq_len=tokens_per_req, | ||||||||||||||||||||||||||||||||||||||||||
| max_cache_len=tokens_per_req - 1, | ||||||||||||||||||||||||||||||||||||||||||
| prefix_total_token_num=0, | ||||||||||||||||||||||||||||||||||||||||||
| input_ids=dummy_input_ids, | ||||||||||||||||||||||||||||||||||||||||||
| # mem_indexes[:actual_batch] provides 1 new KV slot per request for the decode | ||||||||||||||||||||||||||||||||||||||||||
| # step's output token. The full total_tokens block was allocated from mem_manager | ||||||||||||||||||||||||||||||||||||||||||
| # to occupy the KV cache space, but only 1 slot per request is the "new" token. | ||||||||||||||||||||||||||||||||||||||||||
| mem_indexes=mem_indexes[:actual_batch], | ||||||||||||||||||||||||||||||||||||||||||
| b_req_idx=b_req_idx, | ||||||||||||||||||||||||||||||||||||||||||
| b_seq_len=b_seq_len, | ||||||||||||||||||||||||||||||||||||||||||
| b_mtp_index=b_mtp_index, | ||||||||||||||||||||||||||||||||||||||||||
| is_prefill=False, | ||||||||||||||||||||||||||||||||||||||||||
| b_ready_cache_len=b_ready_cache_len, | ||||||||||||||||||||||||||||||||||||||||||
| multimodal_params=[{"images": [], "audios": []}] * actual_batch, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @torch.no_grad() | ||||||||||||||||||||||||||||||||||||||||||
| def _check_decode_infer(self): | ||||||||||||||||||||||||||||||||||||||||||
| """Simulate a decode batch to detect OOM from concurrent request activations. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Ported from origin/qw35_stable:basemodel.py:895 and split into | ||||||||||||||||||||||||||||||||||||||||||
| overridable sub-methods per spec §6.4 so qwen3next (when it lands on | ||||||||||||||||||||||||||||||||||||||||||
| main) can cleanly hook in mamba buffer allocation. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| The prob_out.sort() call at the end is load-bearing: it forces the | ||||||||||||||||||||||||||||||||||||||||||
| top_p/top_k sampling allocation that real inference triggers but a | ||||||||||||||||||||||||||||||||||||||||||
| naive decode forward does not. On a vocab=152k model with | ||||||||||||||||||||||||||||||||||||||||||
| graph_max_batch_size=64 this accounts for ~20 MB per decode step, so | ||||||||||||||||||||||||||||||||||||||||||
| skipping it under-measures the peak by exactly the amount that causes | ||||||||||||||||||||||||||||||||||||||||||
| the "first decode batch after warmup OOMs" bug class. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| torch.distributed.barrier() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| batch_size = self.graph_max_batch_size | ||||||||||||||||||||||||||||||||||||||||||
| if batch_size <= 1: | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"begin check decode infer with batch_size={batch_size}") | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes = self._allocate_decode_stress_slots( | ||||||||||||||||||||||||||||||||||||||||||
| batch_size | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| if actual_batch < 2: | ||||||||||||||||||||||||||||||||||||||||||
| logger.info("skip decode check: not enough req slots") | ||||||||||||||||||||||||||||||||||||||||||
| self.req_manager.free_all() | ||||||||||||||||||||||||||||||||||||||||||
| self.mem_manager.free_all() | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| model_input = self._build_decode_model_input( | ||||||||||||||||||||||||||||||||||||||||||
| actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| model_output = self.forward(model_input) | ||||||||||||||||||||||||||||||||||||||||||
| prob_out = torch.softmax(model_output.logits, dim=-1) | ||||||||||||||||||||||||||||||||||||||||||
| del model_output | ||||||||||||||||||||||||||||||||||||||||||
| # Force top_p/top_k sampling allocation — load-bearing for peak measurement. | ||||||||||||||||||||||||||||||||||||||||||
| prob_out.sort(dim=-1, descending=True) | ||||||||||||||||||||||||||||||||||||||||||
| prob_out = None | ||||||||||||||||||||||||||||||||||||||||||
| self.req_manager.free_all() | ||||||||||||||||||||||||||||||||||||||||||
| self.mem_manager.free_all() | ||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"check decode {actual_batch} infer ok") | ||||||||||||||||||||||||||||||||||||||||||
| except (RuntimeError, torch.OutOfMemoryError) as e: | ||||||||||||||||||||||||||||||||||||||||||
| logger.exception(str(e)) | ||||||||||||||||||||||||||||||||||||||||||
| exception_str = ( | ||||||||||||||||||||||||||||||||||||||||||
| "check decode infer fail, you can try:\n" | ||||||||||||||||||||||||||||||||||||||||||
| "1. Set --graph_max_batch_size to a smaller value.\n" | ||||||||||||||||||||||||||||||||||||||||||
| "2. Set --mem_fraction or --max_total_token_num to a smaller value.\n" | ||||||||||||||||||||||||||||||||||||||||||
| "3. Set --max_req_total_len to a smaller value." | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| logger.error(exception_str) | ||||||||||||||||||||||||||||||||||||||||||
| raise Exception(exception_str) | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _teardown_graphs_and_kv(self): | ||||||||||||||||||||||||||||||||||||||||||
| """Phase 3 teardown: drop references to the probe's kv_buffer and | ||||||||||||||||||||||||||||||||||||||||||
| all captured CUDA graphs so torch.cuda.empty_cache() can actually | ||||||||||||||||||||||||||||||||||||||||||
| return the blocks to the driver. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Order matters — graphs hold tensor pointers into kv_buffer and must | ||||||||||||||||||||||||||||||||||||||||||
| go first, then req_manager (which references mem_manager), then | ||||||||||||||||||||||||||||||||||||||||||
| mem_manager itself. See spec §6.5 for the rationale. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self, "mem_manager") and self.mem_manager is not None: | ||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| self.mem_manager.free_all() | ||||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| for attr in ("graph", "prefill_graph", "prefill_cuda_graph"): | ||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self, attr): | ||||||||||||||||||||||||||||||||||||||||||
| setattr(self, attr, None) | ||||||||||||||||||||||||||||||||||||||||||
| # MTP variants (if present) | ||||||||||||||||||||||||||||||||||||||||||
| for attr in ("graph1", "prefill_graph1"): | ||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self, attr): | ||||||||||||||||||||||||||||||||||||||||||
| setattr(self, attr, None) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self, "req_manager"): | ||||||||||||||||||||||||||||||||||||||||||
| self.req_manager = None | ||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self, "mem_manager"): | ||||||||||||||||||||||||||||||||||||||||||
| self.mem_manager = None | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _auto_profile_rebuild_and_validate(self): | ||||||||||||||||||||||||||||||||||||||||||
| """Phases 2-4 of the auto-profile loop. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Runs after Phase 1 (__init__ through _check_decode_infer). Measures | ||||||||||||||||||||||||||||||||||||||||||
| the probe's peak, computes the target KV size, tears down the probe's | ||||||||||||||||||||||||||||||||||||||||||
| graphs and mem_manager, calls torch.cuda.empty_cache() once, re-inits | ||||||||||||||||||||||||||||||||||||||||||
| everything at the target size, re-captures graphs, allocates the 256 MB | ||||||||||||||||||||||||||||||||||||||||||
| canary, and validates by re-running the stress forwards. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| On validation OOM, shrinks target_tokens by 5% and loops. Retry budget: | ||||||||||||||||||||||||||||||||||||||||||
| 3 retries (4 total attempts). After that, raises with a multi-knob | ||||||||||||||||||||||||||||||||||||||||||
| diagnostic (see spec §7.2). | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| If --max_total_token_num was set explicitly (probe path was skipped), | ||||||||||||||||||||||||||||||||||||||||||
| this method is a pure no-op beyond a single empty_cache call to match | ||||||||||||||||||||||||||||||||||||||||||
| the pre-auto-profile behavior. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| if self.mem_manager._probe_tokens is None: | ||||||||||||||||||||||||||||||||||||||||||
| logger.info("auto-profile phase=skip reason=explicit_max_total_token_num") | ||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.empty_cache() | ||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| peak_reserved = torch.cuda.max_memory_reserved() | ||||||||||||||||||||||||||||||||||||||||||
| initial_target_tokens = self.mem_manager.profile_size_target(peak_reserved) | ||||||||||||||||||||||||||||||||||||||||||
| target_tokens = initial_target_tokens | ||||||||||||||||||||||||||||||||||||||||||
| probe_tokens = self.mem_manager._probe_tokens | ||||||||||||||||||||||||||||||||||||||||||
| probe_kv_bytes = probe_tokens * self.mem_manager.get_cell_size() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| RETRY_BUDGET = 3 | ||||||||||||||||||||||||||||||||||||||||||
| SHRINK_RATIO = 0.95 | ||||||||||||||||||||||||||||||||||||||||||
| CANARY_BYTES = 256 * 1024 * 1024 | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| attempt = 0 | ||||||||||||||||||||||||||||||||||||||||||
| last_exc = None | ||||||||||||||||||||||||||||||||||||||||||
| while attempt <= RETRY_BUDGET: | ||||||||||||||||||||||||||||||||||||||||||
| attempt += 1 | ||||||||||||||||||||||||||||||||||||||||||
| t0 = time.time() | ||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| # Phase 3: tear down probe graphs and mem_manager | ||||||||||||||||||||||||||||||||||||||||||
| self._teardown_graphs_and_kv() | ||||||||||||||||||||||||||||||||||||||||||
| reserved_before_empty = torch.cuda.memory_reserved() | ||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.empty_cache() | ||||||||||||||||||||||||||||||||||||||||||
| reserved_after_empty = torch.cuda.memory_reserved() | ||||||||||||||||||||||||||||||||||||||||||
| released = reserved_before_empty - reserved_after_empty | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Sanity: we should have released at least the probe kv_buffer. | ||||||||||||||||||||||||||||||||||||||||||
| # The threshold is probe_kv_bytes (strict, per spec §6.5 step 6) | ||||||||||||||||||||||||||||||||||||||||||
| # because the probe's kv_buffer is a single contiguous allocation | ||||||||||||||||||||||||||||||||||||||||||
| # and PyTorch's caching allocator returns whole segments on empty_cache. | ||||||||||||||||||||||||||||||||||||||||||
| if attempt == 1 and released < probe_kv_bytes: | ||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||||||||||||
| f"auto-profile phase=rebuild TEARDOWN LEAK: " | ||||||||||||||||||||||||||||||||||||||||||
| f"empty_cache() only released {released / 1024 ** 3:.2f} GB " | ||||||||||||||||||||||||||||||||||||||||||
| f"but probe kv_buffer alone is {probe_kv_bytes / 1024 ** 3:.2f} GB. " | ||||||||||||||||||||||||||||||||||||||||||
| f"Some Python reference to the probe kv_buffer or a captured " | ||||||||||||||||||||||||||||||||||||||||||
| f"CUDA graph was not dropped by _teardown_graphs_and_kv. " | ||||||||||||||||||||||||||||||||||||||||||
| f"Investigate which attribute is leaking." | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Phase 3: re-init everything at target_tokens | ||||||||||||||||||||||||||||||||||||||||||
| self.max_total_token_num = target_tokens | ||||||||||||||||||||||||||||||||||||||||||
| self._init_mem_manager() | ||||||||||||||||||||||||||||||||||||||||||
| self._init_kv_move_buffer() | ||||||||||||||||||||||||||||||||||||||||||
| self._check_mem_size() | ||||||||||||||||||||||||||||||||||||||||||
| self._init_req_manager() | ||||||||||||||||||||||||||||||||||||||||||
| self._init_cudagraph() | ||||||||||||||||||||||||||||||||||||||||||
| self._init_prefill_cuda_graph() | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1217
to
+1224
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The rebuild loop in Phase 3 is missing calls to
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Canary | ||||||||||||||||||||||||||||||||||||||||||
| self._oom_canary = torch.empty(CANARY_BYTES, dtype=torch.uint8, device="cuda") | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||
| f"auto-profile phase=rebuild attempt={attempt} " | ||||||||||||||||||||||||||||||||||||||||||
| f"elapsed_sec={time.time() - t0:.2f} " | ||||||||||||||||||||||||||||||||||||||||||
| f"new_kv_tokens={target_tokens}" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Phase 4: validate | ||||||||||||||||||||||||||||||||||||||||||
| t1 = time.time() | ||||||||||||||||||||||||||||||||||||||||||
| self._check_max_len_infer() | ||||||||||||||||||||||||||||||||||||||||||
| self._check_decode_infer() | ||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||
| f"auto-profile phase=validate attempt={attempt} " f"elapsed_sec={time.time() - t1:.2f} result=ok" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| return # success | ||||||||||||||||||||||||||||||||||||||||||
| except (RuntimeError, torch.cuda.OutOfMemoryError, torch.OutOfMemoryError) as e: | ||||||||||||||||||||||||||||||||||||||||||
| last_exc = e | ||||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||||
| f"auto-profile phase=validate attempt={attempt} " | ||||||||||||||||||||||||||||||||||||||||||
| f"result={'retry' if attempt <= RETRY_BUDGET else 'fail'} " | ||||||||||||||||||||||||||||||||||||||||||
| f"error={type(e).__name__}: {e}" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| if attempt > RETRY_BUDGET: | ||||||||||||||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||||||||||||||
| target_tokens = int(target_tokens * SHRINK_RATIO) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # All retries exhausted — raise a multi-knob diagnostic | ||||||||||||||||||||||||||||||||||||||||||
| total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3 | ||||||||||||||||||||||||||||||||||||||||||
| cell_size = None | ||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| cell_size = self.mem_manager.get_cell_size() | ||||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||||||||
| initial_gb = initial_target_tokens * (cell_size or 0) / 1024 ** 3 | ||||||||||||||||||||||||||||||||||||||||||
| final_gb = target_tokens * (cell_size or 0) / 1024 ** 3 | ||||||||||||||||||||||||||||||||||||||||||
| raise Exception( | ||||||||||||||||||||||||||||||||||||||||||
| f"Auto-profile failed after {attempt} attempts.\n" | ||||||||||||||||||||||||||||||||||||||||||
| f"Initial target: {initial_target_tokens} tokens ({initial_gb:.2f} GB KV)\n" | ||||||||||||||||||||||||||||||||||||||||||
| f"Final attempted: {target_tokens} tokens ({final_gb:.2f} GB KV)\n" | ||||||||||||||||||||||||||||||||||||||||||
| f"Measured peak: {peak_reserved / 1024 ** 3:.2f} GB\n" | ||||||||||||||||||||||||||||||||||||||||||
| f"Total GPU memory: {total_memory_gb:.2f} GB\n" | ||||||||||||||||||||||||||||||||||||||||||
| f"Canary reserve: {CANARY_BYTES / 1024 ** 3:.2f} GB\n" | ||||||||||||||||||||||||||||||||||||||||||
| f"\n" | ||||||||||||||||||||||||||||||||||||||||||
| f"The configured load does not fit on this device. Try:\n" | ||||||||||||||||||||||||||||||||||||||||||
| f" 1. --batch_max_tokens: reduce to lower prefill activation peak\n" | ||||||||||||||||||||||||||||||||||||||||||
| f" 2. --graph_max_batch_size: reduce to lower decode activation peak\n" | ||||||||||||||||||||||||||||||||||||||||||
| f" 3. --visual_infer_batch_size: reduce to lower ViT pinned footprint\n" | ||||||||||||||||||||||||||||||||||||||||||
| f" 4. --max_total_token_num: pin a specific KV size (skips auto-profile)\n" | ||||||||||||||||||||||||||||||||||||||||||
| f" 5. --mem_fraction: as a last resort, set < 1.0 to add extra safety margin\n" | ||||||||||||||||||||||||||||||||||||||||||
| f"\n" | ||||||||||||||||||||||||||||||||||||||||||
| f"Last error: {type(last_exc).__name__}: {last_exc}" | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def autotune_layers(self): | ||||||||||||||||||||||||||||||||||||||||||
| # 控制autotune的层数,用于适配不同模型 | ||||||||||||||||||||||||||||||||||||||||||
| return self.config.get("first_k_dense_replace", 0) + 1 | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
_teardown_graphs_and_kvmethod is missing the cleanup ofself.att_backend(andself.att_backend1if present) andself._oom_canary. If these references are not dropped,torch.cuda.empty_cache()will not be able to release the memory occupied by the probe KV buffer and the canary tensor, which will cause the teardown leak check to fail or lead to OOM during the rebuild phase.