diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp
index 49e2172ad3b..e9ff724042d 100644
--- a/ggml/src/ggml-openvino/ggml-openvino.cpp
+++ b/ggml/src/ggml-openvino/ggml-openvino.cpp
@@ -881,7 +881,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
         //               op->src[0]->ne[0]);
         return true;
     }
-    if (op->type != GGML_TYPE_F32) {
+    if (op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) {
         // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type));
         return true;
     }
diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp
index 71fd90fae36..1954154835c 100644
--- a/ggml/src/ggml-openvino/openvino/op/rope.cpp
+++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp
@@ -75,6 +75,11 @@ OutputVector translate_rope(const NodeContext & context) {
         }
     }
 
+    auto output_type = context.get_output_type();
+    if (data_node->get_element_type() != ov::element::f32) {
+        data_node = std::make_shared<ov::op::v0::Convert>(data_node, ov::element::f32);
+    }
+
     if (mode == ROPE_TYPE_NORMAL) {
         auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
         auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
@@ -139,6 +144,10 @@ OutputVector translate_rope(const NodeContext & context) {
         res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{sub, add}, 3);
     }
 
+    if (res.get_element_type() != output_type) {
+        res = std::make_shared<ov::op::v0::Convert>(res, output_type);
+    }
+
     return rename_outputs_with_suffix({res}, context.get_name());
 }
 
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
index 24384fcf674..b034dc79469 100644
--- a/ggml/src/ggml-openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -283,17 +283,23 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr
 
     const auto get_prefill_chunk_size = []() {
-        const char * chunk_size_str = getenv("GGML_OPENVINO_PREFILL_CHUNK_SIZE");
-        if (chunk_size_str && atoi(chunk_size_str) > 0) {
-            return atoi(chunk_size_str);
+        static int chunk_size = -1;
+        if (chunk_size == -1) {
+            const char * chunk_size_str = getenv("GGML_OPENVINO_PREFILL_CHUNK_SIZE");
+            if (chunk_size_str && atoi(chunk_size_str) > 0) {
+                chunk_size = atoi(chunk_size_str);
+            } else {
+                chunk_size = 256;
+            }
         }
-        return 256;
+        return chunk_size;
     };
 
     static std::string device = "NPU";
     static auto is_static = true;
     static auto stateful = false;
-    static auto prefill_chunk_size = get_prefill_chunk_size();
+
+    auto prefill_chunk_size = get_prefill_chunk_size();
 
     const auto & config = ggml_openvino_get_compile_config();
     if (is_naive(cgraph)) {
@@ -357,6 +363,10 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr
     std::shared_ptr<ov::Model> model;
     auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);
+    if (m_params.n_heads == -1) {
+        // graph is not a LLM, e.g. context-shift graph
+        prefill_chunk_size = inp_pos->ne[0];
+    }
     auto ggml_decoder_prefill = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static,
                                                                 stateful, false, true, prefill_chunk_size);
     auto ggml_decoder_decode = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static,