Skip to content
Open
289 changes: 282 additions & 7 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gc
import copy
import json
import time
import torch
import torch.nn.functional as F
import triton
Expand Down Expand Up @@ -134,7 +135,8 @@ def __init__(self, kvargs):
self._init_cudagraph()
self._init_prefill_cuda_graph()
self._check_max_len_infer()
torch.cuda.empty_cache()
self._check_decode_infer()
self._auto_profile_rebuild_and_validate()
set_model_init_status(True)
return

Expand Down Expand Up @@ -212,7 +214,12 @@ def _init_kv_move_buffer(self):

def _check_mem_size(self):
self.max_total_token_num = self.mem_manager.size
assert self.max_seq_length <= self.max_total_token_num
# Skip the max_seq_length assertion during the auto-profile probe
# phase: the probe KV is intentionally small (just enough for the
# stress forwards), not sized to hold a full-length request. The
# assertion is re-checked after Phase 3 rebuilds at the target size.
if getattr(self.mem_manager, "_probe_tokens", None) is None or self.mem_manager.size >= self.max_seq_length:
assert self.max_seq_length <= self.max_total_token_num
return

def _init_req_manager(self):
Expand Down Expand Up @@ -944,11 +951,6 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1:
@final
@torch.no_grad()
def _check_max_len_infer(self):
disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None
if disable_check_max_len_infer:
logger.info("disable_check_max_len_infer is true")
return

# 做一次 同步
torch.distributed.barrier()

Expand Down Expand Up @@ -1003,6 +1005,279 @@ def _check_max_len_infer(self):
raise Exception(exception_str)
return

@torch.no_grad()
def _allocate_decode_stress_slots(self, batch_size):
"""Allocate req/mem slots for the decode stress forward.

Override this in mamba-aware subclasses (e.g. qwen3next) to also
allocate mamba buffers via req_manager.alloc_buffer_for_req so the
GDN layers exercise their full memory footprint. The base version
only touches the standard req/mem managers.

Returns a tuple (actual_batch, req_idxs, tokens_per_req, total_tokens,
mem_indexes) that _build_decode_model_input consumes.
"""
req_idxs = []
for _ in range(batch_size):
idx = self.req_manager.alloc()
if idx is None:
break
req_idxs.append(idx)
actual_batch = len(req_idxs)
if actual_batch < 2:
return actual_batch, req_idxs, 0, 0, None

tokens_per_req = min(
self.batch_max_tokens,
max(1, self.max_total_token_num // actual_batch),
)
total_tokens = tokens_per_req * actual_batch
total_tokens = min(total_tokens, self.mem_manager.can_use_mem_size)
tokens_per_req = max(1, total_tokens // actual_batch)
total_tokens = tokens_per_req * actual_batch

mem_indexes = self.mem_manager.alloc(total_tokens).cuda()
return actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes

def _build_decode_model_input(self, actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes):
    """Assemble the decode-shaped ModelInput used by the stress forward.

    Factored out so mamba subclasses can override only the ModelInput
    shape, leaving _check_decode_infer's control flow untouched.
    """
    dev = "cuda"
    input_ids = torch.ones(actual_batch, dtype=torch.int32, device=dev)
    req_index = torch.tensor(req_idxs, dtype=torch.int32, device=dev)
    seq_lens = torch.full((actual_batch,), tokens_per_req, dtype=torch.int32, device=dev)
    ready_cache_lens = torch.zeros(actual_batch, dtype=torch.int32, device=dev)
    mtp_index = torch.zeros(actual_batch, dtype=torch.int32, device=dev)
    # Only 1 new KV slot per request is handed to the decode step (for its
    # output token); the rest of the total_tokens block was allocated from
    # mem_manager purely to occupy KV cache space.
    new_token_slots = mem_indexes[:actual_batch]
    return ModelInput(
        batch_size=actual_batch,
        total_token_num=total_tokens,
        max_q_seq_len=1,
        max_kv_seq_len=tokens_per_req,
        max_cache_len=tokens_per_req - 1,
        prefix_total_token_num=0,
        input_ids=input_ids,
        mem_indexes=new_token_slots,
        b_req_idx=req_index,
        b_seq_len=seq_lens,
        b_mtp_index=mtp_index,
        is_prefill=False,
        b_ready_cache_len=ready_cache_lens,
        multimodal_params=[{"images": [], "audios": []}] * actual_batch,
    )

@torch.no_grad()
def _check_decode_infer(self):
    """Simulate a decode batch to detect OOM from concurrent request activations.

    Ported from origin/qw35_stable:basemodel.py:895 and split into
    overridable sub-methods per spec §6.4 so qwen3next (when it lands on
    main) can cleanly hook in mamba buffer allocation.

    The prob_out.sort() call at the end is load-bearing: it forces the
    top_p/top_k sampling allocation that real inference triggers but a
    naive decode forward does not. On a vocab=152k model with
    graph_max_batch_size=64 this accounts for ~20 MB per decode step, so
    skipping it under-measures the peak by exactly the amount that causes
    the "first decode batch after warmup OOMs" bug class.

    Raises:
        Exception: with a remediation message when the stress forward hits
            a RuntimeError/OOM; req/mem slots are NOT freed on this path
            since the raise aborts model init anyway.
    """
    # Sync all ranks so every worker enters the stress forward together.
    torch.distributed.barrier()

    # Stress at the largest batch the decode CUDA graph will ever see.
    batch_size = self.graph_max_batch_size
    if batch_size <= 1:
        # Single-request decode cannot exhibit the concurrency OOM.
        return

    try:
        logger.info(f"begin check decode infer with batch_size={batch_size}")

        actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes = self._allocate_decode_stress_slots(
            batch_size
        )
        if actual_batch < 2:
            logger.info("skip decode check: not enough req slots")
            # Return whatever the partial allocation grabbed.
            self.req_manager.free_all()
            self.mem_manager.free_all()
            return

        model_input = self._build_decode_model_input(
            actual_batch, req_idxs, tokens_per_req, total_tokens, mem_indexes
        )
        model_output = self.forward(model_input)
        prob_out = torch.softmax(model_output.logits, dim=-1)
        # Drop the raw logits before sorting so (presumably) only the
        # sampling-path allocations contribute to the measured peak —
        # matches real inference, which frees logits before top_p/top_k.
        del model_output
        # Force top_p/top_k sampling allocation — load-bearing for peak measurement.
        prob_out.sort(dim=-1, descending=True)
        prob_out = None
        # Stress slots are probe-only; return them before normal serving.
        self.req_manager.free_all()
        self.mem_manager.free_all()
        logger.info(f"check decode {actual_batch} infer ok")
    except (RuntimeError, torch.OutOfMemoryError) as e:
        logger.exception(str(e))
        exception_str = (
            "check decode infer fail, you can try:\n"
            "1. Set --graph_max_batch_size to a smaller value.\n"
            "2. Set --mem_fraction or --max_total_token_num to a smaller value.\n"
            "3. Set --max_req_total_len to a smaller value."
        )
        logger.error(exception_str)
        raise Exception(exception_str)
    return

def _teardown_graphs_and_kv(self):
"""Phase 3 teardown: drop references to the probe's kv_buffer and
all captured CUDA graphs so torch.cuda.empty_cache() can actually
return the blocks to the driver.

Order matters — graphs hold tensor pointers into kv_buffer and must
go first, then req_manager (which references mem_manager), then
mem_manager itself. See spec §6.5 for the rationale.
"""
if hasattr(self, "mem_manager") and self.mem_manager is not None:
try:
self.mem_manager.free_all()
except Exception:
pass

for attr in ("graph", "prefill_graph", "prefill_cuda_graph"):
if hasattr(self, attr):
setattr(self, attr, None)
# MTP variants (if present)
for attr in ("graph1", "prefill_graph1"):
if hasattr(self, attr):
setattr(self, attr, None)

if hasattr(self, "req_manager"):
self.req_manager = None
if hasattr(self, "mem_manager"):
self.mem_manager = None
Comment on lines +1130 to +1156
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _teardown_graphs_and_kv method is missing the cleanup of self.att_backend (and self.att_backend1 if present) and self._oom_canary. If these references are not dropped, torch.cuda.empty_cache() will not be able to release the memory occupied by the probe KV buffer and the canary tensor, which will cause the teardown leak check to fail or lead to OOM during the rebuild phase.

    def _teardown_graphs_and_kv(self):
        """Phase 3 teardown: drop references to the probe's kv_buffer and
        all captured CUDA graphs so torch.cuda.empty_cache() can actually
        return the blocks to the driver.

        Order matters — graphs hold tensor pointers into kv_buffer and must
        go first, then req_manager (which references mem_manager), then
        mem_manager itself. See spec §6.5 for the rationale.
        """
        if hasattr(self, "mem_manager") and self.mem_manager is not None:
            try:
                self.mem_manager.free_all()
            except Exception:
                pass

        for attr in ("graph", "prefill_graph", "prefill_cuda_graph", "graph1", "prefill_graph1"):
            if hasattr(self, attr):
                setattr(self, attr, None)

        if hasattr(self, "att_backend"):
            self.att_backend = None
        if hasattr(self, "att_backend1"):
            self.att_backend1 = None
        if hasattr(self, "_oom_canary"):
            self._oom_canary = None

        if hasattr(self, "req_manager"):
            self.req_manager = None
        if hasattr(self, "mem_manager"):
            self.mem_manager = None


def _auto_profile_rebuild_and_validate(self):
    """Phases 2-4 of the auto-profile loop.

    Runs after Phase 1 (__init__ through _check_decode_infer). Measures
    the probe's peak, computes the target KV size, tears down the probe's
    graphs and mem_manager, calls torch.cuda.empty_cache() once, re-inits
    everything at the target size, re-captures graphs, allocates the 256 MB
    canary, and validates by re-running the stress forwards.

    On validation OOM, shrinks target_tokens by 5% and loops. Retry budget:
    3 retries (4 total attempts). After that, raises with a multi-knob
    diagnostic (see spec §7.2).

    If --max_total_token_num was set explicitly (probe path was skipped),
    this method is a pure no-op beyond a single empty_cache call to match
    the pre-auto-profile behavior.
    """
    if self.mem_manager._probe_tokens is None:
        logger.info("auto-profile phase=skip reason=explicit_max_total_token_num")
        torch.cuda.empty_cache()
        return

    # Phase 2: the reserved high-water mark now covers the probe's stress
    # forwards (_check_max_len_infer / _check_decode_infer just ran).
    peak_reserved = torch.cuda.max_memory_reserved()
    initial_target_tokens = self.mem_manager.profile_size_target(peak_reserved)
    target_tokens = initial_target_tokens
    probe_tokens = self.mem_manager._probe_tokens
    # Captured before teardown — mem_manager is None afterwards.
    probe_kv_bytes = probe_tokens * self.mem_manager.get_cell_size()

    RETRY_BUDGET = 3
    SHRINK_RATIO = 0.95
    CANARY_BYTES = 256 * 1024 * 1024

    attempt = 0
    last_exc = None
    while attempt <= RETRY_BUDGET:
        attempt += 1
        t0 = time.time()
        try:
            # Phase 3: tear down probe graphs and mem_manager
            self._teardown_graphs_and_kv()
            reserved_before_empty = torch.cuda.memory_reserved()
            torch.cuda.empty_cache()
            reserved_after_empty = torch.cuda.memory_reserved()
            released = reserved_before_empty - reserved_after_empty

            # Sanity: we should have released at least the probe kv_buffer.
            # The threshold is probe_kv_bytes (strict, per spec §6.5 step 6)
            # because the probe's kv_buffer is a single contiguous allocation
            # and PyTorch's caching allocator returns whole segments on empty_cache.
            # Only checked on attempt 1: later attempts tear down the rebuilt
            # (differently sized) managers, so the threshold no longer applies.
            if attempt == 1 and released < probe_kv_bytes:
                raise RuntimeError(
                    f"auto-profile phase=rebuild TEARDOWN LEAK: "
                    f"empty_cache() only released {released / 1024 ** 3:.2f} GB "
                    f"but probe kv_buffer alone is {probe_kv_bytes / 1024 ** 3:.2f} GB. "
                    f"Some Python reference to the probe kv_buffer or a captured "
                    f"CUDA graph was not dropped by _teardown_graphs_and_kv. "
                    f"Investigate which attribute is leaking."
                )

            # Phase 3: re-init everything at target_tokens. Review fix: the
            # attention backend(s) and padded request hold pointers into the
            # old kv_buffer / req slots, so they must be rebuilt alongside
            # the managers or validation runs against released probe memory.
            self.max_total_token_num = target_tokens
            self._init_mem_manager()
            self._init_kv_move_buffer()
            self._check_mem_size()
            self._init_req_manager()
            self._init_att_backend()
            if hasattr(self, "_init_att_backend1"):
                self._init_att_backend1()
            self._init_padded_req()
            self._init_cudagraph()
            self._init_prefill_cuda_graph()

            # Canary: a fixed safety reserve that must stay allocatable after
            # the rebuild; _teardown_graphs_and_kv drops it on a retry.
            self._oom_canary = torch.empty(CANARY_BYTES, dtype=torch.uint8, device="cuda")

            logger.info(
                f"auto-profile phase=rebuild attempt={attempt} "
                f"elapsed_sec={time.time() - t0:.2f} "
                f"new_kv_tokens={target_tokens}"
            )

            # Phase 4: validate by re-running both stress forwards at the
            # rebuilt size.
            t1 = time.time()
            self._check_max_len_infer()
            self._check_decode_infer()
            logger.info(
                f"auto-profile phase=validate attempt={attempt} " f"elapsed_sec={time.time() - t1:.2f} result=ok"
            )
            return  # success
        except (RuntimeError, torch.cuda.OutOfMemoryError, torch.OutOfMemoryError) as e:
            last_exc = e
            logger.warning(
                f"auto-profile phase=validate attempt={attempt} "
                f"result={'retry' if attempt <= RETRY_BUDGET else 'fail'} "
                f"error={type(e).__name__}: {e}"
            )
            if attempt > RETRY_BUDGET:
                break
            # Shrink the KV target 5% and try the teardown/rebuild again.
            target_tokens = int(target_tokens * SHRINK_RATIO)

    # All retries exhausted — raise a multi-knob diagnostic
    total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
    cell_size = None
    try:
        # mem_manager may be None (teardown failure path) — best effort only.
        cell_size = self.mem_manager.get_cell_size()
    except Exception:
        pass
    initial_gb = initial_target_tokens * (cell_size or 0) / 1024 ** 3
    final_gb = target_tokens * (cell_size or 0) / 1024 ** 3
    raise Exception(
        f"Auto-profile failed after {attempt} attempts.\n"
        f"Initial target: {initial_target_tokens} tokens ({initial_gb:.2f} GB KV)\n"
        f"Final attempted: {target_tokens} tokens ({final_gb:.2f} GB KV)\n"
        f"Measured peak: {peak_reserved / 1024 ** 3:.2f} GB\n"
        f"Total GPU memory: {total_memory_gb:.2f} GB\n"
        f"Canary reserve: {CANARY_BYTES / 1024 ** 3:.2f} GB\n"
        f"\n"
        f"The configured load does not fit on this device. Try:\n"
        f" 1. --batch_max_tokens: reduce to lower prefill activation peak\n"
        f" 2. --graph_max_batch_size: reduce to lower decode activation peak\n"
        f" 3. --visual_infer_batch_size: reduce to lower ViT pinned footprint\n"
        f" 4. --max_total_token_num: pin a specific KV size (skips auto-profile)\n"
        f" 5. --mem_fraction: as a last resort, set < 1.0 to add extra safety margin\n"
        f"\n"
        f"Last error: {type(last_exc).__name__}: {last_exc}"
    )

def autotune_layers(self):
# 控制autotune的层数,用于适配不同模型
return self.config.get("first_k_dense_replace", 0) + 1
Expand Down
Loading
Loading