Skip to content

Commit 2e1dd8d

Browse files
committed
Fix after rebasing
1 parent 82c3c54 commit 2e1dd8d

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

ggml/src/ggml-openvino/ggml-decoder.cpp

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -198,13 +198,17 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
198198
if (node->src[0]->op != GGML_OP_VIEW) {
199199
m_op_case = 1;
200200
} else if (ggml_is_contiguous(node->src[0])) {
201-
// Permute kv cache (view)
202201
std::string src_name(node->view_src->name);
203-
int layer = extract_layer_from_name(src_name);
204-
if (!is_swa_layer(layer)) {
205-
m_op_case = 2;
202+
if (src_name.find("cache") == std::string::npos) {
203+
m_op_case = 1;
206204
} else {
207-
m_op_case = 3;
205+
// Permute kv cache (view)
206+
int layer = extract_layer_from_name(src_name);
207+
if (!is_swa_layer(layer)) {
208+
m_op_case = 2;
209+
} else {
210+
m_op_case = 3;
211+
}
208212
}
209213
}
210214
break;
@@ -230,6 +234,16 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
230234
}
231235
break;
232236
}
237+
case GGML_OP_VIEW: {
238+
if (node->src[0]->op == GGML_OP_VIEW) {
239+
auto* src = node->src[0];
240+
auto* view_src = src->view_src;
241+
if (view_src->ne[1] != src->ne[2]) {
242+
throw std::runtime_error("Unsupported VIEW case");
243+
}
244+
m_op_case = 2;
245+
}
246+
}
233247
default:
234248
break;
235249
}

ggml/src/ggml-openvino/openvino/op/set_rows.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,9 @@ OutputVector translate_set_rows(const NodeContext& context) {
4545
false);
4646
auto indices_reshaped =
4747
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
48-
auto data_reshaped = std::make_shared<ov::op::v0::Squeeze>(data, zero);
48+
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
49+
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
50+
4951
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
5052
auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
5153
return rename_outputs_with_suffix({res}, context.get_name());

ggml/src/ggml-openvino/openvino/op/view.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,10 @@ namespace op {
99
// Translates a VIEW node into its OpenVINO output(s).
//
// Behaviour, as visible in this hunk:
//   * op case 2  -> the single input is routed through process_view_input()
//                   and the result is renamed with the node's name;
//   * otherwise  -> input 0 is passed through unchanged.
//
// NOTE(review): op case 2 appears to correspond to the VIEW-of-VIEW case that
// the decoder tags with m_op_case = 2 -- confirm against
// GgmlOvDecoder::set_input_output.
OutputVector translate_view(const NodeContext& context) {
1010
// Presumably enforces that the node has exactly one input (min = max = 1).
num_inputs_check(context, 1, 1);
1111

12+
if (context.get_op_case() == 2) {
13+
// to_shape() is called on the output shape; presumably this requires the
// shape to be static -- TODO confirm.
auto dst_shape = context.get_output_shape(0).to_shape();
14+
// The third argument is the product of output dims 1 and 2; its exact meaning
// is defined by process_view_input (not visible here).
return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[1] * dst_shape[2])}, context.get_name());
15+
}
1216
// Every other op case: forward the input as-is, without renaming.
return {context.get_input(0)};
1317
}
1418

0 commit comments

Comments
 (0)