Multiple inputs output support for OnnxTransform. #1586
bool isMultiOutput = ctx.Header.ModelVerReadable > 0x00010001;

//var inputColumn = ctx.LoadNonEmptyString();
//var outputColumn = ctx.LoadNonEmptyString();
delete #Closed
// throw env.Except($"OnnxTransform supports Onnx models with one input. The provided model has ${modelInfo.InputsInfo.Length} input(s).");
//if (modelInfo.OutputsInfo.Length != 1)
// throw env.Except($"OnnxTransform supports Onnx models with one output. The provided model has ${modelInfo.OutputsInfo.Length} output(s).");
delete #Closed
break;
}
if (idx < 0)
    throw _host.Except($"Column {args.OutputColumns[i]} doesn't match output node names of model");
I would make OnnxModel._outputNames public and simplify the code to:
if (!Model.OutputNames.Contains(args.OutputColumns[i]))
    throw ... #Closed
Why not use Contains() instead of IndexOf?
In reply to: 234085225
The line of code just below uses the value -- we need to know the index, since the user-specified inputs might not be in the same order as Model.ModelInfo.OutputsInfo.
In reply to: 234305001
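To illustrate the point about needing the index rather than a yes/no answer, here is a minimal sketch. It is not the PR's code: OutputLookup and ResolveIndices are hypothetical names, and a plain IList<string> stands in for the model's output names.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper, not part of ML.NET.
public static class OutputLookup
{
    // For each requested column, returns its position in the model's output list.
    // Throws if a requested name is not one of the model's outputs.
    public static int[] ResolveIndices(IList<string> modelOutputNames, string[] requested)
    {
        var indices = new int[requested.Length];
        for (int i = 0; i < requested.Length; i++)
        {
            // IndexOf answers "is it there?" and "where?" in one call;
            // Contains would only answer the first question.
            int idx = modelOutputNames.IndexOf(requested[i]);
            if (idx < 0)
                throw new ArgumentException($"Column {requested[i]} doesn't match output node names of model");
            indices[i] = idx;
        }
        return indices;
    }
}
```

With model outputs a, b, c and a request for c, a, the result maps the request order onto the model order, which is what the subsequent OutputsInfo[idx] lookup relies on.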
    throw _host.Except($"Column {args.OutputColumns[i]} doesn't match output node names of model");
var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx];
var shape = outputNodeInfo.Shape;
var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 };
var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 };
This is too dense for me to follow. Please unpack the code into separate cases and add a comment for each one (meaning of a negative count, why we skip, why we skip 0 or 1, etc.).
I've seen this later in the code too. If it makes sense, please refactor it into a separate method. #Closed
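One way the one-liner could be unpacked is sketched below. This is only an illustration: ShapeHelper is a hypothetical class, IReadOnlyList<long> stands in for OnnxShape, and the reading of a negative leading dimension as an unknown (for example, batch) size is an assumption rather than something stated in this thread.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-alone version; the shape type is a stand-in for OnnxShape.
public static class ShapeHelper
{
    public static int[] AdjustDimensions(IReadOnlyList<long> shape)
    {
        // Case 1: the shape has no dimensions.
        // The original expression falls back to a single dimension of 0 here.
        if (shape.Count == 0)
            return new[] { 0 };

        // Case 2: the leading dimension is negative.
        // Assumption: a negative value marks an unknown size (such as batch),
        // so it is dropped and only the remaining dimensions are kept.
        int skip = shape[0] < 0 ? 1 : 0;

        // Case 3: otherwise every dimension is kept. Values are narrowed to int.
        return shape.Skip(skip).Select(x => (int)x).ToArray();
    }
}
```

The behavior is meant to match the original expression exactly; only the layout and comments differ.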
ctx.SaveNonEmptyString(_args.InputColumn);
ctx.SaveNonEmptyString(_args.OutputColumn);
//ctx.SaveNonEmptyString(_args.InputColumn);
//ctx.SaveNonEmptyString(_args.OutputColumn);
delete #Closed
if (model.ModelInfo.InputsInfo[j].Name == _parent.Inputs[i])
{
    idx = j;
    break;
For inputs, you can do the same simplification you did with outputs. #Closed
Same here. You can use Contains(), which doesn't involve dealing with indices.
In reply to: 234080355
var inputNodeInfo = model.ModelInfo.InputsInfo[idx];

var shape = inputNodeInfo.Shape;
int[] inputdims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 };
int[] inputdims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 };
Is this the same logic as before? Please refactor it into a separate method if so. #Closed
Removed this logic as well as the variable -- it is not needed here because of the shape calculation below. The other occurrence sets the output dimension of the vector-type column.
In reply to: 232837839
_inputColIndices = new int[_parent.Inputs.Length];
_isInputVector = new bool[_parent.Inputs.Length];
_inputTensorShapes = new OnnxShape[_parent.Inputs.Length];
_inputOnnxTypes = new DataType[_parent.Inputs.Length];
Does it make sense to bundle all of these into an InputMetadata class and have a single array of that class instead of multiple arrays? #WontFix
This is a matter of readability -- bundling will introduce another data structure. Also, this pattern is what the TF transform uses -- if we later need to consolidate into a single transform that does inferencing for all DNN models, it would be beneficial to keep it this way.
Example here:
private readonly int[] _inputColIndices;
In reply to: 232838745
Parallel arrays are considered an anti-pattern in OOD. Unfortunately, they're used often in our codebase. So it's up to you whether to keep it as is or refactor.
https://codeblog.jonskeet.uk/2014/06/03/anti-pattern-parallel-collections/
In reply to: 234077070
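For reference, the bundled alternative discussed in this thread might look roughly like the sketch below. It was not adopted in the PR (the comment is marked #WontFix). The class and property names are hypothetical, and string and IReadOnlyList<long> stand in for the DataType and OnnxShape types.

```csharp
using System.Collections.Generic;

// Hypothetical bundle of the per-input values that the PR keeps in four parallel arrays.
public sealed class InputMetadata
{
    public int ColumnIndex { get; }                  // replaces _inputColIndices[i]
    public bool IsVector { get; }                    // replaces _isInputVector[i]
    public IReadOnlyList<long> TensorShape { get; }  // replaces _inputTensorShapes[i]
    public string OnnxType { get; }                  // replaces _inputOnnxTypes[i]

    public InputMetadata(int columnIndex, bool isVector, IReadOnlyList<long> tensorShape, string onnxType)
    {
        ColumnIndex = columnIndex;
        IsVector = isVector;
        TensorShape = tensorShape;
        OnnxType = onnxType;
    }
}
```

A single InputMetadata[] of length Inputs.Length would then replace the four arrays, so the values for one input cannot drift out of sync. The trade-off raised in the reply still applies: it is one more type to maintain and it diverges from the TF transform's layout.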
idx = j;
break;
}
if (idx < 0)
Same simplification as before #Closed
{
Contracts.Assert(toOutput.Length == 1);

var outCol = (OutColumn)toOutput[0];
var outCol = (OutColumn)toOutput[0];
Should we extend this also? It's only using 1 output. #Closed
TF seems to have the same reconciler logic. It works fine as-is, although we should double-check it.
var outCol = (OutColumn)toOutput[0];
In reply to: 232840804
var inputColumn = ctx.LoadNonEmptyString();
var outputColumn = ctx.LoadNonEmptyString();
var args = new Arguments() { InputColumn = inputColumn, OutputColumn = outputColumn };
bool isMultiOutput = ctx.Header.ModelVerReadable > 0x00010001;
ModelVerReadable
Should be ModelVerWritten, I think. #Closed
Good catch. That sounds more correct -- updated. TF probably needs updating as well.
In reply to: 232842442

var numInputs = 1;
if (isMultiOutput)
    numInputs = ctx.Reader.ReadInt32();
var numInputs = isMultiOutput ? ctx.Reader.ReadInt32() : 1; #Closed
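The version-gated load being reviewed here can be sketched in isolation as follows. This is a simplified illustration, not the transform's code: ColumnListReader is a hypothetical name, and BinaryReader.ReadString stands in for ctx.LoadNonEmptyString, which works differently in ML.NET. The version that was written is passed in explicitly, following the ModelVerWritten suggestion above.

```csharp
using System.IO;

// Hypothetical reader showing how a version check keeps old single-column models loadable.
public static class ColumnListReader
{
    // Last format version that stored exactly one column name with no count prefix.
    public const uint SingleColumnVersion = 0x00010001;

    public static string[] ReadColumns(BinaryReader reader, uint verWritten)
    {
        // Newer models write a count first; older ones implicitly have one column.
        bool supportsMultiInputOutput = verWritten > SingleColumnVersion;
        int count = supportsMultiInputOutput ? reader.ReadInt32() : 1;

        var columns = new string[count];
        for (int i = 0; i < count; i++)
            columns[i] = reader.ReadString();
        return columns;
    }
}
```

The same routine serves both inputs and outputs; the only format difference between versions is the presence of the count.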
var inputColumn = ctx.LoadNonEmptyString();
var outputColumn = ctx.LoadNonEmptyString();
var args = new Arguments() { InputColumn = inputColumn, OutputColumn = outputColumn };
bool isMultiOutput = ctx.Header.ModelVerReadable > 0x00010001;
isMultiOutput
supportsMultiInputOutput? #Closed
//ctx.SaveNonEmptyString(_args.InputColumn);
//ctx.SaveNonEmptyString(_args.OutputColumn);

_host.AssertNonEmpty(Inputs);
AssertNonEmpty
CheckNonEmpty #Closed
foreach (var colName in Inputs)
    ctx.SaveNonEmptyString(colName);

_host.AssertNonEmpty(Outputs);
AssertNonEmpty
CheckNonEmpty; a public method should use Check instead of Assert. #Closed
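The distinction behind this comment can be shown with standard .NET pieces. The sketch below does not use ML.NET's Contracts helpers; ArgumentGuards is a hypothetical class, and the assumption is that the Check/Assert pair in the codebase follows the same split as throwing an exception versus calling System.Diagnostics.Debug.Assert.

```csharp
using System;
using System.Diagnostics;

// Hypothetical guards illustrating the Check-versus-Assert distinction.
public static class ArgumentGuards
{
    // Check-style: validates caller-supplied data and throws in every build configuration.
    public static void CheckNonEmpty(string[] values, string paramName)
    {
        if (values == null || values.Length == 0)
            throw new ArgumentException("Must be non-empty.", paramName);
    }

    // Assert-style: documents an internal invariant.
    // Debug.Assert is marked [Conditional("DEBUG")], so calls to it are omitted from release builds.
    public static void AssertNonEmpty(string[] values)
    {
        Debug.Assert(values != null && values.Length > 0);
    }
}
```

Because a public method can receive arbitrary input from callers, the check-style guard is the one that still protects it in a release build.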
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWrittenCur: 0x00010002, // Initial
verReadableCur: 0x00010002,
verReadableCur: 0x00010002,
It would be nice to add a comment saying what exactly changed, like // Inputs and Outputs columns now array of strings. #Closed
}

public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string inputColumn, string outputColumn)
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] inputColumns, string[] outputColumns)
public
Can it be internal? #ByDesign
TF is the same way (see below) -- I assume this is required to be public because of the different ways it can be instantiated.
public static IDataTransform Create(IHostEnvironment env, IDataView input, string model, string[] names, string[] source)
In reply to: 233267511
Outputs = args.OutputColumns;
//var type = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
//var shape = outputNodeInfo.Shape;
//var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select( x => (int) x ).ToArray() : new[] { 0 };
delete as well? #Closed
public OnnxTransform(IHostEnvironment env, string modelFile, string[] inputColumns, string[] outputColumns)
    : this(env, new Arguments() { ModelFile = modelFile, InputColumns = inputColumns, OutputColumns = outputColumns })
{
}
I would keep the old constructor; I don't see anything wrong with having a constructor for the case where you have only one input and one output. #Closed
foreach (var input in Inputs)
{
    if (!inputSchema.TryGetColumnIndex(input, out int srcCol))
        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
ExceptSchemaMismatch
Where do we check the type of the input column? During transform? #Closed
We check it in the other function, GetOutputSchema, at line ~500. We check for shape mismatch and type mismatch as well.
In reply to: 233269217
public string[] GetOutputNames()
{
    return _outputNames.ToArray();
}
We don't need these new methods. Let's just rename _inputNames to InputNames and make it a read-only property. List<> should be fine if you use Contains(). #Closed
Done, made public. We need to use IndexOf, since we need the actual index (if the list contains the item).
In reply to: 234305812

private static int[] AdjustDimensions(OnnxShape shape)
{
    if (shape.Count > 0)
shape.Count > 0
Is shape.Count == 0 a valid case or an error? Right now it's being treated as a valid case. If it's valid, please add a comment about what kind of models would have it. #Closed
Merge with master
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransform)))
{
    //Contracts.CheckValue(env, nameof(env));
    //Host = env.Register(RegistrationName);
Feel free to delete; Host now gets created in the base class (and sorry for these changes).
if (type.ValueCount % valCount != 0)
    throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {type.ValueCount}.");

//Host.Assert(_outputItemRawType == _outputColType.ItemType.RawType);
//Host.Assert(_outputItemRawType == _outputColType.ItemType.RawType);
Do we need it or not?
//// disposer = null;
//// return Utils.MarshalInvoke(MakeGetter<int>, _outputItemRawType, input);
////>>>>>>> master
// }
please delete :)
{
    _srcgetter(ref _vBuffer);
    _vBuffer.CopyToDense(ref _vBufferDense);
    return OnnxUtils.CreateTensor(_vBufferDense.GetValues().ToArray(), _tensorShape);
_vBufferDense.GetValues().ToArray()
I would probably change CreateTensor to accept ReadOnlySpan (or whatever GetValues returns) rather than convert every span to an array.
Span is our new reality with VBuffers; you shouldn't expect arrays anymore. #Closed
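A span-accepting factory along the lines suggested here could be shaped like the sketch below. SimpleTensor and TensorFactory are hypothetical stand-ins, not the types behind OnnxUtils.CreateTensor, and the element type is fixed to float for brevity.

```csharp
using System;

// Hypothetical tensor container; stands in for whatever tensor type the transform feeds to the runtime.
public sealed class SimpleTensor
{
    public float[] Data { get; }
    public int[] Shape { get; }

    public SimpleTensor(float[] data, int[] shape)
    {
        Data = data;
        Shape = shape;
    }
}

public static class TensorFactory
{
    // Accepts a span so callers can pass GetValues() directly, without calling ToArray() first.
    public static SimpleTensor CreateTensor(ReadOnlySpan<float> values, int[] shape)
    {
        long expected = 1;
        foreach (int dim in shape)
            expected *= dim;
        if (expected != values.Length)
            throw new ArgumentException("Value count does not match the tensor shape.", nameof(values));

        // A span cannot be stored in a field of a class, so one copy is still made here.
        // The gain is a cleaner call site, with the copy living in a single place.
        return new SimpleTensor(values.ToArray(), shape);
    }
}
```

Arrays convert implicitly to ReadOnlySpan<float>, so existing callers that already hold an array keep working unchanged.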
Fixes issue #1585.
This PR adds support for ONNX models that have multiple inputs and outputs. The current version of the transform allows only a single input and a single output.