Skip to content
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix rdb decode v0 to fit the new changes.
  • Loading branch information
alonre24 committed Apr 12, 2021
commit 989bd6e1efb37b1a071d52467cd48302fa54c454
119 changes: 61 additions & 58 deletions src/serialization/RDB/decoder/previous/v0/decode_v0.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,25 @@ void *RAI_RDBLoadModel_v0(RedisModuleIO *io) {
char *devicestr = NULL;
RedisModuleString *tag = NULL;
size_t ninputs = 0;
const char **inputs = NULL;
char **inputs = NULL;
size_t noutputs = 0;
const char **outputs = NULL;
char **outputs = NULL;
char *buffer = NULL;

RAI_Error err = {0};
char *error_str = "Experienced a short read while reading a model from RDB";

RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
RedisModuleString *key_str =
RedisModule_CreateStringFromString(NULL, RedisModule_GetKeyNameFromIO(io));
if (!key_str) {
RedisModule_LogIOError(io, "error", "Couldn't get model key name from RDB");
return NULL;
}
RAI_Backend backend = RedisModule_LoadUnsigned(io);
devicestr = RedisModule_LoadStringBuffer(io, NULL);
size_t len;
char *cstr_tag = RedisModule_LoadStringBuffer(io, &len);
tag = RedisModule_CreateString(NULL, cstr_tag, len - 1);
size_t tag_len;
char *cstr_tag = RedisModule_LoadStringBuffer(io, &tag_len);
tag = RedisModule_CreateString(NULL, cstr_tag, tag_len - 1);
RedisModule_Free(cstr_tag);

const size_t batchsize = RedisModule_LoadUnsigned(io);
Expand All @@ -84,21 +93,17 @@ void *RAI_RDBLoadModel_v0(RedisModuleIO *io) {
ninputs = RedisModule_LoadUnsigned(io);
if (RedisModule_IsIOError(io))
goto cleanup;

inputs = RedisModule_Alloc(ninputs * sizeof(char *));

inputs = array_new(char *, ninputs);
for (size_t i = 0; i < ninputs; i++) {
inputs[i] = RedisModule_LoadStringBuffer(io, NULL);
inputs = array_append(inputs, RedisModule_LoadStringBuffer(io, NULL));
}

noutputs = RedisModule_LoadUnsigned(io);
if (RedisModule_IsIOError(io))
goto cleanup;

outputs = RedisModule_Alloc(noutputs * sizeof(char *));

outputs = array_new(char *, noutputs);
for (size_t i = 0; i < noutputs; i++) {
outputs[i] = RedisModule_LoadStringBuffer(io, NULL);
outputs = array_append(outputs, RedisModule_LoadStringBuffer(io, NULL));
}

RAI_ModelOpts opts = {
Expand All @@ -108,78 +113,76 @@ void *RAI_RDBLoadModel_v0(RedisModuleIO *io) {
.backends_inter_op_parallelism = getBackendsInterOpParallelism(),
};

size_t len;
buffer = RedisModule_LoadStringBuffer(io, &len);
if (RedisModule_IsIOError(io))
goto cleanup;

RAI_Error err = {0};
RAI_Model *model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs,
outputs, buffer, len, &err);

if (err.code == RAI_EBACKENDNOTLOADED) {
RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
int ret = RAI_LoadDefaultBackend(ctx, backend);
if (ret == REDISMODULE_ERR) {
RedisModule_Log(ctx, "error", "Could not load default backend");
RAI_ClearError(&err);
RAI_Model *model = RedisModule_Calloc(1, sizeof(*model));
model->refCount = 1;
model->infokey = RAI_HoldString(NULL, key_str);
model->backend = backend;
model->devicestr = devicestr;
model->tag = tag;
model->inputs = inputs;
model->ninputs = ninputs;
model->outputs = outputs;
model->noutputs = noutputs;
model->opts = opts;
model->data = buffer;
model->datalen = len;

const char *backend_str = RAI_BackendName(backend);
if (ModelCreateBE(model, &err) != REDISMODULE_OK) {
// If we got an error *not* because of lazy loading, we fail and unblock.
if (RAI_GetErrorCode(&err) != RAI_EBACKENDNOTLOADED) {
error_str = (char *)RAI_GetError(&err);
goto cleanup;
}
RedisModule_Log(ctx, "warning", "backend %s not loaded, will try loading default backend",
backend_str);
int ret = RAI_LoadDefaultBackend(NULL, model->backend);
if (ret != REDISMODULE_OK) {
sprintf(error_str, "could not load %s default backend", backend_str);
goto cleanup;
}
// Try creating model for backend again.
RAI_ClearError(&err);
model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs,
buffer, len, &err);
}

if (err.code != RAI_OK) {
RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
RedisModule_Log(ctx, "error", "%s", err.detail);
RAI_ClearError(&err);
goto cleanup;
}

RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io);
RedisModuleString *stats_keystr =
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));

model->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, devicestr, tag);

for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free((void *)inputs[i]);
}
RedisModule_Free(inputs);
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free((void *)outputs[i]);
if (ModelCreateBE(model, &err) != REDISMODULE_OK) {
error_str = (char *)RAI_GetError(&err);
goto cleanup;
}
}
RedisModule_Free(outputs);
RedisModule_Free(buffer);
RedisModule_Free(devicestr);
RedisModule_FreeString(NULL, stats_keystr);
RedisModule_FreeString(NULL, tag);
RAI_AddStatsEntry(ctx, key_str, RAI_MODEL, backend, devicestr, tag);

return model;

cleanup:
if (devicestr)
RedisModule_Free(devicestr);
if (tag)
RedisModule_Free(tag);
RedisModule_FreeString(NULL, tag);
if (inputs) {
for (size_t i = 0; i < ninputs; i++) {
RedisModule_Free((void *)inputs[i]);
RedisModule_Free(inputs[i]);
}
RedisModule_Free(inputs);
array_free(inputs);
}

if (outputs) {
for (size_t i = 0; i < noutputs; i++) {
RedisModule_Free((void *)outputs[i]);
RedisModule_Free(outputs[i]);
}
RedisModule_Free(outputs);
array_free(outputs);
}

if (buffer)
RedisModule_Free(buffer);

RedisModule_LogIOError(io, "error", "Experienced a short read while reading a model from RDB");
RedisModule_LogIOError(io, "error", "%s", error_str);
if (RAI_GetErrorCode(&err) != RAI_OK) {
RAI_ClearError(&err);
}
return NULL;
}

Expand Down