-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Multiple inputs output support for OnnxTransform. #1586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
a9528de
a7a9229
4b689b3
926cc35
f144251
6289485
6469985
b4268cc
84da83c
8cbd8bd
c96778b
674c612
66a2b55
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,7 +67,9 @@ private static VersionInfo GetVersionInfo() | |
| { | ||
| return new VersionInfo( | ||
| modelSignature: "ONNXSCOR", | ||
| verWrittenCur: 0x00010002, // Initial | ||
| // version 10001 is single input & output. | ||
| // version 10002 = multiple inputs & outputs | ||
| verWrittenCur: 0x00010002, | ||
| verReadableCur: 0x00010002, | ||
| verWeCanReadBack: 0x00010001, | ||
| loaderSignature: LoaderSignature, | ||
|
|
@@ -101,23 +103,16 @@ private static OnnxTransform Create(IHostEnvironment env, ModelLoadContext ctx) | |
| if (!ctx.TryLoadBinaryStream("OnnxModel", r => modelBytes = r.ReadByteArray())) | ||
| throw env.ExceptDecode(); | ||
|
|
||
| bool isMultiOutput = ctx.Header.ModelVerReadable > 0x00010001; | ||
| bool supportsMultiInputOutput = ctx.Header.ModelVerWritten > 0x00010001; | ||
|
|
||
| //var inputColumn = ctx.LoadNonEmptyString(); | ||
| //var outputColumn = ctx.LoadNonEmptyString(); | ||
|
|
||
| var numInputs = 1; | ||
| if (isMultiOutput) | ||
| numInputs = ctx.Reader.ReadInt32(); | ||
| var numInputs = (supportsMultiInputOutput) ? ctx.Reader.ReadInt32() : 1; | ||
|
|
||
| env.CheckDecode(numInputs > 0); | ||
| var inputs = new string[numInputs]; | ||
| for (int j = 0; j < inputs.Length; j++) | ||
| inputs[j] = ctx.LoadNonEmptyString(); | ||
|
|
||
| var numOutputs = 1; | ||
| if (isMultiOutput) | ||
| numOutputs = ctx.Reader.ReadInt32(); | ||
| var numOutputs = (supportsMultiInputOutput) ? ctx.Reader.ReadInt32() : 1; | ||
|
|
||
| env.CheckDecode(numOutputs > 0); | ||
| var outputs = new string[numOutputs]; | ||
|
|
@@ -154,39 +149,29 @@ private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes = | |
| Model = OnnxModel.CreateFromBytes(modelBytes); | ||
|
|
||
| var modelInfo = Model.ModelInfo; | ||
| //if (modelInfo.InputsInfo.Length != 1) | ||
| // throw env.Except($"OnnxTransform supports Onnx models with one input. The provided model has ${modelInfo.InputsInfo.Length} input(s)."); | ||
| //if (modelInfo.OutputsInfo.Length != 1) | ||
| // throw env.Except($"OnnxTransform supports Onnx models with one output. The provided model has ${modelInfo.OutputsInfo.Length} output(s)."); | ||
|
|
||
| Inputs = args.InputColumns; | ||
| Outputs = args.OutputColumns; | ||
| //var type = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type); | ||
| //var shape = outputNodeInfo.Shape; | ||
| //var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select( x => (int) x ).ToArray() : new[] { 0 }; | ||
|
|
||
| OutputTypes = new ColumnType[args.OutputColumns.Length]; | ||
|
|
||
| var numModelOutputs = Model.ModelInfo.OutputsInfo.Length; | ||
| for (int i=0; i < args.OutputColumns.Length; i++) | ||
| { | ||
| var idx = -1; | ||
| for (var j = 0; j < Model.ModelInfo.OutputsInfo.Length; j++) | ||
| if (Model.ModelInfo.OutputsInfo[j].Name == args.OutputColumns[i]) | ||
| { | ||
| idx = j; | ||
| break; | ||
| } | ||
| var idx = Array.IndexOf(Model.GetOutputNames(), args.OutputColumns[i]); | ||
| if (idx < 0) | ||
| throw _host.Except($"Column {args.OutputColumns[i]} doesn't match output node names of model"); | ||
|
||
|
|
||
| var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx]; | ||
| var shape = outputNodeInfo.Shape; | ||
| var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 }; | ||
| var dims = AdjustDimensions(shape); | ||
| OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims); | ||
| } | ||
| _args = args; | ||
| } | ||
|
|
||
| public OnnxTransform(IHostEnvironment env, string modelFile, string inputColumn, string outputColumn) | ||
| : this(env, new Arguments() { ModelFile = modelFile, InputColumns = new[] { inputColumn }, OutputColumns = new[] { outputColumn } }) | ||
| { | ||
| } | ||
|
|
||
| public OnnxTransform(IHostEnvironment env, string modelFile, string[] inputColumns, string[] outputColumns) | ||
| : this(env, new Arguments() { ModelFile = modelFile, InputColumns = inputColumns, OutputColumns = outputColumns }) | ||
| { | ||
|
|
@@ -221,15 +206,13 @@ public void Save(ModelSaveContext ctx) | |
| ctx.SetVersionInfo(GetVersionInfo()); | ||
|
|
||
| ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(Model.ToByteArray()); }); | ||
| //ctx.SaveNonEmptyString(_args.InputColumn); | ||
| //ctx.SaveNonEmptyString(_args.OutputColumn); | ||
|
|
||
| _host.AssertNonEmpty(Inputs); | ||
| _host.CheckNonEmpty(Inputs, nameof(Inputs)); | ||
| ctx.Writer.Write(Inputs.Length); | ||
| foreach (var colName in Inputs) | ||
| ctx.SaveNonEmptyString(colName); | ||
|
|
||
| _host.AssertNonEmpty(Outputs); | ||
| _host.CheckNonEmpty(Outputs, nameof(Outputs)); | ||
| ctx.Writer.Write(Outputs.Length); | ||
| foreach (var colName in Outputs) | ||
| ctx.SaveNonEmptyString(colName); | ||
|
|
@@ -243,6 +226,24 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) | |
| return MakeDataTransform(new EmptyDataView(_host, inputSchema)); | ||
| } | ||
|
|
||
| private static int[] AdjustDimensions(OnnxShape shape) | ||
| { | ||
| if (shape.Count > 0) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
is shape.Count == 0 a valid case or an error? right now it's being treated as a valid case. If it's valid please comment about what kind of models would have it. #Closed
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| { | ||
| // some models may have -1 in first position. | ||
| // skip this dimension when setting output column dimensions. | ||
| if (shape[0] < 0) | ||
| { | ||
| return shape.Skip(1).Select(x => (int)x).ToArray(); | ||
| } | ||
| else | ||
| { | ||
| return shape.Select(x => (int)x).ToArray(); | ||
| } | ||
| } | ||
| return new[] { 0 }; | ||
| } | ||
|
|
||
| private sealed class Mapper : IRowMapper | ||
| { | ||
| private readonly IHost _host; | ||
|
|
@@ -268,20 +269,13 @@ public Mapper(IHostEnvironment env, OnnxTransform parent, Schema inputSchema) | |
| var model = _parent.Model; | ||
| for (int i = 0; i < _parent.Inputs.Length; i++) | ||
| { | ||
| var idx = -1; | ||
| for (var j = 0; j < model.ModelInfo.InputsInfo.Length; j++) | ||
| if (model.ModelInfo.InputsInfo[j].Name == _parent.Inputs[i]) | ||
| { | ||
| idx = j; | ||
| break; | ||
| } | ||
| var idx = Array.IndexOf(model.GetInputNames(), _parent.Inputs[i]); | ||
| if (idx < 0) | ||
| throw _host.Except($"Column {_parent.Inputs[i]} doesn't match input node names of model"); | ||
|
|
||
| var inputNodeInfo = model.ModelInfo.InputsInfo[idx]; | ||
|
|
||
| var shape = inputNodeInfo.Shape; | ||
| int[] inputdims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 }; | ||
| var inputType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type); | ||
|
|
||
| var inputShape = inputNodeInfo.Shape; | ||
|
|
@@ -506,14 +500,8 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) | |
| if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector)) | ||
| throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString()); | ||
|
|
||
| var idx = -1; | ||
| var inputsInfo = Transformer.Model.ModelInfo.InputsInfo; | ||
| for (var j = 0; j < inputsInfo.Length; j++) | ||
| if (inputsInfo[j].Name == input) | ||
| { | ||
| idx = j; | ||
| break; | ||
| } | ||
| var idx = Array.IndexOf(Transformer.Model.GetInputNames(), input); | ||
| if (idx < 0) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same simplification as before #Closed
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| throw Host.Except($"Column {input} doesn't match input node names of model."); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -124,6 +124,16 @@ private OnnxNodeInfo[] GetOutputsInfo() | |
| _modelManager.GetOutputShapesDict(_modelName, _ignoredVersion)); | ||
| } | ||
|
|
||
| public string[] GetInputNames() | ||
| { | ||
| return _inputNames.ToArray(); | ||
| } | ||
|
|
||
| public string[] GetOutputNames() | ||
| { | ||
| return _outputNames.ToArray(); | ||
| } | ||
|
||
|
|
||
| private static OnnxNodeInfo[] DictToNodesInfo( | ||
| Dictionary<string, DataType> typeDict, | ||
| Dictionary<string, long[]> shapeDictArray) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be nice to add comment what exactly changed, like // Inputs and Outputs columns now array of strings. #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
In reply to: 233267446 [](ancestors = 233267446)