Skip to content

feat: optimize fused linear JSD memory and runtime#1261

Open
d7chong wants to merge 3 commits into
linkedin:mainfrom
d7chong:fused-linear-jsd
Open

feat: optimize fused linear JSD memory and runtime#1261
d7chong wants to merge 3 commits into
linkedin:mainfrom
d7chong:fused-linear-jsd

Conversation

@d7chong

@d7chong d7chong commented Jun 14, 2026

Copy link
Copy Markdown

Summary

A faster and more memory-efficient fused linear JSD kernel.

Testing Done

  • Hardware Type: NVIDIA A100-SXM4-80GB
  • Default Tests
    • run make test to ensure correctness
    • run make checkstyle to ensure code style
    • run make test-convergence to ensure convergence

Additional Testing

  • Additional testing was to compare (1) torch (2) old fused-linear-jsd (3) new fused-linear-jsd
  • Blank spaces are empty due to OOM

Forward/Backward Correctness

================================================================================
CORRECTNESS TEST RESULTS (backward_mode=full)
================================================================================
Tolerance: rtol=0.01, atol=0.001
BT         Provider     Output Diff     Output   Grad Match  
--------------------------------------------------------------------------------
16         old_liger    1.739502e-03    PASS     PASS        
16         new_liger    1.736104e-03    PASS     PASS        
32         old_liger    1.737237e-03    PASS     PASS        
32         new_liger    1.737297e-03    PASS     PASS        
64         old_liger    1.738489e-03    PASS     PASS        
64         new_liger    1.738429e-03    PASS     PASS        
128        old_liger    1.739264e-03    PASS     PASS        
128        new_liger    1.739264e-03    PASS     PASS        
256        old_liger    1.739025e-03    PASS     PASS        
256        new_liger    1.739621e-03    PASS     PASS        
512        old_liger    1.738846e-03    PASS     PASS        
512        new_liger    1.739144e-03    PASS     PASS        
1024       old_liger    1.738966e-03    PASS     PASS        
1024       new_liger    1.739025e-03    PASS     PASS        
2048       old_liger    1.723111e-03    PASS     PASS        
2048       new_liger    1.723170e-03    PASS     PASS        
--------------------------------------------------------------------------------
CORRECTNESS TEST: PASSED
================================================================================

Forward Speed (p50 ms)

BT Torch Old New New vs Old New vs Torch
1024 45.17 437.75 80.66 5.43x 0.56x
4096 180.96 613.93 319.99 1.92x 0.57x
16384 - 1501.66 1274.87 1.18x -
65536 - - 5095.01 - -
262144 - - 20420.69 - -

Backward Full Speed (p50 ms)

BT Torch Old New New vs Old New vs Torch
1024 42.60 2.47 2.47 1.00x 17.27x
4096 172.04 2.50 2.49 1.00x 69.01x
16384 - 2.60 2.61 1.00x -
65536 - - 3.07 - -
262144 - - 4.91 - -

Full Pass Speed (p50 ms)

BT Torch Old New New vs Old New vs Torch
1024 88.14 440.46 82.97 5.31x 1.06x
4096 358.12 616.71 322.55 1.91x 1.11x
16384 - 1506.62 1280.81 1.18x -
65536 - - 5110.77 - -
262144 - - 20439.55 - -

Peak Memory (p50 MB)

BT Torch Old New Torch/New
1024 27176 26184 26184 1.04x
4096 60224 26256 26256 2.29x
16384 - 38294 26544 -
65536 - - 27696 -
262144 - - 32400 -

Testing script

  • Comparison script: benchmark/script/compare_jsd.py
  • Original fused_jsd: src/liger_kernel/ops/fused_linear_jsd_old.py
benchmarking code
import argparse
import json

from typing import Callable

import torch
import triton

from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction as NewLigerFusedLinearJSDFunction
from liger_kernel.ops.fused_linear_jsd_old import LigerFusedLinearJSDFunction as OldLigerFusedLinearJSDFunction
from liger_kernel.utils import infer_device

device = infer_device()


class TorchJSD(torch.nn.Module):
    def __init__(
        self,
        beta: float = 0.5,
        ignore_index: int = -100,
        dtype: torch.dtype = torch.float,
    ):
        super().__init__()
        self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
        self.beta = beta
        self.ignore_index = ignore_index
        self.dtype = dtype

    def forward(self, log_q: torch.Tensor, log_p: torch.Tensor, label=None):
        log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
        log_p = log_p.view(-1, log_p.size(-1))
        log_q = log_q.view(-1, log_q.size(-1))
        m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
        loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl(
            torch.log(m), log_q
        ).sum(dim=-1)

        if label is not None:
            loss = torch.where(label != self.ignore_index, loss, 0.0)
            n_non_ignore = (label != self.ignore_index).sum().item()
            if n_non_ignore == 0:
                loss = torch.zeros((), device=log_q.device, dtype=log_q.dtype)
            else:
                loss = (loss / n_non_ignore).sum()
        else:
            loss = (loss / log_q.shape[0]).sum()
        return loss.to(self.dtype)


class TorchLMHeadJSD(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        device: torch.device,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
        self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype)
        self.temperature = temperature

    def forward(self, student_input, teacher_input, label=None):
        student_logits = self.student_lin(student_input).to(torch.float32)
        teacher_logits = self.teacher_lin(teacher_input).to(torch.float32)
        student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1)
        return self.jsd(student_prob, teacher_prob, label)


CORRECTNESS_TOLERANCE = {
    "bf16": {"rtol": 1e-2, "atol": 1e-3},
    "fp16": {"rtol": 1e-2, "atol": 1e-3},
    "fp32": {"rtol": 1e-4, "atol": 1e-5},
}


DTYPE_MAP = {
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "fp16": torch.float16,
    "float16": torch.float16,
    "fp32": torch.float32,
    "float32": torch.float32,
}


class OldLigerLMHeadJSD(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        device: torch.device,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
        self.beta = beta
        self.ignore_index = ignore_index
        self.temperature = temperature

    def forward(self, student_input, teacher_input, label=None):
        return OldLigerFusedLinearJSDFunction.apply(
            student_input,
            self.student_lin.weight,
            teacher_input,
            self.teacher_lin.weight,
            label,
            self.beta,
            self.ignore_index,
            self.temperature,
        )


class NewLigerLMHeadJSD(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        device: torch.device,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
        self.beta = beta
        self.ignore_index = ignore_index
        self.temperature = temperature

    def forward(self, student_input, teacher_input, label=None):
        return NewLigerFusedLinearJSDFunction.apply(
            student_input,
            self.student_lin.weight,
            teacher_input,
            self.teacher_lin.weight,
            label,
            self.beta,
            self.ignore_index,
            self.temperature,
        )


def parse_csv_list(value: str, caster: Callable = str):
    return [caster(item.strip()) for item in value.split(",") if item.strip()]


def make_model(provider, H, V, dtype, torch_device, beta, ignore_index, temperature):
    if provider == "torch":
        return TorchLMHeadJSD(
            H=H,
            V=V,
            dtype=dtype,
            device=torch_device,
            beta=beta,
            ignore_index=ignore_index,
            temperature=temperature,
        ).to(torch_device)
    if provider == "old_liger":
        return OldLigerLMHeadJSD(
            H=H,
            V=V,
            dtype=dtype,
            device=torch_device,
            beta=beta,
            ignore_index=ignore_index,
            temperature=temperature,
        ).to(torch_device)
    if provider == "new_liger":
        return NewLigerLMHeadJSD(
            H=H,
            V=V,
            dtype=dtype,
            device=torch_device,
            beta=beta,
            ignore_index=ignore_index,
            temperature=temperature,
        ).to(torch_device)
    raise ValueError(f"Unknown provider: {provider}")


def setup_case(provider, BT, H, V, dtype, torch_device, beta, ignore_index, temperature, with_labels, backward_mode="full"):
    model = make_model(provider, H, V, dtype, torch_device, beta, ignore_index, temperature)

    student_weight = torch.rand(V, H, device=torch_device, dtype=dtype)
    teacher_weight = torch.rand(V, H, device=torch_device, dtype=dtype)
    model.student_lin.weight.data = student_weight
    model.teacher_lin.weight.data = teacher_weight
    
    model.teacher_lin.weight.requires_grad = False
    model.student_lin.weight.requires_grad = True

    student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=torch_device)
    teacher_input = torch.rand(BT, H, requires_grad=False, dtype=dtype, device=torch_device)
    labels = None
    if with_labels:
        labels = torch.randint(0, V, (BT,), device=torch_device, dtype=torch.long)

    return model, student_input, teacher_input, labels


def tensors_to_clear(model, student_input):
    tensors = [student_input]
    for parameter in model.parameters():
        if parameter.requires_grad:
            tensors.append(parameter)
    return tensors


def test_correctness(model_factory, bt_values, dtype, torch_device, rep=5, backward_mode="full"):
    """Test correctness by comparing outputs and gradients across providers."""
    tol = CORRECTNESS_TOLERANCE.get(str(dtype), {"rtol": 1e-2, "atol": 1e-3})
    rtol, atol = tol["rtol"], tol["atol"]
    
    results = []
    providers = ["torch", "old_liger", "new_liger"]
    
    for bt in bt_values:
        outputs = {}
        grads = {}
        
        for provider in providers:
            torch.manual_seed(123)
            model, student_input, teacher_input, labels = model_factory(provider, bt)
            
            loss = model(student_input, teacher_input, labels)
            loss.backward()
            
            outputs[provider] = loss.item()
            grads[provider] = {
                "student_input": student_input.grad.clone() if student_input.grad is not None else None,
                "student_weight": model.student_lin.weight.grad.clone() if model.student_lin.weight.grad is not None else None,
                "teacher_weight": model.teacher_lin.weight.grad.clone() if model.teacher_lin.weight.grad is not None else None,
            }
        
        torch_output = outputs["torch"]
        torch_grads = grads["torch"]
        
        for provider in ["old_liger", "new_liger"]:
            output_diff = abs(outputs[provider] - torch_output)
            output_match = output_diff <= atol + rtol * abs(torch_output)
            
            grad_results = {}
            for grad_name in ["student_input", "student_weight", "teacher_weight"]:
                if torch_grads[grad_name] is not None and grads[provider][grad_name] is not None:
                    grad_diff = (grads[provider][grad_name] - torch_grads[grad_name]).abs()
                    max_grad_diff = grad_diff.max().item()
                    grad_match = max_grad_diff <= atol + rtol * torch_grads[grad_name].abs().max().item()
                    grad_results[grad_name] = {"max_diff": max_grad_diff, "match": grad_match}
                else:
                    grad_results[grad_name] = {"max_diff": None, "match": True}
            
            results.append({
                "BT": bt,
                "provider": provider,
                "output_diff": output_diff,
                "output_match": output_match,
                "grad_results": grad_results,
            })
    
    all_passed = all(r["output_match"] and all(g["match"] for g in r["grad_results"].values()) for r in results)
    
    print("\n" + "=" * 80)
    print(f"CORRECTNESS TEST RESULTS (backward_mode={backward_mode})")
    print("=" * 80)
    print(f"Tolerance: rtol={rtol}, atol={atol}")
    print(f"{'BT':<10} {'Provider':<12} {'Output Diff':<15} {'Output':<8} {'Grad Match':<12}")
    print("-" * 80)
    
    for r in results:
        grad_match = all(g["match"] for g in r["grad_results"].values())
        grad_match_str = "PASS" if grad_match else "FAIL"
        output_str = "PASS" if r["output_match"] else "FAIL"
        print(f"{r['BT']:<10} {r['provider']:<12} {r['output_diff']:<15.6e} {output_str:<8} {grad_match_str:<12}")
    
    print("-" * 80)
    if all_passed:
        print("CORRECTNESS TEST: PASSED")
    else:
        print("CORRECTNESS TEST: FAILED")
    print("=" * 80 + "\n")
    
    return all_passed


def benchmark_speed(model, student_input, teacher_input, labels, mode, rep):
    def fwd():
        return model(student_input, teacher_input, labels)

    grad_to_none = tensors_to_clear(model, student_input)
    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=grad_to_none,
            rep=rep,
            quantiles=[0.5, 0.2, 0.8],
        )
    elif mode == "backward_full":
        loss = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: loss.backward(retain_graph=True),
            grad_to_none=grad_to_none,
            rep=rep,
            quantiles=[0.5, 0.2, 0.8],
        )
    elif mode == "full":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: fwd().backward(retain_graph=True),
            grad_to_none=grad_to_none,
            rep=rep,
            quantiles=[0.5, 0.2, 0.8],
        )
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    return {"p50": ms_50, "p20": ms_20, "p80": ms_80}


def benchmark_memory(model_factory, iterations):
    values = []
    for _ in range(iterations):
        getattr(torch, device).empty_cache()
        getattr(torch, device).memory.reset_peak_memory_stats()
        model, student_input, teacher_input, labels = model_factory()
        model(student_input, teacher_input, labels).backward()
        values.append(getattr(torch, device).max_memory_allocated() / 2**20)

    tensor = torch.tensor(values, dtype=torch.float32)
    mem_50, mem_20, mem_80 = torch.quantile(tensor, torch.tensor([0.5, 0.2, 0.8])).tolist()
    return {"p50": mem_50, "p20": mem_20, "p80": mem_80}


def print_summary(rows):
    by_key = {(row["metric"], row["mode"], row["provider"], row["BT"]): row for row in rows}
    bts = sorted({row["BT"] for row in rows if row["metric"] == "speed_ms"})

    def p50(row):
        return row.get("p50") if row and "p50" in row else None

    def fmt_ms(value):
        return f"{value:.2f}" if value is not None else "-"

    def fmt_mb(value):
        return f"{value:.0f}" if value is not None else "-"

    def speedup(baseline, candidate):
        return f"{baseline / candidate:.2f}x" if baseline is not None and candidate else "-"

    def print_table(title, headers, table_rows):
        widths = [len(header) for header in headers]
        for row in table_rows:
            widths = [max(width, len(str(cell))) for width, cell in zip(widths, row)]

        def render(row):
            cells = []
            for idx, cell in enumerate(row):
                text = str(cell)
                cells.append(text.rjust(widths[idx]) if idx == 0 else text.rjust(widths[idx]))
            return "| " + " | ".join(cells) + " |"

        print(f"\n{title}")
        print(render(headers))
        print("| " + " | ".join("-" * width for width in widths) + " |")
        for row in table_rows:
            print(render(row))

    for mode, title in (
        ("forward", "Forward Speed (p50 ms)"),
        ("backward_full", "Backward Full Speed (p50 ms)"),
        ("full", "Full Pass Speed (p50 ms)"),
    ):
        table_rows = []
        for bt in bts:
            torch_p50 = p50(by_key.get(("speed_ms", mode, "torch", bt)))
            old_p50 = p50(by_key.get(("speed_ms", mode, "old_liger", bt)))
            new_p50 = p50(by_key.get(("speed_ms", mode, "new_liger", bt)))
            if torch_p50 is None and old_p50 is None and new_p50 is None:
                continue
            table_rows.append(
                [
                    bt,
                    fmt_ms(torch_p50),
                    fmt_ms(old_p50),
                    fmt_ms(new_p50),
                    speedup(old_p50, new_p50),
                    speedup(torch_p50, new_p50),
                ]
            )
        if table_rows:
            print_table(
                title,
                ["BT", "Torch", "Old", "New", "New vs Old", "New vs Torch"],
                table_rows,
            )

    memory_bts = sorted({row["BT"] for row in rows if row["metric"] == "memory_MB"})
    if memory_bts:
        table_rows = []
        for bt in memory_bts:
            torch_p50 = p50(by_key.get(("memory_MB", "full", "torch", bt)))
            old_p50 = p50(by_key.get(("memory_MB", "full", "old_liger", bt)))
            new_p50 = p50(by_key.get(("memory_MB", "full", "new_liger", bt)))
            table_rows.append(
                [
                    bt,
                    fmt_mb(torch_p50),
                    fmt_mb(old_p50),
                    fmt_mb(new_p50),
                    speedup(torch_p50, new_p50),
                ]
            )
        print_table(
            "Peak Memory (p50 MB)",
            ["BT", "Torch", "Old", "New", "Torch/New"],
            table_rows,
        )


def parse_args():
    parser = argparse.ArgumentParser(description="Compare torch, old Liger, and new Liger fused-linear JSD.")
    parser.add_argument("--bt-values", default="1024,4096,16384,65536,262144")
    parser.add_argument("--hidden-size", type=int, default=4096)
    parser.add_argument("--vocab-size", type=int, default=256000)
    parser.add_argument("--dtype", choices=sorted(DTYPE_MAP), default="bf16")
    parser.add_argument("--providers", default="torch,old_liger,new_liger")
    parser.add_argument("--modes", default="forward,backward_full,full")
    parser.add_argument("--rep", type=int, default=10)
    parser.add_argument("--memory-iters", type=int, default=10)
    parser.add_argument("--skip-memory", action="store_true")
    parser.add_argument("--with-labels", action="store_true")
    parser.add_argument("--beta", type=float, default=0.5)
    parser.add_argument("--ignore-index", type=int, default=-100)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--summary", action="store_true")
    parser.add_argument("--skip-correctness", action="store_true", help="Skip correctness test")
    parser.add_argument("--correctness-bt", default="16,32,64,128,256,512,1024,2048", help="BT values for correctness test")
    return parser.parse_args()


def main():
    args = parse_args()
    torch_device = torch.device(device)
    dtype = DTYPE_MAP[args.dtype]
    bt_values = parse_csv_list(args.bt_values, int)
    providers = parse_csv_list(args.providers)
    modes = parse_csv_list(args.modes)

    env = {
        "torch": torch.__version__,
        "triton": triton.__version__,
        "gpu": getattr(torch, device).get_device_name(),
        "H": args.hidden_size,
        "V": args.vocab_size,
        "dtype": str(dtype),
        "with_labels": args.with_labels,
    }
    print(json.dumps({"env": env}), flush=True)

    if not args.skip_correctness:
        correctness_bt_values = parse_csv_list(args.correctness_bt, int)
        
        backward_modes_to_test = ["full"]

        all_correctness_passed = True
        for backward_mode in backward_modes_to_test:
            def correctness_model_factory(provider, bt, bm=backward_mode):
                return setup_case(
                    provider,
                    bt,
                    args.hidden_size,
                    args.vocab_size,
                    dtype,
                    torch_device,
                    args.beta,
                    args.ignore_index,
                    args.temperature,
                    args.with_labels,
                    backward_mode=bm,
                )
            
            correctness_passed = test_correctness(
                lambda provider, bt: correctness_model_factory(provider, bt),
                correctness_bt_values,
                dtype,
                torch_device,
                backward_mode=backward_mode,
            )
            
            if not correctness_passed:
                all_correctness_passed = False
        
        if not all_correctness_passed:
            print("WARNING: Correctness test failed. Benchmark results may be invalid.")
            import sys
            sys.exit(1)

    rows = []
    for mode in modes:
        backward_mode = "full" if mode == "backward_full" else None
        
        for provider in providers:
            for bt in bt_values:
                torch.manual_seed(args.seed)
                getattr(torch, device).empty_cache()
                model, student_input, teacher_input, labels = setup_case(
                    provider,
                    bt,
                    args.hidden_size,
                    args.vocab_size,
                    dtype,
                    torch_device,
                    args.beta,
                    args.ignore_index,
                    args.temperature,
                    args.with_labels,
                    backward_mode=backward_mode,
                )
                try:
                    result = benchmark_speed(model, student_input, teacher_input, labels, mode, args.rep)
                    row = {"metric": "speed_ms", "mode": mode, "provider": provider, "BT": bt, **result}
                except Exception as error:
                    row = {
                        "metric": "speed_ms",
                        "mode": mode,
                        "provider": provider,
                        "BT": bt,
                        "error": repr(error),
                    }
                rows.append(row)
                print(json.dumps(row), flush=True)

    if not args.skip_memory:
        for provider in providers:
            for bt in bt_values:

                def model_factory(provider=provider, bt=bt):
                    torch.manual_seed(args.seed)
                    return setup_case(
                        provider,
                        bt,
                        args.hidden_size,
                        args.vocab_size,
                        dtype,
                        torch_device,
                        args.beta,
                        args.ignore_index,
                        args.temperature,
                        args.with_labels,
                        backward_mode="full",
                    )

                try:
                    result = benchmark_memory(model_factory, args.memory_iters)
                    row = {"metric": "memory_MB", "mode": "full", "provider": provider, "BT": bt, **result}
                except Exception as error:
                    row = {
                        "metric": "memory_MB",
                        "mode": "full",
                        "provider": provider,
                        "BT": bt,
                        "error": repr(error),
                    }
                rows.append(row)
                print(json.dumps(row), flush=True)

    if args.summary:
        print_summary(rows)


if __name__ == "__main__":
    main()

Comment on lines +68 to +69
@triton.jit
def _jsd_lm_head_kernel(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the key difference between this and the original jsd kernel? If the inefficiency exists in the jsd kernel, we should also optimize it

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_jsd_kernel only owns the JSD over log-prob inputs, so it correctly returns dL/dlog_q; any upstream log_softmax backward is handled outside that kernel. In the fused LM-head path, we bypass that upstream autograd path, so _jsd_lm_head_kernel has to fold the log-softmax backward into the kernel and produce dL/dlogits directly.

jsd.py allocates a [BT,V] and uses torch.sum. Since the final loss only needs 1 value per-row, we can accumulate inside the kernel and store [BT]. I can optimize jsd.py with similar changes as well

Comment on lines +56 to +63
chunk_memory_mb = _get_positive_int_env(CHUNK_MEMORY_MB_ENV, DEFAULT_CHUNK_MEMORY_MB)
if chunk_memory_mb is not None:
# The fast path keeps multiple fp32 (chunk, V) intermediates alive.
# Budget for roughly four such tensors: student/teacher logits and
# student/teacher log-probs. Use a power-of-two cap to avoid odd GEMMs.
bytes_per_token = 4 * V * torch.float32.itemsize
max_chunk_size = max(1, (chunk_memory_mb * 2**20) // bytes_per_token)
chunk_size = min(chunk_size, _previous_power_of_2(max_chunk_size))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it mean that chunk_size calculated based on chunk_memory_mb will override chunk_size set via min_chunk_size?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, chunk_memory_mb can still cap the final chunk size below min_chunk_size. I did this because the memory budget is treated as a hard cap, while min_chunk_size is only a lower bound for the adaptive heuristic when the memory cap allows it. I updated the comment/code to make this clearer

Comment on lines +19 to +23
DEFAULT_CHUNK_MEMORY_MB = 1024
DEFAULT_MIN_CHUNK_SIZE = 256
CHUNK_SIZE_ENV = "LIGER_FUSED_LINEAR_JSD_CHUNK_SIZE"
CHUNK_MEMORY_MB_ENV = "LIGER_FUSED_LINEAR_JSD_CHUNK_MEMORY_MB"
MIN_CHUNK_SIZE_ENV = "LIGER_FUSED_LINEAR_JSD_MIN_CHUNK_SIZE"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of setting environment variables, it makes more sense to me to expose these chunk related parameters for users in our fused_linear_ function family. what do you think? cc @Mecoli1219 @vaibhavjindal

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants