Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
addressed PR comments
  • Loading branch information
jignparm committed Nov 16, 2018
commit 4b689b382f2cb53ed87f15805d28a89b633c1f04
84 changes: 36 additions & 48 deletions src/Microsoft.ML.OnnxTransform/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "ONNXSCOR",
verWrittenCur: 0x00010002, // Initial
// version 10001 is single input & output.
// version 10002 = multiple inputs & outputs
verWrittenCur: 0x00010002,
verReadableCur: 0x00010002,
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Nov 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

verReadableCur: 0x00010002, [](start = 16, length = 27)

would be nice to add comment what exactly changed, like // Inputs and Outputs columns now array of strings. #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


In reply to: 233267446 [](ancestors = 233267446)

verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
Expand Down Expand Up @@ -101,23 +103,16 @@ private static OnnxTransform Create(IHostEnvironment env, ModelLoadContext ctx)
if (!ctx.TryLoadBinaryStream("OnnxModel", r => modelBytes = r.ReadByteArray()))
throw env.ExceptDecode();

bool isMultiOutput = ctx.Header.ModelVerReadable > 0x00010001;
bool supportsMultiInputOutput = ctx.Header.ModelVerWritten > 0x00010001;

//var inputColumn = ctx.LoadNonEmptyString();
//var outputColumn = ctx.LoadNonEmptyString();

var numInputs = 1;
if (isMultiOutput)
numInputs = ctx.Reader.ReadInt32();
var numInputs = (supportsMultiInputOutput) ? ctx.Reader.ReadInt32() : 1;

env.CheckDecode(numInputs > 0);
var inputs = new string[numInputs];
for (int j = 0; j < inputs.Length; j++)
inputs[j] = ctx.LoadNonEmptyString();

var numOutputs = 1;
if (isMultiOutput)
numOutputs = ctx.Reader.ReadInt32();
var numOutputs = (supportsMultiInputOutput) ? ctx.Reader.ReadInt32() : 1;

env.CheckDecode(numOutputs > 0);
var outputs = new string[numOutputs];
Expand Down Expand Up @@ -154,39 +149,29 @@ private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes =
Model = OnnxModel.CreateFromBytes(modelBytes);

var modelInfo = Model.ModelInfo;
//if (modelInfo.InputsInfo.Length != 1)
// throw env.Except($"OnnxTransform supports Onnx models with one input. The provided model has ${modelInfo.InputsInfo.Length} input(s).");
//if (modelInfo.OutputsInfo.Length != 1)
// throw env.Except($"OnnxTransform supports Onnx models with one output. The provided model has ${modelInfo.OutputsInfo.Length} output(s).");

Inputs = args.InputColumns;
Outputs = args.OutputColumns;
//var type = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
//var shape = outputNodeInfo.Shape;
//var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select( x => (int) x ).ToArray() : new[] { 0 };

OutputTypes = new ColumnType[args.OutputColumns.Length];

var numModelOutputs = Model.ModelInfo.OutputsInfo.Length;
for (int i=0; i < args.OutputColumns.Length; i++)
{
var idx = -1;
for (var j = 0; j < Model.ModelInfo.OutputsInfo.Length; j++)
if (Model.ModelInfo.OutputsInfo[j].Name == args.OutputColumns[i])
{
idx = j;
break;
}
var idx = Array.IndexOf(Model.GetOutputNames(), args.OutputColumns[i]);
if (idx < 0)
throw _host.Except($"Column {args.OutputColumns[i]} doesn't match output node names of model");
Copy link

@shmoradims shmoradims Nov 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would make OnnxModel._outputNames public and simplify code to:
if (!Model.OutputNames.Contains(args.OutputColumns[i])
throw ... #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simplified .


In reply to: 232835603 [](ancestors = 232835603)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use Contains() instead of IndexOf?


In reply to: 234085225 [](ancestors = 234085225,232835603)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The line of code just below uses the value -- we need to know the index since the user specified inputs might not be in the same order as Model.ModelInfo.OutputsInfo.


In reply to: 234305001 [](ancestors = 234305001,234085225,232835603)


var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx];
var shape = outputNodeInfo.Shape;
var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 };
var dims = AdjustDimensions(shape);
OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims);
}
_args = args;
}

public OnnxTransform(IHostEnvironment env, string modelFile, string inputColumn, string outputColumn)
: this(env, new Arguments() { ModelFile = modelFile, InputColumns = new[] { inputColumn }, OutputColumns = new[] { outputColumn } })
{
}

public OnnxTransform(IHostEnvironment env, string modelFile, string[] inputColumns, string[] outputColumns)
: this(env, new Arguments() { ModelFile = modelFile, InputColumns = inputColumns, OutputColumns = outputColumns })
{
Expand Down Expand Up @@ -221,15 +206,13 @@ public void Save(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());

ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(Model.ToByteArray()); });
//ctx.SaveNonEmptyString(_args.InputColumn);
//ctx.SaveNonEmptyString(_args.OutputColumn);

_host.AssertNonEmpty(Inputs);
_host.CheckNonEmpty(Inputs, nameof(Inputs));
ctx.Writer.Write(Inputs.Length);
foreach (var colName in Inputs)
ctx.SaveNonEmptyString(colName);

_host.AssertNonEmpty(Outputs);
_host.CheckNonEmpty(Outputs, nameof(Outputs));
ctx.Writer.Write(Outputs.Length);
foreach (var colName in Outputs)
ctx.SaveNonEmptyString(colName);
Expand All @@ -243,6 +226,24 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema)
return MakeDataTransform(new EmptyDataView(_host, inputSchema));
}

private static int[] AdjustDimensions(OnnxShape shape)
{
if (shape.Count > 0)
Copy link

@shmoradims shmoradims Nov 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape.Count > 0 [](start = 16, length = 15)

is shape.Count == 0 a valid case or an error? right now it's being treated as a valid case. If it's valid please comment about what kind of models would have it. #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment -- let me know if it makes sense.


In reply to: 234307353 [](ancestors = 234307353)

{
// some models may have -1 in first position.
// skip this dimension when setting output column dimensions.
if (shape[0] < 0)
{
return shape.Skip(1).Select(x => (int)x).ToArray();
}
else
{
return shape.Select(x => (int)x).ToArray();
}
}
return new[] { 0 };
}

private sealed class Mapper : IRowMapper
{
private readonly IHost _host;
Expand All @@ -268,20 +269,13 @@ public Mapper(IHostEnvironment env, OnnxTransform parent, Schema inputSchema)
var model = _parent.Model;
for (int i = 0; i < _parent.Inputs.Length; i++)
{
var idx = -1;
for (var j = 0; j < model.ModelInfo.InputsInfo.Length; j++)
if (model.ModelInfo.InputsInfo[j].Name == _parent.Inputs[i])
{
idx = j;
break;
}
var idx = Array.IndexOf(model.GetInputNames(), _parent.Inputs[i]);
if (idx < 0)
throw _host.Except($"Column {_parent.Inputs[i]} doesn't match input node names of model");

var inputNodeInfo = model.ModelInfo.InputsInfo[idx];

var shape = inputNodeInfo.Shape;
int[] inputdims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 };
var inputType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);

var inputShape = inputNodeInfo.Shape;
Expand Down Expand Up @@ -506,14 +500,8 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());

var idx = -1;
var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;
for (var j = 0; j < inputsInfo.Length; j++)
if (inputsInfo[j].Name == input)
{
idx = j;
break;
}
var idx = Array.IndexOf(Transformer.Model.GetInputNames(), input);
if (idx < 0)
Copy link

@shmoradims shmoradims Nov 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same simplification as before #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


In reply to: 232840337 [](ancestors = 232840337)

throw Host.Except($"Column {input} doesn't match input node names of model.");

Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.ML.OnnxTransform/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,16 @@ private OnnxNodeInfo[] GetOutputsInfo()
_modelManager.GetOutputShapesDict(_modelName, _ignoredVersion));
}

public string[] GetInputNames()
{
return _inputNames.ToArray();
}

public string[] GetOutputNames()
{
return _outputNames.ToArray();
}
Copy link

@shmoradims shmoradims Nov 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need these new methods. Let just rename _inputNames to InputNames and make it readonly property. And List<> should be fine if you use Contains() #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done...made public. We need to use IndexOf, since we need the actual index (if it contains the item)


In reply to: 234305812 [](ancestors = 234305812)


private static OnnxNodeInfo[] DictToNodesInfo(
Dictionary<string, DataType> typeDict,
Dictionary<string, long[]> shapeDictArray)
Expand Down