diff --git a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt index 691b711be..4b3b4d7f4 100644 --- a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt +++ b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt @@ -111,6 +111,9 @@ public open class TensorFlowInferenceModel( internal fun FloatData.toTensor(): Tensor { val preparedData = serializeToBuffer(floats) + require(preparedData.remaining() == shape.dims().reduce(Long::times).toInt()) { + // TODO: add more details about the shape and how to fix it. + } return Tensor.create(longArrayOf(1L, *shape.dims()), preparedData) }