diff --git a/docs/get_started/installation.md b/docs/get_started/installation.md index e37169a11..c3338e31d 100644 --- a/docs/get_started/installation.md +++ b/docs/get_started/installation.md @@ -24,3 +24,18 @@ uv pip install -v . --prerelease=allow ```bash pip install specforge ``` + +- **Install on Ascend NPU** + +1. Pull compatible SGLang image for Ascend NPU, currently `quay.io/ascend/sglang:v0.5.9-cann8.5.0-a3` on A3 device, or `quay.io/ascend/sglang:v0.5.9-cann8.5.0-910b` on A2 device. +2. Install SpecForge. + +```bash +# git clone the source code +git clone https://github.com/sgl-project/SpecForge.git +cd SpecForge + +# install specforge +pip install -r requirements-npu.txt +pip install . --no-deps +``` diff --git a/examples/run_llama3.1_8b_eagle3_online_npu.sh b/examples/run_llama3.1_8b_eagle3_online_npu.sh new file mode 100644 index 000000000..fbd726c98 --- /dev/null +++ b/examples/run_llama3.1_8b_eagle3_online_npu.sh @@ -0,0 +1,30 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +# train eagle3 for llama3.1-8b +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# Currently we only train with --max-length 2048 due to OOM issue on A3(64GB) +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template llama3 \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend sglang \ + --log-interval 10 \ + --sglang-mem-fraction-static 0.3 diff --git a/requirements-npu.txt b/requirements-npu.txt new file mode 100644 index 000000000..cb106f1f3 --- /dev/null +++ b/requirements-npu.txt @@ -0,0 +1,24 @@ +# Use the PyTorch CPU wheel index cause torch_npu depends on a CPU version of PyTorch +--extra-index-url https://download.pytorch.org/whl/cpu + +pre-commit +torch==2.8.0+cpu +torch_npu==2.8.0.post2 +torchaudio==2.8.0 +torchvision==0.23.0 +transformers==4.57.1 +qwen-vl-utils==0.0.11 +datasets +setuptools +tqdm +wandb +psutil +numpy +accelerate +pydantic +sglang==0.5.9 +openai-harmony +ninja +packaging +yunchang +tensorboard diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py index 30ce91942..3d4831536 100644 --- a/scripts/prepare_hidden_states.py +++ b/scripts/prepare_hidden_states.py @@ -59,6 +59,9 @@ ) from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model from specforge.utils import ( + empty_cache, + get_device_type, + get_local_device, print_args_with_dots, print_with_rank, rank_0_priority, @@ -183,7 +186,7 @@ def build_target_model( ), ) .eval() - .cuda() + .to(device=get_local_device()) ) else: target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() @@ -195,7 +198,7 @@ def build_target_model( if hasattr(model_config, "dtype") else model_config.torch_dtype ), - device="cuda", + device=get_device_type(), cache_dir=args.model_download_dir, trust_remote_code=args.trust_remote_code, **target_model_kwargs, @@ -463,11 +466,11 @@ def generate( output_path, current_batch_indices ) exists_tensor = torch.tensor( - exists_list, dtype=torch.bool, device="cuda" + exists_list, dtype=torch.bool, device=get_local_device() ) else: exists_tensor = torch.tensor( - [False] * batch_size, dtype=torch.bool, device="cuda" + [False] * batch_size, dtype=torch.bool, device=get_local_device() ) dist.broadcast(exists_tensor, src=tp_rank_0_global, group=tp_group) @@ -504,7 +507,8 @@ def generate( continue filtered_batch_gpu = { - k: v.cuda(non_blocking=True) for k, v in filtered_batch.items() + k: v.to(get_local_device(), non_blocking=True) + for k, v in filtered_batch.items() } _, _, aux_hidden_states_list, last_hidden_states_list = self.model.extend( **filtered_batch_gpu, @@ -559,7 +563,7 @@ def generate( del aux_hidden_states_list, last_hidden_states_list, filtered_batch if batch_idx % 5 == 0: # Make GC and cache clearing more frequent - torch.cuda.empty_cache() + empty_cache() gc.collect() if self.show_progress: diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index cd22ec7c5..e1047c1ab 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -36,7 +36,12 @@ from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead from specforge.optimizer import BF16Optimizer from specforge.tracker import create_tracker -from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank +from specforge.utils import ( + get_last_checkpoint, + get_local_device, + print_on_rank0, + print_with_rank, +) def parse_args(): @@ -159,11 +164,14 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: if args.target_model_backend == "sglang": target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + device = get_local_device() + device_type = device.type + target_model = get_dflash_target_model( pretrained_model_name_or_path=args.target_model_path, backend=args.target_model_backend, torch_dtype=torch.bfloat16, - device="cuda" if args.target_model_backend == "hf" else None, + device=device_type if args.target_model_backend == "hf" else None, trust_remote_code=args.trust_remote_code, **target_model_kwargs, ) @@ -194,7 +202,7 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: draft_config._attn_implementation = args.attention_backend print_on_rank0(f"Using attention backend: {args.attention_backend}") - draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16) + draft_model = DFlashDraftModel(draft_config).to(device=device, dtype=torch.bfloat16) target_model.set_capture_layers(draft_model.target_layer_ids) @@ -426,7 +434,7 @@ def main(): args.target_model_path, embed_key=args.embedding_key, lm_head_key=args.lm_head_key, - device="cuda", + device=device_type, trust_remote_code=args.trust_remote_code, ) @@ -522,13 +530,13 @@ def main(): continue global_step += 1 - input_ids = data["input_ids"].cuda() - attention_mask = data["attention_mask"].cuda() - loss_mask = data["loss_mask"].cuda() + input_ids = data["input_ids"].to(device, non_blocking=True) + attention_mask = data["attention_mask"].to(device, non_blocking=True) + loss_mask = data["loss_mask"].to(device, non_blocking=True) target_output = target_model.generate_dflash_data( input_ids, attention_mask, loss_mask ) - hidden_states = target_output.hidden_states.cuda() # Ensure on GPU + hidden_states = target_output.hidden_states.to(device, non_blocking=True) loss, accuracy = dflash_model( input_ids=input_ids, diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index f8865fbd5..4bcb0f21e 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -47,26 +47,36 @@ from specforge.tracker import Tracker, create_tracker, get_tracker_class from specforge.utils import ( create_draft_config_from_target, + current_device, + get_device_module, + get_device_type, get_last_checkpoint, + get_local_device, print_args_with_dots, print_on_rank0, print_with_rank, rank_0_priority, safe_conversations_generator, + synchronize, ) def print_cuda_memory_debug(label: str) -> None: - if os.getenv("SPECFORGE_CI_MEMORY_DEBUG") != "1" or not torch.cuda.is_available(): + device_type = get_device_type() + if os.getenv("SPECFORGE_CI_MEMORY_DEBUG") != "1" or device_type == "cpu": return try: - torch.cuda.synchronize() - free_bytes, total_bytes = torch.cuda.mem_get_info() - allocated_bytes = torch.cuda.memory_allocated() - reserved_bytes = torch.cuda.memory_reserved() + synchronize() + device_module = get_device_module() + free_bytes, total_bytes = device_module.mem_get_info() + allocated_bytes = device_module.memory_allocated() + reserved_bytes = device_module.memory_reserved() except Exception as exc: - print(f"[memory-debug] {label}: failed to query CUDA memory: {exc}", flush=True) + print( + f"[memory-debug] {label}: failed to query {device_type} memory: {exc}", + flush=True, + ) return rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else "NA" @@ -333,7 +343,7 @@ def build_target_model( torch_dtype=torch.bfloat16, ) .eval() - .cuda() + .to(get_local_device()) ) else: if args.target_model_backend == "sglang": @@ -344,7 +354,7 @@ def build_target_model( pretrained_model_name_or_path=args.target_model_path, backend=args.target_model_backend, torch_dtype=torch.bfloat16, - device="cuda", + device=get_device_type(), cache_dir=args.model_download_dir, **target_model_kwargs, trust_remote_code=args.trust_remote_code, @@ -472,13 +482,13 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] draft_model_last_checkpoint, attention_backend=args.attention_backend, torch_dtype=torch.bfloat16, - ).cuda() + ).to(get_local_device()) else: draft_model = AutoEagle3DraftModel.from_config( draft_model_config, attention_backend=args.attention_backend, torch_dtype=torch.bfloat16, - ).cuda() + ).to(get_local_device()) # Load training state (optimizer, scheduler, epoch, step) for true resume resume_state = None @@ -680,38 +690,41 @@ def run_forward( metric_losses, metric_loss_denoms, ) = eagle3_model( - input_ids=data["input_ids"].cuda(), - attention_mask=data["attention_mask"].cuda(), - loss_mask=data["loss_mask"].cuda(), - pixel_values=data["pixel_values"].cuda(), - image_grid_thw=data["image_grid_thw"].cuda(), + input_ids=data["input_ids"].to(get_local_device()), + attention_mask=data["attention_mask"].to(get_local_device()), + loss_mask=data["loss_mask"].to(get_local_device()), + pixel_values=data["pixel_values"].to(get_local_device()), + image_grid_thw=data["image_grid_thw"].to(get_local_device()), ) else: image_grid_thw = None if is_online: # we generate the eagle3 using the target model in an online fashion # Handle VLM data: pixel_values and image_grid_thw are lists - # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None + # pixel_values = [pv.to(get_local_device()) for pv in data["pixel_values"]] if args.is_vlm else None if args.is_vlm: image_grid_thw = ( - [thw.cuda().squeeze() for thw in data["image_grid_thw"]] + [ + thw.to(get_local_device()).squeeze() + for thw in data["image_grid_thw"] + ] if args.is_vlm else None ) - pixel_values = data["pixel_values"].cuda() + pixel_values = data["pixel_values"].to(get_local_device()) eagle3_data = target_model.generate_eagle3_data( - input_ids=data["input_ids"].cuda(), - attention_mask=data["attention_mask"].cuda(), - loss_mask=data["loss_mask"].cuda(), + input_ids=data["input_ids"].to(get_local_device()), + attention_mask=data["attention_mask"].to(get_local_device()), + loss_mask=data["loss_mask"].to(get_local_device()), is_vlm=args.is_vlm, pixel_values=pixel_values, image_grid_thw=image_grid_thw, ) else: eagle3_data = target_model.generate_eagle3_data( - input_ids=data["input_ids"].cuda(), - attention_mask=data["attention_mask"].cuda(), - loss_mask=data["loss_mask"].cuda(), + input_ids=data["input_ids"].to(get_local_device()), + attention_mask=data["attention_mask"].to(get_local_device()), + loss_mask=data["loss_mask"].to(get_local_device()), shard_returns=args.shard_target_output, ) @@ -732,16 +745,16 @@ def run_forward( ) else: # we generate the logits using the hidden states loaded from disk - attention_mask = data["attention_mask"].cuda() - hidden_states = data["hidden_state"].cuda() + attention_mask = data["attention_mask"].to(get_local_device()) + hidden_states = data["hidden_state"].to(get_local_device()) input_ids, target, loss_mask = target_model.preprocess( data["input_ids"], data["target"], data["loss_mask"] ) - input_ids = input_ids.cuda() + input_ids = input_ids.to(get_local_device()) target = target_model( - target.cuda() + target.to(get_local_device()) ) # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU. - loss_mask = loss_mask.cuda() + loss_mask = loss_mask.to(get_local_device()) ( plosses, acceptance_rates, @@ -757,7 +770,9 @@ def run_forward( target=target, hidden_states=hidden_states, position_ids=( - data["position_ids"].cuda() if "position_ids" in data else None + data["position_ids"].to(get_local_device()) + if "position_ids" in data + else None ), image_grid_thw=image_grid_thw, is_vlm=args.is_vlm, @@ -787,8 +802,8 @@ def run_backward_and_update( grad_norm = optimizer.step() if dist.is_initialized(): grad_norm = grad_norm.detach().float() - if torch.cuda.is_available(): - grad_norm = grad_norm.to(torch.cuda.current_device()) + if get_device_type() != "cpu": + grad_norm = grad_norm.to(current_device()) grad_norm = grad_norm.pow(2) dist.all_reduce(grad_norm, op=dist.ReduceOp.SUM) grad_norm = grad_norm.sqrt() diff --git a/specforge/benchmarks/benchmark_flex_attention.py b/specforge/benchmarks/benchmark_flex_attention.py index 20f989565..c535879fc 100644 --- a/specforge/benchmarks/benchmark_flex_attention.py +++ b/specforge/benchmarks/benchmark_flex_attention.py @@ -13,6 +13,13 @@ LlamaFlexAttention, prepare_decoder_attention_mask, ) +from specforge.utils import ( + empty_cache, + get_device_module, + get_device_type, + get_local_device, + synchronize, +) dynamo.config.recompile_limit = 64 @@ -40,7 +47,7 @@ def run_attention( attention_backend: str = "sdpa", enable_profile: bool = False, ): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = get_local_device() batch_size = hidden_states_list[0].shape[0] # Initialize cache and attention function based on backend if attention_backend == "sdpa": @@ -133,9 +140,9 @@ def benchmark_function( print(f"\nTesting sequence length: {seq_len}") # Clear GPU cache - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + if get_device_type() != "cpu": + empty_cache() + get_device_module().reset_peak_memory_stats() # Warm up runs for this sequence length if enable_warmup: @@ -147,27 +154,27 @@ def benchmark_function( seq_len, HIDDEN_SIZE, requires_grad=True, - device="cuda", + device=get_local_device(), dtype=torch.bfloat16, ) for _ in range(TTT_LENGTH) ] run_attention(seq_len, hidden_states, attention_backend) # Clear cache again after warmup - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + if get_device_type() != "cpu": + empty_cache() + get_device_module().reset_peak_memory_stats() # Record initial memory initial_memory = 0 - if torch.cuda.is_available(): - initial_memory = torch.cuda.memory_allocated() + if get_device_type() != "cpu": + initial_memory = get_device_module().memory_allocated() hidden_states = [ torch.randn( BATCH_SIZE, seq_len, HIDDEN_SIZE, requires_grad=True, - device="cuda", + device=get_local_device(), dtype=torch.bfloat16, ) for _ in range(TTT_LENGTH) @@ -179,16 +186,16 @@ def benchmark_function( attention_backend, enable_profile and seq_len == seq_lengths[0], ) - if torch.cuda.is_available(): - torch.cuda.synchronize() + if get_device_type() != "cpu": + synchronize() end_time = time.time() # Record memory usage peak_memory = 0 current_memory = 0 - if torch.cuda.is_available(): - peak_memory = torch.cuda.max_memory_allocated() - current_memory = torch.cuda.memory_allocated() + if get_device_type() != "cpu": + peak_memory = get_device_module().max_memory_allocated() + current_memory = get_device_module().memory_allocated() results_per_seq_len.append( { "seq_len": seq_len, @@ -292,16 +299,18 @@ def plot_results(eagle_results, flex_results, seq_lengths): args = parser.parse_args() print("PyTorch version:", torch.__version__) - if torch.cuda.is_available(): - print("CUDA available:", torch.cuda.is_available()) - print("GPU:", torch.cuda.get_device_name()) + device_type = get_device_type() + if device_type != "cpu": + device_module = get_device_module() + print(f"{device_type} available: True") + print("Device:", device_module.get_device_name()) print( - "GPU memory:", - torch.cuda.get_device_properties(0).total_memory / 1024**3, + "Device memory:", + device_module.get_device_properties(0).total_memory / 1024**3, "GB", ) else: - print("CUDA not available - running on CPU") + print("No accelerator available - running on CPU") # Define sequence lengths to test seq_lengths = [128 * i for i in range(1, 28, 4)] diff --git a/specforge/benchmarks/benchmark_loss.py b/specforge/benchmarks/benchmark_loss.py index 940787a86..13dd792b3 100644 --- a/specforge/benchmarks/benchmark_loss.py +++ b/specforge/benchmarks/benchmark_loss.py @@ -3,7 +3,14 @@ import torch -from specforge.core.loss import LogSoftmaxLoss, _compute_loss +from specforge.core.loss import _compute_loss, log_softmax_loss +from specforge.utils import ( + empty_cache, + get_device_module, + get_device_type, + get_local_device, + synchronize, +) TTT_LENGTH = 7 @@ -22,32 +29,39 @@ def benchmark_loss_method( print(f"\nTesting config: B={B}, T={T}, V={V}") # Clear GPU cache - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + if get_device_type() != "cpu": + empty_cache() + get_device_module().reset_peak_memory_stats() # Create tensors outside timing measurement target = torch.softmax( - torch.randn(B, T, V, device="cuda", dtype=torch.float32), dim=-1 + torch.randn(B, T, V, device=get_local_device(), dtype=torch.float32), dim=-1 + ) + position_mask = torch.ones( + (B, T, 1), dtype=torch.bool, device=get_local_device() ) - position_mask = torch.ones((B, T, 1), dtype=torch.bool, device="cuda") # Pre-allocate logits tensors for each TTT step logits_list = [] for i in range(TTT_LENGTH): logits = torch.randn( - B, T, V, device="cuda", requires_grad=True, dtype=torch.float32 + B, + T, + V, + device=get_local_device(), + requires_grad=True, + dtype=torch.float32, ) logits_list.append(logits) - torch.cuda.synchronize() # Ensure all operations are complete + synchronize() start_time = time.time() plosses = [] for i in range(TTT_LENGTH): logits = logits_list[i] if loss_method == "triton": - loss = LogSoftmaxLoss.apply(logits, target, position_mask) + loss = log_softmax_loss(logits, target, position_mask) else: loss = _compute_loss(logits, target, position_mask) plosses.append(loss) @@ -59,15 +73,14 @@ def benchmark_loss_method( ) ploss.backward() - if torch.cuda.is_available(): - torch.cuda.synchronize() + if get_device_type() != "cpu": + synchronize() end_time = time.time() total_time = end_time - start_time - # Record memory usage peak_memory = 0 - if torch.cuda.is_available(): - peak_memory = torch.cuda.max_memory_allocated() + if get_device_type() != "cpu": + peak_memory = get_device_module().max_memory_allocated() results.append( { @@ -93,16 +106,17 @@ def main(): args = parser.parse_args() print("PyTorch version:", torch.__version__) - if torch.cuda.is_available(): - print("CUDA available:", torch.cuda.is_available()) - print("GPU:", torch.cuda.get_device_name()) + device_type = get_device_type() + if device_type != "cpu": + print(f"{device_type} available: True") + print("Device:", get_device_module().get_device_name()) print( - "GPU memory:", - torch.cuda.get_device_properties(0).total_memory / 1024**3, + "Device memory:", + get_device_module().get_device_properties(0).total_memory / 1024**3, "GB", ) else: - print("CUDA not available - running on CPU") + print("No accelerator available - running on CPU") # Define test configurations (B, T, V) test_configs = [ diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index 9151e8e33..9e59d0985 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -29,9 +29,9 @@ from specforge.core.eagle3_adapters import BackendAdapter, SdpaLikeAdapter, UspAdapter from specforge.core.lk_loss import compute_acceptance_rate, compute_lk_loss -from specforge.core.loss import LogSoftmaxLoss +from specforge.core.loss import log_softmax_loss from specforge.modeling.draft import Eagle3DraftModel -from specforge.utils import padding +from specforge.utils import empty_cache, get_compile_backend, padding class Eagle3Model(nn.Module): @@ -65,7 +65,7 @@ def _compute_loss_and_acceptance_rate( reduce_metrics_fn: Optional distributed reducer for metric numer/denom. reduce_loss_fn: Optional distributed reducer for KL loss. """ - kl_loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) + kl_loss = log_softmax_loss(logits, target_p, position_mask) if reduce_loss_fn is not None: kl_loss = reduce_loss_fn(kl_loss) @@ -272,7 +272,7 @@ def forward( length=self.length, ) del target - torch.cuda.empty_cache() + empty_cache() # basic info batch_size, seq_length, _ = hidden_states.shape @@ -806,7 +806,7 @@ def _compute_target_p_padded(target, t2d, loss_mask, length): ) -@torch.compile(dynamic=None) +@torch.compile(dynamic=None, backend=get_compile_backend()) def _compute_target_p(target, t2d, loss_mask): target_head = target.float() target_token_ids = target_head.argmax(-1) @@ -823,13 +823,13 @@ def _compute_target_p(target, t2d, loss_mask): return target_p, target_p_on_draft, target_token_ids, position_mask -@torch.compile(dynamic=None) +@torch.compile(dynamic=None, backend=get_compile_backend()) def _compute_metric_acc(logits, target_token_ids, loss_mask, d2t): correct, denom = _compute_metric_counts(logits, target_token_ids, loss_mask, d2t) return correct / denom -@torch.compile(dynamic=None) +@torch.compile(dynamic=None, backend=get_compile_backend()) def _compute_metric_counts(logits, target_token_ids, loss_mask, d2t): pred_draft_token_ids = logits.argmax(-1) pred_target_token_ids = pred_draft_token_ids + d2t[pred_draft_token_ids] diff --git a/specforge/core/loss.py b/specforge/core/loss.py index 30e7fba7d..6fe2f9c56 100644 --- a/specforge/core/loss.py +++ b/specforge/core/loss.py @@ -10,9 +10,11 @@ import triton import triton.language as tl +from specforge.utils import get_compile_backend, get_device_type + # Reference implementation -@torch.compile(dynamic=None) +@torch.compile(dynamic=None, backend=get_compile_backend()) def _compute_loss(logits, target_p, position_mask): logits = logits.float() out_logp = nn.LogSoftmax(dim=2)(logits) @@ -228,15 +230,197 @@ def backward(ctx, grad_output): return logits, None, None, None, None +def _calculate_settings_npu(n): + NPU_MAX_BLOCK_SIZE = 4096 + BLOCK_SIZE = min(triton.next_power_of_2(n), NPU_MAX_BLOCK_SIZE) + return BLOCK_SIZE + + +@triton.jit +def log_softmax_forward_kernel_npu( + logits_ptr, + logits_stride, + target_ptr, + target_stride, + position_mask_ptr, + position_mask_stride, + loss_ptr, + loss_stride, + m_ptr, + d_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + program_id = tl.program_id(0).to(tl.int64) + logits_ptr += program_id * logits_stride + target_ptr += program_id * target_stride + position_mask_ptr += program_id * position_mask_stride + position_mask = tl.load(position_mask_ptr) + if position_mask == 0: + return + + m = float("-inf") + d = 0.0 + sum_target_logits = 0.0 + sum_target = 0.0 + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + logits_block = tl.load( + logits_ptr + offsets, mask=mask, other=float("-inf"), care_padding=False + ).cast(tl.float32) + target_block = tl.load( + target_ptr + offsets, mask=mask, other=0.0, care_padding=False + ).cast(tl.float32) + + block_max = tl.max(tl.where(mask, logits_block, float("-inf"))) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum( + tl.where(mask, tl.exp(logits_block - m_new), 0.0) + ) + m = m_new + + sum_target_logits += tl.sum(tl.where(mask, target_block * logits_block, 0.0)) + sum_target += tl.sum(tl.where(mask, target_block, 0.0)) + + loss = -(sum_target_logits - sum_target * (m + tl.log(d))) + + loss_ptr += program_id * loss_stride + m_ptr += program_id + d_ptr += program_id + tl.store(loss_ptr, loss) + tl.store(m_ptr, m.to(tl.float32)) + tl.store(d_ptr, d.to(tl.float32)) + + +@triton.jit +def log_softmax_backward_kernel_npu( + logits_ptr, + logits_stride, + target_ptr, + target_stride, + position_mask_ptr, + target_grad_sum_ptr, + m_ptr, + d_ptr, + grad_output_scaled, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + program_id = tl.program_id(0).to(tl.int64) + logits_ptr += program_id * logits_stride + target_ptr += program_id * target_stride + position_mask_ptr += program_id + + position_mask = tl.load(position_mask_ptr) + if position_mask == 0: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + tl.store(logits_ptr + offsets, 0.0, mask=mask) + return + + m_ptr += program_id + d_ptr += program_id + target_grad_sum_ptr += program_id + m = tl.load(m_ptr).to(tl.float32) + d = tl.load(d_ptr).to(tl.float32) + target_grad_sum = tl.load(target_grad_sum_ptr).to(tl.float32) + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + logits_block = tl.load( + logits_ptr + offsets, mask=mask, other=0.0, care_padding=False + ).cast(tl.float32) + target_block = tl.load( + target_ptr + offsets, mask=mask, other=0.0, care_padding=False + ).cast(tl.float32) + softmax_prob = tl.exp(logits_block - m) / d + normalized_grad = softmax_prob * target_grad_sum + grad_block = -(target_block * grad_output_scaled - normalized_grad) + tl.store(logits_ptr + offsets, grad_block.to(tl.float32), mask=mask) + + +class LogSoftmaxLossNPU(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, target, position_mask): + B, T, V = logits.shape + loss = torch.zeros((B * T, 1), device=logits.device) + logits_flat = logits.contiguous().view(B * T, V) + target_flat = target.contiguous().view(B * T, V) + position_mask_flat = position_mask.contiguous().view(B * T, 1).bool() + grid = (B * T,) + m = torch.zeros((B * T,), device=logits.device, dtype=torch.float32) + d = torch.zeros((B * T,), device=logits.device, dtype=torch.float32) + BLOCK_SIZE = _calculate_settings_npu(V) + log_softmax_forward_kernel_npu[grid]( + logits_flat, + logits_flat.stride(0), + target_flat, + target_flat.stride(0), + position_mask_flat, + position_mask_flat.stride(0), + loss, + loss.stride(0), + m, + d, + V, + BLOCK_SIZE=BLOCK_SIZE, + ) + ctx.save_for_backward(logits.detach(), target, position_mask, m, d) + return loss.squeeze(1).mean() + + @staticmethod + def backward(ctx, grad_output): + logits, target, position_mask, m, d = ctx.saved_tensors + B, T, V = logits.shape + logits = logits.contiguous().view(B * T, V) + target = target.contiguous().view(B * T, V) + position_mask = position_mask.contiguous().view(B * T, 1).bool() + + scaling_factor = 1.0 / (B * T) + grad_output_scaled = grad_output.item() * scaling_factor + target_sum_per_row = target.sum(dim=-1) + target_grad_sum_per_row = target_sum_per_row * grad_output_scaled + + grid = (B * T,) + BLOCK_SIZE = _calculate_settings_npu(V) + log_softmax_backward_kernel_npu[grid]( + logits, + logits.stride(0), + target, + target.stride(0), + position_mask, + target_grad_sum_per_row, + m, + d, + grad_output_scaled, + V, + BLOCK_SIZE=BLOCK_SIZE, + ) + + logits = logits.view(B, T, V) + return logits, None, None + + +def log_softmax_loss(logits, target, position_mask): + if get_device_type() == "npu": + return LogSoftmaxLossNPU.apply(logits, target, position_mask) + else: + return LogSoftmaxLoss.apply(logits, target, position_mask) + + if __name__ == "__main__": - device = "cuda" + device = get_device_type() B, T, V = 1, 1024, 16000 logits = torch.randn(B, T, V, device=device, requires_grad=True) logits2 = logits.clone().detach().requires_grad_(True) target = torch.randn(B, T, V, device=device) position_mask = torch.randint(0, 2, (B, T, 1), dtype=torch.bool, device=device) position_mask = torch.ones((B, T, 1), dtype=torch.bool, device=device) - output1 = LogSoftmaxLoss.apply(logits, target, position_mask) + output1 = log_softmax_loss(logits, target, position_mask) output2 = _compute_loss(logits2, target, position_mask) torch.testing.assert_close(output1, output2, rtol=1e-4, atol=1e-4) output1.backward() diff --git a/specforge/data/parse.py b/specforge/data/parse.py index b9a7cccdf..aa83817cd 100644 --- a/specforge/data/parse.py +++ b/specforge/data/parse.py @@ -178,19 +178,19 @@ def parse( warnings.warn( f"Conversation must start with a 'user' role, but found '{role}'. Conversation truncated." ) - break + return None, None else: prev_role = conversation[j - 1]["role"] if role == "tool" and prev_role not in ["assistant", "tool"]: warnings.warn( f"A 'tool' message must follow an 'assistant' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated." ) - break + return None, None if role == "assistant" and prev_role not in ["user", "tool"]: warnings.warn( f"An 'assistant' message must follow a 'user' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated." ) - break + return None, None sentence = self._sanitize_message(sentence) messages.append(sentence) try: diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index 1dcd41d3b..907468ed7 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -171,6 +171,9 @@ def preprocess_conversations( tool=tool, **kwargs_item, ) + if input_ids is None or loss_mask is None: + # if parsing failed, skip this conversation + continue results["input_ids"].append(input_ids[None, :]) results["loss_mask"].append(loss_mask[None, :]) results["attention_mask"].append(torch.ones_like(loss_mask)[None, :]) diff --git a/specforge/distributed.py b/specforge/distributed.py index fb5e882c4..6f45a5f94 100644 --- a/specforge/distributed.py +++ b/specforge/distributed.py @@ -5,7 +5,13 @@ import torch.distributed as dist from yunchang.globals import PROCESS_GROUP, set_seq_parallel_pg -from specforge.utils import print_with_rank +from specforge.utils import ( + device_count, + get_device_type, + get_dist_backend, + print_with_rank, + set_device, +) _DEVICE_MESH = None _TP_DEVICE_MESH = None @@ -72,9 +78,11 @@ def init_distributed( timeout(int): Timeout for collective communication in minutes tp_size(int): The degree of tensor parallelism """ - dist.init_process_group(backend="nccl", timeout=timedelta(minutes=timeout)) - local_rank = dist.get_rank() % torch.cuda.device_count() - torch.cuda.set_device(local_rank) + dist.init_process_group( + backend=get_dist_backend(), timeout=timedelta(minutes=timeout) + ) + local_rank = dist.get_rank() % device_count() + set_device(local_rank) print_with_rank(f"bind to device {local_rank}") world_size = dist.get_world_size() @@ -84,7 +92,7 @@ def init_distributed( ), f"world size must be divisible by tp size, now {world_size=}, {(tp_size * dp_size)=} " device_mesh = dist.device_mesh.init_device_mesh( - "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp") + get_device_type(), (dp_size, tp_size), mesh_dim_names=("dp", "tp") ) assert ( @@ -93,7 +101,7 @@ def init_distributed( draft_dp_size = world_size // (sp_ulysses_size * sp_ring_size) draft_device_mesh = dist.device_mesh.init_device_mesh( - "cuda", + get_device_type(), (draft_dp_size, sp_ulysses_size * sp_ring_size), mesh_dim_names=("draft_dp", "sp"), ) @@ -106,7 +114,7 @@ def init_distributed( sp_ulysses_group = PROCESS_GROUP.ULYSSES_PG sp_ring_group = PROCESS_GROUP.RING_PG # we need to create a 1D submesh - tp_device_mesh = dist.DeviceMesh.from_group(tp_group, device_type="cuda") + tp_device_mesh = dist.DeviceMesh.from_group(tp_group, device_type=get_device_type()) global _TP_GROUP, _DP_GROUP, _DEVICE_MESH, _TP_DEVICE_MESH, _DP_DEVICE_MESH, _SP_RING_GROUP, _SP_ULYSSES_GROUP, _DRAFT_DP_GROUP, _DRAFT_SP_GROUP _DEVICE_MESH = device_mesh @@ -117,7 +125,9 @@ def init_distributed( _DP_GROUP = dp_group _DRAFT_DP_GROUP = draft_device_mesh.get_group("draft_dp") _DRAFT_SP_GROUP = draft_device_mesh.get_group("sp") - _DP_DEVICE_MESH = dist.DeviceMesh.from_group(dp_group, device_type="cuda") + _DP_DEVICE_MESH = dist.DeviceMesh.from_group( + dp_group, device_type=get_device_type() + ) def destroy_distributed(): diff --git a/specforge/modeling/auto.py b/specforge/modeling/auto.py index 1e48a43e7..515f08453 100644 --- a/specforge/modeling/auto.py +++ b/specforge/modeling/auto.py @@ -18,6 +18,8 @@ modeling_utils, ) +from specforge.utils import get_local_device + from .draft.llama3_eagle import LlamaForCausalLMEagle3 from .target.custom_backend import ( GptOssForCausalLM, @@ -125,7 +127,7 @@ def from_pretrained( if device is not None: model = model.to(device) else: - model = model.cuda() + model = model.to(get_local_device()) return model diff --git a/specforge/modeling/draft/flex_attention.py b/specforge/modeling/draft/flex_attention.py index 50ca5f54d..5fb07109d 100644 --- a/specforge/modeling/draft/flex_attention.py +++ b/specforge/modeling/draft/flex_attention.py @@ -7,6 +7,8 @@ ) from transformers.utils import is_torchdynamo_compiling +from specforge.utils import get_compile_backend + dynamo.config.recompile_limit = 64 @@ -35,7 +37,7 @@ def __init__(self): # Enable dynamic shapes to handle different input sizes self._compiled_flex_attention = torch.compile( flex_attention, - # mode="max-autotune-no-cudagraphs", + backend=get_compile_backend(), ) self._is_flex_compiled = True @@ -75,7 +77,9 @@ def __new__(cls, *args, **kwargs): @torch.compiler.disable(recursive=False) def __init__(self): if not self._is_create_block_mask_compiled: - self._compiled_create_block_mask = torch.compile(create_block_mask) + self._compiled_create_block_mask = torch.compile( + create_block_mask, backend=get_compile_backend() + ) self._is_create_block_mask_compiled = True def __call__(self): diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index e96286bdf..267483d13 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -17,7 +17,7 @@ compile_friendly_flex_attention, generate_eagle3_mask, ) -from specforge.utils import print_with_rank +from specforge.utils import get_compile_backend, print_with_rank from ...distributed import get_sp_ring_group, get_sp_ulysses_group from .base import Eagle3DraftModel @@ -112,7 +112,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -@torch.compile(dynamic=True) +@torch.compile(dynamic=True, backend=get_compile_backend()) def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] @@ -282,7 +282,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False ) - @torch.compile(dynamic=True) + @torch.compile(dynamic=True, backend=get_compile_backend()) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len and seq_len > self.max_seq_len_cached: @@ -1532,7 +1532,7 @@ def __init__(self, hidden_size, eps=1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - @torch.compile(dynamic=True) + @torch.compile(dynamic=True, backend=get_compile_backend()) def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) diff --git a/specforge/modeling/target/dflash_target_model.py b/specforge/modeling/target/dflash_target_model.py index 0df938239..ea04de53b 100644 --- a/specforge/modeling/target/dflash_target_model.py +++ b/specforge/modeling/target/dflash_target_model.py @@ -18,6 +18,7 @@ from transformers import AutoModelForCausalLM from specforge.distributed import get_tp_group +from specforge.utils import current_device from .sglang_backend import SGLangRunner @@ -98,7 +99,7 @@ def from_pretrained( model_runner = SGLangRunner( model_config=model_config, mem_fraction_static=server_args.mem_fraction_static, - gpu_id=torch.cuda.current_device(), + gpu_id=current_device(), tp_rank=dist.get_rank(get_tp_group()), tp_size=server_args.tp_size, moe_ep_rank=moe_ep_rank, diff --git a/specforge/modeling/target/eagle3_target_model.py b/specforge/modeling/target/eagle3_target_model.py index d51a9cdc9..c9e3efcf7 100644 --- a/specforge/modeling/target/eagle3_target_model.py +++ b/specforge/modeling/target/eagle3_target_model.py @@ -34,7 +34,7 @@ from transformers import AutoModelForCausalLM from specforge.distributed import get_tp_device_mesh, get_tp_group -from specforge.utils import padding +from specforge.utils import current_device, padding from .sglang_backend import SGLangRunner, wrap_eagle3_logits_processors_in_module from .sglang_backend.utils import LogitsProcessorForEAGLE3 @@ -334,7 +334,7 @@ def from_pretrained( model_runner = SGLangRunner( model_config=model_config, mem_fraction_static=server_args.mem_fraction_static, - gpu_id=torch.cuda.current_device(), + gpu_id=current_device(), tp_rank=dist.get_rank(get_tp_group()), tp_size=server_args.tp_size, moe_ep_rank=moe_ep_rank, diff --git a/specforge/modeling/target/target_head.py b/specforge/modeling/target/target_head.py index 86ab4f501..944d6dd08 100644 --- a/specforge/modeling/target/target_head.py +++ b/specforge/modeling/target/target_head.py @@ -9,7 +9,7 @@ from safetensors import safe_open from transformers import AutoConfig -from specforge.utils import padding +from specforge.utils import get_local_device, padding class TargetHead(nn.Module): @@ -40,7 +40,7 @@ def from_pretrained( cache_dir=cache_dir, ) target_head.freeze_weights() - target_head = target_head.eval().cuda().to(torch.bfloat16) + target_head = target_head.eval().to(get_local_device()).to(torch.bfloat16) return target_head @torch.no_grad() diff --git a/specforge/modeling/target/target_utils.py b/specforge/modeling/target/target_utils.py index 9dacba6be..66f2a08d8 100644 --- a/specforge/modeling/target/target_utils.py +++ b/specforge/modeling/target/target_utils.py @@ -10,6 +10,8 @@ from safetensors import safe_open from transformers import AutoConfig +from specforge.utils import get_device_type + class TargetEmbeddingsAndHead(nn.Module): """ @@ -45,12 +47,13 @@ def from_pretrained( embed_key: Optional[str] = None, lm_head_key: Optional[str] = None, cache_dir: Optional[str] = None, - device: str = "cuda", + device: str = None, dtype: torch.dtype = torch.bfloat16, trust_remote_code: bool = False, ) -> "TargetEmbeddingsAndHead": - # 1. Load Config + if device is None: + device = get_device_type() config = AutoConfig.from_pretrained( model_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code ) diff --git a/specforge/utils.py b/specforge/utils.py index af4d627c8..e3d363d83 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -12,6 +12,13 @@ logger = logging.getLogger(__name__) +def get_compile_backend() -> str: + # now ascend npu can not support inductor for backend, + if hasattr(torch, "npu") and torch.npu.is_available(): + return "eager" + return "inductor" + + @contextmanager def rank_0_priority(): rank = dist.get_rank() @@ -49,6 +56,74 @@ def load_config_from_file(config_path: str): return PretrainedConfig.from_dict(config) +def get_device_type() -> str: + """Auto-detect the available accelerator type. + + Priority: + 1. SPECFORGE_DEVICE environment variable + 2. NVIDIA CUDA (torch.cuda) + 3. Ascend NPU (torch.npu) + 4. CPU fallback + """ + dt = os.environ.get("SPECFORGE_DEVICE", None) + if dt: + return dt + if torch.cuda.is_available(): + return "cuda" + if hasattr(torch, "npu") and torch.npu.is_available(): + return "npu" + return "cpu" + + +def get_local_device() -> torch.device: + """Return the local torch.device for the current process rank.""" + device_type = get_device_type() + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if device_type in ("cuda", "npu"): + return torch.device(device_type, local_rank) + return torch.device("cpu") + + +def get_device_module(): + """Return the torch device module for the current device type (cuda, npu, etc.).""" + return torch.get_device_module(get_device_type()) + + +def empty_cache(): + """Empty the cache for the current accelerator device.""" + get_device_module().empty_cache() + + +def device_count() -> int: + """Return the number of available accelerator devices.""" + return get_device_module().device_count() + + +def current_device() -> int: + """Return the current accelerator device index.""" + return get_device_module().current_device() + + +def set_device(device_id: int): + """Set the current accelerator device.""" + get_device_module().set_device(device_id) + + +def synchronize(): + """Synchronize the current accelerator device.""" + get_device_module().synchronize() + + +def get_dist_backend() -> str: + """Return the appropriate distributed backend for the current device type.""" + device_type = get_device_type() + if device_type == "cuda": + return "nccl" + if device_type == "npu": + return "hccl" + return "gloo" + + def print_with_rank(message): if dist.is_available() and dist.is_initialized(): logger.info(f"rank {dist.get_rank()}: {message}")