feat: optimize fused linear JSD memory and runtime#1261
Conversation
| @triton.jit | ||
| def _jsd_lm_head_kernel( |
There was a problem hiding this comment.
what's the key difference between this and the original jsd kernel? If the inefficiency exists in the jsd kernel, we should also optimize it
There was a problem hiding this comment.
_jsd_kernel only owns the JSD over log-prob inputs, so it correctly returns dL/dlog_q; any upstream log_softmax backward is handled outside that kernel. In the fused LM-head path, we bypass that upstream autograd path, so _jsd_lm_head_kernel has to fold the log-softmax backward into the kernel and produce dL/dlogits directly.
jsd.py allocates a [BT,V] and uses torch.sum. Since the final loss only needs 1 value per-row, we can accumulate inside the kernel and store [BT]. I can optimize jsd.py with similar changes as well
| chunk_memory_mb = _get_positive_int_env(CHUNK_MEMORY_MB_ENV, DEFAULT_CHUNK_MEMORY_MB) | ||
| if chunk_memory_mb is not None: | ||
| # The fast path keeps multiple fp32 (chunk, V) intermediates alive. | ||
| # Budget for roughly four such tensors: student/teacher logits and | ||
| # student/teacher log-probs. Use a power-of-two cap to avoid odd GEMMs. | ||
| bytes_per_token = 4 * V * torch.float32.itemsize | ||
| max_chunk_size = max(1, (chunk_memory_mb * 2**20) // bytes_per_token) | ||
| chunk_size = min(chunk_size, _previous_power_of_2(max_chunk_size)) |
There was a problem hiding this comment.
does it mean that chunk_size calculated based on chunk_memory_mb will override chunk_size set via min_chunk_size?
There was a problem hiding this comment.
Yes, chunk_memory_mb can still cap the final chunk size below min_chunk_size. I did this because the memory budget is treated as a hard cap, while min_chunk_size is only a lower bound for the adaptive heuristic when the memory cap allows it. I updated the comment/code to make this clearer
| DEFAULT_CHUNK_MEMORY_MB = 1024 | ||
| DEFAULT_MIN_CHUNK_SIZE = 256 | ||
| CHUNK_SIZE_ENV = "LIGER_FUSED_LINEAR_JSD_CHUNK_SIZE" | ||
| CHUNK_MEMORY_MB_ENV = "LIGER_FUSED_LINEAR_JSD_CHUNK_MEMORY_MB" | ||
| MIN_CHUNK_SIZE_ENV = "LIGER_FUSED_LINEAR_JSD_MIN_CHUNK_SIZE" |
There was a problem hiding this comment.
Instead of setting environment variables, it makes more sense to me to expose these chunk related parameters for users in our fused_linear_ function family. what do you think? cc @Mecoli1219 @vaibhavjindal
Summary
A faster and more memory-efficient fused linear JSD kernel.
Testing Done
Hardware Type: NVIDIA A100-SXM4-80GBmake testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergenceAdditional Testing
Forward/Backward Correctness
Forward Speed (p50 ms)
Backward Full Speed (p50 ms)
Full Pass Speed (p50 ms)
Peak Memory (p50 MB)
Testing script
benchmark/script/compare_jsd.pysrc/liger_kernel/ops/fused_linear_jsd_old.pybenchmarking code