Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1499,19 +1499,53 @@ struct ggml_backend_cuda_context {
// Convenience overload: forwards to pool(device) using this context's `device` value.
ggml_cuda_pool & pool() {
    ggml_cuda_pool & own_device_pool = pool(device);
    return own_device_pool;
}

// Single-entry cache for the q8_1-quantized form of a mul-mat src1 buffer.
// The entry owns one cudaMalloc'd buffer that is grown on demand and reused across calls.
//
// NOTE(review): nothing visible here releases q8_1_buf; the owning context's teardown
// should cudaFree it, otherwise the buffer is leaked per context.
// NOTE(review): the key is a raw data pointer. If the graph allocator reuses the same
// address for different activations within one graph evaluation, a pointer match does not
// prove the cached contents are still current — confirm against the allocator's behavior.
struct q8_1_cache_entry {
    const void * src1_data  = nullptr; // key: source pointer the cached data was produced from (nullptr = invalid)
    void *       q8_1_buf   = nullptr; // buffer obtained from cudaMalloc, nullptr until first allocation
    size_t       buf_size   = 0;       // capacity of q8_1_buf in bytes
    size_t       valid_size = 0;       // byte size that was requested for the currently cached contents
};
q8_1_cache_entry q8_1_cache;

// Looks up (or prepares) the buffer holding the quantized data for `src1_data`.
//
//   src1_data   - key identifying the source data; nullptr never produces a hit.
//   needed_size - number of bytes the caller needs for the quantized data.
//   cache_hit   - optional out-parameter; set to true when the returned buffer already holds
//                 data registered for the same (src1_data, needed_size) pair, false otherwise.
//
// Returns the cache buffer. On a miss the caller is expected to fill it, since the entry is
// registered for `src1_data` before returning. Returns nullptr (with *cache_hit == false)
// only when no buffer exists and needed_size is 0; nothing is registered in that case.
void * q8_1_cache_get_or_alloc(const void * src1_data, size_t needed_size, bool * cache_hit) {
    // A hit needs a live buffer, a non-null key and the same requested size: the same
    // pointer with a different size means different contents and possibly a larger
    // footprint than what was cached.
    const bool is_hit =
        src1_data != nullptr &&
        q8_1_cache.q8_1_buf != nullptr &&
        q8_1_cache.src1_data == src1_data &&
        q8_1_cache.valid_size == needed_size;

    if (is_hit) {
        if (cache_hit) {
            *cache_hit = true;
        }
        return q8_1_cache.q8_1_buf;
    }

    if (q8_1_cache.buf_size < needed_size) {
        if (q8_1_cache.q8_1_buf) {
            // Drop every reference to the old buffer before releasing it so the entry is
            // never left pointing at freed memory.
            void * old_buf = q8_1_cache.q8_1_buf;
            q8_1_cache.q8_1_buf   = nullptr;
            q8_1_cache.buf_size   = 0;
            q8_1_cache.src1_data  = nullptr;
            q8_1_cache.valid_size = 0;
            CUDA_CHECK(cudaFree(old_buf));
        }
        CUDA_CHECK(cudaMalloc(&q8_1_cache.q8_1_buf, needed_size));
        q8_1_cache.buf_size = needed_size;
    }

    if (q8_1_cache.q8_1_buf) {
        // Register the new key; the caller fills the buffer after a miss.
        q8_1_cache.src1_data  = src1_data;
        q8_1_cache.valid_size = needed_size;
    } else {
        // No buffer to hand out (zero-sized request): keep the entry invalid.
        q8_1_cache.src1_data  = nullptr;
        q8_1_cache.valid_size = 0;
    }

    if (cache_hit) {
        *cache_hit = false;
    }
    return q8_1_cache.q8_1_buf;
}

// Marks the cached contents as stale. The buffer itself is kept for reuse.
void q8_1_cache_invalidate() {
    q8_1_cache.src1_data  = nullptr;
    q8_1_cache.valid_size = 0;
}
};

// Host-side bundle of optional tensors/parameters describing what is fused into a
// mul-mat-vec dispatch. All tensor pointers default to nullptr (field not used).
struct ggml_cuda_mm_fusion_args_host {
// presumably a bias tensor for the matmul result — name-based, confirm against the consumer
const ggml_tensor * x_bias = nullptr;
// presumably the gate branch tensor of a gated (GLU) fusion — TODO confirm
const ggml_tensor * gate = nullptr;
// presumably the bias tensor of the gate branch — TODO confirm
const ggml_tensor * gate_bias = nullptr;
// NOTE(review): the only member without a default initializer; it is indeterminate unless
// the object is value-initialized (e.g. `ggml_cuda_mm_fusion_args_host args{};`).
ggml_glu_op glu_op;
// When non-null, ggml_cuda_mul_mat_vec_q quantizes src1 from this tensor's data through
// quantize_row_q8_1_rms_norm_cuda instead of the plain q8_1 quantization.
const ggml_tensor * rms_norm_src = nullptr;
// Tensor whose data is passed as the float weights to quantize_row_q8_1_rms_norm_cuda.
const ggml_tensor * rms_norm_weights = nullptr;
// Epsilon forwarded to quantize_row_q8_1_rms_norm_cuda; the fusion code copies it from the
// RMS_NORM node's op_params.
float rms_norm_eps = 0.0f;
};
// Device-facing counterpart of ggml_cuda_mm_fusion_args_host: the same fields with the
// tensor pointers replaced by untyped data pointers. All pointers default to nullptr.
struct ggml_cuda_mm_fusion_args_device {
// Read as `const float *` by the RDNA 3.5 bias-only path, which adds it to dst after the matmul.
const void * x_bias = nullptr;
// A null gate together with a non-null x_bias selects the RDNA 3.5 bias-only path.
const void * gate = nullptr;
// presumably the bias data of the gate branch — TODO confirm
const void * gate_bias = nullptr;
// NOTE(review): no default initializer; indeterminate unless the object is value-initialized.
ggml_glu_op glu_op;
// NOTE(review): no device-side reader of the rms_norm_* fields is visible in this diff —
// presumably they mirror the host struct; confirm they are actually consumed.
const void * rms_norm_src = nullptr;
const void * rms_norm_weights = nullptr;
float rms_norm_eps = 0.0f;
};

struct ggml_cuda_kernel_launch_params {
Expand Down
27 changes: 25 additions & 2 deletions ggml/src/ggml-cuda/fattn-tile.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,29 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
return 0;
}

// Tile-kernel configuration lookup for AMD RDNA 3.5.
// Lists RDNA 3.5-specific table rows for the 72/72 and 80/80 cases and defers every other
// combination to ggml_cuda_fattn_tile_get_config_amd_rdna().
// NOTE(review): GGML_CUDA_FATTN_TILE_CONFIG_CASE is defined outside this view; presumably its
// first three arguments are matched against DKQ, DV and ncols and the remaining four are the
// tuning values packed into the returned uint32_t — confirm against the macro definition.
static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna3_5(const int DKQ, const int DV, const int ncols) {
// rows for the 72/72 case, one per ncols value (2..64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 4, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 4, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 128, 4, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 128, 4, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 128, 4, 64, 72)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 128, 4, 64, 72)

// rows for the 80/80 case, one per ncols value (2..64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 4, 64, 80)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 4, 64, 80)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 4, 64, 80)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 4, 64, 80)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 4, 64, 80)
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 4, 64, 80)

// no RDNA 3.5-specific row: use the generic RDNA table
return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
}

static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
if (GGML_CUDA_CC_IS_AMD(cc)) {
if (GGML_CUDA_CC_IS_RDNA3_5(cc)) {
return ggml_cuda_fattn_tile_get_config_amd_rdna3_5(DKQ, DV, ncols);
}
if (GGML_CUDA_CC_IS_RDNA(cc)) {
return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
}
Expand All @@ -324,11 +345,13 @@ static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const in

static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {
#ifdef GGML_USE_HIP
#ifdef RDNA
#ifdef RDNA3_5
return ggml_cuda_fattn_tile_get_config_amd_rdna3_5(DKQ, DV, ncols);
#elif defined(RDNA)
return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
#else
return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
#endif // RDNA
#endif // RDNA3_5
#else
#ifdef FAST_FP16_AVAILABLE
return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
Expand Down
32 changes: 32 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4240,6 +4240,36 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
return 2;
}

// RMS_NORM + MUL + MUL_MAT: fold rms_norm+mul+quantize into a single MMVQ dispatch
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {}) &&
i + 2 < cgraph->n_nodes &&
cgraph->nodes[i + 2]->op == GGML_OP_MUL_MAT) {

ggml_tensor * rms_node = cgraph->nodes[i];
ggml_tensor * mul_node = cgraph->nodes[i + 1];
ggml_tensor * mm_node = cgraph->nodes[i + 2];

if (ggml_node_has_n_uses(cgraph, i + 1, 1) &&
(mm_node->src[1] == mul_node) &&
ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {

float eps;
memcpy(&eps, rms_node->op_params, sizeof(float));

const ggml_tensor * mul_weights = (mul_node->src[0] == rms_node)
? mul_node->src[1] : mul_node->src[0];

ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.rms_norm_src = rms_node->src[0];
fusion_data.rms_norm_weights = mul_weights;
fusion_data.rms_norm_eps = eps;

ggml_cuda_mul_mat_vec_q(*cuda_ctx, mm_node->src[0], mm_node->src[1],
mm_node->src[2], mm_node, &fusion_data);
return 2;
}
}

if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]);
return 1;
Expand Down Expand Up @@ -4373,6 +4403,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
stream_ctx.concurrent_events.clear();
}

cuda_ctx->q8_1_cache_invalidate();

for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (is_concurrent_event_active) {
Expand Down
182 changes: 172 additions & 10 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,18 @@ enum mmvq_parameter_table_id {
MMVQ_PARAMETERS_GCN,
MMVQ_PARAMETERS_RDNA2,
MMVQ_PARAMETERS_RDNA3_0,
MMVQ_PARAMETERS_RDNA3_5,
MMVQ_PARAMETERS_RDNA4
};

static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
#if defined(RDNA4)
return MMVQ_PARAMETERS_RDNA4;
#elif defined(RDNA3_0)
#elif defined(RDNA3_5)
return MMVQ_PARAMETERS_RDNA3_5;
#elif defined(RDNA3)
return MMVQ_PARAMETERS_RDNA3_0;
#elif defined(RDNA2) || defined(RDNA3_5)
#elif defined(RDNA2)
return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
return MMVQ_PARAMETERS_GCN;
Expand All @@ -90,10 +93,13 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return MMVQ_PARAMETERS_RDNA4;
}
if (GGML_CUDA_CC_IS_RDNA3_5(cc)) {
return MMVQ_PARAMETERS_RDNA3_5;
}
if (GGML_CUDA_CC_IS_RDNA3_0(cc)) {
return MMVQ_PARAMETERS_RDNA3_0;
}
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) {
if (GGML_CUDA_CC_IS_RDNA2(cc)) {
return MMVQ_PARAMETERS_RDNA2;
}
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
Expand Down Expand Up @@ -422,6 +428,9 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
}
return 1;
}
if (table_id == MMVQ_PARAMETERS_RDNA3_5) {
return 1;
}
if (table_id == MMVQ_PARAMETERS_TURING) {
if (ncols_dst == 1) {
switch (type) {
Expand Down Expand Up @@ -674,6 +683,86 @@ static __global__ void mul_mat_vec_q(
}
}

// Clears the strided output elements that the split-K kernel accumulates into.
// Launch layout: grid = (nrows, nchannels_dst, nsamples), block = (1).
// Each block writes exactly one float: the element at row blockIdx.x of the
// channel/sample selected by blockIdx.y / blockIdx.z.
static __global__ void mmvq_splitk_zero_output(
        float * dst, const uint32_t stride_channel_dst, const uint32_t stride_sample_dst) {
    const uint32_t sample_base  = blockIdx.z*stride_sample_dst;
    const uint32_t channel_base = blockIdx.y*stride_channel_dst;
    const uint32_t row          = blockIdx.x;

    dst[sample_base + channel_base + row] = 0.0f;
}

// Accumulates a bias into the mat-vec output, one element per thread.
// Launch layout: grid = (ceil(nrows/256), nchannels_dst, nsamples), block = (256).
// `bias` is read at the same strided offset as `dst`, so both must share one layout.
static __global__ void mmvq_splitk_add_bias(
        float * dst, const float * bias, const uint32_t nrows,
        const uint32_t stride_channel_dst, const uint32_t stride_sample_dst) {
    const uint32_t row = blockIdx.x * blockDim.x + threadIdx.x;

    // Threads of the last block that fall past the end of the row range do nothing.
    if (row < nrows) {
        const uint32_t idx = blockIdx.z*stride_sample_dst + blockIdx.y*stride_channel_dst + row;
        dst[idx] = dst[idx] + bias[idx];
    }
}

// Split-K quantized mat-vec kernel (used by the RDNA 3.5 launch path).
// The K dimension of every output row is cut into split_k_factor chunks; each block
// reduces one chunk to a partial dot product and adds it to the output with atomicAdd,
// so the destination elements must hold their starting value (zero) before the launch.
// Launch layout: grid = (nrows, nchannels_dst, split_k_factor * nsamples),
//                block = (warp_size, 1, 1) — one wave, reduced with warp_reduce_sum only,
//                no shared memory.
// nchannels_y, stride_col_y and stride_col_dst are accepted but not read by this kernel.
template <ggml_type type, int split_k_factor>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q_splitk(
        const void * vx_ptr, const void * vy_ptr, float * dst_ptr,
        const uint32_t ncols_x, const uint3 nchannels_y,
        const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
        const uint3 channel_ratio, const uint32_t stride_channel_x,
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
        const uint3 sample_ratio, const uint32_t stride_sample_x,
        const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
    // Compile-time properties of the quantized type.
    constexpr int qk              = ggml_cuda_type_traits<type>::qk;
    constexpr int qi              = ggml_cuda_type_traits<type>::qi;
    constexpr int vdr             = get_vdr_mmvq(type);
    constexpr int warp_size       = ggml_cuda_get_physical_warp_size();
    constexpr int lanes_per_block = qi/vdr;                // lanes sharing one quantized x block
    constexpr int blocks_per_iter = vdr * warp_size / qi;  // x blocks consumed per loop step

    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

    const void * GGML_CUDA_RESTRICT vx  = vx_ptr;
    const void * GGML_CUDA_RESTRICT vy  = vy_ptr;
    float      * GGML_CUDA_RESTRICT dst = dst_ptr;

    // Decode the block index: x = output row, y = dst channel, z = (sample, K chunk).
    const int      row0        = blockIdx.x;
    const uint32_t channel_dst = blockIdx.y;
    const uint32_t split_chunk = blockIdx.z % split_k_factor;
    const uint32_t sample_dst  = blockIdx.z / split_k_factor;

    // x is indexed through the channel/sample broadcast ratios, y directly by the dst indices.
    const uint32_t channel_x = fastdiv(channel_dst, channel_ratio);
    const uint32_t sample_x  = fastdiv(sample_dst,  sample_ratio);

    // Range of quantized x blocks [k_start, k_end) owned by this K chunk.
    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_chunk = (blocks_per_row_x + split_k_factor - 1) / split_k_factor;
    const int k_start          = split_chunk * blocks_per_chunk;
    const int k_end            = min(k_start + blocks_per_chunk, blocks_per_row_x);

    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_dst*stride_sample_y + channel_dst*stride_channel_y;
    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;

    // Per-lane position inside a quantized block; constant over the whole loop.
    const int kqs = vdr * (threadIdx.x % lanes_per_block);

    float partial = 0.0f;

    int kbx = k_start + threadIdx.x / lanes_per_block;
    while (kbx < k_end) {
        const int kby = kbx * (qk/QK8_1);
        partial += vec_dot_q_cuda(vx, &y[kby], kbx_offset + kbx, kqs);
        kbx += blocks_per_iter;
    }

    partial = warp_reduce_sum<warp_size>(partial);

    // Lane 0 publishes the chunk's contribution; the other chunks of the row add theirs too.
    if (threadIdx.x == 0) {
        atomicAdd(&dst[sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0], partial);
    }
}

// Dedicated MoE multi-token kernel.
// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst)
// Block: (warp_size, ncols_dst) - each warp handles one token independently.
Expand Down Expand Up @@ -890,6 +979,57 @@ static void mul_mat_vec_q_switch_ncols_dst(
case 1: {
constexpr int c_ncols_dst = 1;

// Bias-only defusion for RDNA 3.5: run non-fused kernel + separate bias-add.
// The fused template generates slower code for bias-only cases on RDNA 3.5.
if (table_id == MMVQ_PARAMETERS_RDNA3_5 && has_fusion && fusion.gate == nullptr && fusion.x_bias != nullptr && !has_ids) {
const ggml_cuda_mm_fusion_args_device no_fusion{};
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, no_fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
stream);

const int bias_threads = 256;
const dim3 bias_grid((nrows_x + bias_threads - 1) / bias_threads, nchannels_dst, nsamples_dst);
mmvq_splitk_add_bias<<<bias_grid, bias_threads, 0, stream>>>(
dst, (const float *) fusion.x_bias, nrows_x, stride_channel_dst, stride_sample_dst);
CUDA_CHECK(cudaGetLastError());
break;
}

// Non-fused split-K for RDNA 3.5: increase wave count to hide LPDDR5X latency
if (table_id == MMVQ_PARAMETERS_RDNA3_5 && !has_fusion && !has_ids) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
const int blocks_per_row = ncols_x / qk;
const int waves_per_cu = nrows_x / 40;
if (waves_per_cu < 80 && blocks_per_row >= 4) {
const int split_k = (waves_per_cu < 40) ? 4 : 2;

const dim3 zero_grid(nrows_x, nchannels_dst, nsamples_dst);
mmvq_splitk_zero_output<<<zero_grid, 1, 0, stream>>>(dst, stride_channel_dst, stride_sample_dst);
CUDA_CHECK(cudaGetLastError());

const dim3 block_nums(nrows_x, nchannels_dst, split_k * nsamples_dst);
const dim3 block_dims(warp_size, 1, 1);
const ggml_cuda_kernel_launch_params launch_params(block_nums, block_dims, 0, stream);

if (split_k == 4) {
ggml_cuda_kernel_launch(mul_mat_vec_q_splitk<type, 4>, launch_params,
vx, vy, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
ggml_cuda_kernel_launch(mul_mat_vec_q_splitk<type, 2>, launch_params,
vx, vy, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
}
break;
}
}

bool use_small_k = should_use_small_k(c_ncols_dst);

if (use_small_k) {
Expand Down Expand Up @@ -1183,12 +1323,34 @@ void ggml_cuda_mul_mat_vec_q(
}

const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
const size_t q8_1_size = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1;

bool q8_cache_hit = false;
void * q8_1_ptr = ctx.q8_1_cache_get_or_alloc(src1->data, q8_1_size, &q8_cache_hit);

ggml_cuda_pool_alloc<char> src1_q8_1_pool;
if (!q8_1_ptr) {
src1_q8_1_pool.alloc(ctx.pool(), q8_1_size);
q8_1_ptr = src1_q8_1_pool.get();
}

if (!q8_cache_hit) {
if (fusion && fusion->rms_norm_src) {
const ggml_tensor * rms_src = fusion->rms_norm_src;
const size_t ts_rms = ggml_type_size(rms_src->type);
const int64_t rs01 = rms_src->nb[1] / ts_rms;
const int64_t rs02 = rms_src->nb[2] / ts_rms;
const int64_t rs03 = rms_src->nb[3] / ts_rms;
quantize_row_q8_1_rms_norm_cuda(
(const float *)rms_src->data, (const float *)fusion->rms_norm_weights->data,
q8_1_ptr, fusion->rms_norm_eps,
ne10, rs01, rs02, rs03, ne10_padded, ne11, ne12, ne13, stream);
} else {
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
quantize_row_q8_1_cuda(src1_d, nullptr, q8_1_ptr, src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
}
}

const int64_t s01 = src0->nb[1] / ts_src0;
Expand All @@ -1214,7 +1376,7 @@ void ggml_cuda_mul_mat_vec_q(
const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

mul_mat_vec_q_switch_type(
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
src0->data, src0->type, q8_1_ptr, ids_d, fusion_local, dst_d, ne00,
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, ids_stride, stream);
Expand Down
Loading