Skip to content
Prev Previous commit
Next Next commit
Fix CPU execution
  • Loading branch information
lantiga committed Mar 14, 2021
commit 95c3a1ea691797d477bd136af60d11ef6622727f
37 changes: 7 additions & 30 deletions src/backends/tensorflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -449,16 +449,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
TFE_ContextOptions *context_opts = TFE_NewContextOptions();

if (device == RAI_DEVICE_CPU) {
-// Set number of GPU to 0 with
-// config.device_count = {'GPU': 0}
-uint8_t config[] = {0x0a, 0x07, 0x0a, 0x03, 0x47, 0x50, 0x55, 0x10, 0x00};
-TFE_ContextOptionsSetConfig(context_opts, (void *)config, sizeof(config), status);
-
-if (TF_GetCode(status) != TF_OK) {
-RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status)));
-goto cleanup;
-}
-
if (opts.backends_intra_op_parallelism > 0) {
uint8_t proto[] = {0x10, (uint8_t)opts.backends_intra_op_parallelism};
TFE_ContextOptionsSetConfig(context_opts, proto, sizeof(proto), status);
Expand Down Expand Up @@ -679,13 +669,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
return REDISMODULE_ERR;
}

-if (on_cpu) {
-deviceInputTensorsHandles[i] = inputTensorsHandles[i];
-}
-else {
-deviceInputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
-inputTensorsHandles[i], mctxs[0]->model->session, tf_devicestr, status);
-}
+deviceInputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
+inputTensorsHandles[i], mctxs[0]->model->session, tf_devicestr, status);

if (TF_GetCode(status) != TF_OK) {
char *errorMessage = RedisModule_Strdup(TF_Message(status));
Expand Down Expand Up @@ -716,6 +701,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {

int noutputs_ = noutputs;
TFE_Execute(fn_op, deviceOutputTensorsHandles, &noutputs_, status);
+
if (TF_GetCode(status) != TF_OK) {
char *errorMessage = RedisModule_Strdup(TF_Message(status));
RAI_SetError(error, RAI_EMODELRUN, errorMessage);
Expand All @@ -726,9 +712,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {

for (size_t i = 0; i < ninputs; ++i) {
TFE_DeleteTensorHandle(inputTensorsHandles[i]);
-if (!on_cpu) {
-TFE_DeleteTensorHandle(deviceInputTensorsHandles[i]);
-}
+TFE_DeleteTensorHandle(deviceInputTensorsHandles[i]);
}

if (TF_GetCode(status) != TF_OK) {
Expand All @@ -740,13 +724,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
}

for (size_t i = 0; i < noutputs; ++i) {
-if (on_cpu) {
-outputTensorsHandles[i] = deviceOutputTensorsHandles[i];
-}
-else {
-outputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
-deviceOutputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status);
-}
+outputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
+deviceOutputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status);

DLManagedTensor *outputDLTensor = TFE_HandleToDLPack(outputTensorsHandles[i], status, error);

Expand Down Expand Up @@ -784,9 +763,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(outputTensor);
}
RAI_TensorFree(outputTensor);
-if (!on_cpu) {
-TFE_DeleteTensorHandle(deviceOutputTensorsHandles[i]);
-}
+TFE_DeleteTensorHandle(deviceOutputTensorsHandles[i]);
}

TF_DeleteStatus(status);
Expand Down
1 change: 1 addition & 0 deletions tests/flow/tests_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def test_run_tf_model(env):

with open(model_filename, 'rb') as f:
model_pb = f.read()
+DEVICE = "CPU"

ret = con.execute_command('AI.MODELSET', 'm{1}', 'TF', DEVICE,
'INPUTS', 'a', 'b', 'OUTPUTS', 'mul', 'BLOB', model_pb)
Expand Down