diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c3d89e1c8a..71c043f1a4 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -4,6 +4,7 @@ import gc import copy import json +import time import torch import torch.nn.functional as F import triton @@ -134,7 +135,8 @@ def __init__(self, kvargs): self._init_cudagraph() self._init_prefill_cuda_graph() self._check_max_len_infer() - torch.cuda.empty_cache() + self._check_decode_infer() + self._auto_profile_rebuild_and_validate() set_model_init_status(True) return @@ -212,7 +214,12 @@ def _init_kv_move_buffer(self): def _check_mem_size(self): self.max_total_token_num = self.mem_manager.size - assert self.max_seq_length <= self.max_total_token_num + # Skip the max_seq_length assertion during the auto-profile probe + # phase: the probe KV is intentionally small (just enough for the + # stress forwards), not sized to hold a full-length request. The + # assertion is re-checked after Phase 3 rebuilds at the target size. + if getattr(self.mem_manager, "_probe_tokens", None) is None or self.mem_manager.size >= self.max_seq_length: + assert self.max_seq_length <= self.max_total_token_num return def _init_req_manager(self): @@ -944,11 +951,6 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: @final @torch.no_grad() def _check_max_len_infer(self): - disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None - if disable_check_max_len_infer: - logger.info("disable_check_max_len_infer is true") - return - # 做一次 同步 torch.distributed.barrier() @@ -1003,6 +1005,279 @@ def _check_max_len_infer(self): raise Exception(exception_str) return + @torch.no_grad() + def _allocate_decode_stress_slots(self, batch_size): + """Allocate req/mem slots for the decode stress forward. + + Override this in mamba-aware subclasses (e.g. qwen3next) to also + allocate mamba buffers via req_manager.alloc_buffer_for_req so the + GDN layers exercise their full memory footprint. The base version + only touches the standard req/mem managers. + + Returns a tuple (actual_batch, req_idxs, tokens_per_req, total_tokens, + mem_indexes) that _build_decode_model_input consumes. + """ + req_idxs = [] + for _ in range(batch_size): + idx = self.req_manager.alloc() + if idx is None: + break + req_idxs.append(idx) + actual_batch = len(req_idxs) + if actual_batch < 2: + return actual_batch, req_idxs, 0, 0, None + + tokens_per_req = min( + self.batch_max_tokens, + max(1, self.max_total_token_num // actual_batch), + ) + total_tokens = tokens_per_req * actual_batch + total_tokens = min(total_tokens, self.mem_manager.can_use_mem_size) + tokens_per_req = max(1, total_tokens // actual_batch) + total_tokens = tokens_per_req * actual_batch + + mem_indexes = self.mem_manager.alloc(total_tokens).cuda() + return actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes + + def _build_decode_model_input(self, actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes): + """Construct a decode-shaped ModelInput for the stress forward. + + Split out so mamba subclasses can also override just the ModelInput + shape if needed, without touching _check_decode_infer's control flow. + """ + dummy_input_ids = torch.ones(actual_batch, dtype=torch.int32, device="cuda") + b_req_idx = torch.tensor(req_idxs, dtype=torch.int32, device="cuda") + b_seq_len = torch.full((actual_batch,), tokens_per_req, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.zeros(actual_batch, dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(actual_batch, dtype=torch.int32, device="cuda") + return ModelInput( + batch_size=actual_batch, + total_token_num=total_tokens, + max_q_seq_len=1, + max_kv_seq_len=tokens_per_req, + max_cache_len=tokens_per_req - 1, + prefix_total_token_num=0, + input_ids=dummy_input_ids, + # mem_indexes[:actual_batch] provides 1 new KV slot per request for the decode + # step's output token. The full total_tokens block was allocated from mem_manager + # to occupy the KV cache space, but only 1 slot per request is the "new" token. + mem_indexes=mem_indexes[:actual_batch], + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_mtp_index=b_mtp_index, + is_prefill=False, + b_ready_cache_len=b_ready_cache_len, + multimodal_params=[{"images": [], "audios": []}] * actual_batch, + ) + + @torch.no_grad() + def _check_decode_infer(self): + """Simulate a decode batch to detect OOM from concurrent request activations. + + Ported from origin/qw35_stable:basemodel.py:895 and split into + overridable sub-methods per spec §6.4 so qwen3next (when it lands on + main) can cleanly hook in mamba buffer allocation. + + The prob_out.sort() call at the end is load-bearing: it forces the + top_p/top_k sampling allocation that real inference triggers but a + naive decode forward does not. On a vocab=152k model with + graph_max_batch_size=64 this accounts for ~20 MB per decode step, so + skipping it under-measures the peak by exactly the amount that causes + the "first decode batch after warmup OOMs" bug class. + """ + torch.distributed.barrier() + + batch_size = self.graph_max_batch_size + if batch_size <= 1: + return + + try: + logger.info(f"begin check decode infer with batch_size={batch_size}") + + actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes = self._allocate_decode_stress_slots( + batch_size + ) + if actual_batch < 2: + logger.info("skip decode check: not enough req slots") + self.req_manager.free_all() + self.mem_manager.free_all() + return + + model_input = self._build_decode_model_input( + actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes + ) + model_output = self.forward(model_input) + prob_out = torch.softmax(model_output.logits, dim=-1) + del model_output + # Force top_p/top_k sampling allocation — load-bearing for peak measurement. + prob_out.sort(dim=-1, descending=True) + prob_out = None + self.req_manager.free_all() + self.mem_manager.free_all() + logger.info(f"check decode {actual_batch} infer ok") + except (RuntimeError, torch.OutOfMemoryError) as e: + logger.exception(str(e)) + exception_str = ( + "check decode infer fail, you can try:\n" + "1. Set --graph_max_batch_size to a smaller value.\n" + "2. Set --mem_fraction or --max_total_token_num to a smaller value.\n" + "3. Set --max_req_total_len to a smaller value." + ) + logger.error(exception_str) + raise Exception(exception_str) + return + + def _teardown_graphs_and_kv(self): + """Phase 3 teardown: drop references to the probe's kv_buffer and + all captured CUDA graphs so torch.cuda.empty_cache() can actually + return the blocks to the driver. + + Order matters — graphs hold tensor pointers into kv_buffer and must + go first, then req_manager (which references mem_manager), then + mem_manager itself. See spec §6.5 for the rationale. + """ + if hasattr(self, "mem_manager") and self.mem_manager is not None: + try: + self.mem_manager.free_all() + except Exception: + pass + + for attr in ("graph", "prefill_graph", "prefill_cuda_graph"): + if hasattr(self, attr): + setattr(self, attr, None) + # MTP variants (if present) + for attr in ("graph1", "prefill_graph1"): + if hasattr(self, attr): + setattr(self, attr, None) + + if hasattr(self, "req_manager"): + self.req_manager = None + if hasattr(self, "mem_manager"): + self.mem_manager = None + + def _auto_profile_rebuild_and_validate(self): + """Phases 2-4 of the auto-profile loop. + + Runs after Phase 1 (__init__ through _check_decode_infer). Measures + the probe's peak, computes the target KV size, tears down the probe's + graphs and mem_manager, calls torch.cuda.empty_cache() once, re-inits + everything at the target size, re-captures graphs, allocates the 256 MB + canary, and validates by re-running the stress forwards. + + On validation OOM, shrinks target_tokens by 5% and loops. Retry budget: + 3 retries (4 total attempts). After that, raises with a multi-knob + diagnostic (see spec §7.2). + + If --max_total_token_num was set explicitly (probe path was skipped), + this method is a pure no-op beyond a single empty_cache call to match + the pre-auto-profile behavior. + """ + if self.mem_manager._probe_tokens is None: + logger.info("auto-profile phase=skip reason=explicit_max_total_token_num") + torch.cuda.empty_cache() + return + + peak_reserved = torch.cuda.max_memory_reserved() + initial_target_tokens = self.mem_manager.profile_size_target(peak_reserved) + target_tokens = initial_target_tokens + probe_tokens = self.mem_manager._probe_tokens + probe_kv_bytes = probe_tokens * self.mem_manager.get_cell_size() + + RETRY_BUDGET = 3 + SHRINK_RATIO = 0.95 + CANARY_BYTES = 256 * 1024 * 1024 + + attempt = 0 + last_exc = None + while attempt <= RETRY_BUDGET: + attempt += 1 + t0 = time.time() + try: + # Phase 3: tear down probe graphs and mem_manager + self._teardown_graphs_and_kv() + reserved_before_empty = torch.cuda.memory_reserved() + torch.cuda.empty_cache() + reserved_after_empty = torch.cuda.memory_reserved() + released = reserved_before_empty - reserved_after_empty + + # Sanity: we should have released at least the probe kv_buffer. + # The threshold is probe_kv_bytes (strict, per spec §6.5 step 6) + # because the probe's kv_buffer is a single contiguous allocation + # and PyTorch's caching allocator returns whole segments on empty_cache. + if attempt == 1 and released < probe_kv_bytes: + raise RuntimeError( + f"auto-profile phase=rebuild TEARDOWN LEAK: " + f"empty_cache() only released {released / 1024 ** 3:.2f} GB " + f"but probe kv_buffer alone is {probe_kv_bytes / 1024 ** 3:.2f} GB. " + f"Some Python reference to the probe kv_buffer or a captured " + f"CUDA graph was not dropped by _teardown_graphs_and_kv. " + f"Investigate which attribute is leaking." + ) + + # Phase 3: re-init everything at target_tokens + self.max_total_token_num = target_tokens + self._init_mem_manager() + self._init_kv_move_buffer() + self._check_mem_size() + self._init_req_manager() + self._init_cudagraph() + self._init_prefill_cuda_graph() + + # Canary + self._oom_canary = torch.empty(CANARY_BYTES, dtype=torch.uint8, device="cuda") + + logger.info( + f"auto-profile phase=rebuild attempt={attempt} " + f"elapsed_sec={time.time() - t0:.2f} " + f"new_kv_tokens={target_tokens}" + ) + + # Phase 4: validate + t1 = time.time() + self._check_max_len_infer() + self._check_decode_infer() + logger.info( + f"auto-profile phase=validate attempt={attempt} " f"elapsed_sec={time.time() - t1:.2f} result=ok" + ) + return # success + except (RuntimeError, torch.cuda.OutOfMemoryError, torch.OutOfMemoryError) as e: + last_exc = e + logger.warning( + f"auto-profile phase=validate attempt={attempt} " + f"result={'retry' if attempt <= RETRY_BUDGET else 'fail'} " + f"error={type(e).__name__}: {e}" + ) + if attempt > RETRY_BUDGET: + break + target_tokens = int(target_tokens * SHRINK_RATIO) + + # All retries exhausted — raise a multi-knob diagnostic + total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3 + cell_size = None + try: + cell_size = self.mem_manager.get_cell_size() + except Exception: + pass + initial_gb = initial_target_tokens * (cell_size or 0) / 1024 ** 3 + final_gb = target_tokens * (cell_size or 0) / 1024 ** 3 + raise Exception( + f"Auto-profile failed after {attempt} attempts.\n" + f"Initial target: {initial_target_tokens} tokens ({initial_gb:.2f} GB KV)\n" + f"Final attempted: {target_tokens} tokens ({final_gb:.2f} GB KV)\n" + f"Measured peak: {peak_reserved / 1024 ** 3:.2f} GB\n" + f"Total GPU memory: {total_memory_gb:.2f} GB\n" + f"Canary reserve: {CANARY_BYTES / 1024 ** 3:.2f} GB\n" + f"\n" + f"The configured load does not fit on this device. Try:\n" + f" 1. --batch_max_tokens: reduce to lower prefill activation peak\n" + f" 2. --graph_max_batch_size: reduce to lower decode activation peak\n" + f" 3. --visual_infer_batch_size: reduce to lower ViT pinned footprint\n" + f" 4. --max_total_token_num: pin a specific KV size (skips auto-profile)\n" + f" 5. --mem_fraction: as a last resort, set < 1.0 to add extra safety margin\n" + f"\n" + f"Last error: {type(last_exc).__name__}: {last_exc}" + ) + def autotune_layers(self): # 控制autotune的层数,用于适配不同模型 return self.config.get("first_k_dense_replace", 0) + 1 diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 1203cbdec7..354b205795 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -33,6 +33,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.layer_num = layer_num self.always_copy = always_copy self.dtype = dtype + self._probe_tokens = None + self._mem_fraction = mem_fraction # profile the max total token num if the size is None self.profile_size(mem_fraction) @@ -84,24 +86,110 @@ def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) def profile_size(self, mem_fraction): + """ + Phase 1 of the two-pass auto-profile: pick a small-but-realistic + probe KV size for graph capture and stress measurement. + + - If self.size is already set (explicit --max_total_token_num, or + Phase 3's re-init after computing the target), this is a no-op. + - mem_fraction is saved for Phase 2's profile_size_target(), where + it acts as an optional additional safety multiplier on top of the + measured budget. + + See docs/superpowers/specs/2026-04-16-multimodal-oom-fix-design.md + sections 6.3 and 6.5 for the full design rationale. + """ if self.size is not None: return + from lightllm.utils.envs_utils import get_env_start_args + + start_args = get_env_start_args() + gmbs = start_args.graph_max_batch_size + bmt = start_args.batch_max_tokens + # Probe needs enough KV for: + # - one prefill stress (bmt slots — full chunk length) + # - one decode stress (gmbs slots, 1 new token per request) + # NOT gmbs * bmt — that would be the full production KV and defeat + # the purpose of the probe (measuring non-KV overhead with a small + # KV allocation). + # basemodel._check_mem_size relaxes its max_seq_length assertion + # when _probe_tokens is set, so the probe doesn't need to hold a + # full-length request — it only needs enough slots for the stress + # forwards. + self._probe_tokens = max(bmt + gmbs, 8192) + self.size = self._probe_tokens + self._mem_fraction = mem_fraction # redundant with __init__; kept so profile_size is readable in isolation + logger.info( + f"auto-profile phase=probe probe_tokens={self._probe_tokens} " + f"(gmbs={gmbs}, bmt={bmt}, mem_fraction={mem_fraction})" + ) + + def profile_size_target(self, peak_reserved_bytes): + """ + Phase 2 of the two-pass auto-profile: compute target KV size from + the measured `torch.cuda.max_memory_reserved()` peak. + + Formula (see spec §6.3): + non_kv_overhead = peak_reserved - probe_kv_bytes + peers_footprint = max(total - avail - own_reserved, 0) + budget = total - non_kv_overhead - canary - peers_footprint + budget *= mem_fraction # default 1.0 + target_tokens = int(budget / cell_size) + + In TP mode, target_tokens is all_reduce(MIN) across ranks so every + rank agrees on the smallest feasible size. + """ + if self._probe_tokens is None: + raise RuntimeError( + "profile_size_target called before profile_size set a probe. " + "This indicates the auto-profile escape-hatch (--max_total_token_num) " + "was taken — Phase 2 should not be reached in that path." + ) - world_size = dist.get_world_size() - total_memory = get_total_gpu_memory() - available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) + total_memory_bytes = int(get_total_gpu_memory() * 1024 ** 3) cell_size = self.get_cell_size() - self.size = int(available_memory * 1024 ** 3 / cell_size) + probe_kv_bytes = self._probe_tokens * cell_size + non_kv_overhead = peak_reserved_bytes - probe_kv_bytes + if non_kv_overhead < 0: + logger.warning( + f"auto-profile: peak_reserved ({peak_reserved_bytes}) < probe_kv_bytes ({probe_kv_bytes}). " + f"This suggests the allocator released probe blocks before measurement. " + f"Clamping non_kv_overhead to 0." + ) + non_kv_overhead = 0 + canary_bytes = 256 * 1024 * 1024 + + try: + world_size = dist.get_world_size() + except Exception: + world_size = 1 + avail_bytes = int(get_available_gpu_memory(world_size) * 1024 ** 3) + own_reserved = torch.cuda.memory_reserved() + peers_footprint = max(total_memory_bytes - avail_bytes - own_reserved, 0) + + budget = total_memory_bytes - non_kv_overhead - canary_bytes - peers_footprint + budget = int(budget * self._mem_fraction) + target_tokens = max(int(budget / cell_size), 1) + if world_size > 1: - tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") - dist.all_reduce(tensor, op=dist.ReduceOp.MIN) - self.size = tensor.item() + device = f"cuda:{get_current_device_id()}" + t = torch.tensor(target_tokens, dtype=torch.int64, device=device) + dist.all_reduce(t, op=dist.ReduceOp.MIN) + target_tokens = t.item() + logger.info( - f"{str(available_memory)} GB space is available after load the model weight\n" - f"{str(cell_size / 1024 ** 2)} MB is the size of one token kv cache\n" - f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n" + f"auto-profile phase=measure " + f"peak_reserved_gb={peak_reserved_bytes / 1024 ** 3:.2f} " + f"non_kv_overhead_gb={non_kv_overhead / 1024 ** 3:.2f} " + f"peers_footprint_gb={peers_footprint / 1024 ** 3:.2f}" ) - return + logger.info( + f"auto-profile phase=compute " + f"target_tokens={target_tokens} " + f"target_kv_gb={target_tokens * cell_size / 1024 ** 3:.2f} " + f"mem_fraction_applied={self._mem_fraction}" + ) + return target_tokens def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 在初始化 kv_buffer 的时候,每层多初始化了一个 token,这个 token 永远不会被真的被对外 diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py index 8602f2e67e..9303273051 100644 --- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -206,13 +206,18 @@ def profile_size( f"you can add `--disable_dynamic_prompt_cache` to avoid this error.", ) return - from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory + from lightllm.utils.profile_max_tokens import get_available_gpu_memory import torch.distributed as dist - mem_fraction = start_args.mem_fraction world_size = dist.get_world_size() - total_memory = get_total_gpu_memory() - available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) + # Do NOT subtract `total * (1 - mem_fraction)` here. Under the + # auto-profile design, the mem_fraction safety margin applies only + # to the KV cache budget in Phase 2, and the 256 MB LLM-side canary + # absorbs allocator jitter. Subtracting again here would double-count + # the safety margin and starve mamba on memory-tight configurations + # (e.g. Qwen3.5-122B on 80 GB cards where weights alone take ~60 GB + # per TP rank). + available_memory = get_available_gpu_memory(world_size) conv_cell_size = ( self.layer_num * self.conv_dim diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 7156a5ce23..ed2cd6c16c 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -157,6 +157,7 @@ def __init__( super().__init__() self.weight_dir = kvargs["weight_dir"] self.data_type = kvargs.get("data_type", "bfloat16") + self.max_batch_size = kvargs.get("max_batch_size", 1) self.depth = depth self.hidden_size = hidden_size @@ -224,6 +225,12 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") return + @torch.no_grad() + def _check_max_len_infer(self): + from lightllm.models.qwen2_vl.vision_process import qwen_vl_check_max_len_infer + + qwen_vl_check_max_len_infer(self, self.max_batch_size) + def rot_pos_emb(self, grid_thw): pos_ids = [] s = self.spatial_merge_size diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 6076756043..3a7d04fdbd 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -193,6 +193,7 @@ def __init__( ): super().__init__() self.data_type = kvargs.get("data_type", "bfloat16") + self.max_batch_size = kvargs.get("max_batch_size", 1) self.depth = depth self.embed_dim = embed_dim @@ -238,6 +239,12 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") return + @torch.no_grad() + def _check_max_len_infer(self): + from lightllm.models.qwen2_vl.vision_process import qwen_vl_check_max_len_infer + + qwen_vl_check_max_len_infer(self, self.max_batch_size) + def load_model(self, weight_dir): processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index bc313fe467..9324fab4e7 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -1,5 +1,6 @@ from __future__ import annotations import math +import os import torch import numpy as np from PIL import Image @@ -27,6 +28,59 @@ logger = init_logger(__name__) +def closest_factor_pair(n): + """Find the factor pair of n closest to sqrt(n). Returns (smaller, larger).""" + sqrt_n = int(math.sqrt(n)) + for i in range(sqrt_n, 0, -1): + if n % i == 0: + return i, n // i + return 1, n + + +@torch.no_grad() +def qwen_vl_check_max_len_infer(model, max_batch_size): + """OOM pre-check for Qwen-family vision models. + + Constructs worst-case dummy images at max_pixels resolution, replicates + for max_batch_size, and runs a forward pass. Holds the stress peak in + the PyTorch caching allocator for the rest of process lifetime by + deliberately NOT calling torch.cuda.empty_cache() — the Python refs + are dropped, but the driver view continues to see the reservation. + """ + unit = model.patch_size * model.spatial_merge_size + max_pixels = model.processor.max_pixels + max_patches = max_pixels // (unit * unit) + if max_patches < 1: + max_patches = 1 + h_factor, w_factor = closest_factor_pair(max_patches) + worst_h = unit * h_factor + worst_w = unit * w_factor + + try: + dummy_image = Image.new("RGB", (worst_w, worst_h), color=(128, 128, 128)) + pixel_values, grid_thw = model.processor.preprocess(dummy_image) + + pixel_values = pixel_values.repeat(max_batch_size, 1, 1) + grid_thw = grid_thw.repeat(max_batch_size, 1) + + pixel_values = pixel_values.to("cuda", dtype=model.data_type, non_blocking=True) + grid_thw = grid_thw.to("cuda", non_blocking=True) + + result = model.forward(pixel_values, grid_thw=grid_thw) + del result, pixel_values, grid_thw + # Deliberately NOT calling torch.cuda.empty_cache() — we want the + # stress peak to stay pinned at the driver level so the LLM + # subprocess's later get_available_gpu_memory sees it as reserved. + logger.info(f"vit check max_len {max_batch_size} infer ok") + except (RuntimeError, torch.OutOfMemoryError, ValueError): + logger.exception("Qwen VL check max len infer failed") + exception_str = ( + "Vit check max len infer fail, you can try: " "1.Set the --visual_infer_batch_size to a smaller value." + ) + logger.error(exception_str) + raise RuntimeError(exception_str) + + IMAGE_FACTOR = 28 MIN_PIXELS = 4 * 28 * 28 MAX_PIXELS = 16384 * 28 * 28 diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index 0276724749..5700933031 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -140,6 +140,7 @@ def __init__( ): super().__init__() self.data_type = kvargs.get("data_type", "bfloat16") + self.max_batch_size = kvargs.get("max_batch_size", 1) self.depth = depth self.out_hidden_size = out_hidden_size @@ -207,6 +208,12 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") return + @torch.no_grad() + def _check_max_len_infer(self): + from lightllm.models.qwen2_vl.vision_process import qwen_vl_check_max_len_infer + + qwen_vl_check_max_len_infer(self, self.max_batch_size) + def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids): all_chunks = [] diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index bed8898115..4f6b1e119c 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -136,6 +136,7 @@ def __init__( ): super().__init__() self.data_type = kvargs.get("data_type", "bfloat16") + self.max_batch_size = kvargs.get("max_batch_size", 1) self.depth = depth self.out_hidden_size = out_hidden_size @@ -202,6 +203,12 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") return + @torch.no_grad() + def _check_max_len_infer(self): + from lightllm.models.qwen2_vl.vision_process import qwen_vl_check_max_len_infer + + qwen_vl_check_max_len_infer(self, self.max_batch_size) + def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids): all_chunks = [] diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 13f8e2827f..bd902cc034 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -53,16 +53,11 @@ def __init__(self, kvargs): self._init_quant() self._init_weights() self._init_infer_layer() - self._check_max_len_infer() return @final @torch.no_grad() def _check_max_len_infer(self): - disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None - if disable_check_max_len_infer: - return - try: dummy_images = torch.randn( (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type @@ -73,7 +68,7 @@ def _check_max_len_infer(self): except (RuntimeError, torch.OutOfMemoryError) as e: logger.exception(str(e)) exception_str = ( - "Vit check max len infer fail, you can try:" "1.Set the --visual_infer_batch_size to a smaller value." + "Vit check max len infer fail, you can try: 1.Set the --visual_infer_batch_size to a smaller value." ) logger.error(exception_str) raise Exception(exception_str) @@ -85,16 +80,27 @@ def _init_config(self): self.select_layer = self.config["select_layer"] self.config["vision_config"]["llm_hidden_size"] = self.config["llm_config"]["hidden_size"] self.config["vision_config"]["downsample_ratio"] = self.config["downsample_ratio"] + + # Derive worst-case image dimensions from model config + image_size = self.config.get("force_image_size", self.config["vision_config"]["image_size"]) + max_dynamic_patch = self.config.get("max_dynamic_patch", 12) + use_thumbnail = self.config.get("use_thumbnail", True) + dynamic_image_size = self.config.get("dynamic_image_size", True) + self.config = self.config["vision_config"] + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) self.layers_num = self.config["num_hidden_layers"] - # infer info - self.IMAGE_H = int(os.getenv("IMAGE_H", 448)) - self.IMAGE_W = int(os.getenv("IMAGE_W", 448)) - self.MAX_PATH_NUM = os.getenv("MAX_PATH_NUM", 13) + # infer info — computed from config, not env vars + self.IMAGE_H = image_size + self.IMAGE_W = image_size + max_num = max_dynamic_patch if dynamic_image_size else 1 + if use_thumbnail and max_num != 1: + max_num += 1 + self.MAX_PATH_NUM = max_num return def _padding_hidden_size(self): diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 7dcd7df1bb..6fe48d44e9 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -135,9 +135,14 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--mem_fraction", type=float, - default=0.9, - help="""Memory usage ratio, default is 0.9, you can specify a smaller value if OOM occurs at runtime. - If max_total_token_num is not specified, it will be calculated automatically based on this value.""", + default=0.95, + help=( + "Safety multiplier applied on top of the auto-profiled KV cache " + "budget. Default 0.95 reserves 5%% extra headroom for per-request " + "spikes and allocator fragmentation the stress test cannot cover. " + "Set 1.0 to use the full measured budget (aggressive). " + "Ignored when --max_total_token_num is set explicitly." + ), ) parser.add_argument( "--batch_max_tokens", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index be8f8fe682..33a39f3eed 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -136,8 +136,8 @@ def normal_or_p_d_start(args): assert args.config_server_host == args.nccl_host assert ( - args.mem_fraction > 0 and args.mem_fraction < 1 - ), f"Invalid mem_fraction {args.mem_fraction}, The expected value is between 0 and 1." + args.mem_fraction > 0 and args.mem_fraction <= 1 + ), f"Invalid mem_fraction {args.mem_fraction}, The expected value is between 0 and 1 (inclusive)." if args.graph_max_len_in_batch == 0: args.graph_max_len_in_batch = args.max_req_total_len diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py index 39a7e06ac3..7acf817e25 100644 --- a/lightllm/server/audioserver/model_infer/model_rpc.py +++ b/lightllm/server/audioserver/model_infer/model_rpc.py @@ -93,7 +93,7 @@ def _log_latency(self, audio: AudioItem, stage: str): def _init_taskes(self): self.infer_queue = queue.Queue() self.store_queue = queue.Queue() - self.sempare = threading.Semaphore(self.infer_max_batch_size * 8) + self.sempare = threading.Semaphore(self.infer_max_batch_size) self.gloo_group = dist.new_group(ranks=list(range(self.audio_tp)), backend="gloo") self._infer_thread = threading.Thread(target=self._infer_worker, daemon=True) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 49a113b1ba..9a7cdea689 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -291,8 +291,6 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.mtp_step = self.args.mtp_step self.draft_models = [] - os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: num_mtp_modules = self.args.mtp_step elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index a165be78f2..1dffdaf681 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -84,7 +84,7 @@ async def wait_to_model_ready(self): "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], "quant_type": self.args.vit_quant_type, "quant_cfg": self.args.vit_quant_cfg, - "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + "max_batch_size": max(self.infer_batch_size // self.vit_dp, 1), "vit_attn_backend": self.vit_attn_backend, } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 55f4704a31..c1242757e4 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -111,6 +111,13 @@ def exposed_init_model(self, kvargs): self.model.load_model(weight_dir) self.model = self.model.cuda() + if hasattr(self.model, "_check_max_len_infer"): + self.model._check_max_len_infer() + else: + logger.warning( + f"no stress test available for visual model type '{self.model_type}'; " + f"the LLM subprocess's auto-profile is the only remaining OOM defense" + ) if not self.is_visual_only_mode: self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -168,7 +175,7 @@ def _init_taskes(self): self.store_queue = queue.Queue() # 限制并发, 主要是为了控制内存用量,防止过多造成内存OOM - self.sempare = threading.Semaphore(self.infer_max_batch_size * 8) + self.sempare = threading.Semaphore(self.infer_max_batch_size) # 用于同步各个推理tp每次拿到一样的image数量建立的gloo通信组 self.gloo_group = dist.new_group(ranks=list(range(self.vit_tp)), backend="gloo") diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/basemodel/test_auto_profile.py b/unit_tests/common/basemodel/test_auto_profile.py new file mode 100644 index 0000000000..0c4b14074b --- /dev/null +++ b/unit_tests/common/basemodel/test_auto_profile.py @@ -0,0 +1,284 @@ +"""Pure-Python unit tests for the LLM auto-profile path. + +These tests exercise the memory manager's probe sizing, target arithmetic, +and the basemodel init's retry loop. They stub torch.cuda.* and +get_available_gpu_memory via monkeypatch so they run on CPU without any +GPU-related setup. They are the only pure-Python tests in +unit_tests/common/basemodel/ at the time of writing — the other tests +in that directory are triton kernel tests that require a real GPU. +""" +import pytest +from unittest import mock + + +class _StubStartArgs: + def __init__(self, graph_max_batch_size, batch_max_tokens, max_req_total_len=2048): + self.graph_max_batch_size = graph_max_batch_size + self.batch_max_tokens = batch_max_tokens + self.max_req_total_len = max_req_total_len + + +@pytest.fixture +def stub_env_start_args(monkeypatch): + def _install(graph_max_batch_size, batch_max_tokens, max_req_total_len=2048): + stub = _StubStartArgs(graph_max_batch_size, batch_max_tokens, max_req_total_len) + monkeypatch.setattr( + "lightllm.utils.envs_utils.get_env_start_args", + lambda: stub, + ) + return stub + + return _install + + +def _make_bare_mem_manager(): + """Construct a MemoryManager instance without running __init__. + + This lets us call profile_size() / profile_size_target() on a plain + object with only the fields those methods touch, avoiding the need + to initialize CUDA, distributed, or shared-memory state. + """ + from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager + + mgr = MemoryManager.__new__(MemoryManager) + mgr.size = None + mgr._probe_tokens = None + mgr._mem_fraction = 1.0 + # Cell size only matters when profile_size_target() runs — not probe. + mgr.head_num = 8 + mgr.head_dim = 64 + mgr.layer_num = 32 + import torch + + mgr.dtype = torch.float16 + return mgr + + +def test_profile_size_probe_formula_large_bmt(stub_env_start_args): + """Probe size = bmt + gmbs when that exceeds the 8192 floor. + The probe is independent of max_req_total_len — _check_mem_size's + max_seq_length assertion is relaxed during probe phase.""" + stub_env_start_args(graph_max_batch_size=128, batch_max_tokens=16384, max_req_total_len=262144) + mgr = _make_bare_mem_manager() + mgr.profile_size(mem_fraction=1.0) + assert mgr.size == 16384 + 128 + assert mgr._probe_tokens == 16384 + 128 + assert mgr._mem_fraction == 1.0 + + +def test_profile_size_probe_formula_tiny_config(stub_env_start_args): + """Probe size floors to 8192 when bmt+gmbs is smaller.""" + stub_env_start_args(graph_max_batch_size=1, batch_max_tokens=128, max_req_total_len=1024) + mgr = _make_bare_mem_manager() + mgr.profile_size(mem_fraction=1.0) + assert mgr.size == 8192 + + +def test_profile_size_early_return_when_size_preset(stub_env_start_args): + """If size is already set (e.g. --max_total_token_num), profile_size is a no-op.""" + stub_env_start_args(graph_max_batch_size=64, batch_max_tokens=4096) + mgr = _make_bare_mem_manager() + mgr.size = 131072 + mgr.profile_size(mem_fraction=0.9) + assert mgr.size == 131072 + assert mgr._probe_tokens is None # not touched + + +def test_profile_size_target_arithmetic(monkeypatch, stub_env_start_args): + """profile_size_target computes target from peak, peers, canary, budget.""" + stub_env_start_args(graph_max_batch_size=64, batch_max_tokens=4096) + + mgr = _make_bare_mem_manager() + mgr.profile_size(mem_fraction=1.0) # picks probe + probe_tokens = mgr._probe_tokens + cell_size = mgr.get_cell_size() + probe_kv_bytes = probe_tokens * cell_size + + # Set up a synthetic 80 GB card. + TOTAL_GB = 80.0 + total_bytes = int(TOTAL_GB * 1024 ** 3) + # Peer footprint: 10 GB worth of ViT + audio driver reservation. + peer_bytes = int(10 * 1024 ** 3) + # Own reserved: weights + probe KV + graphs + stress activations. Say 35 GB. + own_reserved_bytes = int(35 * 1024 ** 3) + # Peak reserved (high-water-mark after stress) = own_reserved in this model. + peak_reserved = own_reserved_bytes + + monkeypatch.setattr( + "lightllm.common.kv_cache_mem_manager.mem_manager.get_total_gpu_memory", + lambda: TOTAL_GB, + ) + monkeypatch.setattr( + "lightllm.common.kv_cache_mem_manager.mem_manager.get_available_gpu_memory", + lambda world_size=1: (total_bytes - own_reserved_bytes - peer_bytes) / 1024 ** 3, + ) + import torch + + monkeypatch.setattr(torch.cuda, "memory_reserved", lambda: own_reserved_bytes) + monkeypatch.setattr( + "torch.distributed.get_world_size", + lambda: 1, + ) + + target = mgr.profile_size_target(peak_reserved) + + non_kv_overhead = peak_reserved - probe_kv_bytes + canary_bytes = 256 * 1024 * 1024 + expected_budget = total_bytes - non_kv_overhead - canary_bytes - peer_bytes + expected_target = int(expected_budget / cell_size) + + assert target == expected_target + # Sanity: target is strictly larger than the probe (the whole point). + assert target > probe_tokens + + +def test_profile_size_target_peer_footprint_floors_to_zero(monkeypatch, stub_env_start_args): + """If get_available_gpu_memory says more is available than total-own_reserved, + the peer footprint must floor to 0 (the arithmetic produced a negative number).""" + stub_env_start_args(graph_max_batch_size=1, batch_max_tokens=128) + mgr = _make_bare_mem_manager() + mgr.profile_size(mem_fraction=1.0) + + TOTAL_GB = 80.0 + + monkeypatch.setattr( + "lightllm.common.kv_cache_mem_manager.mem_manager.get_total_gpu_memory", + lambda: TOTAL_GB, + ) + # avail > total - own_reserved, i.e. peer_footprint would be negative + monkeypatch.setattr( + "lightllm.common.kv_cache_mem_manager.mem_manager.get_available_gpu_memory", + lambda world_size=1: TOTAL_GB, # "everything is available" + ) + import torch + + monkeypatch.setattr(torch.cuda, "memory_reserved", lambda: 0) + monkeypatch.setattr("torch.distributed.get_world_size", lambda: 1) + + target = mgr.profile_size_target(peak_reserved_bytes=1024 * 1024) + assert target > 0 # didn't crash on negative peer footprint + + +def test_profile_size_target_mem_fraction_multiplier(monkeypatch, stub_env_start_args): + """--mem_fraction 0.95 should produce a target 95% the size of the default 1.0.""" + stub_env_start_args(graph_max_batch_size=64, batch_max_tokens=4096) + + TOTAL_GB = 80.0 + total_bytes = int(TOTAL_GB * 1024 ** 3) + own_reserved = int(35 * 1024 ** 3) + peak_reserved = own_reserved + + def _patch(mgr): + monkeypatch.setattr( + "lightllm.common.kv_cache_mem_manager.mem_manager.get_total_gpu_memory", + lambda: TOTAL_GB, + ) + monkeypatch.setattr( + "lightllm.common.kv_cache_mem_manager.mem_manager.get_available_gpu_memory", + lambda world_size=1: (total_bytes - own_reserved) / 1024 ** 3, + ) + import torch + + monkeypatch.setattr(torch.cuda, "memory_reserved", lambda: own_reserved) + monkeypatch.setattr("torch.distributed.get_world_size", lambda: 1) + + mgr_default = _make_bare_mem_manager() + mgr_default.profile_size(mem_fraction=1.0) + _patch(mgr_default) + target_default = mgr_default.profile_size_target(peak_reserved) + + mgr_paranoid = _make_bare_mem_manager() + mgr_paranoid.profile_size(mem_fraction=0.95) + _patch(mgr_paranoid) + target_paranoid = mgr_paranoid.profile_size_target(peak_reserved) + + # Paranoid target is 95% of default target (± rounding). + ratio = target_paranoid / target_default + assert 0.94 < ratio <= 0.95 + + +def test_auto_profile_retry_budget_respects_cap(monkeypatch): + """The rebuild/validate loop retries at most 3 times (4 total attempts) + before raising a multi-knob diagnostic exception. + """ + from lightllm.common.basemodel.basemodel import TpPartBaseModel + + model = TpPartBaseModel.__new__(TpPartBaseModel) + model.mem_manager = mock.MagicMock() + model.mem_manager._probe_tokens = 8192 + model.mem_manager.profile_size_target.return_value = 100000 + model.mem_manager.get_cell_size.return_value = 64 + model.max_total_token_num = None + + # Stub teardown / re-init / graph re-capture so they succeed without CUDA. + model._teardown_graphs_and_kv = mock.MagicMock() + model._init_mem_manager = mock.MagicMock(side_effect=lambda: setattr(model, "mem_manager", model.mem_manager)) + model._init_kv_move_buffer = mock.MagicMock() + model._check_mem_size = mock.MagicMock() + model._init_req_manager = mock.MagicMock() + model._init_cudagraph = mock.MagicMock() + model._init_prefill_cuda_graph = mock.MagicMock() + + # Always-OOM stress so we exhaust the retry budget. + model._check_max_len_infer = mock.MagicMock(side_effect=RuntimeError("CUDA out of memory (synthetic)")) + model._check_decode_infer = mock.MagicMock() + + import torch + + monkeypatch.setattr(torch.cuda, "max_memory_reserved", lambda: 1024 ** 3) + # memory_reserved is called twice per attempt: before and after empty_cache. + # We need (before - after) >= probe_kv_bytes (8192 * 64 = 524288) so the + # teardown-leak guard on attempt 1 is satisfied. + _mem_reserved_calls = [0] + + def _memory_reserved(): + _mem_reserved_calls[0] += 1 + # Odd calls (before empty_cache) return 1 GB; even calls (after) return 0. + return 1024 ** 3 if _mem_reserved_calls[0] % 2 == 1 else 0 + + monkeypatch.setattr(torch.cuda, "memory_reserved", _memory_reserved) + monkeypatch.setattr(torch.cuda, "empty_cache", lambda: None) + monkeypatch.setattr(torch.cuda, "get_device_properties", lambda dev: mock.Mock(total_memory=80 * 1024 ** 3)) + monkeypatch.setattr(torch, "empty", lambda *a, **kw: mock.MagicMock()) + + with pytest.raises(Exception) as exc_info: + model._auto_profile_rebuild_and_validate() + + msg = str(exc_info.value) + assert "Auto-profile failed after 4 attempts" in msg + # Every knob the diagnostic promises must be named. + assert "--batch_max_tokens" in msg + assert "--graph_max_batch_size" in msg + assert "--visual_infer_batch_size" in msg + assert "--max_total_token_num" in msg + assert "--mem_fraction" in msg + # Confirm each attempt actually ran the stress + assert model._check_max_len_infer.call_count == 4 + + +def test_auto_profile_explicit_max_total_token_num_skips_rebuild(monkeypatch): + """If _probe_tokens is None (explicit --max_total_token_num path), + the rebuild loop is a no-op beyond one empty_cache call. + """ + from lightllm.common.basemodel.basemodel import TpPartBaseModel + + model = TpPartBaseModel.__new__(TpPartBaseModel) + model.mem_manager = mock.MagicMock() + model.mem_manager._probe_tokens = None # escape hatch + + model._teardown_graphs_and_kv = mock.MagicMock() + model._init_mem_manager = mock.MagicMock() + model._check_max_len_infer = mock.MagicMock() + model._check_decode_infer = mock.MagicMock() + + import torch + + empty_cache_calls = [] + monkeypatch.setattr(torch.cuda, "empty_cache", lambda: empty_cache_calls.append(1)) + monkeypatch.setattr(torch.cuda, "max_memory_reserved", lambda: 1) + + model._auto_profile_rebuild_and_validate() + + assert empty_cache_calls == [1] # exactly one call + assert model._teardown_graphs_and_kv.call_count == 0 + assert model._init_mem_manager.call_count == 0