From 1b786034255ff1f1e6186051d4c9bd390c6e48b2 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Sun, 27 May 2018 22:10:38 -0700 Subject: [PATCH 01/22] Export to ONNX and Maml cross-platform executable --- Microsoft.ML.sln | 7 + .../BreastCancer/SaveModelToOnnxTest.json | 720 ++++++++++++++++++ .../Microsoft.ML.MamlExec.csproj | 28 + src/Microsoft.ML/Models/SaveAsOnnx.cs | 24 + .../Microsoft.ML.Tests.csproj | 1 + .../Scenarios/BinaryClassification.cs | 92 +++ .../Scenarios/HousePricePredictionTests.cs | 3 +- 7 files changed, 874 insertions(+), 1 deletion(-) create mode 100644 ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json create mode 100644 src/Microsoft.ML.MamlExec/Microsoft.ML.MamlExec.csproj create mode 100644 src/Microsoft.ML/Models/SaveAsOnnx.cs create mode 100644 test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 4974ca4887..389533be99 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -104,6 +104,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Parquet", "Mic EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Benchmarks", "test\Microsoft.ML.Benchmarks\Microsoft.ML.Benchmarks.csproj", "{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.MamlExec", "src\Microsoft.ML.MamlExec\Microsoft.ML.MamlExec.csproj", "{4217964C-891C-43DD-9E20-E77A5567DCE4}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -202,6 +204,10 @@ Global {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug|Any CPU.Build.0 = Debug|Any CPU {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.ActiveCfg = Release|Any CPU {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.Build.0 = Release|Any CPU + {4217964C-891C-43DD-9E20-E77A5567DCE4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4217964C-891C-43DD-9E20-E77A5567DCE4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4217964C-891C-43DD-9E20-E77A5567DCE4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4217964C-891C-43DD-9E20-E77A5567DCE4}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -236,6 +242,7 @@ Global {DEC8F776-49F7-4D87-836C-FE4DC057D08C} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {6C95FC87-F5F2-4EEF-BB97-567F2F5DD141} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} + {4217964C-891C-43DD-9E20-E77A5567DCE4} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json b/ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json new file mode 100644 index 0000000000..4dda997538 --- /dev/null +++ b/ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json @@ -0,0 +1,720 @@ +{ + "producerName": "SaveModelToOnnxTest", + "domain": "Onnx", + "graph": { + "node": [ + { + "input": [ + "Features" + ], + "output": [ + "Score" + ], + "name": "TreeEnsembleRegressor", + "opType": "TreeEnsembleRegressor", + "attribute": [ + { + "name": "post_transform", + "s": "Tk9ORQ==", + "type": "STRING" + }, + { + "name": "n_targets", + "i": "1", + "type": "INT" + }, + { + "name": "base_values", + "floats": [ + 0 + ], + "type": "FLOATS" + }, + { + "name": "aggregate_function", + "s": "U1VN", + "type": "STRING" + }, + { + "name": "nodes_treeids", + "ints": [ + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "1", + "1", + "1", + "1", + "1", + "1", + "1", + "1", + "1", + "2", + "2", + "2", + "2", + "2", + "2", + "2", + "2", + "2", + "3", + "3", + "3", + "3", + "3", + "3", + "3", + "3", + "3", + "4", + "4", + "4", + "4", + "4", + "4", + "4", + "4", + "4" + ], + "type": "INTS" + }, + { + "name": "nodes_nodeids", + "ints": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8" + ], + "type": "INTS" + }, + { + "name": "nodes_featureids", + "ints": [ + "1", + "2", + "5", + "0", + "0", + "0", + "0", + "0", + "0", + "1", + "5", + "2", + "0", + "0", + "0", + "0", + "0", + "0", + "1", + "5", + "2", + "0", + "0", + "0", + "0", + "0", + "0", + "2", + "6", + "5", + "1", + "0", + "0", + "0", + "0", + "0", + "1", + "5", + "0", + "1", + "0", + "0", + "0", + "0", + "0" + ], + "type": "INTS" + }, + { + "name": "nodes_modes", + "strings": [ + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "QlJBTkNIX0xFUQ==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==", + "TEVBRg==" + ], + "type": "STRINGS" + }, + { + "name": "nodes_values", + "floats": [ + 2.5, + 2.5, + 5.5, + 5.5, + 0, + 0, + 0, + 0, + 0, + 3.5, + 3.5, + 1.5, + 3.5, + 0, + 0, + 0, + 0, + 0, + 3.5, + 2.5, + 2.5, + 5.5, + 0, + 0, + 0, + 0, + 0, + 3.5, + 3.5, + 2.5, + 4.5, + 0, + 0, + 0, + 0, + 0, + 2.5, + 1.5, + 6.5, + 5.5, + 0, + 0, + 0, + 0, + 0 + ], + "type": "FLOATS" + }, + { + "name": "nodes_truenodeids", + "ints": [ + "2", + "3", + "4", + "5", + "0", + "0", + "0", + "0", + "0", + "1", + "4", + "3", + "6", + "0", + "0", + "0", + "0", + "0", + "1", + "4", + "3", + "6", + "0", + "0", + "0", + "0", + "0", + "1", + "4", + "6", + "5", + "0", + "0", + "0", + "0", + "0", + "2", + "3", + "4", + "5", + "0", + "0", + "0", + "0", + "0" + ], + "type": "INTS" + }, + { + "name": "nodes_falsenodeids", + "ints": [ + "1", + "6", + "7", + "8", + "0", + "0", + "0", + "0", + "0", + "5", + "2", + "7", + "8", + "0", + "0", + "0", + "0", + "0", + "5", + "2", + "7", + "8", + "0", + "0", + "0", + "0", + "0", + "3", + "2", + "7", + "8", + "0", + "0", + "0", + "0", + "0", + "1", + "6", + "7", + "8", + "0", + "0", + "0", + "0", + "0" + ], + "type": "INTS" + }, + { + "name": "nodes_missing_value_tracks_true", + "ints": [ + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0" + ], + "type": "INTS" + }, + { + "name": "target_treeids", + "ints": [ + "0", + "0", + "0", + "0", + "0", + "1", + "1", + "1", + "1", + "1", + "2", + "2", + "2", + "2", + "2", + "3", + "3", + "3", + "3", + "3", + "4", + "4", + "4", + "4", + "4" + ], + "type": "INTS" + }, + { + "name": "target_nodeids", + "ints": [ + "4", + "5", + "6", + "7", + "8", + "4", + "5", + "6", + "7", + "8", + "4", + "5", + "6", + "7", + "8", + "4", + "5", + "6", + "7", + "8", + "4", + "5", + "6", + "7", + "8" + ], + "type": "INTS" + }, + { + "name": "target_ids", + "ints": [ + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0", + "0" + ], + "type": "INTS" + }, + { + "name": "target_weights", + "floats": [ + -0.9755501, + -0.8947368, + 0.8347107, + 0.75, + 1, + -0.8157646, + 0.75331986, + -0.8725711, + 0.6149202, + 1.05311215, + -0.739901066, + 0.65209645, + -0.618561566, + 0.6576947, + 0.7696665, + -0.663163662, + 0.229835972, + -0.49779135, + 0.670133352, + 0.660453737, + -0.620322645, + -0.634804964, + 0.513690054, + 0.650555968, + 0.6567067 + ], + "type": "FLOATS" + } + ], + "domain": "ai.onnx.ml" + }, + { + "input": [ + "Score" + ], + "output": [ + "linearOutput" + ], + "name": "Affine", + "opType": "Affine", + "attribute": [ + { + "name": "alpha", + "f": 0.4, + "type": "FLOAT" + }, + { + "name": "beta", + "f": -1E-07, + "type": "FLOAT" + } + ] + }, + { + "input": [ + "linearOutput" + ], + "output": [ + "Probability" + ], + "name": "Sigmoid", + "opType": "Sigmoid" + }, + { + "input": [ + "Probability" + ], + "output": [ + "PredictedLabel" + ], + "name": "Binarizer", + "opType": "Binarizer", + "attribute": [ + { + "name": "threshold", + "f": 0.5, + "type": "FLOAT" + } + ], + "domain": "ai.onnx.ml" + } + ], + "name": "SaveModelToOnnxTest", + "input": [ + { + "name": "Label", + "type": { + "tensorType": { + "elemType": "FLOAT", + "shape": { + "dim": [ + { + "dimValue": "1" + }, + { + "dimValue": "1" + } + ] + } + } + } + }, + { + "name": "Features", + "type": { + "tensorType": { + "elemType": "FLOAT", + "shape": { + "dim": [ + { + "dimValue": "1" + }, + { + "dimValue": "9" + } + ] + } + } + } + } + ], + "output": [ + { + "name": "PredictedLabel", + "type": { + "tensorType": { + "elemType": "FLOAT", + "shape": { + "dim": [ + { + "dimValue": "1" + }, + { + "dimValue": "1" + } + ] + } + } + } + }, + { + "name": "Score", + "type": { + "tensorType": { + "elemType": "FLOAT", + "shape": { + "dim": [ + { + "dimValue": "1" + }, + { + "dimValue": "1" + } + ] + } + } + } + }, + { + "name": "Probability", + "type": { + "tensorType": { + "elemType": "FLOAT", + "shape": { + "dim": [ + { + "dimValue": "1" + }, + { + "dimValue": "1" + } + ] + } + } + } + } + ] + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.MamlExec/Microsoft.ML.MamlExec.csproj b/src/Microsoft.ML.MamlExec/Microsoft.ML.MamlExec.csproj new file mode 100644 index 0000000000..667aedf9b0 --- /dev/null +++ b/src/Microsoft.ML.MamlExec/Microsoft.ML.MamlExec.csproj @@ -0,0 +1,28 @@ + + + + true + CORECLR + netcoreapp2.0 + Microsoft.ML + Exe + Microsoft.ML.Runtime.Tools.Maml + + + + + + MAML.cs + + + ChainCommand.cs + + + HelpCommand.cs + + + VersionCommand.cs + + + + \ No newline at end of file diff --git a/src/Microsoft.ML/Models/SaveAsOnnx.cs b/src/Microsoft.ML/Models/SaveAsOnnx.cs new file mode 100644 index 0000000000..8165de521e --- /dev/null +++ b/src/Microsoft.ML/Models/SaveAsOnnx.cs @@ -0,0 +1,24 @@ +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Model.Onnx; + +namespace Microsoft.ML.Models +{ + /// + /// Converts a model to ONNX format. + /// + public sealed class SaveAsOnnx + { + /// + /// Converts and then saves a model to ONNX format. + /// + /// Arguments such as input model file path, output ONNX file path, etc. + public static void Save(SaveOnnxCommand.Arguments args) + { + using (var env = new TlcEnvironment()) + { + var cmd = new SaveOnnxCommand(env, args); + cmd.Run(); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 733c669ffc..e92a2b85f0 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -10,6 +10,7 @@ + diff --git a/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs b/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs new file mode 100644 index 0000000000..c6e870a1f2 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs @@ -0,0 +1,92 @@ +using Microsoft.ML.Data; +using Microsoft.ML.Models; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Trainers; +using System.IO; +using Xunit; + +namespace Microsoft.ML.Scenarios +{ + public partial class ScenariosTests + { + public class BreastCancerData + { + [Column(ordinal: "0")] + public float Label; + [Column(ordinal: "1-9")] + [VectorType(9)] + public float[] Features; + } + + public class BreastCancerPrediction + { + [ColumnName("PredictedLabel")] + public DvBool Cancerous; + } + + [Fact] + public void SaveModelToOnnxTest() + { + string dataPath = GetDataPath(@"breast-cancer.txt"); + var pipeline = new LearningPipeline(); + + pipeline.Add(new Data.TextLoader(dataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { '\t' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(0) }, + Type = DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "Features", + Source = new [] { new TextLoaderRange(1, 9) }, + Type = DataKind.Num + } + } + } + }); + + pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); + + PredictionModel model = pipeline.Train(); + + var modelOutpath = GetOutputPath(Path.Combine("..", "Common", + "Scenario", "BinaryClassification", "BreastCancer"), "SaveModelToOnnxTest.zip"); + + DeleteOutputPath(modelOutpath); + + var onnxPath = GetOutputPath(Path.Combine("..", "Common", + "Scenario", "BinaryClassification", "BreastCancer"), "SaveModelToOnnxTest.pb"); + + DeleteOutputPath(onnxPath); + + var onnxAsJsonPath = GetOutputPath(Path.Combine("..", "Common", + "Scenario", "BinaryClassification", "BreastCancer"), "SaveModelToOnnxTest.json"); + + DeleteOutputPath(onnxAsJsonPath); + + model.WriteAsync(modelOutpath); + SaveAsOnnx.Save(new Runtime.Model.Onnx.SaveOnnxCommand.Arguments + { + InputModelFile = modelOutpath, + OutputsToDrop = "Label,Features", + Onnx = onnxPath, + Json = onnxAsJsonPath, + Domain = "Onnx" + }); + + Assert.True(CheckEquality(Path.Combine("..", "Common", "Scenario", "BinaryClassification", "BreastCancer"), + "SaveModelToOnnxTest.json")); + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs index 392462a0eb..fdeb66123b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Models; using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; @@ -12,7 +13,7 @@ namespace Microsoft.ML.Scenarios { - public partial class ScenariosTests : BaseTestClass + public partial class ScenariosTests : BaseTestBaseline { /* A real-estate firm Contoso wants to add a house price prediction to their ASP.NET/Xamarin application. From 1508cd4a1b4a389c8bd0db20f60972599c4a1f1d Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 29 May 2018 10:55:44 -0700 Subject: [PATCH 02/22] misc. --- src/Microsoft.ML.UniversalModelFormat/LotusIR/OnnxMl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.UniversalModelFormat/LotusIR/OnnxMl.md b/src/Microsoft.ML.UniversalModelFormat/LotusIR/OnnxMl.md index 4dd0a0fccd..c6fa3ecb6e 100644 --- a/src/Microsoft.ML.UniversalModelFormat/LotusIR/OnnxMl.md +++ b/src/Microsoft.ML.UniversalModelFormat/LotusIR/OnnxMl.md @@ -4,7 +4,7 @@ 2. Download protobuf C# compiler version 3.0 or greater from https://github.com/google/protobuf/tree/master/csharp 3. Add `option csharp_namespace = - "Microsoft.MachineLearning.Runtime.UniversalModelFormat.Onnx";` to `onnx-ml.proto3` right below `package ONNX_NAMESPACE;` + "Microsoft.ML.Runtime.UniversalModelFormat.Onnx";` to `onnx-ml.proto3` right below `package ONNX_NAMESPACE;` 4. Assuming the compiler and proto file are saved at `E:\protobuf-csharp-port\lib` then run the following in command line to get C# code from the proto file: ``` From c6763f33c97a6dbadb67384c8612022e9e1ed63f Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 30 May 2018 00:40:38 -0700 Subject: [PATCH 03/22] PR feedback. --- Microsoft.ML.sln | 17 +- .../BreastCancer/SaveModelToOnnxTest.json | 18 - src/Microsoft.ML.Core/Data/ITransformModel.cs | 10 + .../EntryPoints/PredictorModel.cs | 12 + .../EntryPoints/TransformModel.cs | 10 + .../Model/Onnx/SaveOnnxCommand.cs | 124 ++++- src/Microsoft.ML.Maml/ChainCommand.cs | 84 --- src/Microsoft.ML.Maml/HelpCommand.cs | 509 ------------------ .../Microsoft.ML.Maml.csproj | 14 - src/Microsoft.ML.Maml/VersionCommand.cs | 38 -- .../Microsoft.ML.ResultProcessor.csproj | 2 +- src/Microsoft.ML/CSharpApi.cs | 69 +++ src/Microsoft.ML/Microsoft.ML.csproj | 2 +- src/Microsoft.ML/Models/SaveAsOnnx.cs | 24 - src/Microsoft.ML/PredictionModel.cs | 21 +- .../Microsoft.ML.TestFramework.csproj | 2 +- .../Scenarios/BinaryClassification.cs | 30 +- 17 files changed, 253 insertions(+), 733 deletions(-) delete mode 100644 src/Microsoft.ML.Maml/ChainCommand.cs delete mode 100644 src/Microsoft.ML.Maml/HelpCommand.cs delete mode 100644 src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj delete mode 100644 src/Microsoft.ML.Maml/VersionCommand.cs delete mode 100644 src/Microsoft.ML/Models/SaveAsOnnx.cs diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 389533be99..91aa04754e 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -5,6 +5,9 @@ MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Core", "src\Microsoft.ML.Core\Microsoft.ML.Core.csproj", "{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{09EADF06-BE25-4228-AB53-95AE3E15B530}" + ProjectSection(SolutionItems) = preProject + src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj = src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj + EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{AED9C836-31E3-4F3F-8ABC-929555D3F3C4}" EndProject @@ -30,7 +33,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.KMeansClusteri EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.PCA", "src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj", "{58E06735-1129-4DD5-86E0-6BBFF049AAD9}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj", "{D956E291-F6E5-4474-9023-91793F45ABEB}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Console", "src\Microsoft.ML.Maml\Microsoft.ML.Console.csproj", "{D956E291-F6E5-4474-9023-91793F45ABEB}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Api", "src\Microsoft.ML.Api\Microsoft.ML.Api.csproj", "{2F636A2C-062C-49F4-85F3-60DCADAB6A43}" EndProject @@ -104,7 +107,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Parquet", "Mic EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Benchmarks", "test\Microsoft.ML.Benchmarks\Microsoft.ML.Benchmarks.csproj", "{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.MamlExec", "src\Microsoft.ML.MamlExec\Microsoft.ML.MamlExec.csproj", "{4217964C-891C-43DD-9E20-E77A5567DCE4}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Commands", "src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj", "{C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -204,10 +207,10 @@ Global {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug|Any CPU.Build.0 = Debug|Any CPU {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.ActiveCfg = Release|Any CPU {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.Build.0 = Release|Any CPU - {4217964C-891C-43DD-9E20-E77A5567DCE4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {4217964C-891C-43DD-9E20-E77A5567DCE4}.Debug|Any CPU.Build.0 = Debug|Any CPU - {4217964C-891C-43DD-9E20-E77A5567DCE4}.Release|Any CPU.ActiveCfg = Release|Any CPU - {4217964C-891C-43DD-9E20-E77A5567DCE4}.Release|Any CPU.Build.0 = Release|Any CPU + {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -242,7 +245,7 @@ Global {DEC8F776-49F7-4D87-836C-FE4DC057D08C} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {6C95FC87-F5F2-4EEF-BB97-567F2F5DD141} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} - {4217964C-891C-43DD-9E20-E77A5567DCE4} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json b/ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json index 4dda997538..eb76e1527e 100644 --- a/ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json +++ b/ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json @@ -623,24 +623,6 @@ ], "name": "SaveModelToOnnxTest", "input": [ - { - "name": "Label", - "type": { - "tensorType": { - "elemType": "FLOAT", - "shape": { - "dim": [ - { - "dimValue": "1" - }, - { - "dimValue": "1" - } - ] - } - } - } - }, { "name": "Features", "type": { diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs index ccc73265ec..6075a7fe1c 100644 --- a/src/Microsoft.ML.Core/Data/ITransformModel.cs +++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs @@ -32,6 +32,16 @@ public interface ITransformModel /// ISchema OutputSchema { get; } + /// + /// This contains the transforms to save instantiated on an IDataView with + /// appropriate initial schema. Note that the "root" of this is typically either + /// an empty IDataView or a BinaryLoader with no rows. However, other root + /// types are possible, since we don't insist on this when loading a model + /// from a zip file. However, whenever we save, we force a BinaryLoader to + /// be serialized for the root. + /// + IDataView View { get; } + /// /// Apply the transform(s) in the model to the given input data. /// diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs index af726fa758..5f3c92c029 100644 --- a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs @@ -36,6 +36,18 @@ public PredictorModel(IHostEnvironment env, RoleMappedData trainingData, IDataVi _predictor = predictor; } + //REVIEW: I'm not sure this is the right thing to do because we are setting predictor to null + //when this class is supposed to contain a predictor. TransformModel may or may not + //contain a predictor. Here we are just using this class as a wrapper for TransformModel + //so that we can use a single class to accept TransformModel and PredictorModel has inputs. + public PredictorModel(IHostEnvironment env, ITransformModel transformModel) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(transformModel, nameof(transformModel)); + + _transformModel = transformModel; + } + public PredictorModel(IHostEnvironment env, Stream stream) { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs index 9edc87df6d..783668f9d3 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs @@ -48,6 +48,16 @@ public sealed class TransformModel : ITransformModel /// public ISchema OutputSchema => _chain.Schema; + /// + /// This contains the transforms to save instantiated on an IDataView with + /// appropriate initial schema. Note that the "root" of this is typically either + /// an empty IDataView or a BinaryLoader with no rows. However, other root + /// types are possible, since we don't insist on this when loading a model + /// from a zip file. However, whenever we save, we force a BinaryLoader to + /// be serialized for the root. + /// + public IDataView View => _chain; + /// /// Create a TransformModel containing the transforms from "result" back to "input". /// diff --git a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs index 730916b031..cb1f13f81d 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs @@ -9,13 +9,17 @@ using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model.Onnx; +using Microsoft.ML.Runtime.UniversalModelFormat.Onnx; using Newtonsoft.Json; [assembly: LoadableClass(SaveOnnxCommand.Summary, typeof(SaveOnnxCommand), typeof(SaveOnnxCommand.Arguments), typeof(SignatureCommand), "Save ONNX", "SaveOnnx", DocName = "command/SaveOnnx.md")] +[assembly: LoadableClass(typeof(void), typeof(SaveOnnxCommand), null, typeof(SignatureEntryPointModule), "SaveOnnxCommand")] + namespace Microsoft.ML.Runtime.Model.Onnx { public sealed class SaveOnnxCommand : DataCommand.ImplBase @@ -40,11 +44,21 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.AtMostOnce, HelpText = "Comma delimited list of input column names to drop", ShortName = "idrop", SortOrder = 5)] public string InputsToDrop; - [Argument(ArgumentType.AtMostOnce, HelpText = "Comma delimited list of output column names to drop", ShortName = "odrop", SortOrder = 6)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of input column names to drop", SortOrder = 6)] + public string[] InputsToDropArray; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Comma delimited list of output column names to drop", ShortName = "odrop", SortOrder = 7)] public string OutputsToDrop; - [Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 7)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of output column names to drop", SortOrder = 8)] + public string[] OutputsToDropArray; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)] public bool? LoadPredictor; + + [Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] + + public IPredictorModel Model; } private readonly string _outputModelPath; @@ -54,6 +68,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase private readonly bool? _loadPredictor; private readonly HashSet _inputsToDrop; private readonly HashSet _outputsToDrop; + private readonly IPredictorModel _model; public SaveOnnxCommand(IHostEnvironment env, Arguments args) : base(env, args, LoadName) @@ -68,18 +83,28 @@ public SaveOnnxCommand(IHostEnvironment env, Arguments args) _name = args.Name; _loadPredictor = args.LoadPredictor; - _inputsToDrop = CreateDropMap(args.InputsToDrop); - _outputsToDrop = CreateDropMap(args.OutputsToDrop); + _inputsToDrop = args.InputsToDropArray != null ? CreateDropMap(args.InputsToDropArray) : CreateDropMap(args.InputsToDrop); + _outputsToDrop = args.OutputsToDropArray != null ? CreateDropMap(args.OutputsToDropArray) : CreateDropMap(args.OutputsToDrop); _domain = args.Domain; + _model = args.Model; } private static HashSet CreateDropMap(string toDrop) { if (string.IsNullOrWhiteSpace(toDrop)) return new HashSet(); + return new HashSet(toDrop.Split(',')); } + private static HashSet CreateDropMap(string[] toDrop) + { + if (toDrop == null) + return new HashSet(); + + return new HashSet(toDrop); + } + public override void Run() { using (var ch = Host.Start("Run")) @@ -115,26 +140,39 @@ private void GetPipe(IChannel ch, IDataView end, out IDataView source, out IData private void Run(IChannel ch) { - IDataLoader loader; + IDataLoader loader = null; ; IPredictor rawPred; - RoleMappedSchema trainSchema; + IDataView view; + RoleMappedSchema trainSchema = null; - if (string.IsNullOrEmpty(Args.InputModelFile)) + if (_model == null) { - loader = CreateLoader(); - rawPred = null; - trainSchema = null; - Host.CheckUserArg(Args.LoadPredictor != true, nameof(Args.LoadPredictor), - "Cannot be set to true unless " + nameof(Args.InputModelFile) + " is also specifified."); + if (string.IsNullOrEmpty(Args.InputModelFile)) + { + loader = CreateLoader(); + rawPred = null; + trainSchema = null; + Host.CheckUserArg(Args.LoadPredictor != true, nameof(Args.LoadPredictor), + "Cannot be set to true unless " + nameof(Args.InputModelFile) + " is also specifified."); + } + else + LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader); + + view = loader; } else - LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader); + { + view = _model.TransformModel.View; + rawPred = _model?.Predictor; + if (rawPred != null) + trainSchema = _model.GetTrainingSchema(Host); + } // Get the transform chain. IDataView source; IDataView end; LinkedList transforms; - GetPipe(ch, loader, out source, out end, out transforms); + GetPipe(ch, view, out source, out end, out transforms); Host.Assert(transforms.Count == 0 || transforms.Last.Value == end); var ctx = new OnnxContext(Host, _name, _domain); @@ -228,10 +266,68 @@ private void Run(IChannel ch) if (!string.IsNullOrWhiteSpace(Args.OutputModelFile)) { + Contracts.Assert(loader != null); + ch.Trace("Saving the data pipe"); // Should probably include "end"? SaveLoader(loader, Args.OutputModelFile); } } + + public sealed class Output + { + //REVIEW: Would be nice to include ONNX protobuf model here but code generator needs an upgrade. + } + + //REVIEW: Ideally there is no need to define this input class and just reuse the Argument class from SaveONNX command + //but the code generator cannot parse certain complicated data types in the base class that Argument class extends. + //We should fix the code generator and use the Argument class. + public sealed class Input + { + [Argument(ArgumentType.AtMostOnce, HelpText = "The path to write the output ONNX to.", SortOrder = 1)] + public string Onnx; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The path to write the output JSON to.", SortOrder = 2)] + public string Json; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The 'name' property in the output ONNX. By default this will be the ONNX extension-less name.", NullName = "", SortOrder = 3)] + public string Name; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The 'domain' property in the output ONNX.", NullName = "", SortOrder = 4)] + public string Domain; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Array of input column names to drop", SortOrder = 5)] + public string[] InputsToDrop; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Array of output column names to drop", SortOrder = 6)] + public string[] OutputsToDrop; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 7)] + public bool? LoadPredictor; + + [Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 8)] + + public IPredictorModel Model; + } + + + [TlcModule.EntryPoint(Name = "Models.OnnxConverter", Desc = "Converts the model to ONNX format.", UserName = "ONNX Converter.")] + public static Output Apply(IHostEnvironment env, Input input) + { + Arguments args = new Arguments(); + args.Onnx = input.Onnx; + args.Json = input.Json; + args.Name = input.Name; + args.Domain = input.Domain; + args.InputsToDropArray = input.InputsToDrop; + args.OutputsToDropArray = input.OutputsToDrop; + args.LoadPredictor = input.LoadPredictor; + args.Model = input.Model; + + var cmd = new SaveOnnxCommand(env, args); + cmd.Run(); + return new Output(); + } + } } diff --git a/src/Microsoft.ML.Maml/ChainCommand.cs b/src/Microsoft.ML.Maml/ChainCommand.cs deleted file mode 100644 index 829923a60c..0000000000 --- a/src/Microsoft.ML.Maml/ChainCommand.cs +++ /dev/null @@ -1,84 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Globalization; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Tools; - -[assembly: LoadableClass(ChainCommand.Summary, typeof(ChainCommand), typeof(ChainCommand.Arguments), typeof(SignatureCommand), - "Chain Command", "Chain")] - -namespace Microsoft.ML.Runtime.Tools -{ - using Stopwatch = System.Diagnostics.Stopwatch; - - public sealed class ChainCommand : ICommand - { - public sealed class Arguments - { -#pragma warning disable 649 // never assigned - [Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd")] - public SubComponent[] Command; -#pragma warning restore 649 // never assigned - } - - internal const string Summary = "A command that chains multiple other commands."; - - private readonly IHost _host; - - private readonly Arguments _args; - - public ChainCommand(IHostEnvironment env, Arguments args) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); - - _args = args; - _host = env.Register("Chain"); - } - - public void Run() - { - using (var ch = _host.Start("Run")) - { - var sw = new Stopwatch(); - int count = 0; - - sw.Start(); - if (_args.Command != null) - { - for (int i = 0; i < _args.Command.Length; i++) - { - using (var chCmd = _host.Start(string.Format(CultureInfo.InvariantCulture, "Command[{0}]", i))) - { - var sub = _args.Command[i]; - - chCmd.Info("====================================================================================="); - chCmd.Info("Executing: {0}", sub); - chCmd.Info("====================================================================================="); - - var cmd = sub.CreateInstance(_host); - cmd.Run(); - count++; - - chCmd.Info(" "); - - chCmd.Done(); - } - } - } - sw.Stop(); - - ch.Info("====================================================================================="); - ch.Info("Executed {0} commands in {1}", count, sw.Elapsed); - ch.Info("====================================================================================="); - - ch.Done(); - } - } - } -} diff --git a/src/Microsoft.ML.Maml/HelpCommand.cs b/src/Microsoft.ML.Maml/HelpCommand.cs deleted file mode 100644 index e0941f5a93..0000000000 --- a/src/Microsoft.ML.Maml/HelpCommand.cs +++ /dev/null @@ -1,509 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Globalization; -using System.IO; -using System.Linq; -using System.Text; -using System.Xml.Linq; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Tools; - -[assembly: LoadableClass(HelpCommand.Summary, typeof(HelpCommand), typeof(HelpCommand.Arguments), typeof(SignatureCommand), - "MAML Help Command", "Help", "?")] - -[assembly: LoadableClass(typeof(XmlGenerator), typeof(XmlGenerator.Arguments), typeof(SignatureModuleGenerator), - "Xml generator", "XmlGenerator", "Xml")] - -namespace Microsoft.ML.Runtime.Tools -{ - public interface IGenerator - { - void Generate(IEnumerable infos); - } - - public delegate void SignatureModuleGenerator(string regenerate); - - public sealed class HelpCommand : ICommand - { - public sealed class Arguments - { -#pragma warning disable 649 // never assigned - [DefaultArgument(ArgumentType.AtMostOnce, HelpText = "The component name to get help for")] - public string Component; - - [Argument(ArgumentType.AtMostOnce, HelpText = "The kind of component to look for", ShortName = "kind")] - public string Kind; - - [Argument(ArgumentType.AtMostOnce, HelpText = "List the component kinds", ShortName = "list")] - public bool ListKinds; - - [Argument(ArgumentType.AtMostOnce, ShortName = "all", Hide = true)] - public bool AllComponents; - - // extra DLLs for dynamic loading - [Argument(ArgumentType.Multiple, HelpText = "Extra DLLs", ShortName = "dll")] - public string[] ExtraAssemblies; - - [Argument(ArgumentType.LastOccurenceWins, Hide = true)] - public SubComponent Generator; -#pragma warning restore 649 // never assigned - } - - internal const string Summary = "Prints command line help."; - - private readonly IHostEnvironment _env; - private readonly string _component; - private readonly string _kind; - private readonly bool _listKinds; - private readonly bool _allComponents; - private readonly string[] _extraAssemblies; - private readonly IGenerator _generator; - - // REVIEW: Need to change the help command to use the provided host environment for output, - // instead of assuming the console. - public HelpCommand(IHostEnvironment env, Arguments args) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); - - _env = env; - _component = args.Component; - if (string.IsNullOrWhiteSpace(_component)) - _component = null; - - _kind = args.Kind; - if (string.IsNullOrWhiteSpace(_kind)) - _kind = null; - - _listKinds = args.ListKinds; - _allComponents = args.AllComponents; - - _extraAssemblies = args.ExtraAssemblies; - - if (args.Generator.IsGood()) - { - _generator = args.Generator.CreateInstance(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments())); - } - } - - public void Run() - { - Run(null); - } - - public void Run(int? columns) - { - ComponentCatalog.CacheClassesExtra(_extraAssemblies); - - using (var ch = _env.Start("Help")) - using (var sw = new StringWriter(CultureInfo.InvariantCulture)) - using (var writer = IndentingTextWriter.Wrap(sw)) - { - if (_listKinds) - { - if (_component != null) - writer.WriteLine("Listing component kinds so ignoring specified component"); - else if (_kind != null) - writer.WriteLine("Listing component kinds so ignoring specified kind"); - ListKinds(writer); - } - else if (_component != null) - ShowHelp(writer, columns); - else if (_allComponents) - ShowAllHelp(writer, columns); - else - ShowComponents(writer); - - ch.Info(sw.ToString()); - ch.Done(); - } - } - - private void ShowHelp(IndentingTextWriter writer, int? columns = null) - { - _env.AssertValue(_component); - - string name = _component.Trim(); - - string sig = _kind?.ToLowerInvariant(); - - // Note that we don't check IsHidden here. The current policy is when IsHidden is true, we don't - // show the item in "list all" functionality, but will still show help when explicitly requested. - - var infos = ComponentCatalog.FindLoadableClasses(name) - .OrderBy(x => ComponentCatalog.SignatureToString(x.SignatureTypes[0]).ToLowerInvariant()); - var kinds = new StringBuilder(); - var components = new List(); - foreach (var info in infos) - { - _env.AssertValue(info.SignatureTypes); - kinds.Clear(); - bool foundSig = false; - foreach (var signature in info.SignatureTypes) - { - _env.Assert(signature.BaseType == typeof(MulticastDelegate)); - - string kind; - if (signature == typeof(SignatureDefault)) - { - kind = "Component"; - if (sig == null || "default".StartsWithInvariantCulture(sig)) - foundSig = true; - } - else - { - kind = ComponentCatalog.SignatureToString(signature); - if (sig == null || kind.StartsWithInvariantCultureIgnoreCase(sig)) - foundSig = true; - } - - if (kinds.Length > 0) - kinds.Append(", "); - kinds.Append(kind); - } - if (foundSig) - { - string kindsStr = kinds.ToString(); - var args = info.CreateArguments(); - - ShowUsage(writer, kindsStr, info.Summary, info.LoadNames[0], info.LoadNames, args, columns); - components.Add(new Component(kindsStr, info, args)); - } - } - - if (components.Count == 0) - writer.WriteLine("Unknown component: '{0}'", name); - else - Serialize(components); - } - - private void ShowAllHelp(IndentingTextWriter writer, int? columns = null) - { - string sig = _kind?.ToLowerInvariant(); - - var infos = ComponentCatalog.GetAllClasses() - .OrderBy(info => info.LoadNames[0].ToLowerInvariant()) - .ThenBy(info => ComponentCatalog.SignatureToString(info.SignatureTypes[0]).ToLowerInvariant()); - var components = new List(); - foreach (var info in infos) - { - // REVIEW: We should only be printing the usage once, not for every signature. - _env.AssertValue(info.SignatureTypes); - foreach (var signature in info.SignatureTypes) - { - _env.Assert(signature.BaseType == typeof(MulticastDelegate)); - - string kind = ComponentCatalog.SignatureToString(signature); - if (sig != null && !kind.StartsWithInvariantCultureIgnoreCase(sig)) - continue; - - // Don't show classes that have no arguments. - var args = info.CreateArguments(); - if (args == null) - continue; - - ShowUsage(writer, kind, info.Summary, info.LoadNames[0], info.LoadNames, args, columns); - components.Add(new Component(kind, info, args)); - } - } - - if (components.Count > 0) - Serialize(components); - } - - private void ShowUsage(IndentingTextWriter writer, string kind, string summary, string loadName, - IReadOnlyList loadNames, object args, int? columns) - { - _env.Assert(loadName == loadNames[0]); - - writer.WriteLine("Help for {0}: '{1}'", kind, loadName); - using (writer.Nest()) - ShowAliases(writer, loadNames); - - writer.WriteLine(); - ShowFormattedSummary(writer, summary, columns); - - if (args == null) - { - writer.WriteLine("Component '{0}' is not configurable", loadName); - writer.WriteLine(); - } - else - writer.WriteLine(CmdParser.ArgumentsUsage(_env, args.GetType(), args, false, columns)); - } - - private void ShowComponents(IndentingTextWriter writer) - { - Type typeSig; - Type typeRes; - string kind; - - if (_kind == null) - { - // Show commands. - typeSig = typeof(SignatureCommand); - typeRes = typeof(ICommand); - kind = "Command"; - writer.WriteLine("Available commands:"); - } - else - { - kind = _kind.ToLowerInvariant(); - var sigs = ComponentCatalog.GetAllSignatureTypes(); - typeSig = sigs.FirstOrDefault(t => ComponentCatalog.SignatureToString(t).ToLowerInvariant() == kind); - if (typeSig == null) - { - typeSig = sigs.FirstOrDefault(t => ComponentCatalog.SignatureToString(t).StartsWithInvariantCultureIgnoreCase(kind)); - if (typeSig == null) - { - writer.WriteLine("Couldn't find kind '{0}'", kind); - ListKinds(writer); - return; - } - } - typeRes = typeof(object); - writer.WriteLine("Available components for kind '{0}':", ComponentCatalog.SignatureToString(typeSig)); - } - - var infos = ComponentCatalog.GetAllDerivedClasses(typeRes, typeSig) - .Where(x => !x.IsHidden) - .OrderBy(x => x.LoadNames[0].ToLowerInvariant()); - using (writer.Nest()) - { - var components = new List(); - foreach (var info in infos) - { - _env.Assert(info.LoadNames.Count > 0); - - writer.Write("{0}", info.LoadNames[0]); - if (!string.IsNullOrWhiteSpace(info.UserName)) - writer.Write(": {0}", info.UserName); - writer.WriteLine(); - - using (writer.Nest()) - ShowAliases(writer, info.LoadNames); - components.Add(new Component(kind, info, info.CreateArguments())); - } - - if (components.Count > 0) - Serialize(components); - } - } - - private void Serialize(List components) - { - _env.AssertValue(components); - - if (_generator != null) - GenerateModule(components); - } - - private void ShowAliases(IndentingTextWriter writer, IReadOnlyList names) - { - if (names.Count <= 1) - return; - - string pre = "Aliases: "; - for (int i = 1; i < names.Count; i++) - { - writer.Write(pre); - pre = ", "; - writer.Write(names[i]); - } - writer.WriteLine(); - } - - private void ListKinds(IndentingTextWriter writer) - { - var sigs = ComponentCatalog.GetAllSignatureTypes() - .Select(ComponentCatalog.SignatureToString) - .OrderBy(x => x); - - writer.WriteLine("Available component kinds:"); - using (writer.Nest()) - { - foreach (var sig in sigs) - writer.WriteLine(sig); - } - } - - private void ShowFormattedSummary(IndentingTextWriter writer, string summary, int? columns) - { - _env.AssertValue(writer); - - if (string.IsNullOrWhiteSpace(summary)) - return; - - // REVIEW: should we replace consecutive spaces with a single space as a preprocessing step? - int screenWidth = (columns ?? CmdParser.GetConsoleWindowWidth()) - 1; - - // GetConsoleWindowWidth returns 0 if command redirection operator is used - if (screenWidth <= 0) - screenWidth = 80; - - const int indentLen = 3; - string indent = new string(' ', indentLen); - var builder = new StringBuilder(); - - // REVIEW: is using StringSplitOptions.RemoveEmptyEntries the right thing to do here? - var blocks = summary.Split(new[] { "\n", "\r" }, StringSplitOptions.RemoveEmptyEntries); - for (int i = 0; i < blocks.Length; i++) - AppendFormattedText(builder, blocks[i], indent, screenWidth); - - writer.WriteLine("Summary:"); - writer.WriteLine(builder); - } - - private void AppendFormattedText(StringBuilder builder, string text, string indent, int screenWidth) - { - _env.AssertValue(builder); - _env.AssertNonEmpty(text); - _env.AssertNonEmpty(indent); - _env.Assert(screenWidth > 0); - - int textIdx = 0; - while (textIdx < text.Length) - { - int screenLeft = screenWidth - indent.Length; - int summaryLeft = text.Length - textIdx; - if (summaryLeft <= screenLeft) - { - builder.Append(indent).Append(text, textIdx, summaryLeft).AppendLine(); - break; - } - - int spaceIdx = text.LastIndexOf(' ', screenLeft + textIdx, screenLeft); - if (spaceIdx < 0) - { - // Print to the first space. - int startIdx = screenLeft + textIdx + 1; - spaceIdx = text.IndexOf(' ', startIdx, text.Length - startIdx); - if (spaceIdx < 0) - { - // Print to the end. - builder.Append(indent).Append(text, textIdx, summaryLeft).AppendLine(); - break; - } - } - - int appendCount = spaceIdx - textIdx; - builder.Append(indent).Append(text, textIdx, appendCount).AppendLine(); - textIdx += appendCount + 1; - } - } - - public struct Component - { - public readonly string Kind; - public readonly ComponentCatalog.LoadableClassInfo Info; - public readonly object Args; - - public Component(string kind, ComponentCatalog.LoadableClassInfo info, object args) - { - Contracts.AssertNonEmpty(kind); - Contracts.AssertValue(info); - Contracts.AssertValueOrNull(args); - - Kind = kind; - Info = info; - Args = args; - } - } - - private void GenerateModule(List components) - { - Contracts.AssertValue(components); - _generator.Generate(components); - } - } - - public sealed class XmlGenerator : IGenerator - { - public sealed class Arguments - { - [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The path of the XML documentation file", - ShortName = "xml", Hide = true)] - public string XmlFilename; - } - - private readonly string _xmlFilename; - private readonly IHost _host; - - public XmlGenerator(IHostEnvironment env, Arguments args, string regenerate) - { - Contracts.CheckValue(env, nameof(env)); - env.AssertValue(args, nameof(args)); - env.AssertNonEmpty(regenerate, nameof(regenerate)); - - _xmlFilename = args.XmlFilename; - if (!string.IsNullOrWhiteSpace(_xmlFilename)) - Utils.CheckOptionalUserDirectory(_xmlFilename, nameof(args.XmlFilename)); - else - _xmlFilename = null; - _host = env.Register("XML Generator"); - } - - public void Generate(IEnumerable infos) - { - if (_xmlFilename == null) - return; - using (var ch = _host.Start("Generating XML")) - { - var content = new XElement("Components", - from c in infos - where !string.IsNullOrWhiteSpace(c.Info.UserName) - select new XElement("Component", - new XAttribute("Kind", c.Kind), - new XElement("Name", c.Info.UserName), - string.IsNullOrWhiteSpace(c.Info.Summary) ? null : new XElement("Summary", c.Info.Summary), - new XElement("LoadNames", - from l in c.Info.LoadNames - select new XElement("LoadName", l)), - new XElement("Type", c.Info.Type.ToString()), - new XElement("SignatureTypes", - from s in c.Info.SignatureTypes - select new XElement("SignatureType", s.ToString())), - c.Args == null - ? null - : new XElement("Arguments", - from a in CmdParser.GetArgInfo(c.Args.GetType(), c.Args).Args - select new XElement("Argument", - new XElement("LongName", a.LongName), - a.ShortNames == null - ? null - : new XElement("ShortNames", - from sn in a.ShortNames - select new XElement("ShortName", sn)), - new XElement("HelpText", a.HelpText), - CreateDefaultValueElement(ch, c.Kind, a))))); - File.WriteAllText(_xmlFilename, content.ToString()); - ch.Done(); - } - } - - private XElement CreateDefaultValueElement(IChannel ch, string name, CmdParser.ArgInfo.Arg a) - { - if (a.DefaultValue == null) - return null; - if (a.DefaultValue is char) - { - char val = (char)a.DefaultValue; - if (!char.IsLetterOrDigit(val) && !char.IsPunctuation(val) && !char.IsSymbol(val)) - { - ch.Warning("Unprintable default value for component {0}, character valued field {1}: {2}", name, - a.LongName, Convert.ToUInt16(val).ToString("X4", CultureInfo.InvariantCulture)); - - return null; - } - } - return new XElement("DefaultValue", a.DefaultValue); - } - } -} diff --git a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj deleted file mode 100644 index 2ddbf9c7f7..0000000000 --- a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj +++ /dev/null @@ -1,14 +0,0 @@ - - - - true - CORECLR - netstandard2.0 - Microsoft.ML - - - - - - - \ No newline at end of file diff --git a/src/Microsoft.ML.Maml/VersionCommand.cs b/src/Microsoft.ML.Maml/VersionCommand.cs deleted file mode 100644 index 2ba0afe116..0000000000 --- a/src/Microsoft.ML.Maml/VersionCommand.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Reflection; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Command; -using Microsoft.ML.Runtime.Tools; - -[assembly: LoadableClass(VersionCommand.Summary, typeof(VersionCommand), null, typeof(SignatureCommand), - "Version Command", "Version")] - -namespace Microsoft.ML.Runtime.Tools -{ - public sealed class VersionCommand : ICommand - { - internal const string Summary = "Prints the TLC version."; - - private readonly IHost _host; - - public VersionCommand(IHostEnvironment env) - { - Contracts.CheckValue(env, nameof(env)); - - _host = env.Register("Version"); - } - - public void Run() - { - using (var ch = _host.Start("Version")) - { - string version = typeof(VersionCommand).GetTypeInfo().Assembly.GetName().Version.ToString(); - ch.Info(version); - ch.Done(); - } - } - } -} diff --git a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj index b420da0eb0..75597c7c54 100644 --- a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj +++ b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj @@ -8,10 +8,10 @@ + - diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 5bb4782599..cb53256a73 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -238,6 +238,18 @@ public void Add(Microsoft.ML.Models.OneVersusAll input, Microsoft.ML.Models.OneV _jsonNodes.Add(Serialize("Models.OneVersusAll", input, output)); } + public Microsoft.ML.Models.OnnxConverter.Output Add(Microsoft.ML.Models.OnnxConverter input) + { + var output = new Microsoft.ML.Models.OnnxConverter.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Models.OnnxConverter input, Microsoft.ML.Models.OnnxConverter.Output output) + { + _jsonNodes.Add(Serialize("Models.OnnxConverter", input, output)); + } + public Microsoft.ML.Models.OvaModelCombiner.Output Add(Microsoft.ML.Models.OvaModelCombiner input) { var output = new Microsoft.ML.Models.OvaModelCombiner.Output(); @@ -2666,6 +2678,63 @@ public OneVersusAllPipelineStep(Output output) } } + namespace Models + { + + /// + /// Converts the model to ONNX format. + /// + public sealed partial class OnnxConverter + { + + + /// + /// The path to write the output ONNX to. + /// + public string Onnx { get; set; } + + /// + /// The path to write the output JSON to. + /// + public string Json { get; set; } + + /// + /// The 'name' property in the output ONNX. By default this will be the ONNX extension-less name. + /// + public string Name { get; set; } + + /// + /// The 'domain' property in the output ONNX. + /// + public string Domain { get; set; } + + /// + /// Array of input column names to drop + /// + public string[] InputsToDrop { get; set; } + + /// + /// Array of output column names to drop + /// + public string[] OutputsToDrop { get; set; } + + /// + /// Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present. + /// + public bool? LoadPredictor { get; set; } + + /// + /// Model that needs to be converted to ONNX format. + /// + public Var Model { get; set; } = new Var(); + + + public sealed class Output + { + } + } + } + namespace Models { diff --git a/src/Microsoft.ML/Microsoft.ML.csproj b/src/Microsoft.ML/Microsoft.ML.csproj index 7b370f2804..4f3f8f0dc2 100644 --- a/src/Microsoft.ML/Microsoft.ML.csproj +++ b/src/Microsoft.ML/Microsoft.ML.csproj @@ -14,9 +14,9 @@ + - diff --git a/src/Microsoft.ML/Models/SaveAsOnnx.cs b/src/Microsoft.ML/Models/SaveAsOnnx.cs deleted file mode 100644 index 8165de521e..0000000000 --- a/src/Microsoft.ML/Models/SaveAsOnnx.cs +++ /dev/null @@ -1,24 +0,0 @@ -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Model.Onnx; - -namespace Microsoft.ML.Models -{ - /// - /// Converts a model to ONNX format. - /// - public sealed class SaveAsOnnx - { - /// - /// Converts and then saves a model to ONNX format. - /// - /// Arguments such as input model file path, output ONNX file path, etc. - public static void Save(SaveOnnxCommand.Arguments args) - { - using (var env = new TlcEnvironment()) - { - var cmd = new SaveOnnxCommand(env, args); - cmd.Run(); - } - } - } -} diff --git a/src/Microsoft.ML/PredictionModel.cs b/src/Microsoft.ML/PredictionModel.cs index 6eb1b3c6f5..1e4d457b8c 100644 --- a/src/Microsoft.ML/PredictionModel.cs +++ b/src/Microsoft.ML/PredictionModel.cs @@ -25,11 +25,26 @@ internal PredictionModel(Stream stream) _predictorModel = new Runtime.EntryPoints.TransformModel(_env, stream); } - internal Runtime.EntryPoints.TransformModel PredictorModel + internal TransformModel PredictorModel { get { return _predictorModel; } } + /// + /// Converts the model to ONNX format. + /// + /// Arguments to ONNX converter. + public void ExportToOnnx(Models.OnnxConverter onnxConverter) + { + _env.CheckValue(onnxConverter, nameof(onnxConverter)); + + Experiment experiment = _env.CreateExperiment(); + experiment.Add(onnxConverter); + experiment.Compile(); + experiment.SetInput(onnxConverter.Model, new PredictorModel(_env, _predictorModel)); + experiment.Run(); + } + /// /// Returns labels that correspond to indices of the score array in the case of /// multi-class classification problem. @@ -44,7 +59,7 @@ public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = D int colIndex = -1; if (!schema.TryGetColumnIndex(scoreColumnName, out colIndex)) return false; - + int expectedLabelCount = schema.GetColumnType(colIndex).ValueCount; if (!schema.HasSlotNames(colIndex, expectedLabelCount)) return false; @@ -57,7 +72,7 @@ public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = D names = new string[expectedLabelCount]; int index = 0; - foreach(var label in labels.DenseValues()) + foreach (var label in labels.DenseValues()) names[index++] = label.ToString(); return true; diff --git a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj index d9cf8a2f29..6e4a6dc85f 100644 --- a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj +++ b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj @@ -9,7 +9,7 @@ - + diff --git a/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs b/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs index c6e870a1f2..ebe9513e64 100644 --- a/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs +++ b/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs @@ -12,9 +12,8 @@ public partial class ScenariosTests { public class BreastCancerData { - [Column(ordinal: "0")] public float Label; - [Column(ordinal: "1-9")] + [VectorType(9)] public float[] Features; } @@ -57,36 +56,29 @@ public void SaveModelToOnnxTest() }); pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); - - PredictionModel model = pipeline.Train(); - - var modelOutpath = GetOutputPath(Path.Combine("..", "Common", - "Scenario", "BinaryClassification", "BreastCancer"), "SaveModelToOnnxTest.zip"); + var model = pipeline.Train(); + var subDir = Path.Combine("..", "Common", "Scenario", "BinaryClassification", "BreastCancer"); + var modelOutpath = GetOutputPath(subDir, "SaveModelToOnnxTest.zip"); DeleteOutputPath(modelOutpath); - var onnxPath = GetOutputPath(Path.Combine("..", "Common", - "Scenario", "BinaryClassification", "BreastCancer"), "SaveModelToOnnxTest.pb"); - + var onnxPath = GetOutputPath(subDir, "SaveModelToOnnxTest.pb"); DeleteOutputPath(onnxPath); - var onnxAsJsonPath = GetOutputPath(Path.Combine("..", "Common", - "Scenario", "BinaryClassification", "BreastCancer"), "SaveModelToOnnxTest.json"); - + var onnxAsJsonPath = GetOutputPath(subDir, "SaveModelToOnnxTest.json"); DeleteOutputPath(onnxAsJsonPath); - model.WriteAsync(modelOutpath); - SaveAsOnnx.Save(new Runtime.Model.Onnx.SaveOnnxCommand.Arguments + model.ExportToOnnx(new OnnxConverter() { - InputModelFile = modelOutpath, - OutputsToDrop = "Label,Features", + InputsToDrop = new[] { "Label" }, + OutputsToDrop = new[] { "Label", "Features" }, Onnx = onnxPath, Json = onnxAsJsonPath, Domain = "Onnx" }); - Assert.True(CheckEquality(Path.Combine("..", "Common", "Scenario", "BinaryClassification", "BreastCancer"), - "SaveModelToOnnxTest.json")); + CheckEquality(subDir, "SaveModelToOnnxTest.json"); + Done(); } } } From b04ef49ff2a5865d035ea7c45cd8d0288fb13381 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 30 May 2018 00:41:11 -0700 Subject: [PATCH 04/22] PR feedback. --- src/Microsoft.ML.Commands/ChainCommand.cs | 84 +++ src/Microsoft.ML.Commands/HelpCommand.cs | 509 ++++++++++++++++++ .../Microsoft.ML.Commands.csproj | 14 + .../Properties/AssemblyInfo.cs | 9 + src/Microsoft.ML.Commands/VersionCommand.cs | 38 ++ .../Microsoft.ML.Console.csproj | 17 + 6 files changed, 671 insertions(+) create mode 100644 src/Microsoft.ML.Commands/ChainCommand.cs create mode 100644 src/Microsoft.ML.Commands/HelpCommand.cs create mode 100644 src/Microsoft.ML.Commands/Microsoft.ML.Commands.csproj create mode 100644 src/Microsoft.ML.Commands/Properties/AssemblyInfo.cs create mode 100644 src/Microsoft.ML.Commands/VersionCommand.cs create mode 100644 src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj diff --git a/src/Microsoft.ML.Commands/ChainCommand.cs b/src/Microsoft.ML.Commands/ChainCommand.cs new file mode 100644 index 0000000000..829923a60c --- /dev/null +++ b/src/Microsoft.ML.Commands/ChainCommand.cs @@ -0,0 +1,84 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Globalization; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Command; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Tools; + +[assembly: LoadableClass(ChainCommand.Summary, typeof(ChainCommand), typeof(ChainCommand.Arguments), typeof(SignatureCommand), + "Chain Command", "Chain")] + +namespace Microsoft.ML.Runtime.Tools +{ + using Stopwatch = System.Diagnostics.Stopwatch; + + public sealed class ChainCommand : ICommand + { + public sealed class Arguments + { +#pragma warning disable 649 // never assigned + [Argument(ArgumentType.Multiple, HelpText = "Command", ShortName = "cmd")] + public SubComponent[] Command; +#pragma warning restore 649 // never assigned + } + + internal const string Summary = "A command that chains multiple other commands."; + + private readonly IHost _host; + + private readonly Arguments _args; + + public ChainCommand(IHostEnvironment env, Arguments args) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + + _args = args; + _host = env.Register("Chain"); + } + + public void Run() + { + using (var ch = _host.Start("Run")) + { + var sw = new Stopwatch(); + int count = 0; + + sw.Start(); + if (_args.Command != null) + { + for (int i = 0; i < _args.Command.Length; i++) + { + using (var chCmd = _host.Start(string.Format(CultureInfo.InvariantCulture, "Command[{0}]", i))) + { + var sub = _args.Command[i]; + + chCmd.Info("====================================================================================="); + chCmd.Info("Executing: {0}", sub); + chCmd.Info("====================================================================================="); + + var cmd = sub.CreateInstance(_host); + cmd.Run(); + count++; + + chCmd.Info(" "); + + chCmd.Done(); + } + } + } + sw.Stop(); + + ch.Info("====================================================================================="); + ch.Info("Executed {0} commands in {1}", count, sw.Elapsed); + ch.Info("====================================================================================="); + + ch.Done(); + } + } + } +} diff --git a/src/Microsoft.ML.Commands/HelpCommand.cs b/src/Microsoft.ML.Commands/HelpCommand.cs new file mode 100644 index 0000000000..e0941f5a93 --- /dev/null +++ b/src/Microsoft.ML.Commands/HelpCommand.cs @@ -0,0 +1,509 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Text; +using System.Xml.Linq; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Command; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Tools; + +[assembly: LoadableClass(HelpCommand.Summary, typeof(HelpCommand), typeof(HelpCommand.Arguments), typeof(SignatureCommand), + "MAML Help Command", "Help", "?")] + +[assembly: LoadableClass(typeof(XmlGenerator), typeof(XmlGenerator.Arguments), typeof(SignatureModuleGenerator), + "Xml generator", "XmlGenerator", "Xml")] + +namespace Microsoft.ML.Runtime.Tools +{ + public interface IGenerator + { + void Generate(IEnumerable infos); + } + + public delegate void SignatureModuleGenerator(string regenerate); + + public sealed class HelpCommand : ICommand + { + public sealed class Arguments + { +#pragma warning disable 649 // never assigned + [DefaultArgument(ArgumentType.AtMostOnce, HelpText = "The component name to get help for")] + public string Component; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The kind of component to look for", ShortName = "kind")] + public string Kind; + + [Argument(ArgumentType.AtMostOnce, HelpText = "List the component kinds", ShortName = "list")] + public bool ListKinds; + + [Argument(ArgumentType.AtMostOnce, ShortName = "all", Hide = true)] + public bool AllComponents; + + // extra DLLs for dynamic loading + [Argument(ArgumentType.Multiple, HelpText = "Extra DLLs", ShortName = "dll")] + public string[] ExtraAssemblies; + + [Argument(ArgumentType.LastOccurenceWins, Hide = true)] + public SubComponent Generator; +#pragma warning restore 649 // never assigned + } + + internal const string Summary = "Prints command line help."; + + private readonly IHostEnvironment _env; + private readonly string _component; + private readonly string _kind; + private readonly bool _listKinds; + private readonly bool _allComponents; + private readonly string[] _extraAssemblies; + private readonly IGenerator _generator; + + // REVIEW: Need to change the help command to use the provided host environment for output, + // instead of assuming the console. + public HelpCommand(IHostEnvironment env, Arguments args) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + + _env = env; + _component = args.Component; + if (string.IsNullOrWhiteSpace(_component)) + _component = null; + + _kind = args.Kind; + if (string.IsNullOrWhiteSpace(_kind)) + _kind = null; + + _listKinds = args.ListKinds; + _allComponents = args.AllComponents; + + _extraAssemblies = args.ExtraAssemblies; + + if (args.Generator.IsGood()) + { + _generator = args.Generator.CreateInstance(_env, "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments())); + } + } + + public void Run() + { + Run(null); + } + + public void Run(int? columns) + { + ComponentCatalog.CacheClassesExtra(_extraAssemblies); + + using (var ch = _env.Start("Help")) + using (var sw = new StringWriter(CultureInfo.InvariantCulture)) + using (var writer = IndentingTextWriter.Wrap(sw)) + { + if (_listKinds) + { + if (_component != null) + writer.WriteLine("Listing component kinds so ignoring specified component"); + else if (_kind != null) + writer.WriteLine("Listing component kinds so ignoring specified kind"); + ListKinds(writer); + } + else if (_component != null) + ShowHelp(writer, columns); + else if (_allComponents) + ShowAllHelp(writer, columns); + else + ShowComponents(writer); + + ch.Info(sw.ToString()); + ch.Done(); + } + } + + private void ShowHelp(IndentingTextWriter writer, int? columns = null) + { + _env.AssertValue(_component); + + string name = _component.Trim(); + + string sig = _kind?.ToLowerInvariant(); + + // Note that we don't check IsHidden here. The current policy is when IsHidden is true, we don't + // show the item in "list all" functionality, but will still show help when explicitly requested. + + var infos = ComponentCatalog.FindLoadableClasses(name) + .OrderBy(x => ComponentCatalog.SignatureToString(x.SignatureTypes[0]).ToLowerInvariant()); + var kinds = new StringBuilder(); + var components = new List(); + foreach (var info in infos) + { + _env.AssertValue(info.SignatureTypes); + kinds.Clear(); + bool foundSig = false; + foreach (var signature in info.SignatureTypes) + { + _env.Assert(signature.BaseType == typeof(MulticastDelegate)); + + string kind; + if (signature == typeof(SignatureDefault)) + { + kind = "Component"; + if (sig == null || "default".StartsWithInvariantCulture(sig)) + foundSig = true; + } + else + { + kind = ComponentCatalog.SignatureToString(signature); + if (sig == null || kind.StartsWithInvariantCultureIgnoreCase(sig)) + foundSig = true; + } + + if (kinds.Length > 0) + kinds.Append(", "); + kinds.Append(kind); + } + if (foundSig) + { + string kindsStr = kinds.ToString(); + var args = info.CreateArguments(); + + ShowUsage(writer, kindsStr, info.Summary, info.LoadNames[0], info.LoadNames, args, columns); + components.Add(new Component(kindsStr, info, args)); + } + } + + if (components.Count == 0) + writer.WriteLine("Unknown component: '{0}'", name); + else + Serialize(components); + } + + private void ShowAllHelp(IndentingTextWriter writer, int? columns = null) + { + string sig = _kind?.ToLowerInvariant(); + + var infos = ComponentCatalog.GetAllClasses() + .OrderBy(info => info.LoadNames[0].ToLowerInvariant()) + .ThenBy(info => ComponentCatalog.SignatureToString(info.SignatureTypes[0]).ToLowerInvariant()); + var components = new List(); + foreach (var info in infos) + { + // REVIEW: We should only be printing the usage once, not for every signature. + _env.AssertValue(info.SignatureTypes); + foreach (var signature in info.SignatureTypes) + { + _env.Assert(signature.BaseType == typeof(MulticastDelegate)); + + string kind = ComponentCatalog.SignatureToString(signature); + if (sig != null && !kind.StartsWithInvariantCultureIgnoreCase(sig)) + continue; + + // Don't show classes that have no arguments. + var args = info.CreateArguments(); + if (args == null) + continue; + + ShowUsage(writer, kind, info.Summary, info.LoadNames[0], info.LoadNames, args, columns); + components.Add(new Component(kind, info, args)); + } + } + + if (components.Count > 0) + Serialize(components); + } + + private void ShowUsage(IndentingTextWriter writer, string kind, string summary, string loadName, + IReadOnlyList loadNames, object args, int? columns) + { + _env.Assert(loadName == loadNames[0]); + + writer.WriteLine("Help for {0}: '{1}'", kind, loadName); + using (writer.Nest()) + ShowAliases(writer, loadNames); + + writer.WriteLine(); + ShowFormattedSummary(writer, summary, columns); + + if (args == null) + { + writer.WriteLine("Component '{0}' is not configurable", loadName); + writer.WriteLine(); + } + else + writer.WriteLine(CmdParser.ArgumentsUsage(_env, args.GetType(), args, false, columns)); + } + + private void ShowComponents(IndentingTextWriter writer) + { + Type typeSig; + Type typeRes; + string kind; + + if (_kind == null) + { + // Show commands. + typeSig = typeof(SignatureCommand); + typeRes = typeof(ICommand); + kind = "Command"; + writer.WriteLine("Available commands:"); + } + else + { + kind = _kind.ToLowerInvariant(); + var sigs = ComponentCatalog.GetAllSignatureTypes(); + typeSig = sigs.FirstOrDefault(t => ComponentCatalog.SignatureToString(t).ToLowerInvariant() == kind); + if (typeSig == null) + { + typeSig = sigs.FirstOrDefault(t => ComponentCatalog.SignatureToString(t).StartsWithInvariantCultureIgnoreCase(kind)); + if (typeSig == null) + { + writer.WriteLine("Couldn't find kind '{0}'", kind); + ListKinds(writer); + return; + } + } + typeRes = typeof(object); + writer.WriteLine("Available components for kind '{0}':", ComponentCatalog.SignatureToString(typeSig)); + } + + var infos = ComponentCatalog.GetAllDerivedClasses(typeRes, typeSig) + .Where(x => !x.IsHidden) + .OrderBy(x => x.LoadNames[0].ToLowerInvariant()); + using (writer.Nest()) + { + var components = new List(); + foreach (var info in infos) + { + _env.Assert(info.LoadNames.Count > 0); + + writer.Write("{0}", info.LoadNames[0]); + if (!string.IsNullOrWhiteSpace(info.UserName)) + writer.Write(": {0}", info.UserName); + writer.WriteLine(); + + using (writer.Nest()) + ShowAliases(writer, info.LoadNames); + components.Add(new Component(kind, info, info.CreateArguments())); + } + + if (components.Count > 0) + Serialize(components); + } + } + + private void Serialize(List components) + { + _env.AssertValue(components); + + if (_generator != null) + GenerateModule(components); + } + + private void ShowAliases(IndentingTextWriter writer, IReadOnlyList names) + { + if (names.Count <= 1) + return; + + string pre = "Aliases: "; + for (int i = 1; i < names.Count; i++) + { + writer.Write(pre); + pre = ", "; + writer.Write(names[i]); + } + writer.WriteLine(); + } + + private void ListKinds(IndentingTextWriter writer) + { + var sigs = ComponentCatalog.GetAllSignatureTypes() + .Select(ComponentCatalog.SignatureToString) + .OrderBy(x => x); + + writer.WriteLine("Available component kinds:"); + using (writer.Nest()) + { + foreach (var sig in sigs) + writer.WriteLine(sig); + } + } + + private void ShowFormattedSummary(IndentingTextWriter writer, string summary, int? columns) + { + _env.AssertValue(writer); + + if (string.IsNullOrWhiteSpace(summary)) + return; + + // REVIEW: should we replace consecutive spaces with a single space as a preprocessing step? + int screenWidth = (columns ?? CmdParser.GetConsoleWindowWidth()) - 1; + + // GetConsoleWindowWidth returns 0 if command redirection operator is used + if (screenWidth <= 0) + screenWidth = 80; + + const int indentLen = 3; + string indent = new string(' ', indentLen); + var builder = new StringBuilder(); + + // REVIEW: is using StringSplitOptions.RemoveEmptyEntries the right thing to do here? + var blocks = summary.Split(new[] { "\n", "\r" }, StringSplitOptions.RemoveEmptyEntries); + for (int i = 0; i < blocks.Length; i++) + AppendFormattedText(builder, blocks[i], indent, screenWidth); + + writer.WriteLine("Summary:"); + writer.WriteLine(builder); + } + + private void AppendFormattedText(StringBuilder builder, string text, string indent, int screenWidth) + { + _env.AssertValue(builder); + _env.AssertNonEmpty(text); + _env.AssertNonEmpty(indent); + _env.Assert(screenWidth > 0); + + int textIdx = 0; + while (textIdx < text.Length) + { + int screenLeft = screenWidth - indent.Length; + int summaryLeft = text.Length - textIdx; + if (summaryLeft <= screenLeft) + { + builder.Append(indent).Append(text, textIdx, summaryLeft).AppendLine(); + break; + } + + int spaceIdx = text.LastIndexOf(' ', screenLeft + textIdx, screenLeft); + if (spaceIdx < 0) + { + // Print to the first space. + int startIdx = screenLeft + textIdx + 1; + spaceIdx = text.IndexOf(' ', startIdx, text.Length - startIdx); + if (spaceIdx < 0) + { + // Print to the end. + builder.Append(indent).Append(text, textIdx, summaryLeft).AppendLine(); + break; + } + } + + int appendCount = spaceIdx - textIdx; + builder.Append(indent).Append(text, textIdx, appendCount).AppendLine(); + textIdx += appendCount + 1; + } + } + + public struct Component + { + public readonly string Kind; + public readonly ComponentCatalog.LoadableClassInfo Info; + public readonly object Args; + + public Component(string kind, ComponentCatalog.LoadableClassInfo info, object args) + { + Contracts.AssertNonEmpty(kind); + Contracts.AssertValue(info); + Contracts.AssertValueOrNull(args); + + Kind = kind; + Info = info; + Args = args; + } + } + + private void GenerateModule(List components) + { + Contracts.AssertValue(components); + _generator.Generate(components); + } + } + + public sealed class XmlGenerator : IGenerator + { + public sealed class Arguments + { + [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The path of the XML documentation file", + ShortName = "xml", Hide = true)] + public string XmlFilename; + } + + private readonly string _xmlFilename; + private readonly IHost _host; + + public XmlGenerator(IHostEnvironment env, Arguments args, string regenerate) + { + Contracts.CheckValue(env, nameof(env)); + env.AssertValue(args, nameof(args)); + env.AssertNonEmpty(regenerate, nameof(regenerate)); + + _xmlFilename = args.XmlFilename; + if (!string.IsNullOrWhiteSpace(_xmlFilename)) + Utils.CheckOptionalUserDirectory(_xmlFilename, nameof(args.XmlFilename)); + else + _xmlFilename = null; + _host = env.Register("XML Generator"); + } + + public void Generate(IEnumerable infos) + { + if (_xmlFilename == null) + return; + using (var ch = _host.Start("Generating XML")) + { + var content = new XElement("Components", + from c in infos + where !string.IsNullOrWhiteSpace(c.Info.UserName) + select new XElement("Component", + new XAttribute("Kind", c.Kind), + new XElement("Name", c.Info.UserName), + string.IsNullOrWhiteSpace(c.Info.Summary) ? null : new XElement("Summary", c.Info.Summary), + new XElement("LoadNames", + from l in c.Info.LoadNames + select new XElement("LoadName", l)), + new XElement("Type", c.Info.Type.ToString()), + new XElement("SignatureTypes", + from s in c.Info.SignatureTypes + select new XElement("SignatureType", s.ToString())), + c.Args == null + ? null + : new XElement("Arguments", + from a in CmdParser.GetArgInfo(c.Args.GetType(), c.Args).Args + select new XElement("Argument", + new XElement("LongName", a.LongName), + a.ShortNames == null + ? null + : new XElement("ShortNames", + from sn in a.ShortNames + select new XElement("ShortName", sn)), + new XElement("HelpText", a.HelpText), + CreateDefaultValueElement(ch, c.Kind, a))))); + File.WriteAllText(_xmlFilename, content.ToString()); + ch.Done(); + } + } + + private XElement CreateDefaultValueElement(IChannel ch, string name, CmdParser.ArgInfo.Arg a) + { + if (a.DefaultValue == null) + return null; + if (a.DefaultValue is char) + { + char val = (char)a.DefaultValue; + if (!char.IsLetterOrDigit(val) && !char.IsPunctuation(val) && !char.IsSymbol(val)) + { + ch.Warning("Unprintable default value for component {0}, character valued field {1}: {2}", name, + a.LongName, Convert.ToUInt16(val).ToString("X4", CultureInfo.InvariantCulture)); + + return null; + } + } + return new XElement("DefaultValue", a.DefaultValue); + } + } +} diff --git a/src/Microsoft.ML.Commands/Microsoft.ML.Commands.csproj b/src/Microsoft.ML.Commands/Microsoft.ML.Commands.csproj new file mode 100644 index 0000000000..88b1e9e55b --- /dev/null +++ b/src/Microsoft.ML.Commands/Microsoft.ML.Commands.csproj @@ -0,0 +1,14 @@ + + + + true + CORECLR + Microsoft.ML + netstandard2.0 + + + + + + + \ No newline at end of file diff --git a/src/Microsoft.ML.Commands/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Commands/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..97db2a8a07 --- /dev/null +++ b/src/Microsoft.ML.Commands/Properties/AssemblyInfo.cs @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +[assembly: InternalsVisibleTo("Microsoft.ML.TestFramework, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] diff --git a/src/Microsoft.ML.Commands/VersionCommand.cs b/src/Microsoft.ML.Commands/VersionCommand.cs new file mode 100644 index 0000000000..2ba0afe116 --- /dev/null +++ b/src/Microsoft.ML.Commands/VersionCommand.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Reflection; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Command; +using Microsoft.ML.Runtime.Tools; + +[assembly: LoadableClass(VersionCommand.Summary, typeof(VersionCommand), null, typeof(SignatureCommand), + "Version Command", "Version")] + +namespace Microsoft.ML.Runtime.Tools +{ + public sealed class VersionCommand : ICommand + { + internal const string Summary = "Prints the TLC version."; + + private readonly IHost _host; + + public VersionCommand(IHostEnvironment env) + { + Contracts.CheckValue(env, nameof(env)); + + _host = env.Register("Version"); + } + + public void Run() + { + using (var ch = _host.Start("Version")) + { + string version = typeof(VersionCommand).GetTypeInfo().Assembly.GetName().Version.ToString(); + ch.Info(version); + ch.Done(); + } + } + } +} diff --git a/src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj new file mode 100644 index 0000000000..49742313e6 --- /dev/null +++ b/src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj @@ -0,0 +1,17 @@ + + + + true + CORECLR + Microsoft.ML + netcoreapp2.0 + Exe + Microsoft.ML.Runtime.Tools.Maml + + + + + + + + \ No newline at end of file From de02e1e7bf231813e11b5c676f1015f537beefe3 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 30 May 2018 00:56:37 -0700 Subject: [PATCH 05/22] resolve merge issues. --- .../Common/EntryPoints/core_ep-list.tsv | 1 + .../Common/EntryPoints/core_manifest.json | 89 +++++++++++++++++++ .../BreastCancer/SaveModelToOnnxTest.json | 0 .../Scenarios/BinaryClassification.cs | 2 +- 4 files changed, 91 insertions(+), 1 deletion(-) rename {ZBaselines => test/BaselineOutput}/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json (100%) diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 61f2604a8d..5efa7e2013 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -17,6 +17,7 @@ Models.FixedPlattCalibrator Apply a Platt calibrator with a fixed slope and offs Models.MultiOutputRegressionEvaluator Evaluates a multi output regression scored dataset. Microsoft.ML.Runtime.Data.Evaluate MultiOutputRegression Microsoft.ML.Runtime.Data.MultiOutputRegressionMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput Models.NaiveCalibrator Apply a Naive calibrator to an input model Microsoft.ML.Runtime.Internal.Calibration.Calibrate Naive Microsoft.ML.Runtime.Internal.Calibration.Calibrate+NoArgumentsInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CalibratorOutput Models.OneVersusAll One-vs-All macro (OVA) Microsoft.ML.Runtime.EntryPoints.OneVersusAllMacro OVA Microsoft.ML.Runtime.EntryPoints.OneVersusAllMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.OneVersusAllMacro+Output] +Models.OnnxConverter Converts the model to ONNX format. Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand Apply Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Input Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Output Models.OvaModelCombiner Combines a sequence of PredictorModels into a single model Microsoft.ML.Runtime.Learners.OvaPredictor CombineOvaModels Microsoft.ML.Runtime.EntryPoints.ModelOperations+CombineOvaPredictorModelsInput Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelOutput Models.PAVCalibrator Apply a PAV calibrator to an input model Microsoft.ML.Runtime.Internal.Calibration.Calibrate Pav Microsoft.ML.Runtime.Internal.Calibration.Calibrate+NoArgumentsInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CalibratorOutput Models.PipelineSweeper AutoML pipeline sweeping optimzation macro. Microsoft.ML.Runtime.EntryPoints.PipelineSweeperMacro PipelineSweep Microsoft.ML.Runtime.EntryPoints.PipelineSweeperMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.PipelineSweeperMacro+Output] diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 010d4a0afa..35e2064bda 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -2055,6 +2055,95 @@ "ITrainerInput" ] }, + { + "Name": "Models.OnnxConverter", + "Desc": "Converts the model to ONNX format.", + "FriendlyName": "ONNX Converter.", + "ShortName": null, + "Inputs": [ + { + "Name": "Onnx", + "Type": "String", + "Desc": "The path to write the output ONNX to.", + "Required": false, + "SortOrder": 1.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Json", + "Type": "String", + "Desc": "The path to write the output JSON to.", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Name", + "Type": "String", + "Desc": "The 'name' property in the output ONNX. By default this will be the ONNX extension-less name.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Domain", + "Type": "String", + "Desc": "The 'domain' property in the output ONNX.", + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "InputsToDrop", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "Array of input column names to drop", + "Required": false, + "SortOrder": 5.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "OutputsToDrop", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "Array of output column names to drop", + "Required": false, + "SortOrder": 6.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "LoadPredictor", + "Type": "Bool", + "Desc": "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", + "Aliases": [ + "pred" + ], + "Required": false, + "SortOrder": 7.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "Model", + "Type": "PredictorModel", + "Desc": "Model that needs to be converted to ONNX format.", + "Required": true, + "SortOrder": 8.0, + "IsNullable": false + } + ], + "Outputs": [] + }, { "Name": "Models.OvaModelCombiner", "Desc": "Combines a sequence of PredictorModels into a single model", diff --git a/ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json b/test/BaselineOutput/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json similarity index 100% rename from ZBaselines/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json rename to test/BaselineOutput/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json diff --git a/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs b/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs index ebe9513e64..767e0e835c 100644 --- a/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs +++ b/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs @@ -58,7 +58,7 @@ public void SaveModelToOnnxTest() pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); var model = pipeline.Train(); - var subDir = Path.Combine("..", "Common", "Scenario", "BinaryClassification", "BreastCancer"); + var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Scenario", "BinaryClassification", "BreastCancer"); var modelOutpath = GetOutputPath(subDir, "SaveModelToOnnxTest.zip"); DeleteOutputPath(modelOutpath); From 6f4434e2772adeffa97e36b3123e9dfb7c29b4aa Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 30 May 2018 01:00:16 -0700 Subject: [PATCH 06/22] cleanup. --- .../Microsoft.ML.MamlExec.csproj | 28 ------------------- 1 file changed, 28 deletions(-) delete mode 100644 src/Microsoft.ML.MamlExec/Microsoft.ML.MamlExec.csproj diff --git a/src/Microsoft.ML.MamlExec/Microsoft.ML.MamlExec.csproj b/src/Microsoft.ML.MamlExec/Microsoft.ML.MamlExec.csproj deleted file mode 100644 index 667aedf9b0..0000000000 --- a/src/Microsoft.ML.MamlExec/Microsoft.ML.MamlExec.csproj +++ /dev/null @@ -1,28 +0,0 @@ - - - - true - CORECLR - netcoreapp2.0 - Microsoft.ML - Exe - Microsoft.ML.Runtime.Tools.Maml - - - - - - MAML.cs - - - ChainCommand.cs - - - HelpCommand.cs - - - VersionCommand.cs - - - - \ No newline at end of file From 6c974163291148331740ea3626794e54c86a0d23 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 30 May 2018 08:26:17 -0700 Subject: [PATCH 07/22] cleanup. --- src/Microsoft.ML/Models/OnnxConverter.cs | 27 +++++++++++++++++++ src/Microsoft.ML/PredictionModel.cs | 15 ----------- .../Scenarios/BinaryClassification.cs | 6 +++-- 3 files changed, 31 insertions(+), 17 deletions(-) create mode 100644 src/Microsoft.ML/Models/OnnxConverter.cs diff --git a/src/Microsoft.ML/Models/OnnxConverter.cs b/src/Microsoft.ML/Models/OnnxConverter.cs new file mode 100644 index 0000000000..420ff15a28 --- /dev/null +++ b/src/Microsoft.ML/Models/OnnxConverter.cs @@ -0,0 +1,27 @@ +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; + +namespace Microsoft.ML.Models +{ + public sealed partial class OnnxConverter + { + /// + /// Converts the model to ONNX format. + /// + /// Model that needs to be converted to ONNX format. + public void Convert(PredictionModel model) + { + using (var environment = new TlcEnvironment()) + { + environment.CheckValue(model, nameof(model)); + + Experiment experiment = environment.CreateExperiment(); + experiment.Add(this); + experiment.Compile(); + experiment.SetInput(Model, new PredictorModel(environment, model.PredictorModel)); + experiment.Run(); + } + } + } +} diff --git a/src/Microsoft.ML/PredictionModel.cs b/src/Microsoft.ML/PredictionModel.cs index 1e4d457b8c..c1dded82b8 100644 --- a/src/Microsoft.ML/PredictionModel.cs +++ b/src/Microsoft.ML/PredictionModel.cs @@ -30,21 +30,6 @@ internal TransformModel PredictorModel get { return _predictorModel; } } - /// - /// Converts the model to ONNX format. - /// - /// Arguments to ONNX converter. - public void ExportToOnnx(Models.OnnxConverter onnxConverter) - { - _env.CheckValue(onnxConverter, nameof(onnxConverter)); - - Experiment experiment = _env.CreateExperiment(); - experiment.Add(onnxConverter); - experiment.Compile(); - experiment.SetInput(onnxConverter.Model, new PredictorModel(_env, _predictorModel)); - experiment.Run(); - } - /// /// Returns labels that correspond to indices of the score array in the case of /// multi-class classification problem. diff --git a/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs b/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs index 767e0e835c..7d90663e39 100644 --- a/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs +++ b/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs @@ -68,14 +68,16 @@ public void SaveModelToOnnxTest() var onnxAsJsonPath = GetOutputPath(subDir, "SaveModelToOnnxTest.json"); DeleteOutputPath(onnxAsJsonPath); - model.ExportToOnnx(new OnnxConverter() + OnnxConverter converter = new OnnxConverter() { InputsToDrop = new[] { "Label" }, OutputsToDrop = new[] { "Label", "Features" }, Onnx = onnxPath, Json = onnxAsJsonPath, Domain = "Onnx" - }); + }; + + converter.Convert(model); CheckEquality(subDir, "SaveModelToOnnxTest.json"); Done(); From 534fcd1a1e9d40b84ff473a0c017f1ab05b25002 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 30 May 2018 13:36:20 -0700 Subject: [PATCH 08/22] PR feedback. --- Microsoft.ML.sln | 2 +- .../MAML.cs | 0 ...mmands.csproj => Microsoft.ML.Maml.csproj} | 0 src/Microsoft.ML.Core/Data/ITransformModel.cs | 10 ---- .../EntryPoints/PredictorModel.cs | 12 ----- .../EntryPoints/TransformModel.cs | 10 ---- .../Model/Onnx/SaveOnnxCommand.cs | 51 ++++++------------- src/Microsoft.ML.Maml/Console.cs | 21 ++++++++ .../Microsoft.ML.Console.csproj | 4 +- .../Microsoft.ML.ResultProcessor.csproj | 2 +- src/Microsoft.ML/CSharpApi.cs | 7 +-- src/Microsoft.ML/Microsoft.ML.csproj | 2 +- src/Microsoft.ML/Models/OnnxConverter.cs | 9 ++-- .../Microsoft.ML.TestFramework.csproj | 2 +- 14 files changed, 50 insertions(+), 82 deletions(-) rename src/{Microsoft.ML.Maml => Microsoft.ML.Commands}/MAML.cs (100%) rename src/Microsoft.ML.Commands/{Microsoft.ML.Commands.csproj => Microsoft.ML.Maml.csproj} (100%) create mode 100644 src/Microsoft.ML.Maml/Console.cs diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 91aa04754e..f4c7ae95ed 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -107,7 +107,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Parquet", "Mic EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Benchmarks", "test\Microsoft.ML.Benchmarks\Microsoft.ML.Benchmarks.csproj", "{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Commands", "src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj", "{C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Commands\Microsoft.ML.Maml.csproj", "{C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution diff --git a/src/Microsoft.ML.Maml/MAML.cs b/src/Microsoft.ML.Commands/MAML.cs similarity index 100% rename from src/Microsoft.ML.Maml/MAML.cs rename to src/Microsoft.ML.Commands/MAML.cs diff --git a/src/Microsoft.ML.Commands/Microsoft.ML.Commands.csproj b/src/Microsoft.ML.Commands/Microsoft.ML.Maml.csproj similarity index 100% rename from src/Microsoft.ML.Commands/Microsoft.ML.Commands.csproj rename to src/Microsoft.ML.Commands/Microsoft.ML.Maml.csproj diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs index 6075a7fe1c..ccc73265ec 100644 --- a/src/Microsoft.ML.Core/Data/ITransformModel.cs +++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs @@ -32,16 +32,6 @@ public interface ITransformModel /// ISchema OutputSchema { get; } - /// - /// This contains the transforms to save instantiated on an IDataView with - /// appropriate initial schema. Note that the "root" of this is typically either - /// an empty IDataView or a BinaryLoader with no rows. However, other root - /// types are possible, since we don't insist on this when loading a model - /// from a zip file. However, whenever we save, we force a BinaryLoader to - /// be serialized for the root. - /// - IDataView View { get; } - /// /// Apply the transform(s) in the model to the given input data. /// diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs index 5f3c92c029..af726fa758 100644 --- a/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModel.cs @@ -36,18 +36,6 @@ public PredictorModel(IHostEnvironment env, RoleMappedData trainingData, IDataVi _predictor = predictor; } - //REVIEW: I'm not sure this is the right thing to do because we are setting predictor to null - //when this class is supposed to contain a predictor. TransformModel may or may not - //contain a predictor. Here we are just using this class as a wrapper for TransformModel - //so that we can use a single class to accept TransformModel and PredictorModel has inputs. - public PredictorModel(IHostEnvironment env, ITransformModel transformModel) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(transformModel, nameof(transformModel)); - - _transformModel = transformModel; - } - public PredictorModel(IHostEnvironment env, Stream stream) { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs index 783668f9d3..9edc87df6d 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs @@ -48,16 +48,6 @@ public sealed class TransformModel : ITransformModel /// public ISchema OutputSchema => _chain.Schema; - /// - /// This contains the transforms to save instantiated on an IDataView with - /// appropriate initial schema. Note that the "root" of this is typically either - /// an empty IDataView or a BinaryLoader with no rows. However, other root - /// types are possible, since we don't insist on this when loading a model - /// from a zip file. However, whenever we save, we force a BinaryLoader to - /// be serialized for the root. - /// - public IDataView View => _chain; - /// /// Create a TransformModel containing the transforms from "result" back to "input". /// diff --git a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs index cb1f13f81d..92feb35253 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs @@ -18,7 +18,7 @@ [assembly: LoadableClass(SaveOnnxCommand.Summary, typeof(SaveOnnxCommand), typeof(SaveOnnxCommand.Arguments), typeof(SignatureCommand), "Save ONNX", "SaveOnnx", DocName = "command/SaveOnnx.md")] -[assembly: LoadableClass(typeof(void), typeof(SaveOnnxCommand), null, typeof(SignatureEntryPointModule), "SaveOnnxCommand")] +[assembly: LoadableClass(typeof(void), typeof(SaveOnnxCommand), null, typeof(SignatureEntryPointModule), "SaveOnnx")] namespace Microsoft.ML.Runtime.Model.Onnx { @@ -41,24 +41,24 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.AtMostOnce, HelpText = "The 'domain' property in the output ONNX.", NullName = "", SortOrder = 4)] public string Domain; - [Argument(ArgumentType.AtMostOnce, HelpText = "Comma delimited list of input column names to drop", ShortName = "idrop", SortOrder = 5)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Comma delimited list of input column names to drop", ShortName = "idrop", SortOrder = 5)] public string InputsToDrop; - [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of input column names to drop", SortOrder = 6)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of input column names to drop", Name = nameof(InputsToDrop), SortOrder = 6)] public string[] InputsToDropArray; - [Argument(ArgumentType.AtMostOnce, HelpText = "Comma delimited list of output column names to drop", ShortName = "odrop", SortOrder = 7)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Comma delimited list of output column names to drop", ShortName = "odrop", SortOrder = 7)] public string OutputsToDrop; - [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of output column names to drop", SortOrder = 8)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of output column names to drop", Name = nameof(OutputsToDrop), SortOrder = 8)] public string[] OutputsToDropArray; - [Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)] public bool? LoadPredictor; - [Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] + [Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] - public IPredictorModel Model; + public ITransformModel Model; } private readonly string _outputModelPath; @@ -68,7 +68,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase private readonly bool? _loadPredictor; private readonly HashSet _inputsToDrop; private readonly HashSet _outputsToDrop; - private readonly IPredictorModel _model; + private readonly ITransformModel _model; public SaveOnnxCommand(IHostEnvironment env, Arguments args) : base(env, args, LoadName) @@ -83,20 +83,12 @@ public SaveOnnxCommand(IHostEnvironment env, Arguments args) _name = args.Name; _loadPredictor = args.LoadPredictor; - _inputsToDrop = args.InputsToDropArray != null ? CreateDropMap(args.InputsToDropArray) : CreateDropMap(args.InputsToDrop); - _outputsToDrop = args.OutputsToDropArray != null ? CreateDropMap(args.OutputsToDropArray) : CreateDropMap(args.OutputsToDrop); + _inputsToDrop = CreateDropMap(args.InputsToDropArray ?? args.InputsToDrop?.Split(',')); + _outputsToDrop = CreateDropMap(args.OutputsToDropArray ?? args.OutputsToDrop?.Split(',')); _domain = args.Domain; _model = args.Model; } - private static HashSet CreateDropMap(string toDrop) - { - if (string.IsNullOrWhiteSpace(toDrop)) - return new HashSet(); - - return new HashSet(toDrop.Split(',')); - } - private static HashSet CreateDropMap(string[] toDrop) { if (toDrop == null) @@ -140,8 +132,8 @@ private void GetPipe(IChannel ch, IDataView end, out IDataView source, out IData private void Run(IChannel ch) { - IDataLoader loader = null; ; - IPredictor rawPred; + IDataLoader loader = null; + IPredictor rawPred = null; IDataView view; RoleMappedSchema trainSchema = null; @@ -161,12 +153,7 @@ private void Run(IChannel ch) view = loader; } else - { - view = _model.TransformModel.View; - rawPred = _model?.Predictor; - if (rawPred != null) - trainSchema = _model.GetTrainingSchema(Host); - } + view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema)); // Get the transform chain. IDataView source; @@ -276,7 +263,6 @@ private void Run(IChannel ch) public sealed class Output { - //REVIEW: Would be nice to include ONNX protobuf model here but code generator needs an upgrade. } //REVIEW: Ideally there is no need to define this input class and just reuse the Argument class from SaveONNX command @@ -302,12 +288,8 @@ public sealed class Input [Argument(ArgumentType.AtMostOnce, HelpText = "Array of output column names to drop", SortOrder = 6)] public string[] OutputsToDrop; - [Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 7)] - public bool? LoadPredictor; - - [Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 8)] - - public IPredictorModel Model; + [Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 7)] + public ITransformModel Model; } @@ -321,7 +303,6 @@ public static Output Apply(IHostEnvironment env, Input input) args.Domain = input.Domain; args.InputsToDropArray = input.InputsToDrop; args.OutputsToDropArray = input.OutputsToDrop; - args.LoadPredictor = input.LoadPredictor; args.Model = input.Model; var cmd = new SaveOnnxCommand(env, args); diff --git a/src/Microsoft.ML.Maml/Console.cs b/src/Microsoft.ML.Maml/Console.cs new file mode 100644 index 0000000000..7e5f0dbe9b --- /dev/null +++ b/src/Microsoft.ML.Maml/Console.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Runtime.Tools.Console +{ + public static class Console + { + public static int Main(string[] args) + { + string all = string.Join(" ", args); + return Maml.MainAll(all); + } + + public static unsafe int MainRaw(char* psz) + { + string args = new string(psz); + return Maml.MainAll(args); + } + } +} diff --git a/src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj index 49742313e6..961046356c 100644 --- a/src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj @@ -6,11 +6,11 @@ Microsoft.ML netcoreapp2.0 Exe - Microsoft.ML.Runtime.Tools.Maml + Microsoft.ML.Runtime.Tools.Console.Console - + diff --git a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj index 75597c7c54..4dcd467b37 100644 --- a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj +++ b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj @@ -8,7 +8,7 @@ - + diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 54a65028d6..700df3f5df 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2789,15 +2789,10 @@ public sealed partial class OnnxConverter /// public string[] OutputsToDrop { get; set; } - /// - /// Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present. - /// - public bool? LoadPredictor { get; set; } - /// /// Model that needs to be converted to ONNX format. /// - public Var Model { get; set; } = new Var(); + public Var Model { get; set; } = new Var(); public sealed class Output diff --git a/src/Microsoft.ML/Microsoft.ML.csproj b/src/Microsoft.ML/Microsoft.ML.csproj index 4f3f8f0dc2..2485893912 100644 --- a/src/Microsoft.ML/Microsoft.ML.csproj +++ b/src/Microsoft.ML/Microsoft.ML.csproj @@ -14,7 +14,7 @@ - + diff --git a/src/Microsoft.ML/Models/OnnxConverter.cs b/src/Microsoft.ML/Models/OnnxConverter.cs index 420ff15a28..4bd4a03bc0 100644 --- a/src/Microsoft.ML/Models/OnnxConverter.cs +++ b/src/Microsoft.ML/Models/OnnxConverter.cs @@ -1,6 +1,9 @@ -using Microsoft.ML.Runtime; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; namespace Microsoft.ML.Models { @@ -19,7 +22,7 @@ public void Convert(PredictionModel model) Experiment experiment = environment.CreateExperiment(); experiment.Add(this); experiment.Compile(); - experiment.SetInput(Model, new PredictorModel(environment, model.PredictorModel)); + experiment.SetInput(Model, model.PredictorModel); experiment.Run(); } } diff --git a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj index 6e4a6dc85f..b8dc6a9e5e 100644 --- a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj +++ b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj @@ -6,10 +6,10 @@ + - From 692b526c3c6297f27d5a66ca230767dac5169aad Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 30 May 2018 13:51:05 -0700 Subject: [PATCH 09/22] update test baselines. --- .../Common/EntryPoints/core_manifest.json | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index c0ebccf5a1..c298579fd6 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -2121,24 +2121,12 @@ "IsNullable": false, "Default": null }, - { - "Name": "LoadPredictor", - "Type": "Bool", - "Desc": "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", - "Aliases": [ - "pred" - ], - "Required": false, - "SortOrder": 7.0, - "IsNullable": true, - "Default": null - }, { "Name": "Model", - "Type": "PredictorModel", + "Type": "TransformModel", "Desc": "Model that needs to be converted to ONNX format.", "Required": true, - "SortOrder": 8.0, + "SortOrder": 7.0, "IsNullable": false } ], From 6c9203393f67e0046ad2881f0fc067f9917a675b Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 30 May 2018 14:05:23 -0700 Subject: [PATCH 10/22] cleanup. --- Microsoft.ML.sln | 26 +++++++++---------- .../Console.cs | 0 .../Microsoft.ML.Console.csproj | 2 +- .../Properties/AssemblyInfo.cs | 0 .../ChainCommand.cs | 0 .../HelpCommand.cs | 0 .../MAML.cs | 0 .../Microsoft.ML.Maml.csproj | 0 .../VersionCommand.cs | 0 .../Microsoft.ML.ResultProcessor.csproj | 2 +- src/Microsoft.ML/Microsoft.ML.csproj | 2 +- .../Microsoft.ML.TestFramework.csproj | 2 +- .../Microsoft.ML.Tests.csproj | 1 + 13 files changed, 18 insertions(+), 17 deletions(-) rename src/{Microsoft.ML.Maml => Microsoft.ML.Console}/Console.cs (100%) rename src/{Microsoft.ML.Maml => Microsoft.ML.Console}/Microsoft.ML.Console.csproj (85%) rename src/{Microsoft.ML.Commands => Microsoft.ML.Console}/Properties/AssemblyInfo.cs (100%) rename src/{Microsoft.ML.Commands => Microsoft.ML.Maml}/ChainCommand.cs (100%) rename src/{Microsoft.ML.Commands => Microsoft.ML.Maml}/HelpCommand.cs (100%) rename src/{Microsoft.ML.Commands => Microsoft.ML.Maml}/MAML.cs (100%) rename src/{Microsoft.ML.Commands => Microsoft.ML.Maml}/Microsoft.ML.Maml.csproj (100%) rename src/{Microsoft.ML.Commands => Microsoft.ML.Maml}/VersionCommand.cs (100%) diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index f4c7ae95ed..c0a2911262 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -33,8 +33,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.KMeansClusteri EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.PCA", "src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj", "{58E06735-1129-4DD5-86E0-6BBFF049AAD9}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Console", "src\Microsoft.ML.Maml\Microsoft.ML.Console.csproj", "{D956E291-F6E5-4474-9023-91793F45ABEB}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Api", "src\Microsoft.ML.Api\Microsoft.ML.Api.csproj", "{2F636A2C-062C-49F4-85F3-60DCADAB6A43}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Tests", "test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj", "{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}" @@ -107,7 +105,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Parquet", "Mic EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Benchmarks", "test\Microsoft.ML.Benchmarks\Microsoft.ML.Benchmarks.csproj", "{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Commands\Microsoft.ML.Maml.csproj", "{C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj", "{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Console", "src\Microsoft.ML.Console\Microsoft.ML.Console.csproj", "{362A98CF-FBF7-4EBB-A11B-990BBF845B15}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -163,10 +163,6 @@ Global {58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Debug|Any CPU.Build.0 = Debug|Any CPU {58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.ActiveCfg = Release|Any CPU {58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.Build.0 = Release|Any CPU - {D956E291-F6E5-4474-9023-91793F45ABEB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {D956E291-F6E5-4474-9023-91793F45ABEB}.Debug|Any CPU.Build.0 = Debug|Any CPU - {D956E291-F6E5-4474-9023-91793F45ABEB}.Release|Any CPU.ActiveCfg = Release|Any CPU - {D956E291-F6E5-4474-9023-91793F45ABEB}.Release|Any CPU.Build.0 = Release|Any CPU {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.Build.0 = Debug|Any CPU {2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -207,10 +203,14 @@ Global {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug|Any CPU.Build.0 = Debug|Any CPU {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.ActiveCfg = Release|Any CPU {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.Build.0 = Release|Any CPU - {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}.Debug|Any CPU.Build.0 = Debug|Any CPU - {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}.Release|Any CPU.ActiveCfg = Release|Any CPU - {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D}.Release|Any CPU.Build.0 = Release|Any CPU + {64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.Build.0 = Release|Any CPU + {362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.Build.0 = Debug|Any CPU + {362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.ActiveCfg = Release|Any CPU + {362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -228,7 +228,6 @@ Global {7288C084-11C0-43BE-AC7F-45DCFEAEEBF6} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {58E06735-1129-4DD5-86E0-6BBFF049AAD9} = {09EADF06-BE25-4228-AB53-95AE3E15B530} - {D956E291-F6E5-4474-9023-91793F45ABEB} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {2F636A2C-062C-49F4-85F3-60DCADAB6A43} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {64BC22D3-1E76-41EF-94D8-C79E471FF2DD} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {FDA2FD2C-A708-43AC-A941-4D941B0853BF} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} @@ -245,7 +244,8 @@ Global {DEC8F776-49F7-4D87-836C-FE4DC057D08C} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {6C95FC87-F5F2-4EEF-BB97-567F2F5DD141} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} - {C5EB2982-739C-4D42-8DA5-0FB5F4223B6D} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {64F40A0D-D4C2-4AA7-8470-E9CC437827E4} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {362A98CF-FBF7-4EBB-A11B-990BBF845B15} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/src/Microsoft.ML.Maml/Console.cs b/src/Microsoft.ML.Console/Console.cs similarity index 100% rename from src/Microsoft.ML.Maml/Console.cs rename to src/Microsoft.ML.Console/Console.cs diff --git a/src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj similarity index 85% rename from src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj rename to src/Microsoft.ML.Console/Microsoft.ML.Console.csproj index 961046356c..339faf2224 100644 --- a/src/Microsoft.ML.Maml/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj @@ -10,8 +10,8 @@ - + \ No newline at end of file diff --git a/src/Microsoft.ML.Commands/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Console/Properties/AssemblyInfo.cs similarity index 100% rename from src/Microsoft.ML.Commands/Properties/AssemblyInfo.cs rename to src/Microsoft.ML.Console/Properties/AssemblyInfo.cs diff --git a/src/Microsoft.ML.Commands/ChainCommand.cs b/src/Microsoft.ML.Maml/ChainCommand.cs similarity index 100% rename from src/Microsoft.ML.Commands/ChainCommand.cs rename to src/Microsoft.ML.Maml/ChainCommand.cs diff --git a/src/Microsoft.ML.Commands/HelpCommand.cs b/src/Microsoft.ML.Maml/HelpCommand.cs similarity index 100% rename from src/Microsoft.ML.Commands/HelpCommand.cs rename to src/Microsoft.ML.Maml/HelpCommand.cs diff --git a/src/Microsoft.ML.Commands/MAML.cs b/src/Microsoft.ML.Maml/MAML.cs similarity index 100% rename from src/Microsoft.ML.Commands/MAML.cs rename to src/Microsoft.ML.Maml/MAML.cs diff --git a/src/Microsoft.ML.Commands/Microsoft.ML.Maml.csproj b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj similarity index 100% rename from src/Microsoft.ML.Commands/Microsoft.ML.Maml.csproj rename to src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj diff --git a/src/Microsoft.ML.Commands/VersionCommand.cs b/src/Microsoft.ML.Maml/VersionCommand.cs similarity index 100% rename from src/Microsoft.ML.Commands/VersionCommand.cs rename to src/Microsoft.ML.Maml/VersionCommand.cs diff --git a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj index 4dcd467b37..b420da0eb0 100644 --- a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj +++ b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj @@ -8,10 +8,10 @@ - + diff --git a/src/Microsoft.ML/Microsoft.ML.csproj b/src/Microsoft.ML/Microsoft.ML.csproj index 2485893912..7b370f2804 100644 --- a/src/Microsoft.ML/Microsoft.ML.csproj +++ b/src/Microsoft.ML/Microsoft.ML.csproj @@ -14,9 +14,9 @@ - + diff --git a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj index b8dc6a9e5e..d9cf8a2f29 100644 --- a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj +++ b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj @@ -6,10 +6,10 @@ - + diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index e92a2b85f0..c4ccb76ae9 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -5,6 +5,7 @@ + From e466157c0e9eec1e3d80f8a7586ebeffd86eb673 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 30 May 2018 14:12:07 -0700 Subject: [PATCH 11/22] cleanup. --- src/Microsoft.ML.Console/Properties/AssemblyInfo.cs | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 src/Microsoft.ML.Console/Properties/AssemblyInfo.cs diff --git a/src/Microsoft.ML.Console/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Console/Properties/AssemblyInfo.cs deleted file mode 100644 index 97db2a8a07..0000000000 --- a/src/Microsoft.ML.Console/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,9 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; - -[assembly: InternalsVisibleTo("Microsoft.ML.TestFramework, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] From 8ef9b3fbcfea266be71216edd6ca7158fa874549 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 31 May 2018 01:21:53 -0700 Subject: [PATCH 12/22] PR feedback. --- src/Microsoft.ML.Data/Commands/DataCommand.cs | 6 +- .../Model/Onnx/SaveOnnxCommand.cs | 44 +-------- src/Microsoft.ML/CSharpApi.cs | 44 ++++++++- .../Common/EntryPoints/core_ep-list.tsv | 2 +- .../Common/EntryPoints/core_manifest.json | 90 ++++++++++++++++++- .../BreastCancer/SaveModelToOnnxTest.json | 0 .../BinaryClassification.cs => OnnxTests.cs} | 14 ++- .../Scenarios/HousePricePredictionTests.cs | 2 +- 8 files changed, 147 insertions(+), 55 deletions(-) rename test/BaselineOutput/Common/{Scenario => Onnx}/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json (100%) rename test/Microsoft.ML.Tests/{Scenarios/BinaryClassification.cs => OnnxTests.cs} (88%) diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs index a68dd2022a..c128aeb569 100644 --- a/src/Microsoft.ML.Data/Commands/DataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs @@ -20,7 +20,7 @@ public static class DataCommand { public abstract class ArgumentsBase { - [Argument(ArgumentType.Multiple, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "")] + [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "")] public SubComponent Loader; [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)] @@ -41,7 +41,7 @@ public abstract class ArgumentsBase [Argument(ArgumentType.AtMostOnce, HelpText = "Verbose?", ShortName = "v", Hide = true)] public bool? Verbose; - [Argument(ArgumentType.AtMostOnce, HelpText = "The web server to publish the RESTful API", Hide = true)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The web server to publish the RESTful API", Hide = true)] public ServerChannel.IServerFactory Server; // This is actually an advisory value. The implementations themselves are responsible for @@ -51,7 +51,7 @@ public abstract class ArgumentsBase HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")] public int? Parallel; - [Argument(ArgumentType.Multiple, HelpText = "Transform", ShortName = "xf")] + [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf")] public KeyValuePair>[] Transform; } diff --git a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs index 92feb35253..c4eca0abac 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs @@ -53,7 +53,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of output column names to drop", Name = nameof(OutputsToDrop), SortOrder = 8)] public string[] OutputsToDropArray; - [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)] public bool? LoadPredictor; [Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] @@ -265,48 +265,10 @@ public sealed class Output { } - //REVIEW: Ideally there is no need to define this input class and just reuse the Argument class from SaveONNX command - //but the code generator cannot parse certain complicated data types in the base class that Argument class extends. - //We should fix the code generator and use the Argument class. - public sealed class Input - { - [Argument(ArgumentType.AtMostOnce, HelpText = "The path to write the output ONNX to.", SortOrder = 1)] - public string Onnx; - - [Argument(ArgumentType.AtMostOnce, HelpText = "The path to write the output JSON to.", SortOrder = 2)] - public string Json; - - [Argument(ArgumentType.AtMostOnce, HelpText = "The 'name' property in the output ONNX. By default this will be the ONNX extension-less name.", NullName = "", SortOrder = 3)] - public string Name; - - [Argument(ArgumentType.AtMostOnce, HelpText = "The 'domain' property in the output ONNX.", NullName = "", SortOrder = 4)] - public string Domain; - - [Argument(ArgumentType.AtMostOnce, HelpText = "Array of input column names to drop", SortOrder = 5)] - public string[] InputsToDrop; - - [Argument(ArgumentType.AtMostOnce, HelpText = "Array of output column names to drop", SortOrder = 6)] - public string[] OutputsToDrop; - - [Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 7)] - public ITransformModel Model; - } - - [TlcModule.EntryPoint(Name = "Models.OnnxConverter", Desc = "Converts the model to ONNX format.", UserName = "ONNX Converter.")] - public static Output Apply(IHostEnvironment env, Input input) + public static Output Apply(IHostEnvironment env, Arguments input) { - Arguments args = new Arguments(); - args.Onnx = input.Onnx; - args.Json = input.Json; - args.Name = input.Name; - args.Domain = input.Domain; - args.InputsToDropArray = input.InputsToDrop; - args.OutputsToDropArray = input.OutputsToDrop; - args.Model = input.Model; - - var cmd = new SaveOnnxCommand(env, args); - cmd.Run(); + new SaveOnnxCommand(env, input).Run(); return new Output(); } diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 700df3f5df..804ae6c832 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2794,6 +2794,41 @@ public sealed partial class OnnxConverter /// public Var Model { get; set; } = new Var(); + /// + /// The data file + /// + public string DataFile { get; set; } + + /// + /// Model file to save + /// + public string OutputModelFile { get; set; } + + /// + /// Model file to load + /// + public string InputModelFile { get; set; } + + /// + /// Load transforms from model file? + /// + public bool? LoadTransforms { get; set; } + + /// + /// Random seed + /// + public int? RandomSeed { get; set; } + + /// + /// Verbose? + /// + public bool? Verbose { get; set; } + + /// + /// Desired degree of parallelism in the data pipeline + /// + public int? Parallel { get; set; } + public sealed class Output { @@ -6237,7 +6272,7 @@ public enum KMeansPlusPlusTrainerInitAlgorithm /// /// K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. /// - public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem { @@ -6272,6 +6307,11 @@ public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.Entry /// public int? NumThreads { get; set; } + /// + /// Column to use for example weight + /// + public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; } + /// /// The data to be used for training /// @@ -7088,7 +7128,7 @@ namespace Trainers /// /// Train an PCA Anomaly model. /// - public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem { diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 5efa7e2013..f6888637ab 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -17,7 +17,7 @@ Models.FixedPlattCalibrator Apply a Platt calibrator with a fixed slope and offs Models.MultiOutputRegressionEvaluator Evaluates a multi output regression scored dataset. Microsoft.ML.Runtime.Data.Evaluate MultiOutputRegression Microsoft.ML.Runtime.Data.MultiOutputRegressionMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput Models.NaiveCalibrator Apply a Naive calibrator to an input model Microsoft.ML.Runtime.Internal.Calibration.Calibrate Naive Microsoft.ML.Runtime.Internal.Calibration.Calibrate+NoArgumentsInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CalibratorOutput Models.OneVersusAll One-vs-All macro (OVA) Microsoft.ML.Runtime.EntryPoints.OneVersusAllMacro OVA Microsoft.ML.Runtime.EntryPoints.OneVersusAllMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.OneVersusAllMacro+Output] -Models.OnnxConverter Converts the model to ONNX format. Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand Apply Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Input Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Output +Models.OnnxConverter Converts the model to ONNX format. Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand Apply Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Arguments Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Output Models.OvaModelCombiner Combines a sequence of PredictorModels into a single model Microsoft.ML.Runtime.Learners.OvaPredictor CombineOvaModels Microsoft.ML.Runtime.EntryPoints.ModelOperations+CombineOvaPredictorModelsInput Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelOutput Models.PAVCalibrator Apply a PAV calibrator to an input model Microsoft.ML.Runtime.Internal.Calibration.Calibrate Pav Microsoft.ML.Runtime.Internal.Calibration.Calibrate+NoArgumentsInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CalibratorOutput Models.PipelineSweeper AutoML pipeline sweeping optimzation macro. Microsoft.ML.Runtime.EntryPoints.PipelineSweeperMacro PipelineSweep Microsoft.ML.Runtime.EntryPoints.PipelineSweeperMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.PipelineSweeperMacro+Output] diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index c298579fd6..49770d7f56 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -2061,6 +2061,18 @@ "FriendlyName": "ONNX Converter.", "ShortName": null, "Inputs": [ + { + "Name": "DataFile", + "Type": "String", + "Desc": "The data file", + "Aliases": [ + "data" + ], + "Required": false, + "SortOrder": 0.0, + "IsNullable": false, + "Default": null + }, { "Name": "Onnx", "Type": "String", @@ -2105,7 +2117,7 @@ }, "Desc": "Array of input column names to drop", "Required": false, - "SortOrder": 5.0, + "SortOrder": 6.0, "IsNullable": false, "Default": null }, @@ -2117,7 +2129,7 @@ }, "Desc": "Array of output column names to drop", "Required": false, - "SortOrder": 6.0, + "SortOrder": 8.0, "IsNullable": false, "Default": null }, @@ -2126,8 +2138,80 @@ "Type": "TransformModel", "Desc": "Model that needs to be converted to ONNX format.", "Required": true, - "SortOrder": 7.0, + "SortOrder": 10.0, "IsNullable": false + }, + { + "Name": "InputModelFile", + "Type": "String", + "Desc": "Model file to load", + "Aliases": [ + "in" + ], + "Required": false, + "SortOrder": 90.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "LoadTransforms", + "Type": "Bool", + "Desc": "Load transforms from model file?", + "Aliases": [ + "loadTrans" + ], + "Required": false, + "SortOrder": 91.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "RandomSeed", + "Type": "Int", + "Desc": "Random seed", + "Aliases": [ + "seed" + ], + "Required": false, + "SortOrder": 101.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "OutputModelFile", + "Type": "String", + "Desc": "Model file to save", + "Aliases": [ + "out" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Verbose", + "Type": "Bool", + "Desc": "Verbose?", + "Aliases": [ + "v" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "Parallel", + "Type": "Int", + "Desc": "Desired degree of parallelism in the data pipeline", + "Aliases": [ + "n" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null } ], "Outputs": [] diff --git a/test/BaselineOutput/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json similarity index 100% rename from test/BaselineOutput/Common/Scenario/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json rename to test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json diff --git a/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs b/test/Microsoft.ML.Tests/OnnxTests.cs similarity index 88% rename from test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs rename to test/Microsoft.ML.Tests/OnnxTests.cs index 7d90663e39..9ccd2f1615 100644 --- a/test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs +++ b/test/Microsoft.ML.Tests/OnnxTests.cs @@ -2,14 +2,20 @@ using Microsoft.ML.Models; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Trainers; using System.IO; using Xunit; +using Xunit.Abstractions; -namespace Microsoft.ML.Scenarios +namespace Microsoft.ML.Tests { - public partial class ScenariosTests + public class OnnxTests : BaseTestBaseline { + public OnnxTests(ITestOutputHelper output) : base(output) + { + } + public class BreastCancerData { public float Label; @@ -25,7 +31,7 @@ public class BreastCancerPrediction } [Fact] - public void SaveModelToOnnxTest() + public void BinaryClassificationSaveModelToOnnxTest() { string dataPath = GetDataPath(@"breast-cancer.txt"); var pipeline = new LearningPipeline(); @@ -58,7 +64,7 @@ public void SaveModelToOnnxTest() pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); var model = pipeline.Train(); - var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Scenario", "BinaryClassification", "BreastCancer"); + var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", "BreastCancer"); var modelOutpath = GetOutputPath(subDir, "SaveModelToOnnxTest.zip"); DeleteOutputPath(modelOutpath); diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs index 08b70d7748..85e4a13eac 100644 --- a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs @@ -10,7 +10,7 @@ namespace Microsoft.ML.Scenarios { - public partial class ScenariosTests : BaseTestBaseline + public partial class ScenariosTests : BaseTestClass { /* A real-estate firm Contoso wants to add a house price prediction to their ASP.NET/Xamarin application. From 64cdb80938b5307786d4fa578edc6ffcd0c1fe09 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 31 May 2018 01:43:00 -0700 Subject: [PATCH 13/22] PR feedback. --- src/Microsoft.ML.Data/Commands/DataCommand.cs | 2 +- src/Microsoft.ML/CSharpApi.cs | 5 ----- .../Common/EntryPoints/core_manifest.json | 12 ------------ 3 files changed, 1 insertion(+), 18 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs index c128aeb569..8eac40f105 100644 --- a/src/Microsoft.ML.Data/Commands/DataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs @@ -26,7 +26,7 @@ public abstract class ArgumentsBase [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)] public string DataFile; - [Argument(ArgumentType.AtMostOnce, HelpText = "Model file to save", ShortName = "out")] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Model file to save", ShortName = "out")] public string OutputModelFile; [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)] diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 804ae6c832..ddabfeec31 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2799,11 +2799,6 @@ public sealed partial class OnnxConverter /// public string DataFile { get; set; } - /// - /// Model file to save - /// - public string OutputModelFile { get; set; } - /// /// Model file to load /// diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 49770d7f56..4a494634aa 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -2177,18 +2177,6 @@ "IsNullable": true, "Default": null }, - { - "Name": "OutputModelFile", - "Type": "String", - "Desc": "Model file to save", - "Aliases": [ - "out" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": false, - "Default": null - }, { "Name": "Verbose", "Type": "Bool", From c6bc1c6925d73341d0dd73ecb6d8b9707ce965f8 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 1 Jun 2018 10:39:29 -0700 Subject: [PATCH 14/22] PR feedback. --- .../Microsoft.ML.Console.csproj | 1 + src/Microsoft.ML.Data/Commands/DataCommand.cs | 10 +- .../Model/Onnx/SaveOnnxCommand.cs | 1 - .../Onnx/OnnxMl.cs | 7596 ++++++++++------- .../Onnx/OnnxMl.md | 2 +- test/Microsoft.ML.Tests/OnnxTests.cs | 6 +- 6 files changed, 4312 insertions(+), 3304 deletions(-) diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj index 339faf2224..9fe1010c2c 100644 --- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj @@ -6,6 +6,7 @@ Microsoft.ML netcoreapp2.0 Exe + MML Microsoft.ML.Runtime.Tools.Console.Console diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs index 8eac40f105..958dc1613a 100644 --- a/src/Microsoft.ML.Data/Commands/DataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs @@ -29,16 +29,16 @@ public abstract class ArgumentsBase [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Model file to save", ShortName = "out")] public string OutputModelFile; - [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)] public string InputModelFile; - [Argument(ArgumentType.Multiple, HelpText = "Load transforms from model file?", ShortName = "loadTrans", SortOrder = 91)] + [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Load transforms from model file?", ShortName = "loadTrans", SortOrder = 91)] public bool? LoadTransforms; - [Argument(ArgumentType.AtMostOnce, HelpText = "Random seed", ShortName = "seed", SortOrder = 101)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Random seed", ShortName = "seed", SortOrder = 101)] public int? RandomSeed; - [Argument(ArgumentType.AtMostOnce, HelpText = "Verbose?", ShortName = "v", Hide = true)] + [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Verbose?", ShortName = "v", Hide = true)] public bool? Verbose; [Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The web server to publish the RESTful API", Hide = true)] @@ -47,7 +47,7 @@ public abstract class ArgumentsBase // This is actually an advisory value. The implementations themselves are responsible for // determining what they consider appropriate, and the actual heuristics is a bit more // complex than just this. - [Argument(ArgumentType.LastOccurenceWins, + [Argument(ArgumentType.LastOccurenceWins, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")] public int? Parallel; diff --git a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs index c4eca0abac..614f5f7a92 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs @@ -57,7 +57,6 @@ public sealed class Arguments : DataCommand.ArgumentsBase public bool? LoadPredictor; [Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] - public ITransformModel Model; } diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs index 8bf193202e..0117712614 100644 --- a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs +++ b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.cs @@ -1,5 +1,7 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: onnx-ml.proto3 +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: onnx-ml.proto3 +// #pragma warning disable 1591, 0612, 3021 #region Designer generated code @@ -7,3745 +9,4747 @@ using pbc = global::Google.Protobuf.Collections; using pbr = global::Google.Protobuf.Reflection; using scg = global::System.Collections.Generic; -namespace Microsoft.ML.Runtime.UniversalModelFormat.Onnx { - - /// Holder for reflection information generated from onnx-ml.proto3 - public static partial class OnnxMlReflection { - - #region Descriptor - /// File descriptor for onnx-ml.proto3 - public static pbr::FileDescriptor Descriptor { - get { return descriptor; } - } - private static pbr::FileDescriptor descriptor; - - static OnnxMlReflection() { - byte[] descriptorData = global::System.Convert.FromBase64String( - string.Concat( - "Cg5vbm54LW1sLnByb3RvMxIOT05OWF9OQU1FU1BBQ0Ui+wMKDkF0dHJpYnV0", - "ZVByb3RvEgwKBG5hbWUYASABKAkSEgoKZG9jX3N0cmluZxgNIAEoCRI6CgR0", - "eXBlGBQgASgOMiwuT05OWF9OQU1FU1BBQ0UuQXR0cmlidXRlUHJvdG8uQXR0", - "cmlidXRlVHlwZRIJCgFmGAIgASgCEgkKAWkYAyABKAMSCQoBcxgEIAEoDBIm", - "CgF0GAUgASgLMhsuT05OWF9OQU1FU1BBQ0UuVGVuc29yUHJvdG8SJQoBZxgG", - "IAEoCzIaLk9OTlhfTkFNRVNQQUNFLkdyYXBoUHJvdG8SDgoGZmxvYXRzGAcg", - "AygCEgwKBGludHMYCCADKAMSDwoHc3RyaW5ncxgJIAMoDBIsCgd0ZW5zb3Jz", - "GAogAygLMhsuT05OWF9OQU1FU1BBQ0UuVGVuc29yUHJvdG8SKgoGZ3JhcGhz", - "GAsgAygLMhouT05OWF9OQU1FU1BBQ0UuR3JhcGhQcm90byKRAQoNQXR0cmli", - "dXRlVHlwZRINCglVTkRFRklORUQQABIJCgVGTE9BVBABEgcKA0lOVBACEgoK", - "BlNUUklORxADEgoKBlRFTlNPUhAEEgkKBUdSQVBIEAUSCgoGRkxPQVRTEAYS", - "CAoESU5UUxAHEgsKB1NUUklOR1MQCBILCgdURU5TT1JTEAkSCgoGR1JBUEhT", - "EAoiWwoOVmFsdWVJbmZvUHJvdG8SDAoEbmFtZRgBIAEoCRInCgR0eXBlGAIg", - "ASgLMhkuT05OWF9OQU1FU1BBQ0UuVHlwZVByb3RvEhIKCmRvY19zdHJpbmcY", - "AyABKAkioAEKCU5vZGVQcm90bxINCgVpbnB1dBgBIAMoCRIOCgZvdXRwdXQY", - "AiADKAkSDAoEbmFtZRgDIAEoCRIPCgdvcF90eXBlGAQgASgJEg4KBmRvbWFp", - "bhgHIAEoCRIxCglhdHRyaWJ1dGUYBSADKAsyHi5PTk5YX05BTUVTUEFDRS5B", - "dHRyaWJ1dGVQcm90bxISCgpkb2Nfc3RyaW5nGAYgASgJIrECCgpNb2RlbFBy", - "b3RvEhIKCmlyX3ZlcnNpb24YASABKAMSOAoMb3BzZXRfaW1wb3J0GAggAygL", - "MiIuT05OWF9OQU1FU1BBQ0UuT3BlcmF0b3JTZXRJZFByb3RvEhUKDXByb2R1", - "Y2VyX25hbWUYAiABKAkSGAoQcHJvZHVjZXJfdmVyc2lvbhgDIAEoCRIOCgZk", - "b21haW4YBCABKAkSFQoNbW9kZWxfdmVyc2lvbhgFIAEoAxISCgpkb2Nfc3Ry", - "aW5nGAYgASgJEikKBWdyYXBoGAcgASgLMhouT05OWF9OQU1FU1BBQ0UuR3Jh", - "cGhQcm90bxI+Cg5tZXRhZGF0YV9wcm9wcxgOIAMoCzImLk9OTlhfTkFNRVNQ", - "QUNFLlN0cmluZ1N0cmluZ0VudHJ5UHJvdG8iNAoWU3RyaW5nU3RyaW5nRW50", - "cnlQcm90bxILCgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAkinAIKCkdyYXBo", - "UHJvdG8SJwoEbm9kZRgBIAMoCzIZLk9OTlhfTkFNRVNQQUNFLk5vZGVQcm90", - "bxIMCgRuYW1lGAIgASgJEjAKC2luaXRpYWxpemVyGAUgAygLMhsuT05OWF9O", - "QU1FU1BBQ0UuVGVuc29yUHJvdG8SEgoKZG9jX3N0cmluZxgKIAEoCRItCgVp", - "bnB1dBgLIAMoCzIeLk9OTlhfTkFNRVNQQUNFLlZhbHVlSW5mb1Byb3RvEi4K", - "Bm91dHB1dBgMIAMoCzIeLk9OTlhfTkFNRVNQQUNFLlZhbHVlSW5mb1Byb3Rv", - "EjIKCnZhbHVlX2luZm8YDSADKAsyHi5PTk5YX05BTUVTUEFDRS5WYWx1ZUlu", - "Zm9Qcm90byLDBAoLVGVuc29yUHJvdG8SDAoEZGltcxgBIAMoAxI3CglkYXRh", - "X3R5cGUYAiABKA4yJC5PTk5YX05BTUVTUEFDRS5UZW5zb3JQcm90by5EYXRh", - "VHlwZRI0CgdzZWdtZW50GAMgASgLMiMuT05OWF9OQU1FU1BBQ0UuVGVuc29y", - "UHJvdG8uU2VnbWVudBIWCgpmbG9hdF9kYXRhGAQgAygCQgIQARIWCgppbnQz", - "Ml9kYXRhGAUgAygFQgIQARITCgtzdHJpbmdfZGF0YRgGIAMoDBIWCgppbnQ2", - "NF9kYXRhGAcgAygDQgIQARIMCgRuYW1lGAggASgJEhIKCmRvY19zdHJpbmcY", - "DCABKAkSEAoIcmF3X2RhdGEYCSABKAwSFwoLZG91YmxlX2RhdGEYCiADKAFC", - "AhABEhcKC3VpbnQ2NF9kYXRhGAsgAygEQgIQARolCgdTZWdtZW50Eg0KBWJl", - "Z2luGAEgASgDEgsKA2VuZBgCIAEoAyLMAQoIRGF0YVR5cGUSDQoJVU5ERUZJ", - "TkVEEAASCQoFRkxPQVQQARIJCgVVSU5UOBACEggKBElOVDgQAxIKCgZVSU5U", - "MTYQBBIJCgVJTlQxNhAFEgkKBUlOVDMyEAYSCQoFSU5UNjQQBxIKCgZTVFJJ", - "TkcQCBIICgRCT09MEAkSCwoHRkxPQVQxNhAKEgoKBkRPVUJMRRALEgoKBlVJ", - "TlQzMhAMEgoKBlVJTlQ2NBANEg0KCUNPTVBMRVg2NBAOEg4KCkNPTVBMRVgx", - "MjgQDyKLAQoQVGVuc29yU2hhcGVQcm90bxI3CgNkaW0YASADKAsyKi5PTk5Y", - "X05BTUVTUEFDRS5UZW5zb3JTaGFwZVByb3RvLkRpbWVuc2lvbho+CglEaW1l", - "bnNpb24SEwoJZGltX3ZhbHVlGAEgASgDSAASEwoJZGltX3BhcmFtGAIgASgJ", - "SABCBwoFdmFsdWUi2QMKCVR5cGVQcm90bxI3Cgt0ZW5zb3JfdHlwZRgBIAEo", - "CzIgLk9OTlhfTkFNRVNQQUNFLlR5cGVQcm90by5UZW5zb3JIABI7Cg1zZXF1", - "ZW5jZV90eXBlGAQgASgLMiIuT05OWF9OQU1FU1BBQ0UuVHlwZVByb3RvLlNl", - "cXVlbmNlSAASMQoIbWFwX3R5cGUYBSABKAsyHS5PTk5YX05BTUVTUEFDRS5U", - "eXBlUHJvdG8uTWFwSAAacgoGVGVuc29yEjcKCWVsZW1fdHlwZRgBIAEoDjIk", - "Lk9OTlhfTkFNRVNQQUNFLlRlbnNvclByb3RvLkRhdGFUeXBlEi8KBXNoYXBl", - "GAIgASgLMiAuT05OWF9OQU1FU1BBQ0UuVGVuc29yU2hhcGVQcm90bxo4CghT", - "ZXF1ZW5jZRIsCgllbGVtX3R5cGUYASABKAsyGS5PTk5YX05BTUVTUEFDRS5U", - "eXBlUHJvdG8abAoDTWFwEjYKCGtleV90eXBlGAEgASgOMiQuT05OWF9OQU1F", - "U1BBQ0UuVGVuc29yUHJvdG8uRGF0YVR5cGUSLQoKdmFsdWVfdHlwZRgCIAEo", - "CzIZLk9OTlhfTkFNRVNQQUNFLlR5cGVQcm90b0IHCgV2YWx1ZSI1ChJPcGVy", - "YXRvclNldElkUHJvdG8SDgoGZG9tYWluGAEgASgJEg8KB3ZlcnNpb24YAiAB", - "KAMqYwoHVmVyc2lvbhISCg5fU1RBUlRfVkVSU0lPThAAEhkKFUlSX1ZFUlNJ", - "T05fMjAxN18xMF8xMBABEhkKFUlSX1ZFUlNJT05fMjAxN18xMF8zMBACEg4K", - "CklSX1ZFUlNJT04QA0I2qgIzTWljcm9zb2Z0Lk1hY2hpbmVMZWFybmluZy5V", - "bml2ZXJzYWxNb2RlbEZvcm1hdC5Pbm54YgZwcm90bzM=")); - descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, - new pbr::FileDescriptor[] { }, - new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.Version), }, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Parser, new[]{ "Name", "DocString", "Type", "F", "I", "S", "T", "G", "Floats", "Ints", "Strings", "Tensors", "Graphs" }, null, new[]{ typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Types.AttributeType) }, null), +namespace Microsoft.ML.Runtime.UniversalModelFormat.Onnx +{ + + /// Holder for reflection information generated from onnx-ml.proto3 + public static partial class OnnxMlReflection + { + + #region Descriptor + /// File descriptor for onnx-ml.proto3 + public static pbr::FileDescriptor Descriptor + { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static OnnxMlReflection() + { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cg5vbm54LW1sLnByb3RvMxIEb25ueCLgAwoOQXR0cmlidXRlUHJvdG8SDAoE", + "bmFtZRgBIAEoCRIVCg1yZWZfYXR0cl9uYW1lGBUgASgJEhIKCmRvY19zdHJp", + "bmcYDSABKAkSMAoEdHlwZRgUIAEoDjIiLm9ubnguQXR0cmlidXRlUHJvdG8u", + "QXR0cmlidXRlVHlwZRIJCgFmGAIgASgCEgkKAWkYAyABKAMSCQoBcxgEIAEo", + "DBIcCgF0GAUgASgLMhEub25ueC5UZW5zb3JQcm90bxIbCgFnGAYgASgLMhAu", + "b25ueC5HcmFwaFByb3RvEg4KBmZsb2F0cxgHIAMoAhIMCgRpbnRzGAggAygD", + "Eg8KB3N0cmluZ3MYCSADKAwSIgoHdGVuc29ycxgKIAMoCzIRLm9ubnguVGVu", + "c29yUHJvdG8SIAoGZ3JhcGhzGAsgAygLMhAub25ueC5HcmFwaFByb3RvIpEB", + "Cg1BdHRyaWJ1dGVUeXBlEg0KCVVOREVGSU5FRBAAEgkKBUZMT0FUEAESBwoD", + "SU5UEAISCgoGU1RSSU5HEAMSCgoGVEVOU09SEAQSCQoFR1JBUEgQBRIKCgZG", + "TE9BVFMQBhIICgRJTlRTEAcSCwoHU1RSSU5HUxAIEgsKB1RFTlNPUlMQCRIK", + "CgZHUkFQSFMQCiJRCg5WYWx1ZUluZm9Qcm90bxIMCgRuYW1lGAEgASgJEh0K", + "BHR5cGUYAiABKAsyDy5vbm54LlR5cGVQcm90bxISCgpkb2Nfc3RyaW5nGAMg", + "ASgJIpYBCglOb2RlUHJvdG8SDQoFaW5wdXQYASADKAkSDgoGb3V0cHV0GAIg", + "AygJEgwKBG5hbWUYAyABKAkSDwoHb3BfdHlwZRgEIAEoCRIOCgZkb21haW4Y", + "ByABKAkSJwoJYXR0cmlidXRlGAUgAygLMhQub25ueC5BdHRyaWJ1dGVQcm90", + "bxISCgpkb2Nfc3RyaW5nGAYgASgJIpMCCgpNb2RlbFByb3RvEhIKCmlyX3Zl", + "cnNpb24YASABKAMSLgoMb3BzZXRfaW1wb3J0GAggAygLMhgub25ueC5PcGVy", + "YXRvclNldElkUHJvdG8SFQoNcHJvZHVjZXJfbmFtZRgCIAEoCRIYChBwcm9k", + "dWNlcl92ZXJzaW9uGAMgASgJEg4KBmRvbWFpbhgEIAEoCRIVCg1tb2RlbF92", + "ZXJzaW9uGAUgASgDEhIKCmRvY19zdHJpbmcYBiABKAkSHwoFZ3JhcGgYByAB", + "KAsyEC5vbm54LkdyYXBoUHJvdG8SNAoObWV0YWRhdGFfcHJvcHMYDiADKAsy", + "HC5vbm54LlN0cmluZ1N0cmluZ0VudHJ5UHJvdG8iNAoWU3RyaW5nU3RyaW5n", + "RW50cnlQcm90bxILCgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAki6gEKCkdy", + "YXBoUHJvdG8SHQoEbm9kZRgBIAMoCzIPLm9ubnguTm9kZVByb3RvEgwKBG5h", + "bWUYAiABKAkSJgoLaW5pdGlhbGl6ZXIYBSADKAsyES5vbm54LlRlbnNvclBy", + "b3RvEhIKCmRvY19zdHJpbmcYCiABKAkSIwoFaW5wdXQYCyADKAsyFC5vbm54", + "LlZhbHVlSW5mb1Byb3RvEiQKBm91dHB1dBgMIAMoCzIULm9ubnguVmFsdWVJ", + "bmZvUHJvdG8SKAoKdmFsdWVfaW5mbxgNIAMoCzIULm9ubnguVmFsdWVJbmZv", + "UHJvdG8irwQKC1RlbnNvclByb3RvEgwKBGRpbXMYASADKAMSLQoJZGF0YV90", + "eXBlGAIgASgOMhoub25ueC5UZW5zb3JQcm90by5EYXRhVHlwZRIqCgdzZWdt", + "ZW50GAMgASgLMhkub25ueC5UZW5zb3JQcm90by5TZWdtZW50EhYKCmZsb2F0", + "X2RhdGEYBCADKAJCAhABEhYKCmludDMyX2RhdGEYBSADKAVCAhABEhMKC3N0", + "cmluZ19kYXRhGAYgAygMEhYKCmludDY0X2RhdGEYByADKANCAhABEgwKBG5h", + "bWUYCCABKAkSEgoKZG9jX3N0cmluZxgMIAEoCRIQCghyYXdfZGF0YRgJIAEo", + "DBIXCgtkb3VibGVfZGF0YRgKIAMoAUICEAESFwoLdWludDY0X2RhdGEYCyAD", + "KARCAhABGiUKB1NlZ21lbnQSDQoFYmVnaW4YASABKAMSCwoDZW5kGAIgASgD", + "IswBCghEYXRhVHlwZRINCglVTkRFRklORUQQABIJCgVGTE9BVBABEgkKBVVJ", + "TlQ4EAISCAoESU5UOBADEgoKBlVJTlQxNhAEEgkKBUlOVDE2EAUSCQoFSU5U", + "MzIQBhIJCgVJTlQ2NBAHEgoKBlNUUklORxAIEggKBEJPT0wQCRILCgdGTE9B", + "VDE2EAoSCgoGRE9VQkxFEAsSCgoGVUlOVDMyEAwSCgoGVUlOVDY0EA0SDQoJ", + "Q09NUExFWDY0EA4SDgoKQ09NUExFWDEyOBAPIpUBChBUZW5zb3JTaGFwZVBy", + "b3RvEi0KA2RpbRgBIAMoCzIgLm9ubnguVGVuc29yU2hhcGVQcm90by5EaW1l", + "bnNpb24aUgoJRGltZW5zaW9uEhMKCWRpbV92YWx1ZRgBIAEoA0gAEhMKCWRp", + "bV9wYXJhbRgCIAEoCUgAEhIKCmRlbm90YXRpb24YAyABKAlCBwoFdmFsdWUi", + "nQMKCVR5cGVQcm90bxItCgt0ZW5zb3JfdHlwZRgBIAEoCzIWLm9ubnguVHlw", + "ZVByb3RvLlRlbnNvckgAEjEKDXNlcXVlbmNlX3R5cGUYBCABKAsyGC5vbm54", + "LlR5cGVQcm90by5TZXF1ZW5jZUgAEicKCG1hcF90eXBlGAUgASgLMhMub25u", + "eC5UeXBlUHJvdG8uTWFwSAASEgoKZGVub3RhdGlvbhgGIAEoCRpeCgZUZW5z", + "b3ISLQoJZWxlbV90eXBlGAEgASgOMhoub25ueC5UZW5zb3JQcm90by5EYXRh", + "VHlwZRIlCgVzaGFwZRgCIAEoCzIWLm9ubnguVGVuc29yU2hhcGVQcm90bxou", + "CghTZXF1ZW5jZRIiCgllbGVtX3R5cGUYASABKAsyDy5vbm54LlR5cGVQcm90", + "bxpYCgNNYXASLAoIa2V5X3R5cGUYASABKA4yGi5vbm54LlRlbnNvclByb3Rv", + "LkRhdGFUeXBlEiMKCnZhbHVlX3R5cGUYAiABKAsyDy5vbm54LlR5cGVQcm90", + "b0IHCgV2YWx1ZSI1ChJPcGVyYXRvclNldElkUHJvdG8SDgoGZG9tYWluGAEg", + "ASgJEg8KB3ZlcnNpb24YAiABKAMqYwoHVmVyc2lvbhISCg5fU1RBUlRfVkVS", + "U0lPThAAEhkKFUlSX1ZFUlNJT05fMjAxN18xMF8xMBABEhkKFUlSX1ZFUlNJ", + "T05fMjAxN18xMF8zMBACEg4KCklSX1ZFUlNJT04QA0IxqgIuTWljcm9zb2Z0", + "Lk1MLlJ1bnRpbWUuVW5pdmVyc2FsTW9kZWxGb3JtYXQuT25ueGIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] { typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.Version), }, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Parser, new[]{ "Name", "RefAttrName", "DocString", "Type", "F", "I", "S", "T", "G", "Floats", "Ints", "Strings", "Tensors", "Graphs" }, null, new[]{ typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Types.AttributeType) }, null), new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ValueInfoProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ValueInfoProto.Parser, new[]{ "Name", "Type", "DocString" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.NodeProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.NodeProto.Parser, new[]{ "Input", "Output", "Name", "OpType", "Domain", "Attribute", "DocString" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ModelProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ModelProto.Parser, new[]{ "IrVersion", "OpsetImport", "ProducerName", "ProducerVersion", "Domain", "ModelVersion", "DocString", "Graph", "MetadataProps" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.StringStringEntryProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.StringStringEntryProto.Parser, new[]{ "Key", "Value" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto.Parser, new[]{ "Node", "Name", "Initializer", "DocString", "Input", "Output", "ValueInfo" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Parser, new[]{ "Dims", "DataType", "Segment", "FloatData", "Int32Data", "StringData", "Int64Data", "Name", "DocString", "RawData", "DoubleData", "Uint64Data" }, null, new[]{ typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType) }, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment.Parser, new[]{ "Begin", "End" }, null, null, null)}), - new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Parser, new[]{ "Dim" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Types.Dimension), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Types.Dimension.Parser, new[]{ "DimValue", "DimParam" }, new[]{ "Value" }, null, null)}), - new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Parser, new[]{ "TensorType", "SequenceType", "MapType" }, new[]{ "Value" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor.Parser, new[]{ "ElemType", "Shape" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Parser, new[]{ "Dim" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Types.Dimension), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Types.Dimension.Parser, new[]{ "DimValue", "DimParam", "Denotation" }, new[]{ "Value" }, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Parser, new[]{ "TensorType", "SequenceType", "MapType", "Denotation" }, new[]{ "Value" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor.Parser, new[]{ "ElemType", "Shape" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence.Parser, new[]{ "ElemType" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map.Parser, new[]{ "KeyType", "ValueType" }, null, null, null)}), new pbr::GeneratedClrTypeInfo(typeof(global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OperatorSetIdProto), global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OperatorSetIdProto.Parser, new[]{ "Domain", "Version" }, null, null, null) - })); - } - #endregion - - } - #region Enums - /// - /// To be compatible with both proto2 and proto3, we will use a version number - /// that is not defined by the default value but an explicit enum number. - /// - public enum Version { - /// - /// proto3 requires the first enum value to be zero. - /// We add this just to appease the compiler. - /// - [pbr::OriginalName("_START_VERSION")] StartVersion = 0, - /// - /// The version field is always serialized and we will use it to store the - /// version that the graph is generated from. This helps us set up version - /// control. We should use version as - /// xx(major) - xx(minor) - xxxx(bugfix) - /// and we are starting with 0x00000001 (0.0.1), which was the - /// version we published on Oct 10, 2017. - /// - [pbr::OriginalName("IR_VERSION_2017_10_10")] IrVersion20171010 = 1, - /// - /// IR_VERSION 0.0.2 published on Oct 30, 2017 - /// - Added type discriminator to AttributeProto to support proto3 users - /// - [pbr::OriginalName("IR_VERSION_2017_10_30")] IrVersion20171030 = 2, - /// - /// IR VERSION 0.0.3 published on Nov 3, 2017 - /// - For operator versioning: - /// - Added new message OperatorSetIdProto - /// - Added opset_import in ModelProto - /// - For vendor extensions, added domain in NodeProto - /// - [pbr::OriginalName("IR_VERSION")] IrVersion = 3, - } - - #endregion - - #region Messages - /// - /// A named attribute containing either singular float, integer, string - /// and tensor values, or repeated float, integer, string and tensor values. - /// An AttributeProto MUST contain the name field, and *only one* of the - /// following content fields, effectively enforcing a C/C++ union equivalent. - /// - public sealed partial class AttributeProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AttributeProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[0]; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public AttributeProto() { - OnConstruction(); - } - - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public AttributeProto(AttributeProto other) : this() { - name_ = other.name_; - docString_ = other.docString_; - type_ = other.type_; - f_ = other.f_; - i_ = other.i_; - s_ = other.s_; - T = other.t_ != null ? other.T.Clone() : null; - G = other.g_ != null ? other.G.Clone() : null; - floats_ = other.floats_.Clone(); - ints_ = other.ints_.Clone(); - strings_ = other.strings_.Clone(); - tensors_ = other.tensors_.Clone(); - graphs_ = other.graphs_.Clone(); - } + })); + } + #endregion - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public AttributeProto Clone() { - return new AttributeProto(this); } - - /// Field number for the "name" field. - public const int NameFieldNumber = 1; - private string name_ = ""; + #region Enums /// - /// The name field MUST be present for this version of the IR. + /// Versioning + /// + /// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md + /// + /// To be compatible with both proto2 and proto3, we will use a version number + /// that is not defined by the default value but an explicit enum number. /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Name { - get { return name_; } - set { - name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } + public enum Version + { + /// + /// proto3 requires the first enum value to be zero. + /// We add this just to appease the compiler. + /// + [pbr::OriginalName("_START_VERSION")] StartVersion = 0, + /// + /// The version field is always serialized and we will use it to store the + /// version that the graph is generated from. This helps us set up version + /// control. + /// For the IR, we are using simple numbers starting with with 0x00000001, + /// which was the version we published on Oct 10, 2017. + /// + [pbr::OriginalName("IR_VERSION_2017_10_10")] IrVersion20171010 = 1, + /// + /// IR_VERSION 2 published on Oct 30, 2017 + /// - Added type discriminator to AttributeProto to support proto3 users + /// + [pbr::OriginalName("IR_VERSION_2017_10_30")] IrVersion20171030 = 2, + /// + /// IR VERSION 3 published on Nov 3, 2017 + /// - For operator versioning: + /// - Added new message OperatorSetIdProto + /// - Added opset_import in ModelProto + /// - For vendor extensions, added domain in NodeProto + /// + [pbr::OriginalName("IR_VERSION")] IrVersion = 3, } - /// Field number for the "doc_string" field. - public const int DocStringFieldNumber = 13; - private string docString_ = ""; - /// - /// A human-readable documentation for this attribute. Markdown is allowed. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string DocString { - get { return docString_; } - set { - docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + #endregion - /// Field number for the "type" field. - public const int TypeFieldNumber = 20; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Types.AttributeType type_ = 0; + #region Messages /// - /// The type field MUST be present for this version of the IR. - /// For 0.0.1 versions of the IR, this field was not defined, and - /// implementations needed to use has_field hueristics to determine - /// which value field was in use. For IR_VERSION 0.0.2 or later, this - /// field MUST be set and match the f|i|s|t|... field in use. This - /// change was made to accomodate proto3 implementations. + /// Attributes + /// + /// A named attribute containing either singular float, integer, string, graph, + /// and tensor values, or repeated float, integer, string, graph, and tensor values. + /// An AttributeProto MUST contain the name field, and *only one* of the + /// following content fields, effectively enforcing a C/C++ union equivalent. /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Types.AttributeType Type { - get { return type_; } - set { - type_ = value; - } - } + public sealed partial class AttributeProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AttributeProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } - /// Field number for the "f" field. - public const int FFieldNumber = 2; - private float f_; - /// - /// Exactly ONE of the following fields must be present for this version of the IR - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public float F { - get { return f_; } - set { - f_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[0]; } + } - /// Field number for the "i" field. - public const int IFieldNumber = 3; - private long i_; - /// - /// int - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public long I { - get { return i_; } - set { - i_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - /// Field number for the "s" field. - public const int SFieldNumber = 4; - private pb::ByteString s_ = pb::ByteString.Empty; - /// - /// UTF-8 string - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pb::ByteString S { - get { return s_; } - set { - s_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttributeProto() + { + OnConstruction(); + } - /// Field number for the "t" field. - public const int TFieldNumber = 5; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto t_; - /// - /// tensor value - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto T { - get { return t_; } - set { - t_ = value; - } - } + partial void OnConstruction(); - /// Field number for the "g" field. - public const int GFieldNumber = 6; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto g_; - /// - /// graph - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto G { - get { return g_; } - set { - g_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttributeProto(AttributeProto other) : this() + { + name_ = other.name_; + refAttrName_ = other.refAttrName_; + docString_ = other.docString_; + type_ = other.type_; + f_ = other.f_; + i_ = other.i_; + s_ = other.s_; + T = other.t_ != null ? other.T.Clone() : null; + G = other.g_ != null ? other.G.Clone() : null; + floats_ = other.floats_.Clone(); + ints_ = other.ints_.Clone(); + strings_ = other.strings_.Clone(); + tensors_ = other.tensors_.Clone(); + graphs_ = other.graphs_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - /// Field number for the "floats" field. - public const int FloatsFieldNumber = 7; - private static readonly pb::FieldCodec _repeated_floats_codec - = pb::FieldCodec.ForFloat(58); - private readonly pbc::RepeatedField floats_ = new pbc::RepeatedField(); - /// - /// list of floats - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Floats { - get { return floats_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttributeProto Clone() + { + return new AttributeProto(this); + } - /// Field number for the "ints" field. - public const int IntsFieldNumber = 8; - private static readonly pb::FieldCodec _repeated_ints_codec - = pb::FieldCodec.ForInt64(66); - private readonly pbc::RepeatedField ints_ = new pbc::RepeatedField(); - /// - /// list of ints - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Ints { - get { return ints_; } - } + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// The name field MUST be present for this version of the IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name + { + get { return name_; } + set + { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - /// Field number for the "strings" field. - public const int StringsFieldNumber = 9; - private static readonly pb::FieldCodec _repeated_strings_codec - = pb::FieldCodec.ForBytes(74); - private readonly pbc::RepeatedField strings_ = new pbc::RepeatedField(); - /// - /// list of UTF-8 strings - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Strings { - get { return strings_; } - } + /// Field number for the "ref_attr_name" field. + public const int RefAttrNameFieldNumber = 21; + private string refAttrName_ = ""; + /// + /// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + /// In this case, this AttributeProto does not contain data, and it's a reference of attribute + /// in parent scope. + /// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string RefAttrName + { + get { return refAttrName_; } + set + { + refAttrName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - /// Field number for the "tensors" field. - public const int TensorsFieldNumber = 10; - private static readonly pb::FieldCodec _repeated_tensors_codec - = pb::FieldCodec.ForMessage(82, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Parser); - private readonly pbc::RepeatedField tensors_ = new pbc::RepeatedField(); - /// - /// list of tensors - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Tensors { - get { return tensors_; } - } + /// Field number for the "doc_string" field. + public const int DocStringFieldNumber = 13; + private string docString_ = ""; + /// + /// A human-readable documentation for this attribute. Markdown is allowed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DocString + { + get { return docString_; } + set + { + docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - /// Field number for the "graphs" field. - public const int GraphsFieldNumber = 11; - private static readonly pb::FieldCodec _repeated_graphs_codec - = pb::FieldCodec.ForMessage(90, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto.Parser); - private readonly pbc::RepeatedField graphs_ = new pbc::RepeatedField(); - /// - /// list of graph - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Graphs { - get { return graphs_; } - } + /// Field number for the "type" field. + public const int TypeFieldNumber = 20; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Types.AttributeType type_ = 0; + /// + /// The type field MUST be present for this version of the IR. + /// For 0.0.1 versions of the IR, this field was not defined, and + /// implementations needed to use has_field hueristics to determine + /// which value field was in use. For IR_VERSION 0.0.2 or later, this + /// field MUST be set and match the f|i|s|t|... field in use. This + /// change was made to accomodate proto3 implementations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Types.AttributeType Type + { + get { return type_; } + set + { + type_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as AttributeProto); - } + /// Field number for the "f" field. + public const int FFieldNumber = 2; + private float f_; + /// + /// Exactly ONE of the following fields must be present for this version of the IR + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float F + { + get { return f_; } + set + { + f_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(AttributeProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (Name != other.Name) return false; - if (DocString != other.DocString) return false; - if (Type != other.Type) return false; - if (F != other.F) return false; - if (I != other.I) return false; - if (S != other.S) return false; - if (!object.Equals(T, other.T)) return false; - if (!object.Equals(G, other.G)) return false; - if(!floats_.Equals(other.floats_)) return false; - if(!ints_.Equals(other.ints_)) return false; - if(!strings_.Equals(other.strings_)) return false; - if(!tensors_.Equals(other.tensors_)) return false; - if(!graphs_.Equals(other.graphs_)) return false; - return true; - } + /// Field number for the "i" field. + public const int IFieldNumber = 3; + private long i_; + /// + /// int + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long I + { + get { return i_; } + set + { + i_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (Name.Length != 0) hash ^= Name.GetHashCode(); - if (DocString.Length != 0) hash ^= DocString.GetHashCode(); - if (Type != 0) hash ^= Type.GetHashCode(); - if (F != 0F) hash ^= F.GetHashCode(); - if (I != 0L) hash ^= I.GetHashCode(); - if (S.Length != 0) hash ^= S.GetHashCode(); - if (t_ != null) hash ^= T.GetHashCode(); - if (g_ != null) hash ^= G.GetHashCode(); - hash ^= floats_.GetHashCode(); - hash ^= ints_.GetHashCode(); - hash ^= strings_.GetHashCode(); - hash ^= tensors_.GetHashCode(); - hash ^= graphs_.GetHashCode(); - return hash; - } + /// Field number for the "s" field. + public const int SFieldNumber = 4; + private pb::ByteString s_ = pb::ByteString.Empty; + /// + /// UTF-8 string + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString S + { + get { return s_; } + set + { + s_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + /// Field number for the "t" field. + public const int TFieldNumber = 5; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto t_; + /// + /// tensor value + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto T + { + get { return t_; } + set + { + t_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (Name.Length != 0) { - output.WriteRawTag(10); - output.WriteString(Name); - } - if (F != 0F) { - output.WriteRawTag(21); - output.WriteFloat(F); - } - if (I != 0L) { - output.WriteRawTag(24); - output.WriteInt64(I); - } - if (S.Length != 0) { - output.WriteRawTag(34); - output.WriteBytes(S); - } - if (t_ != null) { - output.WriteRawTag(42); - output.WriteMessage(T); - } - if (g_ != null) { - output.WriteRawTag(50); - output.WriteMessage(G); - } - floats_.WriteTo(output, _repeated_floats_codec); - ints_.WriteTo(output, _repeated_ints_codec); - strings_.WriteTo(output, _repeated_strings_codec); - tensors_.WriteTo(output, _repeated_tensors_codec); - graphs_.WriteTo(output, _repeated_graphs_codec); - if (DocString.Length != 0) { - output.WriteRawTag(106); - output.WriteString(DocString); - } - if (Type != 0) { - output.WriteRawTag(160, 1); - output.WriteEnum((int) Type); - } - } + /// Field number for the "g" field. + public const int GFieldNumber = 6; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto g_; + /// + /// graph + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto G + { + get { return g_; } + set + { + g_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (Name.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); - } - if (DocString.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); - } - if (Type != 0) { - size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) Type); - } - if (F != 0F) { - size += 1 + 4; - } - if (I != 0L) { - size += 1 + pb::CodedOutputStream.ComputeInt64Size(I); - } - if (S.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeBytesSize(S); - } - if (t_ != null) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(T); - } - if (g_ != null) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(G); - } - size += floats_.CalculateSize(_repeated_floats_codec); - size += ints_.CalculateSize(_repeated_ints_codec); - size += strings_.CalculateSize(_repeated_strings_codec); - size += tensors_.CalculateSize(_repeated_tensors_codec); - size += graphs_.CalculateSize(_repeated_graphs_codec); - return size; - } + /// Field number for the "floats" field. + public const int FloatsFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_floats_codec + = pb::FieldCodec.ForFloat(58); + private readonly pbc::RepeatedField floats_ = new pbc::RepeatedField(); + /// + /// list of floats + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Floats + { + get { return floats_; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(AttributeProto other) { - if (other == null) { - return; - } - if (other.Name.Length != 0) { - Name = other.Name; - } - if (other.DocString.Length != 0) { - DocString = other.DocString; - } - if (other.Type != 0) { - Type = other.Type; - } - if (other.F != 0F) { - F = other.F; - } - if (other.I != 0L) { - I = other.I; - } - if (other.S.Length != 0) { - S = other.S; - } - if (other.t_ != null) { - if (t_ == null) { - t_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto(); - } - T.MergeFrom(other.T); - } - if (other.g_ != null) { - if (g_ == null) { - g_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto(); - } - G.MergeFrom(other.G); - } - floats_.Add(other.floats_); - ints_.Add(other.ints_); - strings_.Add(other.strings_); - tensors_.Add(other.tensors_); - graphs_.Add(other.graphs_); - } + /// Field number for the "ints" field. + public const int IntsFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_ints_codec + = pb::FieldCodec.ForInt64(66); + private readonly pbc::RepeatedField ints_ = new pbc::RepeatedField(); + /// + /// list of ints + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Ints + { + get { return ints_; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - Name = input.ReadString(); - break; - } - case 21: { - F = input.ReadFloat(); - break; - } - case 24: { - I = input.ReadInt64(); - break; - } - case 34: { - S = input.ReadBytes(); - break; - } - case 42: { - if (t_ == null) { - t_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto(); - } - input.ReadMessage(t_); - break; - } - case 50: { - if (g_ == null) { - g_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto(); - } - input.ReadMessage(g_); - break; - } - case 58: - case 61: { - floats_.AddEntriesFrom(input, _repeated_floats_codec); - break; - } - case 66: - case 64: { - ints_.AddEntriesFrom(input, _repeated_ints_codec); - break; - } - case 74: { - strings_.AddEntriesFrom(input, _repeated_strings_codec); - break; - } - case 82: { - tensors_.AddEntriesFrom(input, _repeated_tensors_codec); - break; - } - case 90: { - graphs_.AddEntriesFrom(input, _repeated_graphs_codec); - break; - } - case 106: { - DocString = input.ReadString(); - break; - } - case 160: { - type_ = (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Types.AttributeType) input.ReadEnum(); - break; - } - } - } - } + /// Field number for the "strings" field. + public const int StringsFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_strings_codec + = pb::FieldCodec.ForBytes(74); + private readonly pbc::RepeatedField strings_ = new pbc::RepeatedField(); + /// + /// list of UTF-8 strings + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Strings + { + get { return strings_; } + } - #region Nested types - /// Container for nested types declared in the AttributeProto message type. - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static partial class Types { - /// - /// Note: this enum is structurally identical to the OpSchema::AttrType - /// enum defined in schema.h. If you rev one, you likely need to rev the other. - /// - public enum AttributeType { - [pbr::OriginalName("UNDEFINED")] Undefined = 0, - [pbr::OriginalName("FLOAT")] Float = 1, - [pbr::OriginalName("INT")] Int = 2, - [pbr::OriginalName("STRING")] String = 3, - [pbr::OriginalName("TENSOR")] Tensor = 4, - [pbr::OriginalName("GRAPH")] Graph = 5, - [pbr::OriginalName("FLOATS")] Floats = 6, - [pbr::OriginalName("INTS")] Ints = 7, - [pbr::OriginalName("STRINGS")] Strings = 8, - [pbr::OriginalName("TENSORS")] Tensors = 9, - [pbr::OriginalName("GRAPHS")] Graphs = 10, - } + /// Field number for the "tensors" field. + public const int TensorsFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_tensors_codec + = pb::FieldCodec.ForMessage(82, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Parser); + private readonly pbc::RepeatedField tensors_ = new pbc::RepeatedField(); + /// + /// list of tensors + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Tensors + { + get { return tensors_; } + } - } - #endregion + /// Field number for the "graphs" field. + public const int GraphsFieldNumber = 11; + private static readonly pb::FieldCodec _repeated_graphs_codec + = pb::FieldCodec.ForMessage(90, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto.Parser); + private readonly pbc::RepeatedField graphs_ = new pbc::RepeatedField(); + /// + /// list of graph + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Graphs + { + get { return graphs_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as AttributeProto); + } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AttributeProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (Name != other.Name) return false; + if (RefAttrName != other.RefAttrName) return false; + if (DocString != other.DocString) return false; + if (Type != other.Type) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(F, other.F)) return false; + if (I != other.I) return false; + if (S != other.S) return false; + if (!object.Equals(T, other.T)) return false; + if (!object.Equals(G, other.G)) return false; + if (!floats_.Equals(other.floats_)) return false; + if (!ints_.Equals(other.ints_)) return false; + if (!strings_.Equals(other.strings_)) return false; + if (!tensors_.Equals(other.tensors_)) return false; + if (!graphs_.Equals(other.graphs_)) return false; + return Equals(_unknownFields, other._unknownFields); + } - /// - /// Defines information on value, including the name, the type, and - /// the shape of the value. - /// - public sealed partial class ValueInfoProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ValueInfoProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (RefAttrName.Length != 0) hash ^= RefAttrName.GetHashCode(); + if (DocString.Length != 0) hash ^= DocString.GetHashCode(); + if (Type != 0) hash ^= Type.GetHashCode(); + if (F != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(F); + if (I != 0L) hash ^= I.GetHashCode(); + if (S.Length != 0) hash ^= S.GetHashCode(); + if (t_ != null) hash ^= T.GetHashCode(); + if (g_ != null) hash ^= G.GetHashCode(); + hash ^= floats_.GetHashCode(); + hash ^= ints_.GetHashCode(); + hash ^= strings_.GetHashCode(); + hash ^= tensors_.GetHashCode(); + hash ^= graphs_.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[1]; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + if (Name.Length != 0) + { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (F != 0F) + { + output.WriteRawTag(21); + output.WriteFloat(F); + } + if (I != 0L) + { + output.WriteRawTag(24); + output.WriteInt64(I); + } + if (S.Length != 0) + { + output.WriteRawTag(34); + output.WriteBytes(S); + } + if (t_ != null) + { + output.WriteRawTag(42); + output.WriteMessage(T); + } + if (g_ != null) + { + output.WriteRawTag(50); + output.WriteMessage(G); + } + floats_.WriteTo(output, _repeated_floats_codec); + ints_.WriteTo(output, _repeated_ints_codec); + strings_.WriteTo(output, _repeated_strings_codec); + tensors_.WriteTo(output, _repeated_tensors_codec); + graphs_.WriteTo(output, _repeated_graphs_codec); + if (DocString.Length != 0) + { + output.WriteRawTag(106); + output.WriteString(DocString); + } + if (Type != 0) + { + output.WriteRawTag(160, 1); + output.WriteEnum((int)Type); + } + if (RefAttrName.Length != 0) + { + output.WriteRawTag(170, 1); + output.WriteString(RefAttrName); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public ValueInfoProto() { - OnConstruction(); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + if (Name.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (RefAttrName.Length != 0) + { + size += 2 + pb::CodedOutputStream.ComputeStringSize(RefAttrName); + } + if (DocString.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); + } + if (Type != 0) + { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int)Type); + } + if (F != 0F) + { + size += 1 + 4; + } + if (I != 0L) + { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(I); + } + if (S.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(S); + } + if (t_ != null) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(T); + } + if (g_ != null) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(G); + } + size += floats_.CalculateSize(_repeated_floats_codec); + size += ints_.CalculateSize(_repeated_ints_codec); + size += strings_.CalculateSize(_repeated_strings_codec); + size += tensors_.CalculateSize(_repeated_tensors_codec); + size += graphs_.CalculateSize(_repeated_graphs_codec); + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } - partial void OnConstruction(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AttributeProto other) + { + if (other == null) + { + return; + } + if (other.Name.Length != 0) + { + Name = other.Name; + } + if (other.RefAttrName.Length != 0) + { + RefAttrName = other.RefAttrName; + } + if (other.DocString.Length != 0) + { + DocString = other.DocString; + } + if (other.Type != 0) + { + Type = other.Type; + } + if (other.F != 0F) + { + F = other.F; + } + if (other.I != 0L) + { + I = other.I; + } + if (other.S.Length != 0) + { + S = other.S; + } + if (other.t_ != null) + { + if (t_ == null) + { + t_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto(); + } + T.MergeFrom(other.T); + } + if (other.g_ != null) + { + if (g_ == null) + { + g_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto(); + } + G.MergeFrom(other.G); + } + floats_.Add(other.floats_); + ints_.Add(other.ints_); + strings_.Add(other.strings_); + tensors_.Add(other.tensors_); + graphs_.Add(other.graphs_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public ValueInfoProto(ValueInfoProto other) : this() { - name_ = other.name_; - Type = other.type_ != null ? other.Type.Clone() : null; - docString_ = other.docString_; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + { + Name = input.ReadString(); + break; + } + case 21: + { + F = input.ReadFloat(); + break; + } + case 24: + { + I = input.ReadInt64(); + break; + } + case 34: + { + S = input.ReadBytes(); + break; + } + case 42: + { + if (t_ == null) + { + t_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto(); + } + input.ReadMessage(t_); + break; + } + case 50: + { + if (g_ == null) + { + g_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto(); + } + input.ReadMessage(g_); + break; + } + case 58: + case 61: + { + floats_.AddEntriesFrom(input, _repeated_floats_codec); + break; + } + case 66: + case 64: + { + ints_.AddEntriesFrom(input, _repeated_ints_codec); + break; + } + case 74: + { + strings_.AddEntriesFrom(input, _repeated_strings_codec); + break; + } + case 82: + { + tensors_.AddEntriesFrom(input, _repeated_tensors_codec); + break; + } + case 90: + { + graphs_.AddEntriesFrom(input, _repeated_graphs_codec); + break; + } + case 106: + { + DocString = input.ReadString(); + break; + } + case 160: + { + type_ = (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Types.AttributeType)input.ReadEnum(); + break; + } + case 170: + { + RefAttrName = input.ReadString(); + break; + } + } + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public ValueInfoProto Clone() { - return new ValueInfoProto(this); - } + #region Nested types + /// Container for nested types declared in the AttributeProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types + { + /// + /// Note: this enum is structurally identical to the OpSchema::AttrType + /// enum defined in schema.h. If you rev one, you likely need to rev the other. + /// + public enum AttributeType + { + [pbr::OriginalName("UNDEFINED")] Undefined = 0, + [pbr::OriginalName("FLOAT")] Float = 1, + [pbr::OriginalName("INT")] Int = 2, + [pbr::OriginalName("STRING")] String = 3, + [pbr::OriginalName("TENSOR")] Tensor = 4, + [pbr::OriginalName("GRAPH")] Graph = 5, + [pbr::OriginalName("FLOATS")] Floats = 6, + [pbr::OriginalName("INTS")] Ints = 7, + [pbr::OriginalName("STRINGS")] Strings = 8, + [pbr::OriginalName("TENSORS")] Tensors = 9, + [pbr::OriginalName("GRAPHS")] Graphs = 10, + } - /// Field number for the "name" field. - public const int NameFieldNumber = 1; - private string name_ = ""; - /// - /// This field MUST be present in this version of the IR. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Name { - get { return name_; } - set { - name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + } + #endregion - /// Field number for the "type" field. - public const int TypeFieldNumber = 2; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto type_; - /// - /// This field MUST be present in this version of the IR. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto Type { - get { return type_; } - set { - type_ = value; - } } - /// Field number for the "doc_string" field. - public const int DocStringFieldNumber = 3; - private string docString_ = ""; /// - /// A human-readable documentation for this value. Markdown is allowed. + /// Defines information on value, including the name, the type, and + /// the shape of the value. /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string DocString { - get { return docString_; } - set { - docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + public sealed partial class ValueInfoProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ValueInfoProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as ValueInfoProto); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[1]; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(ValueInfoProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (Name != other.Name) return false; - if (!object.Equals(Type, other.Type)) return false; - if (DocString != other.DocString) return false; - return true; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (Name.Length != 0) hash ^= Name.GetHashCode(); - if (type_ != null) hash ^= Type.GetHashCode(); - if (DocString.Length != 0) hash ^= DocString.GetHashCode(); - return hash; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValueInfoProto() + { + OnConstruction(); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + partial void OnConstruction(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (Name.Length != 0) { - output.WriteRawTag(10); - output.WriteString(Name); - } - if (type_ != null) { - output.WriteRawTag(18); - output.WriteMessage(Type); - } - if (DocString.Length != 0) { - output.WriteRawTag(26); - output.WriteString(DocString); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValueInfoProto(ValueInfoProto other) : this() + { + name_ = other.name_; + Type = other.type_ != null ? other.Type.Clone() : null; + docString_ = other.docString_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (Name.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); - } - if (type_ != null) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(Type); - } - if (DocString.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); - } - return size; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValueInfoProto Clone() + { + return new ValueInfoProto(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(ValueInfoProto other) { - if (other == null) { - return; - } - if (other.Name.Length != 0) { - Name = other.Name; - } - if (other.type_ != null) { - if (type_ == null) { - type_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); - } - Type.MergeFrom(other.Type); - } - if (other.DocString.Length != 0) { - DocString = other.DocString; - } - } + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// This field MUST be present in this version of the IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name + { + get { return name_; } + set + { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - Name = input.ReadString(); - break; - } - case 18: { - if (type_ == null) { - type_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); - } - input.ReadMessage(type_); - break; - } - case 26: { - DocString = input.ReadString(); - break; - } - } - } - } + /// Field number for the "type" field. + public const int TypeFieldNumber = 2; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto type_; + /// + /// This field MUST be present in this version of the IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto Type + { + get { return type_; } + set + { + type_ = value; + } + } - } - - /// - /// NodeProto stores a node that is similar to the notion of "layer" - /// or "operator" in many deep learning frameworks. For example, it can be a - /// node of type "Conv" that takes in an image, a filter tensor and a bias - /// tensor, and produces the convolved output. - /// - public sealed partial class NodeProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NodeProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[2]; } - } + /// Field number for the "doc_string" field. + public const int DocStringFieldNumber = 3; + private string docString_ = ""; + /// + /// A human-readable documentation for this value. Markdown is allowed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DocString + { + get { return docString_; } + set + { + docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as ValueInfoProto); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public NodeProto() { - OnConstruction(); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ValueInfoProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (Name != other.Name) return false; + if (!object.Equals(Type, other.Type)) return false; + if (DocString != other.DocString) return false; + return Equals(_unknownFields, other._unknownFields); + } - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public NodeProto(NodeProto other) : this() { - input_ = other.input_.Clone(); - output_ = other.output_.Clone(); - name_ = other.name_; - opType_ = other.opType_; - domain_ = other.domain_; - attribute_ = other.attribute_.Clone(); - docString_ = other.docString_; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (type_ != null) hash ^= Type.GetHashCode(); + if (DocString.Length != 0) hash ^= DocString.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public NodeProto Clone() { - return new NodeProto(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } - /// Field number for the "input" field. - public const int InputFieldNumber = 1; - private static readonly pb::FieldCodec _repeated_input_codec - = pb::FieldCodec.ForString(10); - private readonly pbc::RepeatedField input_ = new pbc::RepeatedField(); - /// - /// namespace Value - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Input { - get { return input_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + if (Name.Length != 0) + { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (type_ != null) + { + output.WriteRawTag(18); + output.WriteMessage(Type); + } + if (DocString.Length != 0) + { + output.WriteRawTag(26); + output.WriteString(DocString); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } - /// Field number for the "output" field. - public const int OutputFieldNumber = 2; - private static readonly pb::FieldCodec _repeated_output_codec - = pb::FieldCodec.ForString(18); - private readonly pbc::RepeatedField output_ = new pbc::RepeatedField(); - /// - /// namespace Value - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Output { - get { return output_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + if (Name.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (type_ != null) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Type); + } + if (DocString.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } - /// Field number for the "name" field. - public const int NameFieldNumber = 3; - private string name_ = ""; - /// - /// An optional identifier for this node in a graph. - /// This field MAY be absent in ths version of the IR. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Name { - get { return name_; } - set { - name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ValueInfoProto other) + { + if (other == null) + { + return; + } + if (other.Name.Length != 0) + { + Name = other.Name; + } + if (other.type_ != null) + { + if (type_ == null) + { + type_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); + } + Type.MergeFrom(other.Type); + } + if (other.DocString.Length != 0) + { + DocString = other.DocString; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } - /// Field number for the "op_type" field. - public const int OpTypeFieldNumber = 4; - private string opType_ = ""; - /// - /// The symbolic identifier of the Operator to execute. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string OpType { - get { return opType_; } - set { - opType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + { + Name = input.ReadString(); + break; + } + case 18: + { + if (type_ == null) + { + type_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); + } + input.ReadMessage(type_); + break; + } + case 26: + { + DocString = input.ReadString(); + break; + } + } + } + } - /// Field number for the "domain" field. - public const int DomainFieldNumber = 7; - private string domain_ = ""; - /// - /// The domain of the OperatorSet that specifies the operator named by op_type. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Domain { - get { return domain_; } - set { - domain_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } } - /// Field number for the "attribute" field. - public const int AttributeFieldNumber = 5; - private static readonly pb::FieldCodec _repeated_attribute_codec - = pb::FieldCodec.ForMessage(42, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Parser); - private readonly pbc::RepeatedField attribute_ = new pbc::RepeatedField(); /// - /// Additional named attributes. - /// NOTE: Simply using ValueProto.NameValuePairProto is the most general - /// solution. I kept AttributeProto to minimize churn on CI results. + /// Nodes + /// + /// Computation graphs are made up of a DAG of nodes, which represent what is + /// commonly called a "layer" or "pipeline stage" in machine learning frameworks. + /// + /// For example, it can be a node of type "Conv" that takes in an image, a filter + /// tensor and a bias tensor, and produces the convolved output. /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Attribute { - get { return attribute_; } - } + public sealed partial class NodeProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NodeProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } - /// Field number for the "doc_string" field. - public const int DocStringFieldNumber = 6; - private string docString_ = ""; - /// - /// A human-readable documentation for this node. Markdown is allowed. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string DocString { - get { return docString_; } - set { - docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[2]; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as NodeProto); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(NodeProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if(!input_.Equals(other.input_)) return false; - if(!output_.Equals(other.output_)) return false; - if (Name != other.Name) return false; - if (OpType != other.OpType) return false; - if (Domain != other.Domain) return false; - if(!attribute_.Equals(other.attribute_)) return false; - if (DocString != other.DocString) return false; - return true; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NodeProto() + { + OnConstruction(); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - hash ^= input_.GetHashCode(); - hash ^= output_.GetHashCode(); - if (Name.Length != 0) hash ^= Name.GetHashCode(); - if (OpType.Length != 0) hash ^= OpType.GetHashCode(); - if (Domain.Length != 0) hash ^= Domain.GetHashCode(); - hash ^= attribute_.GetHashCode(); - if (DocString.Length != 0) hash ^= DocString.GetHashCode(); - return hash; - } + partial void OnConstruction(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NodeProto(NodeProto other) : this() + { + input_ = other.input_.Clone(); + output_ = other.output_.Clone(); + name_ = other.name_; + opType_ = other.opType_; + domain_ = other.domain_; + attribute_ = other.attribute_.Clone(); + docString_ = other.docString_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - input_.WriteTo(output, _repeated_input_codec); - output_.WriteTo(output, _repeated_output_codec); - if (Name.Length != 0) { - output.WriteRawTag(26); - output.WriteString(Name); - } - if (OpType.Length != 0) { - output.WriteRawTag(34); - output.WriteString(OpType); - } - attribute_.WriteTo(output, _repeated_attribute_codec); - if (DocString.Length != 0) { - output.WriteRawTag(50); - output.WriteString(DocString); - } - if (Domain.Length != 0) { - output.WriteRawTag(58); - output.WriteString(Domain); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NodeProto Clone() + { + return new NodeProto(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - size += input_.CalculateSize(_repeated_input_codec); - size += output_.CalculateSize(_repeated_output_codec); - if (Name.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); - } - if (OpType.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(OpType); - } - if (Domain.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Domain); - } - size += attribute_.CalculateSize(_repeated_attribute_codec); - if (DocString.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); - } - return size; - } + /// Field number for the "input" field. + public const int InputFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_input_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField input_ = new pbc::RepeatedField(); + /// + /// namespace Value + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Input + { + get { return input_; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(NodeProto other) { - if (other == null) { - return; - } - input_.Add(other.input_); - output_.Add(other.output_); - if (other.Name.Length != 0) { - Name = other.Name; - } - if (other.OpType.Length != 0) { - OpType = other.OpType; - } - if (other.Domain.Length != 0) { - Domain = other.Domain; - } - attribute_.Add(other.attribute_); - if (other.DocString.Length != 0) { - DocString = other.DocString; - } - } + /// Field number for the "output" field. + public const int OutputFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_output_codec + = pb::FieldCodec.ForString(18); + private readonly pbc::RepeatedField output_ = new pbc::RepeatedField(); + /// + /// namespace Value + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Output + { + get { return output_; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - input_.AddEntriesFrom(input, _repeated_input_codec); - break; - } - case 18: { - output_.AddEntriesFrom(input, _repeated_output_codec); - break; - } - case 26: { - Name = input.ReadString(); - break; - } - case 34: { - OpType = input.ReadString(); - break; - } - case 42: { - attribute_.AddEntriesFrom(input, _repeated_attribute_codec); - break; - } - case 50: { - DocString = input.ReadString(); - break; - } - case 58: { - Domain = input.ReadString(); - break; - } - } - } - } + /// Field number for the "name" field. + public const int NameFieldNumber = 3; + private string name_ = ""; + /// + /// An optional identifier for this node in a graph. + /// This field MAY be absent in ths version of the IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name + { + get { return name_; } + set + { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - } - - /// - /// ModelProto is a top-level file/container format for bundling a ML model. - /// The semantics of the model are described by the GraphProto that represents - /// a parameterized computation graph against a set of named operators that are - /// defined independently from the graph. - /// - public sealed partial class ModelProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ModelProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[3]; } - } + /// Field number for the "op_type" field. + public const int OpTypeFieldNumber = 4; + private string opType_ = ""; + /// + /// The symbolic identifier of the Operator to execute. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string OpType + { + get { return opType_; } + set + { + opType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } + /// Field number for the "domain" field. + public const int DomainFieldNumber = 7; + private string domain_ = ""; + /// + /// The domain of the OperatorSet that specifies the operator named by op_type. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Domain + { + get { return domain_; } + set + { + domain_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public ModelProto() { - OnConstruction(); - } + /// Field number for the "attribute" field. + public const int AttributeFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_attribute_codec + = pb::FieldCodec.ForMessage(42, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.AttributeProto.Parser); + private readonly pbc::RepeatedField attribute_ = new pbc::RepeatedField(); + /// + /// Additional named attributes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Attribute + { + get { return attribute_; } + } - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public ModelProto(ModelProto other) : this() { - irVersion_ = other.irVersion_; - opsetImport_ = other.opsetImport_.Clone(); - producerName_ = other.producerName_; - producerVersion_ = other.producerVersion_; - domain_ = other.domain_; - modelVersion_ = other.modelVersion_; - docString_ = other.docString_; - Graph = other.graph_ != null ? other.Graph.Clone() : null; - metadataProps_ = other.metadataProps_.Clone(); - } + /// Field number for the "doc_string" field. + public const int DocStringFieldNumber = 6; + private string docString_ = ""; + /// + /// A human-readable documentation for this node. Markdown is allowed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DocString + { + get { return docString_; } + set + { + docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public ModelProto Clone() { - return new ModelProto(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as NodeProto); + } - /// Field number for the "ir_version" field. - public const int IrVersionFieldNumber = 1; - private long irVersion_; - /// - /// The version of the IR this model targets. See Version enum above. - /// This field MUST be present. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public long IrVersion { - get { return irVersion_; } - set { - irVersion_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(NodeProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (!input_.Equals(other.input_)) return false; + if (!output_.Equals(other.output_)) return false; + if (Name != other.Name) return false; + if (OpType != other.OpType) return false; + if (Domain != other.Domain) return false; + if (!attribute_.Equals(other.attribute_)) return false; + if (DocString != other.DocString) return false; + return Equals(_unknownFields, other._unknownFields); + } - /// Field number for the "opset_import" field. - public const int OpsetImportFieldNumber = 8; - private static readonly pb::FieldCodec _repeated_opsetImport_codec - = pb::FieldCodec.ForMessage(66, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OperatorSetIdProto.Parser); - private readonly pbc::RepeatedField opsetImport_ = new pbc::RepeatedField(); - /// - /// The OperatorSets this model relies on. - /// All ModelProtos MUST have at least one entry that - /// specifies which version of the ONNX OperatorSet is - /// being imported. - /// - /// All nodes in the ModelProto's graph will bind against the operator - /// with the same-domain/same-op_type operator with the HIGHEST version - /// in the referenced operator sets. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField OpsetImport { - get { return opsetImport_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + hash ^= input_.GetHashCode(); + hash ^= output_.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (OpType.Length != 0) hash ^= OpType.GetHashCode(); + if (Domain.Length != 0) hash ^= Domain.GetHashCode(); + hash ^= attribute_.GetHashCode(); + if (DocString.Length != 0) hash ^= DocString.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - /// Field number for the "producer_name" field. - public const int ProducerNameFieldNumber = 2; - private string producerName_ = ""; - /// - /// The name of the framework or tool used to generate this model. - /// This field SHOULD be present to indicate which implementation/tool/framework - /// emitted the model. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string ProducerName { - get { return producerName_; } - set { - producerName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + input_.WriteTo(output, _repeated_input_codec); + output_.WriteTo(output, _repeated_output_codec); + if (Name.Length != 0) + { + output.WriteRawTag(26); + output.WriteString(Name); + } + if (OpType.Length != 0) + { + output.WriteRawTag(34); + output.WriteString(OpType); + } + attribute_.WriteTo(output, _repeated_attribute_codec); + if (DocString.Length != 0) + { + output.WriteRawTag(50); + output.WriteString(DocString); + } + if (Domain.Length != 0) + { + output.WriteRawTag(58); + output.WriteString(Domain); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + size += input_.CalculateSize(_repeated_input_codec); + size += output_.CalculateSize(_repeated_output_codec); + if (Name.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (OpType.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(OpType); + } + if (Domain.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Domain); + } + size += attribute_.CalculateSize(_repeated_attribute_codec); + if (DocString.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(NodeProto other) + { + if (other == null) + { + return; + } + input_.Add(other.input_); + output_.Add(other.output_); + if (other.Name.Length != 0) + { + Name = other.Name; + } + if (other.OpType.Length != 0) + { + OpType = other.OpType; + } + if (other.Domain.Length != 0) + { + Domain = other.Domain; + } + attribute_.Add(other.attribute_); + if (other.DocString.Length != 0) + { + DocString = other.DocString; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + { + input_.AddEntriesFrom(input, _repeated_input_codec); + break; + } + case 18: + { + output_.AddEntriesFrom(input, _repeated_output_codec); + break; + } + case 26: + { + Name = input.ReadString(); + break; + } + case 34: + { + OpType = input.ReadString(); + break; + } + case 42: + { + attribute_.AddEntriesFrom(input, _repeated_attribute_codec); + break; + } + case 50: + { + DocString = input.ReadString(); + break; + } + case 58: + { + Domain = input.ReadString(); + break; + } + } + } + } - /// Field number for the "producer_version" field. - public const int ProducerVersionFieldNumber = 3; - private string producerVersion_ = ""; - /// - /// The version of the framework or tool used to generate this model. - /// This field SHOULD be present to indicate which implementation/tool/framework - /// emitted the model. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string ProducerVersion { - get { return producerVersion_; } - set { - producerVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } } - /// Field number for the "domain" field. - public const int DomainFieldNumber = 4; - private string domain_ = ""; /// - /// Domain name of the model. - /// We use reverse domain names as name space indicators. For example: - /// `com.facebook.fair` or `com.microsoft.cognitiveservices` + /// Models + /// + /// ModelProto is a top-level file/container format for bundling a ML model and + /// associating its computation graph with metadata. /// - /// Together with `model_version` and GraphProto.name, this forms the unique identity of - /// the graph. + /// The semantics of the model are described by the associated GraphProto. /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Domain { - get { return domain_; } - set { - domain_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + public sealed partial class ModelProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ModelProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } - /// Field number for the "model_version" field. - public const int ModelVersionFieldNumber = 5; - private long modelVersion_; - /// - /// The version of the graph encoded. See Version enum below. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public long ModelVersion { - get { return modelVersion_; } - set { - modelVersion_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[3]; } + } - /// Field number for the "doc_string" field. - public const int DocStringFieldNumber = 6; - private string docString_ = ""; - /// - /// A human-readable documentation for this model. Markdown is allowed. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string DocString { - get { return docString_; } - set { - docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - /// Field number for the "graph" field. - public const int GraphFieldNumber = 7; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto graph_; - /// - /// The parameterized graph that is evaluated to execute the model. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto Graph { - get { return graph_; } - set { - graph_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ModelProto() + { + OnConstruction(); + } - /// Field number for the "metadata_props" field. - public const int MetadataPropsFieldNumber = 14; - private static readonly pb::FieldCodec _repeated_metadataProps_codec - = pb::FieldCodec.ForMessage(114, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.StringStringEntryProto.Parser); - private readonly pbc::RepeatedField metadataProps_ = new pbc::RepeatedField(); - /// - /// Named metadata values; keys should be distinct. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField MetadataProps { - get { return metadataProps_; } - } + partial void OnConstruction(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as ModelProto); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ModelProto(ModelProto other) : this() + { + irVersion_ = other.irVersion_; + opsetImport_ = other.opsetImport_.Clone(); + producerName_ = other.producerName_; + producerVersion_ = other.producerVersion_; + domain_ = other.domain_; + modelVersion_ = other.modelVersion_; + docString_ = other.docString_; + Graph = other.graph_ != null ? other.Graph.Clone() : null; + metadataProps_ = other.metadataProps_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(ModelProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (IrVersion != other.IrVersion) return false; - if(!opsetImport_.Equals(other.opsetImport_)) return false; - if (ProducerName != other.ProducerName) return false; - if (ProducerVersion != other.ProducerVersion) return false; - if (Domain != other.Domain) return false; - if (ModelVersion != other.ModelVersion) return false; - if (DocString != other.DocString) return false; - if (!object.Equals(Graph, other.Graph)) return false; - if(!metadataProps_.Equals(other.metadataProps_)) return false; - return true; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ModelProto Clone() + { + return new ModelProto(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (IrVersion != 0L) hash ^= IrVersion.GetHashCode(); - hash ^= opsetImport_.GetHashCode(); - if (ProducerName.Length != 0) hash ^= ProducerName.GetHashCode(); - if (ProducerVersion.Length != 0) hash ^= ProducerVersion.GetHashCode(); - if (Domain.Length != 0) hash ^= Domain.GetHashCode(); - if (ModelVersion != 0L) hash ^= ModelVersion.GetHashCode(); - if (DocString.Length != 0) hash ^= DocString.GetHashCode(); - if (graph_ != null) hash ^= Graph.GetHashCode(); - hash ^= metadataProps_.GetHashCode(); - return hash; - } + /// Field number for the "ir_version" field. + public const int IrVersionFieldNumber = 1; + private long irVersion_; + /// + /// The version of the IR this model targets. See Version enum above. + /// This field MUST be present. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long IrVersion + { + get { return irVersion_; } + set + { + irVersion_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + /// Field number for the "opset_import" field. + public const int OpsetImportFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_opsetImport_codec + = pb::FieldCodec.ForMessage(66, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OperatorSetIdProto.Parser); + private readonly pbc::RepeatedField opsetImport_ = new pbc::RepeatedField(); + /// + /// The OperatorSets this model relies on. + /// All ModelProtos MUST have at least one entry that + /// specifies which version of the ONNX OperatorSet is + /// being imported. + /// + /// All nodes in the ModelProto's graph will bind against the operator + /// with the same-domain/same-op_type operator with the HIGHEST version + /// in the referenced operator sets. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField OpsetImport + { + get { return opsetImport_; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (IrVersion != 0L) { - output.WriteRawTag(8); - output.WriteInt64(IrVersion); - } - if (ProducerName.Length != 0) { - output.WriteRawTag(18); - output.WriteString(ProducerName); - } - if (ProducerVersion.Length != 0) { - output.WriteRawTag(26); - output.WriteString(ProducerVersion); - } - if (Domain.Length != 0) { - output.WriteRawTag(34); - output.WriteString(Domain); - } - if (ModelVersion != 0L) { - output.WriteRawTag(40); - output.WriteInt64(ModelVersion); - } - if (DocString.Length != 0) { - output.WriteRawTag(50); - output.WriteString(DocString); - } - if (graph_ != null) { - output.WriteRawTag(58); - output.WriteMessage(Graph); - } - opsetImport_.WriteTo(output, _repeated_opsetImport_codec); - metadataProps_.WriteTo(output, _repeated_metadataProps_codec); - } + /// Field number for the "producer_name" field. + public const int ProducerNameFieldNumber = 2; + private string producerName_ = ""; + /// + /// The name of the framework or tool used to generate this model. + /// This field SHOULD be present to indicate which implementation/tool/framework + /// emitted the model. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ProducerName + { + get { return producerName_; } + set + { + producerName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (IrVersion != 0L) { - size += 1 + pb::CodedOutputStream.ComputeInt64Size(IrVersion); - } - size += opsetImport_.CalculateSize(_repeated_opsetImport_codec); - if (ProducerName.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(ProducerName); - } - if (ProducerVersion.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(ProducerVersion); - } - if (Domain.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Domain); - } - if (ModelVersion != 0L) { - size += 1 + pb::CodedOutputStream.ComputeInt64Size(ModelVersion); - } - if (DocString.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); - } - if (graph_ != null) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(Graph); - } - size += metadataProps_.CalculateSize(_repeated_metadataProps_codec); - return size; - } + /// Field number for the "producer_version" field. + public const int ProducerVersionFieldNumber = 3; + private string producerVersion_ = ""; + /// + /// The version of the framework or tool used to generate this model. + /// This field SHOULD be present to indicate which implementation/tool/framework + /// emitted the model. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ProducerVersion + { + get { return producerVersion_; } + set + { + producerVersion_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(ModelProto other) { - if (other == null) { - return; - } - if (other.IrVersion != 0L) { - IrVersion = other.IrVersion; - } - opsetImport_.Add(other.opsetImport_); - if (other.ProducerName.Length != 0) { - ProducerName = other.ProducerName; - } - if (other.ProducerVersion.Length != 0) { - ProducerVersion = other.ProducerVersion; - } - if (other.Domain.Length != 0) { - Domain = other.Domain; - } - if (other.ModelVersion != 0L) { - ModelVersion = other.ModelVersion; - } - if (other.DocString.Length != 0) { - DocString = other.DocString; - } - if (other.graph_ != null) { - if (graph_ == null) { - graph_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto(); - } - Graph.MergeFrom(other.Graph); - } - metadataProps_.Add(other.metadataProps_); - } + /// Field number for the "domain" field. + public const int DomainFieldNumber = 4; + private string domain_ = ""; + /// + /// Domain name of the model. + /// We use reverse domain names as name space indicators. For example: + /// `com.facebook.fair` or `com.microsoft.cognitiveservices` + /// + /// Together with `model_version` and GraphProto.name, this forms the unique identity of + /// the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Domain + { + get { return domain_; } + set + { + domain_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 8: { - IrVersion = input.ReadInt64(); - break; - } - case 18: { - ProducerName = input.ReadString(); - break; - } - case 26: { - ProducerVersion = input.ReadString(); - break; - } - case 34: { - Domain = input.ReadString(); - break; - } - case 40: { - ModelVersion = input.ReadInt64(); - break; - } - case 50: { - DocString = input.ReadString(); - break; - } - case 58: { - if (graph_ == null) { - graph_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto(); - } - input.ReadMessage(graph_); - break; - } - case 66: { - opsetImport_.AddEntriesFrom(input, _repeated_opsetImport_codec); - break; - } - case 114: { - metadataProps_.AddEntriesFrom(input, _repeated_metadataProps_codec); - break; - } - } - } - } + /// Field number for the "model_version" field. + public const int ModelVersionFieldNumber = 5; + private long modelVersion_; + /// + /// The version of the graph encoded. See Version enum below. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long ModelVersion + { + get { return modelVersion_; } + set + { + modelVersion_ = value; + } + } - } + /// Field number for the "doc_string" field. + public const int DocStringFieldNumber = 6; + private string docString_ = ""; + /// + /// A human-readable documentation for this model. Markdown is allowed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DocString + { + get { return docString_; } + set + { + docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - /// - /// StringStringEntryProto follows the pattern for cross-proto-version maps. - /// See https://developers.google.com/protocol-buffers/docs/proto3#maps - /// - public sealed partial class StringStringEntryProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new StringStringEntryProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } + /// Field number for the "graph" field. + public const int GraphFieldNumber = 7; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto graph_; + /// + /// The parameterized graph that is evaluated to execute the model. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto Graph + { + get { return graph_; } + set + { + graph_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[4]; } - } + /// Field number for the "metadata_props" field. + public const int MetadataPropsFieldNumber = 14; + private static readonly pb::FieldCodec _repeated_metadataProps_codec + = pb::FieldCodec.ForMessage(114, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.StringStringEntryProto.Parser); + private readonly pbc::RepeatedField metadataProps_ = new pbc::RepeatedField(); + /// + /// Named metadata values; keys should be distinct. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField MetadataProps + { + get { return metadataProps_; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as ModelProto); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public StringStringEntryProto() { - OnConstruction(); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ModelProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (IrVersion != other.IrVersion) return false; + if (!opsetImport_.Equals(other.opsetImport_)) return false; + if (ProducerName != other.ProducerName) return false; + if (ProducerVersion != other.ProducerVersion) return false; + if (Domain != other.Domain) return false; + if (ModelVersion != other.ModelVersion) return false; + if (DocString != other.DocString) return false; + if (!object.Equals(Graph, other.Graph)) return false; + if (!metadataProps_.Equals(other.metadataProps_)) return false; + return Equals(_unknownFields, other._unknownFields); + } - partial void OnConstruction(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + if (IrVersion != 0L) hash ^= IrVersion.GetHashCode(); + hash ^= opsetImport_.GetHashCode(); + if (ProducerName.Length != 0) hash ^= ProducerName.GetHashCode(); + if (ProducerVersion.Length != 0) hash ^= ProducerVersion.GetHashCode(); + if (Domain.Length != 0) hash ^= Domain.GetHashCode(); + if (ModelVersion != 0L) hash ^= ModelVersion.GetHashCode(); + if (DocString.Length != 0) hash ^= DocString.GetHashCode(); + if (graph_ != null) hash ^= Graph.GetHashCode(); + hash ^= metadataProps_.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public StringStringEntryProto(StringStringEntryProto other) : this() { - key_ = other.key_; - value_ = other.value_; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public StringStringEntryProto Clone() { - return new StringStringEntryProto(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + if (IrVersion != 0L) + { + output.WriteRawTag(8); + output.WriteInt64(IrVersion); + } + if (ProducerName.Length != 0) + { + output.WriteRawTag(18); + output.WriteString(ProducerName); + } + if (ProducerVersion.Length != 0) + { + output.WriteRawTag(26); + output.WriteString(ProducerVersion); + } + if (Domain.Length != 0) + { + output.WriteRawTag(34); + output.WriteString(Domain); + } + if (ModelVersion != 0L) + { + output.WriteRawTag(40); + output.WriteInt64(ModelVersion); + } + if (DocString.Length != 0) + { + output.WriteRawTag(50); + output.WriteString(DocString); + } + if (graph_ != null) + { + output.WriteRawTag(58); + output.WriteMessage(Graph); + } + opsetImport_.WriteTo(output, _repeated_opsetImport_codec); + metadataProps_.WriteTo(output, _repeated_metadataProps_codec); + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } - /// Field number for the "key" field. - public const int KeyFieldNumber = 1; - private string key_ = ""; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Key { - get { return key_; } - set { - key_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + if (IrVersion != 0L) + { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(IrVersion); + } + size += opsetImport_.CalculateSize(_repeated_opsetImport_codec); + if (ProducerName.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ProducerName); + } + if (ProducerVersion.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ProducerVersion); + } + if (Domain.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Domain); + } + if (ModelVersion != 0L) + { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(ModelVersion); + } + if (DocString.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); + } + if (graph_ != null) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Graph); + } + size += metadataProps_.CalculateSize(_repeated_metadataProps_codec); + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } - /// Field number for the "value" field. - public const int ValueFieldNumber = 2; - private string value_ = ""; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Value { - get { return value_; } - set { - value_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ModelProto other) + { + if (other == null) + { + return; + } + if (other.IrVersion != 0L) + { + IrVersion = other.IrVersion; + } + opsetImport_.Add(other.opsetImport_); + if (other.ProducerName.Length != 0) + { + ProducerName = other.ProducerName; + } + if (other.ProducerVersion.Length != 0) + { + ProducerVersion = other.ProducerVersion; + } + if (other.Domain.Length != 0) + { + Domain = other.Domain; + } + if (other.ModelVersion != 0L) + { + ModelVersion = other.ModelVersion; + } + if (other.DocString.Length != 0) + { + DocString = other.DocString; + } + if (other.graph_ != null) + { + if (graph_ == null) + { + graph_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto(); + } + Graph.MergeFrom(other.Graph); + } + metadataProps_.Add(other.metadataProps_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as StringStringEntryProto); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: + { + IrVersion = input.ReadInt64(); + break; + } + case 18: + { + ProducerName = input.ReadString(); + break; + } + case 26: + { + ProducerVersion = input.ReadString(); + break; + } + case 34: + { + Domain = input.ReadString(); + break; + } + case 40: + { + ModelVersion = input.ReadInt64(); + break; + } + case 50: + { + DocString = input.ReadString(); + break; + } + case 58: + { + if (graph_ == null) + { + graph_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.GraphProto(); + } + input.ReadMessage(graph_); + break; + } + case 66: + { + opsetImport_.AddEntriesFrom(input, _repeated_opsetImport_codec); + break; + } + case 114: + { + metadataProps_.AddEntriesFrom(input, _repeated_metadataProps_codec); + break; + } + } + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(StringStringEntryProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (Key != other.Key) return false; - if (Value != other.Value) return false; - return true; } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (Key.Length != 0) hash ^= Key.GetHashCode(); - if (Value.Length != 0) hash ^= Value.GetHashCode(); - return hash; - } + /// + /// StringStringEntryProto follows the pattern for cross-proto-version maps. + /// See https://developers.google.com/protocol-buffers/docs/proto3#maps + /// + public sealed partial class StringStringEntryProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new StringStringEntryProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[4]; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (Key.Length != 0) { - output.WriteRawTag(10); - output.WriteString(Key); - } - if (Value.Length != 0) { - output.WriteRawTag(18); - output.WriteString(Value); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (Key.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Key); - } - if (Value.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Value); - } - return size; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public StringStringEntryProto() + { + OnConstruction(); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(StringStringEntryProto other) { - if (other == null) { - return; - } - if (other.Key.Length != 0) { - Key = other.Key; - } - if (other.Value.Length != 0) { - Value = other.Value; - } - } + partial void OnConstruction(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - Key = input.ReadString(); - break; - } - case 18: { - Value = input.ReadString(); - break; - } - } - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public StringStringEntryProto(StringStringEntryProto other) : this() + { + key_ = other.key_; + value_ = other.value_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - } - - /// - /// GraphProto defines a parameterized series of nodes to form a directed acyclic graph. - /// This is the equivalent of the "network" and "graph" in many deep learning - /// frameworks. - /// - public sealed partial class GraphProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[5]; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public StringStringEntryProto Clone() + { + return new StringStringEntryProto(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } + /// Field number for the "key" field. + public const int KeyFieldNumber = 1; + private string key_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Key + { + get { return key_; } + set + { + key_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public GraphProto() { - OnConstruction(); - } + /// Field number for the "value" field. + public const int ValueFieldNumber = 2; + private string value_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Value + { + get { return value_; } + set + { + value_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public GraphProto(GraphProto other) : this() { - node_ = other.node_.Clone(); - name_ = other.name_; - initializer_ = other.initializer_.Clone(); - docString_ = other.docString_; - input_ = other.input_.Clone(); - output_ = other.output_.Clone(); - valueInfo_ = other.valueInfo_.Clone(); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as StringStringEntryProto); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public GraphProto Clone() { - return new GraphProto(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(StringStringEntryProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (Key != other.Key) return false; + if (Value != other.Value) return false; + return Equals(_unknownFields, other._unknownFields); + } - /// Field number for the "node" field. - public const int NodeFieldNumber = 1; - private static readonly pb::FieldCodec _repeated_node_codec - = pb::FieldCodec.ForMessage(10, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.NodeProto.Parser); - private readonly pbc::RepeatedField node_ = new pbc::RepeatedField(); - /// - /// The nodes in the graph. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Node { - get { return node_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + if (Key.Length != 0) hash ^= Key.GetHashCode(); + if (Value.Length != 0) hash ^= Value.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - /// Field number for the "name" field. - public const int NameFieldNumber = 2; - private string name_ = ""; - /// - /// The name of the graph. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Name { - get { return name_; } - set { - name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } - /// Field number for the "initializer" field. - public const int InitializerFieldNumber = 5; - private static readonly pb::FieldCodec _repeated_initializer_codec - = pb::FieldCodec.ForMessage(42, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Parser); - private readonly pbc::RepeatedField initializer_ = new pbc::RepeatedField(); - /// - /// A list of named tensor values (constants), used to specify default - /// values for some of the inputs of the graph. - /// Each TensorProto entry must have a distinct name (within the list) that - /// also appears in the input list. - /// In an evaluation, the default value specified here is used if and only if - /// user specifies no value for the corresponding input parameter. - /// May be used to pass serialized parameters for networks. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Initializer { - get { return initializer_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + if (Key.Length != 0) + { + output.WriteRawTag(10); + output.WriteString(Key); + } + if (Value.Length != 0) + { + output.WriteRawTag(18); + output.WriteString(Value); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } - /// Field number for the "doc_string" field. - public const int DocStringFieldNumber = 10; - private string docString_ = ""; - /// - /// A human-readable documentation for this graph. Markdown is allowed. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string DocString { - get { return docString_; } - set { - docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + if (Key.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Key); + } + if (Value.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Value); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } - /// Field number for the "input" field. - public const int InputFieldNumber = 11; - private static readonly pb::FieldCodec _repeated_input_codec - = pb::FieldCodec.ForMessage(90, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ValueInfoProto.Parser); - private readonly pbc::RepeatedField input_ = new pbc::RepeatedField(); - /// - /// The inputs and outputs of the graph. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Input { - get { return input_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(StringStringEntryProto other) + { + if (other == null) + { + return; + } + if (other.Key.Length != 0) + { + Key = other.Key; + } + if (other.Value.Length != 0) + { + Value = other.Value; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + { + Key = input.ReadString(); + break; + } + case 18: + { + Value = input.ReadString(); + break; + } + } + } + } - /// Field number for the "output" field. - public const int OutputFieldNumber = 12; - private static readonly pb::FieldCodec _repeated_output_codec - = pb::FieldCodec.ForMessage(98, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ValueInfoProto.Parser); - private readonly pbc::RepeatedField output_ = new pbc::RepeatedField(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Output { - get { return output_; } } - /// Field number for the "value_info" field. - public const int ValueInfoFieldNumber = 13; - private static readonly pb::FieldCodec _repeated_valueInfo_codec - = pb::FieldCodec.ForMessage(106, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ValueInfoProto.Parser); - private readonly pbc::RepeatedField valueInfo_ = new pbc::RepeatedField(); /// - /// Information for the values in the graph. The ValueInfoProto.name's - /// must be distinct. It is optional for a value to appear in value_info list. + /// Graphs + /// + /// A graph defines the computational logic of a model and is comprised of a parameterized + /// list of nodes that form a directed acyclic graph based on their inputs and outputs. + /// This is the equivalent of the "network" or "graph" in many deep learning + /// frameworks. /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField ValueInfo { - get { return valueInfo_; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as GraphProto); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(GraphProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if(!node_.Equals(other.node_)) return false; - if (Name != other.Name) return false; - if(!initializer_.Equals(other.initializer_)) return false; - if (DocString != other.DocString) return false; - if(!input_.Equals(other.input_)) return false; - if(!output_.Equals(other.output_)) return false; - if(!valueInfo_.Equals(other.valueInfo_)) return false; - return true; - } + public sealed partial class GraphProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - hash ^= node_.GetHashCode(); - if (Name.Length != 0) hash ^= Name.GetHashCode(); - hash ^= initializer_.GetHashCode(); - if (DocString.Length != 0) hash ^= DocString.GetHashCode(); - hash ^= input_.GetHashCode(); - hash ^= output_.GetHashCode(); - hash ^= valueInfo_.GetHashCode(); - return hash; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[5]; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - node_.WriteTo(output, _repeated_node_codec); - if (Name.Length != 0) { - output.WriteRawTag(18); - output.WriteString(Name); - } - initializer_.WriteTo(output, _repeated_initializer_codec); - if (DocString.Length != 0) { - output.WriteRawTag(82); - output.WriteString(DocString); - } - input_.WriteTo(output, _repeated_input_codec); - output_.WriteTo(output, _repeated_output_codec); - valueInfo_.WriteTo(output, _repeated_valueInfo_codec); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GraphProto() + { + OnConstruction(); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - size += node_.CalculateSize(_repeated_node_codec); - if (Name.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); - } - size += initializer_.CalculateSize(_repeated_initializer_codec); - if (DocString.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); - } - size += input_.CalculateSize(_repeated_input_codec); - size += output_.CalculateSize(_repeated_output_codec); - size += valueInfo_.CalculateSize(_repeated_valueInfo_codec); - return size; - } + partial void OnConstruction(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(GraphProto other) { - if (other == null) { - return; - } - node_.Add(other.node_); - if (other.Name.Length != 0) { - Name = other.Name; - } - initializer_.Add(other.initializer_); - if (other.DocString.Length != 0) { - DocString = other.DocString; - } - input_.Add(other.input_); - output_.Add(other.output_); - valueInfo_.Add(other.valueInfo_); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GraphProto(GraphProto other) : this() + { + node_ = other.node_.Clone(); + name_ = other.name_; + initializer_ = other.initializer_.Clone(); + docString_ = other.docString_; + input_ = other.input_.Clone(); + output_ = other.output_.Clone(); + valueInfo_ = other.valueInfo_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - node_.AddEntriesFrom(input, _repeated_node_codec); - break; - } - case 18: { - Name = input.ReadString(); - break; - } - case 42: { - initializer_.AddEntriesFrom(input, _repeated_initializer_codec); - break; - } - case 82: { - DocString = input.ReadString(); - break; - } - case 90: { - input_.AddEntriesFrom(input, _repeated_input_codec); - break; - } - case 98: { - output_.AddEntriesFrom(input, _repeated_output_codec); - break; - } - case 106: { - valueInfo_.AddEntriesFrom(input, _repeated_valueInfo_codec); - break; - } - } - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GraphProto Clone() + { + return new GraphProto(this); + } - } + /// Field number for the "node" field. + public const int NodeFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_node_codec + = pb::FieldCodec.ForMessage(10, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.NodeProto.Parser); + private readonly pbc::RepeatedField node_ = new pbc::RepeatedField(); + /// + /// The nodes in the graph, sorted topologically. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Node + { + get { return node_; } + } - /// - /// A message defined to store a tensor in its serialized format. - /// - public sealed partial class TensorProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } + /// Field number for the "name" field. + public const int NameFieldNumber = 2; + private string name_ = ""; + /// + /// The name of the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name + { + get { return name_; } + set + { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[6]; } - } + /// Field number for the "initializer" field. + public const int InitializerFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_initializer_codec + = pb::FieldCodec.ForMessage(42, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Parser); + private readonly pbc::RepeatedField initializer_ = new pbc::RepeatedField(); + /// + /// A list of named tensor values, used to specify constant inputs of the graph. + /// Each TensorProto entry must have a distinct name (within the list) that + /// also appears in the input list. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Initializer + { + get { return initializer_; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } + /// Field number for the "doc_string" field. + public const int DocStringFieldNumber = 10; + private string docString_ = ""; + /// + /// A human-readable documentation for this graph. Markdown is allowed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DocString + { + get { return docString_; } + set + { + docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public TensorProto() { - OnConstruction(); - } + /// Field number for the "input" field. + public const int InputFieldNumber = 11; + private static readonly pb::FieldCodec _repeated_input_codec + = pb::FieldCodec.ForMessage(90, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ValueInfoProto.Parser); + private readonly pbc::RepeatedField input_ = new pbc::RepeatedField(); + /// + /// The inputs and outputs of the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Input + { + get { return input_; } + } - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public TensorProto(TensorProto other) : this() { - dims_ = other.dims_.Clone(); - dataType_ = other.dataType_; - Segment = other.segment_ != null ? other.Segment.Clone() : null; - floatData_ = other.floatData_.Clone(); - int32Data_ = other.int32Data_.Clone(); - stringData_ = other.stringData_.Clone(); - int64Data_ = other.int64Data_.Clone(); - name_ = other.name_; - docString_ = other.docString_; - rawData_ = other.rawData_; - doubleData_ = other.doubleData_.Clone(); - uint64Data_ = other.uint64Data_.Clone(); - } + /// Field number for the "output" field. + public const int OutputFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_output_codec + = pb::FieldCodec.ForMessage(98, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ValueInfoProto.Parser); + private readonly pbc::RepeatedField output_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Output + { + get { return output_; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public TensorProto Clone() { - return new TensorProto(this); - } + /// Field number for the "value_info" field. + public const int ValueInfoFieldNumber = 13; + private static readonly pb::FieldCodec _repeated_valueInfo_codec + = pb::FieldCodec.ForMessage(106, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.ValueInfoProto.Parser); + private readonly pbc::RepeatedField valueInfo_ = new pbc::RepeatedField(); + /// + /// Information for the values in the graph. The ValueInfoProto.name's + /// must be distinct. It is optional for a value to appear in value_info list. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField ValueInfo + { + get { return valueInfo_; } + } - /// Field number for the "dims" field. - public const int DimsFieldNumber = 1; - private static readonly pb::FieldCodec _repeated_dims_codec - = pb::FieldCodec.ForInt64(10); - private readonly pbc::RepeatedField dims_ = new pbc::RepeatedField(); - /// - /// The shape of the tensor. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Dims { - get { return dims_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as GraphProto); + } - /// Field number for the "data_type" field. - public const int DataTypeFieldNumber = 2; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType dataType_ = 0; - /// - /// The data type of the tensor. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType DataType { - get { return dataType_; } - set { - dataType_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(GraphProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (!node_.Equals(other.node_)) return false; + if (Name != other.Name) return false; + if (!initializer_.Equals(other.initializer_)) return false; + if (DocString != other.DocString) return false; + if (!input_.Equals(other.input_)) return false; + if (!output_.Equals(other.output_)) return false; + if (!valueInfo_.Equals(other.valueInfo_)) return false; + return Equals(_unknownFields, other._unknownFields); + } - /// Field number for the "segment" field. - public const int SegmentFieldNumber = 3; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment segment_; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment Segment { - get { return segment_; } - set { - segment_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + hash ^= node_.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= initializer_.GetHashCode(); + if (DocString.Length != 0) hash ^= DocString.GetHashCode(); + hash ^= input_.GetHashCode(); + hash ^= output_.GetHashCode(); + hash ^= valueInfo_.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - /// Field number for the "float_data" field. - public const int FloatDataFieldNumber = 4; - private static readonly pb::FieldCodec _repeated_floatData_codec - = pb::FieldCodec.ForFloat(34); - private readonly pbc::RepeatedField floatData_ = new pbc::RepeatedField(); - /// - /// For float and complex64 values - /// Complex64 tensors are encoded as a single array of floats, - /// with the real components appearing in odd numbered positions, - /// and the corresponding imaginary component apparing in the - /// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - /// is encoded as [1.0, 2.0 ,3.0 ,4.0] - /// When this field is present, the data_type field MUST be FLOAT or COMPLEX64. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField FloatData { - get { return floatData_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } - /// Field number for the "int32_data" field. - public const int Int32DataFieldNumber = 5; - private static readonly pb::FieldCodec _repeated_int32Data_codec - = pb::FieldCodec.ForInt32(42); - private readonly pbc::RepeatedField int32Data_ = new pbc::RepeatedField(); - /// - /// For int32, uint8, int8, uint16, int16, bool, and float16 values - /// float16 values must be bit-wise converted to an uint16_t prior - /// to writing to the buffer. - /// When this field is present, the data_type field MUST be - /// INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Int32Data { - get { return int32Data_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + node_.WriteTo(output, _repeated_node_codec); + if (Name.Length != 0) + { + output.WriteRawTag(18); + output.WriteString(Name); + } + initializer_.WriteTo(output, _repeated_initializer_codec); + if (DocString.Length != 0) + { + output.WriteRawTag(82); + output.WriteString(DocString); + } + input_.WriteTo(output, _repeated_input_codec); + output_.WriteTo(output, _repeated_output_codec); + valueInfo_.WriteTo(output, _repeated_valueInfo_codec); + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } - /// Field number for the "string_data" field. - public const int StringDataFieldNumber = 6; - private static readonly pb::FieldCodec _repeated_stringData_codec - = pb::FieldCodec.ForBytes(50); - private readonly pbc::RepeatedField stringData_ = new pbc::RepeatedField(); - /// - /// For strings. - /// Each element of string_data is a UTF-8 encoded Unicode - /// string. No trailing null, no leading BOM. The protobuf "string" - /// scalar type is not used to match ML community conventions. - /// When this field is present, the data_type field MUST be STRING - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField StringData { - get { return stringData_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + size += node_.CalculateSize(_repeated_node_codec); + if (Name.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += initializer_.CalculateSize(_repeated_initializer_codec); + if (DocString.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); + } + size += input_.CalculateSize(_repeated_input_codec); + size += output_.CalculateSize(_repeated_output_codec); + size += valueInfo_.CalculateSize(_repeated_valueInfo_codec); + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } - /// Field number for the "int64_data" field. - public const int Int64DataFieldNumber = 7; - private static readonly pb::FieldCodec _repeated_int64Data_codec - = pb::FieldCodec.ForInt64(58); - private readonly pbc::RepeatedField int64Data_ = new pbc::RepeatedField(); - /// - /// For int64. - /// When this field is present, the data_type field MUST be INT64 - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Int64Data { - get { return int64Data_; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(GraphProto other) + { + if (other == null) + { + return; + } + node_.Add(other.node_); + if (other.Name.Length != 0) + { + Name = other.Name; + } + initializer_.Add(other.initializer_); + if (other.DocString.Length != 0) + { + DocString = other.DocString; + } + input_.Add(other.input_); + output_.Add(other.output_); + valueInfo_.Add(other.valueInfo_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } - /// Field number for the "name" field. - public const int NameFieldNumber = 8; - private string name_ = ""; - /// - /// Optionally, a name for the tensor. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Name { - get { return name_; } - set { - name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + { + node_.AddEntriesFrom(input, _repeated_node_codec); + break; + } + case 18: + { + Name = input.ReadString(); + break; + } + case 42: + { + initializer_.AddEntriesFrom(input, _repeated_initializer_codec); + break; + } + case 82: + { + DocString = input.ReadString(); + break; + } + case 90: + { + input_.AddEntriesFrom(input, _repeated_input_codec); + break; + } + case 98: + { + output_.AddEntriesFrom(input, _repeated_output_codec); + break; + } + case 106: + { + valueInfo_.AddEntriesFrom(input, _repeated_valueInfo_codec); + break; + } + } + } + } - /// Field number for the "doc_string" field. - public const int DocStringFieldNumber = 12; - private string docString_ = ""; - /// - /// A human-readable documentation for this tensor. Markdown is allowed. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string DocString { - get { return docString_; } - set { - docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } } - /// Field number for the "raw_data" field. - public const int RawDataFieldNumber = 9; - private pb::ByteString rawData_ = pb::ByteString.Empty; /// - /// Serializations can either use one of the fields above, or use this - /// raw bytes field. The only exception is the string case, where one is - /// required to store the content in the repeated bytes string_data field. - /// - /// When this raw_data field is used to store tensor value, elements MUST - /// be stored in as fixed-width, little-endian order. - /// Floating-point data types MUST be stored in IEEE 754 format. - /// Complex64 elements must be written as two consecutive FLOAT values, real component first. - /// Complex128 elements must be written as two consecutive DOUBLE values, real component first. - /// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + /// Tensors /// - /// Note: the advantage of specific field rather than the raw_data field is - /// that in some cases (e.g. int data), protobuf does a better packing via - /// variable length storage, and may lead to smaller binary footprint. - /// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pb::ByteString RawData { - get { return rawData_; } - set { - rawData_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } - - /// Field number for the "double_data" field. - public const int DoubleDataFieldNumber = 10; - private static readonly pb::FieldCodec _repeated_doubleData_codec - = pb::FieldCodec.ForDouble(82); - private readonly pbc::RepeatedField doubleData_ = new pbc::RepeatedField(); - /// - /// For double - /// Complex64 tensors are encoded as a single array of doubles, - /// with the real components appearing in odd numbered positions, - /// and the corresponding imaginary component apparing in the - /// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - /// is encoded as [1.0, 2.0 ,3.0 ,4.0] - /// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + /// A serialized tensor value. /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField DoubleData { - get { return doubleData_; } - } - - /// Field number for the "uint64_data" field. - public const int Uint64DataFieldNumber = 11; - private static readonly pb::FieldCodec _repeated_uint64Data_codec - = pb::FieldCodec.ForUInt64(90); - private readonly pbc::RepeatedField uint64Data_ = new pbc::RepeatedField(); - /// - /// For uint64 and uint32 values - /// When this field is present, the data_type field MUST be - /// UINT32 or UINT64 - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Uint64Data { - get { return uint64Data_; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as TensorProto); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(TensorProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if(!dims_.Equals(other.dims_)) return false; - if (DataType != other.DataType) return false; - if (!object.Equals(Segment, other.Segment)) return false; - if(!floatData_.Equals(other.floatData_)) return false; - if(!int32Data_.Equals(other.int32Data_)) return false; - if(!stringData_.Equals(other.stringData_)) return false; - if(!int64Data_.Equals(other.int64Data_)) return false; - if (Name != other.Name) return false; - if (DocString != other.DocString) return false; - if (RawData != other.RawData) return false; - if(!doubleData_.Equals(other.doubleData_)) return false; - if(!uint64Data_.Equals(other.uint64Data_)) return false; - return true; - } + public sealed partial class TensorProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - hash ^= dims_.GetHashCode(); - if (DataType != 0) hash ^= DataType.GetHashCode(); - if (segment_ != null) hash ^= Segment.GetHashCode(); - hash ^= floatData_.GetHashCode(); - hash ^= int32Data_.GetHashCode(); - hash ^= stringData_.GetHashCode(); - hash ^= int64Data_.GetHashCode(); - if (Name.Length != 0) hash ^= Name.GetHashCode(); - if (DocString.Length != 0) hash ^= DocString.GetHashCode(); - if (RawData.Length != 0) hash ^= RawData.GetHashCode(); - hash ^= doubleData_.GetHashCode(); - hash ^= uint64Data_.GetHashCode(); - return hash; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[6]; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - dims_.WriteTo(output, _repeated_dims_codec); - if (DataType != 0) { - output.WriteRawTag(16); - output.WriteEnum((int) DataType); - } - if (segment_ != null) { - output.WriteRawTag(26); - output.WriteMessage(Segment); - } - floatData_.WriteTo(output, _repeated_floatData_codec); - int32Data_.WriteTo(output, _repeated_int32Data_codec); - stringData_.WriteTo(output, _repeated_stringData_codec); - int64Data_.WriteTo(output, _repeated_int64Data_codec); - if (Name.Length != 0) { - output.WriteRawTag(66); - output.WriteString(Name); - } - if (RawData.Length != 0) { - output.WriteRawTag(74); - output.WriteBytes(RawData); - } - doubleData_.WriteTo(output, _repeated_doubleData_codec); - uint64Data_.WriteTo(output, _repeated_uint64Data_codec); - if (DocString.Length != 0) { - output.WriteRawTag(98); - output.WriteString(DocString); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorProto() + { + OnConstruction(); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - size += dims_.CalculateSize(_repeated_dims_codec); - if (DataType != 0) { - size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) DataType); - } - if (segment_ != null) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(Segment); - } - size += floatData_.CalculateSize(_repeated_floatData_codec); - size += int32Data_.CalculateSize(_repeated_int32Data_codec); - size += stringData_.CalculateSize(_repeated_stringData_codec); - size += int64Data_.CalculateSize(_repeated_int64Data_codec); - if (Name.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); - } - if (DocString.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); - } - if (RawData.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeBytesSize(RawData); - } - size += doubleData_.CalculateSize(_repeated_doubleData_codec); - size += uint64Data_.CalculateSize(_repeated_uint64Data_codec); - return size; - } + partial void OnConstruction(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(TensorProto other) { - if (other == null) { - return; - } - dims_.Add(other.dims_); - if (other.DataType != 0) { - DataType = other.DataType; - } - if (other.segment_ != null) { - if (segment_ == null) { - segment_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment(); - } - Segment.MergeFrom(other.Segment); - } - floatData_.Add(other.floatData_); - int32Data_.Add(other.int32Data_); - stringData_.Add(other.stringData_); - int64Data_.Add(other.int64Data_); - if (other.Name.Length != 0) { - Name = other.Name; - } - if (other.DocString.Length != 0) { - DocString = other.DocString; - } - if (other.RawData.Length != 0) { - RawData = other.RawData; - } - doubleData_.Add(other.doubleData_); - uint64Data_.Add(other.uint64Data_); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorProto(TensorProto other) : this() + { + dims_ = other.dims_.Clone(); + dataType_ = other.dataType_; + Segment = other.segment_ != null ? other.Segment.Clone() : null; + floatData_ = other.floatData_.Clone(); + int32Data_ = other.int32Data_.Clone(); + stringData_ = other.stringData_.Clone(); + int64Data_ = other.int64Data_.Clone(); + name_ = other.name_; + docString_ = other.docString_; + rawData_ = other.rawData_; + doubleData_ = other.doubleData_.Clone(); + uint64Data_ = other.uint64Data_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: - case 8: { - dims_.AddEntriesFrom(input, _repeated_dims_codec); - break; - } - case 16: { - dataType_ = (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType) input.ReadEnum(); - break; - } - case 26: { - if (segment_ == null) { - segment_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment(); - } - input.ReadMessage(segment_); - break; - } - case 34: - case 37: { - floatData_.AddEntriesFrom(input, _repeated_floatData_codec); - break; - } - case 42: - case 40: { - int32Data_.AddEntriesFrom(input, _repeated_int32Data_codec); - break; - } - case 50: { - stringData_.AddEntriesFrom(input, _repeated_stringData_codec); - break; - } - case 58: - case 56: { - int64Data_.AddEntriesFrom(input, _repeated_int64Data_codec); - break; - } - case 66: { - Name = input.ReadString(); - break; - } - case 74: { - RawData = input.ReadBytes(); - break; - } - case 82: - case 81: { - doubleData_.AddEntriesFrom(input, _repeated_doubleData_codec); - break; - } - case 90: - case 88: { - uint64Data_.AddEntriesFrom(input, _repeated_uint64Data_codec); - break; - } - case 98: { - DocString = input.ReadString(); - break; - } - } - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorProto Clone() + { + return new TensorProto(this); + } - #region Nested types - /// Container for nested types declared in the TensorProto message type. - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static partial class Types { - public enum DataType { - [pbr::OriginalName("UNDEFINED")] Undefined = 0, - /// - /// Basic types. - /// - [pbr::OriginalName("FLOAT")] Float = 1, + /// Field number for the "dims" field. + public const int DimsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_dims_codec + = pb::FieldCodec.ForInt64(10); + private readonly pbc::RepeatedField dims_ = new pbc::RepeatedField(); /// - /// uint8_t + /// The shape of the tensor. /// - [pbr::OriginalName("UINT8")] Uint8 = 2, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Dims + { + get { return dims_; } + } + + /// Field number for the "data_type" field. + public const int DataTypeFieldNumber = 2; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType dataType_ = 0; /// - /// int8_t + /// The data type of the tensor. /// - [pbr::OriginalName("INT8")] Int8 = 3, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType DataType + { + get { return dataType_; } + set + { + dataType_ = value; + } + } + + /// Field number for the "segment" field. + public const int SegmentFieldNumber = 3; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment segment_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment Segment + { + get { return segment_; } + set + { + segment_ = value; + } + } + + /// Field number for the "float_data" field. + public const int FloatDataFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_floatData_codec + = pb::FieldCodec.ForFloat(34); + private readonly pbc::RepeatedField floatData_ = new pbc::RepeatedField(); /// - /// uint16_t + /// For float and complex64 values + /// Complex64 tensors are encoded as a single array of floats, + /// with the real components appearing in odd numbered positions, + /// and the corresponding imaginary component apparing in the + /// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + /// is encoded as [1.0, 2.0 ,3.0 ,4.0] + /// When this field is present, the data_type field MUST be FLOAT or COMPLEX64. /// - [pbr::OriginalName("UINT16")] Uint16 = 4, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField FloatData + { + get { return floatData_; } + } + + /// Field number for the "int32_data" field. + public const int Int32DataFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_int32Data_codec + = pb::FieldCodec.ForInt32(42); + private readonly pbc::RepeatedField int32Data_ = new pbc::RepeatedField(); /// - /// int16_t + /// For int32, uint8, int8, uint16, int16, bool, and float16 values + /// float16 values must be bit-wise converted to an uint16_t prior + /// to writing to the buffer. + /// When this field is present, the data_type field MUST be + /// INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT16 /// - [pbr::OriginalName("INT16")] Int16 = 5, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Int32Data + { + get { return int32Data_; } + } + + /// Field number for the "string_data" field. + public const int StringDataFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_stringData_codec + = pb::FieldCodec.ForBytes(50); + private readonly pbc::RepeatedField stringData_ = new pbc::RepeatedField(); /// - /// int32_t + /// For strings. + /// Each element of string_data is a UTF-8 encoded Unicode + /// string. No trailing null, no leading BOM. The protobuf "string" + /// scalar type is not used to match ML community conventions. + /// When this field is present, the data_type field MUST be STRING /// - [pbr::OriginalName("INT32")] Int32 = 6, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField StringData + { + get { return stringData_; } + } + + /// Field number for the "int64_data" field. + public const int Int64DataFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_int64Data_codec + = pb::FieldCodec.ForInt64(58); + private readonly pbc::RepeatedField int64Data_ = new pbc::RepeatedField(); /// - /// int64_t + /// For int64. + /// When this field is present, the data_type field MUST be INT64 /// - [pbr::OriginalName("INT64")] Int64 = 7, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Int64Data + { + get { return int64Data_; } + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 8; + private string name_ = ""; /// - /// string + /// Optionally, a name for the tensor. /// - [pbr::OriginalName("STRING")] String = 8, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name + { + get { return name_; } + set + { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "doc_string" field. + public const int DocStringFieldNumber = 12; + private string docString_ = ""; /// - /// bool + /// A human-readable documentation for this tensor. Markdown is allowed. /// - [pbr::OriginalName("BOOL")] Bool = 9, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DocString + { + get { return docString_; } + set + { + docString_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "raw_data" field. + public const int RawDataFieldNumber = 9; + private pb::ByteString rawData_ = pb::ByteString.Empty; /// - /// Advanced types + /// Serializations can either use one of the fields above, or use this + /// raw bytes field. The only exception is the string case, where one is + /// required to store the content in the repeated bytes string_data field. + /// + /// When this raw_data field is used to store tensor value, elements MUST + /// be stored in as fixed-width, little-endian order. + /// Floating-point data types MUST be stored in IEEE 754 format. + /// Complex64 elements must be written as two consecutive FLOAT values, real component first. + /// Complex128 elements must be written as two consecutive DOUBLE values, real component first. + /// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + /// + /// Note: the advantage of specific field rather than the raw_data field is + /// that in some cases (e.g. int data), protobuf does a better packing via + /// variable length storage, and may lead to smaller binary footprint. + /// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED /// - [pbr::OriginalName("FLOAT16")] Float16 = 10, - [pbr::OriginalName("DOUBLE")] Double = 11, - [pbr::OriginalName("UINT32")] Uint32 = 12, - [pbr::OriginalName("UINT64")] Uint64 = 13, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString RawData + { + get { return rawData_; } + set + { + rawData_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "double_data" field. + public const int DoubleDataFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_doubleData_codec + = pb::FieldCodec.ForDouble(82); + private readonly pbc::RepeatedField doubleData_ = new pbc::RepeatedField(); /// - /// complex with float32 real and imaginary components + /// For double + /// Complex64 tensors are encoded as a single array of doubles, + /// with the real components appearing in odd numbered positions, + /// and the corresponding imaginary component apparing in the + /// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + /// is encoded as [1.0, 2.0 ,3.0 ,4.0] + /// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 /// - [pbr::OriginalName("COMPLEX64")] Complex64 = 14, + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField DoubleData + { + get { return doubleData_; } + } + + /// Field number for the "uint64_data" field. + public const int Uint64DataFieldNumber = 11; + private static readonly pb::FieldCodec _repeated_uint64Data_codec + = pb::FieldCodec.ForUInt64(90); + private readonly pbc::RepeatedField uint64Data_ = new pbc::RepeatedField(); /// - /// complex with float64 real and imaginary components + /// For uint64 and uint32 values + /// When this field is present, the data_type field MUST be + /// UINT32 or UINT64 /// - [pbr::OriginalName("COMPLEX128")] Complex128 = 15, - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Uint64Data + { + get { return uint64Data_; } + } - /// - /// For very large tensors, we may want to store them in chunks, in which - /// case the following fields will specify the segment that is stored in - /// the current TensorProto. - /// - public sealed partial class Segment : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Segment()); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } + public override bool Equals(object other) + { + return Equals(other as TensorProto); + } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Descriptor.NestedTypes[0]; } + public bool Equals(TensorProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (!dims_.Equals(other.dims_)) return false; + if (DataType != other.DataType) return false; + if (!object.Equals(Segment, other.Segment)) return false; + if (!floatData_.Equals(other.floatData_)) return false; + if (!int32Data_.Equals(other.int32Data_)) return false; + if (!stringData_.Equals(other.stringData_)) return false; + if (!int64Data_.Equals(other.int64Data_)) return false; + if (Name != other.Name) return false; + if (DocString != other.DocString) return false; + if (RawData != other.RawData) return false; + if (!doubleData_.Equals(other.doubleData_)) return false; + if (!uint64Data_.Equals(other.uint64Data_)) return false; + return Equals(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } + public override int GetHashCode() + { + int hash = 1; + hash ^= dims_.GetHashCode(); + if (DataType != 0) hash ^= DataType.GetHashCode(); + if (segment_ != null) hash ^= Segment.GetHashCode(); + hash ^= floatData_.GetHashCode(); + hash ^= int32Data_.GetHashCode(); + hash ^= stringData_.GetHashCode(); + hash ^= int64Data_.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (DocString.Length != 0) hash ^= DocString.GetHashCode(); + if (RawData.Length != 0) hash ^= RawData.GetHashCode(); + hash ^= doubleData_.GetHashCode(); + hash ^= uint64Data_.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Segment() { - OnConstruction(); + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); } - partial void OnConstruction(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + dims_.WriteTo(output, _repeated_dims_codec); + if (DataType != 0) + { + output.WriteRawTag(16); + output.WriteEnum((int)DataType); + } + if (segment_ != null) + { + output.WriteRawTag(26); + output.WriteMessage(Segment); + } + floatData_.WriteTo(output, _repeated_floatData_codec); + int32Data_.WriteTo(output, _repeated_int32Data_codec); + stringData_.WriteTo(output, _repeated_stringData_codec); + int64Data_.WriteTo(output, _repeated_int64Data_codec); + if (Name.Length != 0) + { + output.WriteRawTag(66); + output.WriteString(Name); + } + if (RawData.Length != 0) + { + output.WriteRawTag(74); + output.WriteBytes(RawData); + } + doubleData_.WriteTo(output, _repeated_doubleData_codec); + uint64Data_.WriteTo(output, _repeated_uint64Data_codec); + if (DocString.Length != 0) + { + output.WriteRawTag(98); + output.WriteString(DocString); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + size += dims_.CalculateSize(_repeated_dims_codec); + if (DataType != 0) + { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int)DataType); + } + if (segment_ != null) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Segment); + } + size += floatData_.CalculateSize(_repeated_floatData_codec); + size += int32Data_.CalculateSize(_repeated_int32Data_codec); + size += stringData_.CalculateSize(_repeated_stringData_codec); + size += int64Data_.CalculateSize(_repeated_int64Data_codec); + if (Name.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (DocString.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DocString); + } + if (RawData.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(RawData); + } + size += doubleData_.CalculateSize(_repeated_doubleData_codec); + size += uint64Data_.CalculateSize(_repeated_uint64Data_codec); + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TensorProto other) + { + if (other == null) + { + return; + } + dims_.Add(other.dims_); + if (other.DataType != 0) + { + DataType = other.DataType; + } + if (other.segment_ != null) + { + if (segment_ == null) + { + segment_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment(); + } + Segment.MergeFrom(other.Segment); + } + floatData_.Add(other.floatData_); + int32Data_.Add(other.int32Data_); + stringData_.Add(other.stringData_); + int64Data_.Add(other.int64Data_); + if (other.Name.Length != 0) + { + Name = other.Name; + } + if (other.DocString.Length != 0) + { + DocString = other.DocString; + } + if (other.RawData.Length != 0) + { + RawData = other.RawData; + } + doubleData_.Add(other.doubleData_); + uint64Data_.Add(other.uint64Data_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: + { + dims_.AddEntriesFrom(input, _repeated_dims_codec); + break; + } + case 16: + { + dataType_ = (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType)input.ReadEnum(); + break; + } + case 26: + { + if (segment_ == null) + { + segment_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.Segment(); + } + input.ReadMessage(segment_); + break; + } + case 34: + case 37: + { + floatData_.AddEntriesFrom(input, _repeated_floatData_codec); + break; + } + case 42: + case 40: + { + int32Data_.AddEntriesFrom(input, _repeated_int32Data_codec); + break; + } + case 50: + { + stringData_.AddEntriesFrom(input, _repeated_stringData_codec); + break; + } + case 58: + case 56: + { + int64Data_.AddEntriesFrom(input, _repeated_int64Data_codec); + break; + } + case 66: + { + Name = input.ReadString(); + break; + } + case 74: + { + RawData = input.ReadBytes(); + break; + } + case 82: + case 81: + { + doubleData_.AddEntriesFrom(input, _repeated_doubleData_codec); + break; + } + case 90: + case 88: + { + uint64Data_.AddEntriesFrom(input, _repeated_uint64Data_codec); + break; + } + case 98: + { + DocString = input.ReadString(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the TensorProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types + { + public enum DataType + { + [pbr::OriginalName("UNDEFINED")] Undefined = 0, + /// + /// Basic types. + /// + [pbr::OriginalName("FLOAT")] Float = 1, + /// + /// uint8_t + /// + [pbr::OriginalName("UINT8")] Uint8 = 2, + /// + /// int8_t + /// + [pbr::OriginalName("INT8")] Int8 = 3, + /// + /// uint16_t + /// + [pbr::OriginalName("UINT16")] Uint16 = 4, + /// + /// int16_t + /// + [pbr::OriginalName("INT16")] Int16 = 5, + /// + /// int32_t + /// + [pbr::OriginalName("INT32")] Int32 = 6, + /// + /// int64_t + /// + [pbr::OriginalName("INT64")] Int64 = 7, + /// + /// string + /// + [pbr::OriginalName("STRING")] String = 8, + /// + /// bool + /// + [pbr::OriginalName("BOOL")] Bool = 9, + /// + /// Advanced types + /// + [pbr::OriginalName("FLOAT16")] Float16 = 10, + [pbr::OriginalName("DOUBLE")] Double = 11, + [pbr::OriginalName("UINT32")] Uint32 = 12, + [pbr::OriginalName("UINT64")] Uint64 = 13, + /// + /// complex with float32 real and imaginary components + /// + [pbr::OriginalName("COMPLEX64")] Complex64 = 14, + /// + /// complex with float64 real and imaginary components + /// + [pbr::OriginalName("COMPLEX128")] Complex128 = 15, + } + + /// + /// For very large tensors, we may want to store them in chunks, in which + /// case the following fields will specify the segment that is stored in + /// the current TensorProto. + /// + public sealed partial class Segment : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Segment()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Segment() + { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Segment(Segment other) : this() + { + begin_ = other.begin_; + end_ = other.end_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Segment Clone() + { + return new Segment(this); + } + + /// Field number for the "begin" field. + public const int BeginFieldNumber = 1; + private long begin_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long Begin + { + get { return begin_; } + set + { + begin_ = value; + } + } + + /// Field number for the "end" field. + public const int EndFieldNumber = 2; + private long end_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long End + { + get { return end_; } + set + { + end_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as Segment); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Segment other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (Begin != other.Begin) return false; + if (End != other.End) return false; + return Equals(_unknownFields, other._unknownFields); + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + if (Begin != 0L) hash ^= Begin.GetHashCode(); + if (End != 0L) hash ^= End.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + if (Begin != 0L) + { + output.WriteRawTag(8); + output.WriteInt64(Begin); + } + if (End != 0L) + { + output.WriteRawTag(16); + output.WriteInt64(End); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + if (Begin != 0L) + { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Begin); + } + if (End != 0L) + { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(End); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Segment other) + { + if (other == null) + { + return; + } + if (other.Begin != 0L) + { + Begin = other.Begin; + } + if (other.End != 0L) + { + End = other.End; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: + { + Begin = input.ReadInt64(); + break; + } + case 16: + { + End = input.ReadInt64(); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// Defines a tensor shape. A dimension can be either an integer value + /// or a symbolic variable. A symbolic variable represents an unknown + /// dimension. + /// + public sealed partial class TensorShapeProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorShapeProto()); + private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Segment(Segment other) : this() { - begin_ = other.begin_; - end_ = other.end_; + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[7]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Segment Clone() { - return new Segment(this); + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } } - /// Field number for the "begin" field. - public const int BeginFieldNumber = 1; - private long begin_; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public long Begin { - get { return begin_; } - set { - begin_ = value; - } + public TensorShapeProto() + { + OnConstruction(); } - /// Field number for the "end" field. - public const int EndFieldNumber = 2; - private long end_; + partial void OnConstruction(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public long End { - get { return end_; } - set { - end_ = value; - } + public TensorShapeProto(TensorShapeProto other) : this() + { + dim_ = other.dim_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as Segment); + public TensorShapeProto Clone() + { + return new TensorShapeProto(this); } + /// Field number for the "dim" field. + public const int DimFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_dim_codec + = pb::FieldCodec.ForMessage(10, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Types.Dimension.Parser); + private readonly pbc::RepeatedField dim_ = new pbc::RepeatedField(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(Segment other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (Begin != other.Begin) return false; - if (End != other.End) return false; - return true; + public pbc::RepeatedField Dim + { + get { return dim_; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (Begin != 0L) hash ^= Begin.GetHashCode(); - if (End != 0L) hash ^= End.GetHashCode(); - return hash; + public override bool Equals(object other) + { + return Equals(other as TensorShapeProto); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); + public bool Equals(TensorShapeProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (!dim_.Equals(other.dim_)) return false; + return Equals(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (Begin != 0L) { - output.WriteRawTag(8); - output.WriteInt64(Begin); - } - if (End != 0L) { - output.WriteRawTag(16); - output.WriteInt64(End); - } + public override int GetHashCode() + { + int hash = 1; + hash ^= dim_.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (Begin != 0L) { - size += 1 + pb::CodedOutputStream.ComputeInt64Size(Begin); - } - if (End != 0L) { - size += 1 + pb::CodedOutputStream.ComputeInt64Size(End); - } - return size; + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(Segment other) { - if (other == null) { - return; - } - if (other.Begin != 0L) { - Begin = other.Begin; - } - if (other.End != 0L) { - End = other.End; - } + public void WriteTo(pb::CodedOutputStream output) + { + dim_.WriteTo(output, _repeated_dim_codec); + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 8: { - Begin = input.ReadInt64(); - break; - } - case 16: { - End = input.ReadInt64(); - break; - } + public int CalculateSize() + { + int size = 0; + size += dim_.CalculateSize(_repeated_dim_codec); + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); } - } + return size; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TensorShapeProto other) + { + if (other == null) + { + return; + } + dim_.Add(other.dim_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } - } - #endregion + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + { + dim_.AddEntriesFrom(input, _repeated_dim_codec); + break; + } + } + } + } - } - - /// - /// Defines a tensor shape. A dimension can be either an integer value - /// or a symbolic variable. A symbolic variable represents an unknown - /// dimension. - /// - public sealed partial class TensorShapeProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorShapeProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[7]; } - } + #region Nested types + /// Container for nested types declared in the TensorShapeProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types + { + public sealed partial class Dimension : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Dimension()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Descriptor.NestedTypes[0]; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public TensorShapeProto() { - OnConstruction(); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Dimension() + { + OnConstruction(); + } - partial void OnConstruction(); + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Dimension(Dimension other) : this() + { + denotation_ = other.denotation_; + switch (other.ValueCase) + { + case ValueOneofCase.DimValue: + DimValue = other.DimValue; + break; + case ValueOneofCase.DimParam: + DimParam = other.DimParam; + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public TensorShapeProto(TensorShapeProto other) : this() { - dim_ = other.dim_.Clone(); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Dimension Clone() + { + return new Dimension(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public TensorShapeProto Clone() { - return new TensorShapeProto(this); - } + /// Field number for the "dim_value" field. + public const int DimValueFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long DimValue + { + get { return valueCase_ == ValueOneofCase.DimValue ? (long)value_ : 0L; } + set + { + value_ = value; + valueCase_ = ValueOneofCase.DimValue; + } + } - /// Field number for the "dim" field. - public const int DimFieldNumber = 1; - private static readonly pb::FieldCodec _repeated_dim_codec - = pb::FieldCodec.ForMessage(10, global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Types.Dimension.Parser); - private readonly pbc::RepeatedField dim_ = new pbc::RepeatedField(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Dim { - get { return dim_; } - } + /// Field number for the "dim_param" field. + public const int DimParamFieldNumber = 2; + /// + /// namespace Shape + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DimParam + { + get { return valueCase_ == ValueOneofCase.DimParam ? (string)value_ : ""; } + set + { + value_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + valueCase_ = ValueOneofCase.DimParam; + } + } + + /// Field number for the "denotation" field. + public const int DenotationFieldNumber = 3; + private string denotation_ = ""; + /// + /// Standard denotation can optionally be used to denote tensor + /// dimensions with standard semantic descriptions to ensure + /// that operations are applied to the correct axis of a tensor. + /// Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + /// for pre-defined dimension denotations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Denotation + { + get { return denotation_; } + set + { + denotation_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + private object value_; + /// Enum of possible cases for the "value" oneof. + public enum ValueOneofCase + { + None = 0, + DimValue = 1, + DimParam = 2, + } + private ValueOneofCase valueCase_ = ValueOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValueOneofCase ValueCase + { + get { return valueCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearValue() + { + valueCase_ = ValueOneofCase.None; + value_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as Dimension); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Dimension other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (DimValue != other.DimValue) return false; + if (DimParam != other.DimParam) return false; + if (Denotation != other.Denotation) return false; + if (ValueCase != other.ValueCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + if (valueCase_ == ValueOneofCase.DimValue) hash ^= DimValue.GetHashCode(); + if (valueCase_ == ValueOneofCase.DimParam) hash ^= DimParam.GetHashCode(); + if (Denotation.Length != 0) hash ^= Denotation.GetHashCode(); + hash ^= (int)valueCase_; + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as TensorShapeProto); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(TensorShapeProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if(!dim_.Equals(other.dim_)) return false; - return true; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + if (valueCase_ == ValueOneofCase.DimValue) + { + output.WriteRawTag(8); + output.WriteInt64(DimValue); + } + if (valueCase_ == ValueOneofCase.DimParam) + { + output.WriteRawTag(18); + output.WriteString(DimParam); + } + if (Denotation.Length != 0) + { + output.WriteRawTag(26); + output.WriteString(Denotation); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - hash ^= dim_.GetHashCode(); - return hash; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + if (valueCase_ == ValueOneofCase.DimValue) + { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(DimValue); + } + if (valueCase_ == ValueOneofCase.DimParam) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DimParam); + } + if (Denotation.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Denotation); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Dimension other) + { + if (other == null) + { + return; + } + if (other.Denotation.Length != 0) + { + Denotation = other.Denotation; + } + switch (other.ValueCase) + { + case ValueOneofCase.DimValue: + DimValue = other.DimValue; + break; + case ValueOneofCase.DimParam: + DimParam = other.DimParam; + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - dim_.WriteTo(output, _repeated_dim_codec); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: + { + DimValue = input.ReadInt64(); + break; + } + case 18: + { + DimParam = input.ReadString(); + break; + } + case 26: + { + Denotation = input.ReadString(); + break; + } + } + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - size += dim_.CalculateSize(_repeated_dim_codec); - return size; - } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(TensorShapeProto other) { - if (other == null) { - return; - } - dim_.Add(other.dim_); - } + } + #endregion - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - dim_.AddEntriesFrom(input, _repeated_dim_codec); - break; - } - } - } } - #region Nested types - /// Container for nested types declared in the TensorShapeProto message type. - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static partial class Types { - public sealed partial class Dimension : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Dimension()); + /// + /// Types + /// + /// The standard ONNX data types. + /// + public sealed partial class TypeProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TypeProto()); + private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } + public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto.Descriptor.NestedTypes[0]; } + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[8]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Dimension() { - OnConstruction(); + public TypeProto() + { + OnConstruction(); } partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Dimension(Dimension other) : this() { - switch (other.ValueCase) { - case ValueOneofCase.DimValue: - DimValue = other.DimValue; - break; - case ValueOneofCase.DimParam: - DimParam = other.DimParam; - break; - } + public TypeProto(TypeProto other) : this() + { + denotation_ = other.denotation_; + switch (other.ValueCase) + { + case ValueOneofCase.TensorType: + TensorType = other.TensorType.Clone(); + break; + case ValueOneofCase.SequenceType: + SequenceType = other.SequenceType.Clone(); + break; + case ValueOneofCase.MapType: + MapType = other.MapType.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TypeProto Clone() + { + return new TypeProto(this); + } + /// Field number for the "tensor_type" field. + public const int TensorTypeFieldNumber = 1; + /// + /// The type of a tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor TensorType + { + get { return valueCase_ == ValueOneofCase.TensorType ? (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor)value_ : null; } + set + { + value_ = value; + valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.TensorType; + } } + /// Field number for the "sequence_type" field. + public const int SequenceTypeFieldNumber = 4; + /// + /// The type of a sequence. + /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Dimension Clone() { - return new Dimension(this); + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence SequenceType + { + get { return valueCase_ == ValueOneofCase.SequenceType ? (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence)value_ : null; } + set + { + value_ = value; + valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.SequenceType; + } } - /// Field number for the "dim_value" field. - public const int DimValueFieldNumber = 1; + /// Field number for the "map_type" field. + public const int MapTypeFieldNumber = 5; + /// + /// The type of a map. + /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public long DimValue { - get { return valueCase_ == ValueOneofCase.DimValue ? (long) value_ : 0L; } - set { - value_ = value; - valueCase_ = ValueOneofCase.DimValue; - } + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map MapType + { + get { return valueCase_ == ValueOneofCase.MapType ? (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map)value_ : null; } + set + { + value_ = value; + valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.MapType; + } } - /// Field number for the "dim_param" field. - public const int DimParamFieldNumber = 2; + /// Field number for the "denotation" field. + public const int DenotationFieldNumber = 6; + private string denotation_ = ""; /// - /// namespace Shape + /// An optional denotation can be used to denote the whole + /// type with a standard semantic description as to what is + /// stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + /// for pre-defined type denotations. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string DimParam { - get { return valueCase_ == ValueOneofCase.DimParam ? (string) value_ : ""; } - set { - value_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - valueCase_ = ValueOneofCase.DimParam; - } + public string Denotation + { + get { return denotation_; } + set + { + denotation_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } } private object value_; /// Enum of possible cases for the "value" oneof. - public enum ValueOneofCase { - None = 0, - DimValue = 1, - DimParam = 2, + public enum ValueOneofCase + { + None = 0, + TensorType = 1, + SequenceType = 4, + MapType = 5, } private ValueOneofCase valueCase_ = ValueOneofCase.None; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public ValueOneofCase ValueCase { - get { return valueCase_; } + public ValueOneofCase ValueCase + { + get { return valueCase_; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void ClearValue() { - valueCase_ = ValueOneofCase.None; - value_ = null; + public void ClearValue() + { + valueCase_ = ValueOneofCase.None; + value_ = null; } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as Dimension); + public override bool Equals(object other) + { + return Equals(other as TypeProto); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(Dimension other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (DimValue != other.DimValue) return false; - if (DimParam != other.DimParam) return false; - if (ValueCase != other.ValueCase) return false; - return true; + public bool Equals(TypeProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (!object.Equals(TensorType, other.TensorType)) return false; + if (!object.Equals(SequenceType, other.SequenceType)) return false; + if (!object.Equals(MapType, other.MapType)) return false; + if (Denotation != other.Denotation) return false; + if (ValueCase != other.ValueCase) return false; + return Equals(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (valueCase_ == ValueOneofCase.DimValue) hash ^= DimValue.GetHashCode(); - if (valueCase_ == ValueOneofCase.DimParam) hash ^= DimParam.GetHashCode(); - hash ^= (int) valueCase_; - return hash; + public override int GetHashCode() + { + int hash = 1; + if (valueCase_ == ValueOneofCase.TensorType) hash ^= TensorType.GetHashCode(); + if (valueCase_ == ValueOneofCase.SequenceType) hash ^= SequenceType.GetHashCode(); + if (valueCase_ == ValueOneofCase.MapType) hash ^= MapType.GetHashCode(); + if (Denotation.Length != 0) hash ^= Denotation.GetHashCode(); + hash ^= (int)valueCase_; + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (valueCase_ == ValueOneofCase.DimValue) { - output.WriteRawTag(8); - output.WriteInt64(DimValue); - } - if (valueCase_ == ValueOneofCase.DimParam) { - output.WriteRawTag(18); - output.WriteString(DimParam); - } + public void WriteTo(pb::CodedOutputStream output) + { + if (valueCase_ == ValueOneofCase.TensorType) + { + output.WriteRawTag(10); + output.WriteMessage(TensorType); + } + if (valueCase_ == ValueOneofCase.SequenceType) + { + output.WriteRawTag(34); + output.WriteMessage(SequenceType); + } + if (valueCase_ == ValueOneofCase.MapType) + { + output.WriteRawTag(42); + output.WriteMessage(MapType); + } + if (Denotation.Length != 0) + { + output.WriteRawTag(50); + output.WriteString(Denotation); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (valueCase_ == ValueOneofCase.DimValue) { - size += 1 + pb::CodedOutputStream.ComputeInt64Size(DimValue); - } - if (valueCase_ == ValueOneofCase.DimParam) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(DimParam); - } - return size; + public int CalculateSize() + { + int size = 0; + if (valueCase_ == ValueOneofCase.TensorType) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorType); + } + if (valueCase_ == ValueOneofCase.SequenceType) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SequenceType); + } + if (valueCase_ == ValueOneofCase.MapType) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MapType); + } + if (Denotation.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Denotation); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(Dimension other) { - if (other == null) { - return; - } - switch (other.ValueCase) { - case ValueOneofCase.DimValue: - DimValue = other.DimValue; - break; - case ValueOneofCase.DimParam: - DimParam = other.DimParam; - break; - } + public void MergeFrom(TypeProto other) + { + if (other == null) + { + return; + } + if (other.Denotation.Length != 0) + { + Denotation = other.Denotation; + } + switch (other.ValueCase) + { + case ValueOneofCase.TensorType: + if (TensorType == null) + { + TensorType = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor(); + } + TensorType.MergeFrom(other.TensorType); + break; + case ValueOneofCase.SequenceType: + if (SequenceType == null) + { + SequenceType = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence(); + } + SequenceType.MergeFrom(other.SequenceType); + break; + case ValueOneofCase.MapType: + if (MapType == null) + { + MapType = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map(); + } + MapType.MergeFrom(other.MapType); + break; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 8: { - DimValue = input.ReadInt64(); - break; - } - case 18: { - DimParam = input.ReadString(); - break; - } + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + { + global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor subBuilder = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor(); + if (valueCase_ == ValueOneofCase.TensorType) + { + subBuilder.MergeFrom(TensorType); + } + input.ReadMessage(subBuilder); + TensorType = subBuilder; + break; + } + case 34: + { + global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence subBuilder = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence(); + if (valueCase_ == ValueOneofCase.SequenceType) + { + subBuilder.MergeFrom(SequenceType); + } + input.ReadMessage(subBuilder); + SequenceType = subBuilder; + break; + } + case 42: + { + global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map subBuilder = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map(); + if (valueCase_ == ValueOneofCase.MapType) + { + subBuilder.MergeFrom(MapType); + } + input.ReadMessage(subBuilder); + MapType = subBuilder; + break; + } + case 50: + { + Denotation = input.ReadString(); + break; + } + } } - } } - } - - } - #endregion - - } - - /// - /// Define the types. - /// - public sealed partial class TypeProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TypeProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[8]; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public TypeProto() { - OnConstruction(); - } - - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public TypeProto(TypeProto other) : this() { - switch (other.ValueCase) { - case ValueOneofCase.TensorType: - TensorType = other.TensorType.Clone(); - break; - case ValueOneofCase.SequenceType: - SequenceType = other.SequenceType.Clone(); - break; - case ValueOneofCase.MapType: - MapType = other.MapType.Clone(); - break; - } + #region Nested types + /// Container for nested types declared in the TypeProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types + { + public sealed partial class Tensor : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Tensor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Descriptor.NestedTypes[0]; } + } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public TypeProto Clone() { - return new TypeProto(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Tensor() + { + OnConstruction(); + } - /// Field number for the "tensor_type" field. - public const int TensorTypeFieldNumber = 1; - /// - /// The type of a tensor. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor TensorType { - get { return valueCase_ == ValueOneofCase.TensorType ? (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor) value_ : null; } - set { - value_ = value; - valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.TensorType; - } - } + partial void OnConstruction(); - /// Field number for the "sequence_type" field. - public const int SequenceTypeFieldNumber = 4; - /// - /// The type of a sequence. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence SequenceType { - get { return valueCase_ == ValueOneofCase.SequenceType ? (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence) value_ : null; } - set { - value_ = value; - valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.SequenceType; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Tensor(Tensor other) : this() + { + elemType_ = other.elemType_; + Shape = other.shape_ != null ? other.Shape.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - /// Field number for the "map_type" field. - public const int MapTypeFieldNumber = 5; - /// - /// The type of a map. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map MapType { - get { return valueCase_ == ValueOneofCase.MapType ? (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map) value_ : null; } - set { - value_ = value; - valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.MapType; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Tensor Clone() + { + return new Tensor(this); + } - private object value_; - /// Enum of possible cases for the "value" oneof. - public enum ValueOneofCase { - None = 0, - TensorType = 1, - SequenceType = 4, - MapType = 5, - } - private ValueOneofCase valueCase_ = ValueOneofCase.None; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public ValueOneofCase ValueCase { - get { return valueCase_; } - } + /// Field number for the "elem_type" field. + public const int ElemTypeFieldNumber = 1; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType elemType_ = 0; + /// + /// This field MUST NOT have the value of UNDEFINED + /// This field MUST be present for this version of the IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType ElemType + { + get { return elemType_; } + set + { + elemType_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void ClearValue() { - valueCase_ = ValueOneofCase.None; - value_ = null; - } + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 2; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto shape_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto Shape + { + get { return shape_; } + set + { + shape_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as TypeProto); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as Tensor); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(TypeProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (!object.Equals(TensorType, other.TensorType)) return false; - if (!object.Equals(SequenceType, other.SequenceType)) return false; - if (!object.Equals(MapType, other.MapType)) return false; - if (ValueCase != other.ValueCase) return false; - return true; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Tensor other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (ElemType != other.ElemType) return false; + if (!object.Equals(Shape, other.Shape)) return false; + return Equals(_unknownFields, other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (valueCase_ == ValueOneofCase.TensorType) hash ^= TensorType.GetHashCode(); - if (valueCase_ == ValueOneofCase.SequenceType) hash ^= SequenceType.GetHashCode(); - if (valueCase_ == ValueOneofCase.MapType) hash ^= MapType.GetHashCode(); - hash ^= (int) valueCase_; - return hash; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + if (ElemType != 0) hash ^= ElemType.GetHashCode(); + if (shape_ != null) hash ^= Shape.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (valueCase_ == ValueOneofCase.TensorType) { - output.WriteRawTag(10); - output.WriteMessage(TensorType); - } - if (valueCase_ == ValueOneofCase.SequenceType) { - output.WriteRawTag(34); - output.WriteMessage(SequenceType); - } - if (valueCase_ == ValueOneofCase.MapType) { - output.WriteRawTag(42); - output.WriteMessage(MapType); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + if (ElemType != 0) + { + output.WriteRawTag(8); + output.WriteEnum((int)ElemType); + } + if (shape_ != null) + { + output.WriteRawTag(18); + output.WriteMessage(Shape); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (valueCase_ == ValueOneofCase.TensorType) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorType); - } - if (valueCase_ == ValueOneofCase.SequenceType) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(SequenceType); - } - if (valueCase_ == ValueOneofCase.MapType) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(MapType); - } - return size; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + if (ElemType != 0) + { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int)ElemType); + } + if (shape_ != null) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(TypeProto other) { - if (other == null) { - return; - } - switch (other.ValueCase) { - case ValueOneofCase.TensorType: - TensorType = other.TensorType; - break; - case ValueOneofCase.SequenceType: - SequenceType = other.SequenceType; - break; - case ValueOneofCase.MapType: - MapType = other.MapType; - break; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Tensor other) + { + if (other == null) + { + return; + } + if (other.ElemType != 0) + { + ElemType = other.ElemType; + } + if (other.shape_ != null) + { + if (shape_ == null) + { + shape_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: + { + elemType_ = (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType)input.ReadEnum(); + break; + } + case 18: + { + if (shape_ == null) + { + shape_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto(); + } + input.ReadMessage(shape_); + break; + } + } + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor subBuilder = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Tensor(); - if (valueCase_ == ValueOneofCase.TensorType) { - subBuilder.MergeFrom(TensorType); - } - input.ReadMessage(subBuilder); - TensorType = subBuilder; - break; - } - case 34: { - global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence subBuilder = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Sequence(); - if (valueCase_ == ValueOneofCase.SequenceType) { - subBuilder.MergeFrom(SequenceType); - } - input.ReadMessage(subBuilder); - SequenceType = subBuilder; - break; - } - case 42: { - global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map subBuilder = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Types.Map(); - if (valueCase_ == ValueOneofCase.MapType) { - subBuilder.MergeFrom(MapType); - } - input.ReadMessage(subBuilder); - MapType = subBuilder; - break; - } - } - } - } + } - #region Nested types - /// Container for nested types declared in the TypeProto message type. - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static partial class Types { - public sealed partial class Tensor : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Tensor()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } + /// + /// repeated T + /// + public sealed partial class Sequence : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Sequence()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Descriptor.NestedTypes[1]; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Descriptor.NestedTypes[0]; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Sequence() + { + OnConstruction(); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Tensor() { - OnConstruction(); - } + partial void OnConstruction(); - partial void OnConstruction(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Sequence(Sequence other) : this() + { + ElemType = other.elemType_ != null ? other.ElemType.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Tensor(Tensor other) : this() { - elemType_ = other.elemType_; - Shape = other.shape_ != null ? other.Shape.Clone() : null; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Sequence Clone() + { + return new Sequence(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Tensor Clone() { - return new Tensor(this); - } + /// Field number for the "elem_type" field. + public const int ElemTypeFieldNumber = 1; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto elemType_; + /// + /// The type and optional shape of each element of the sequence. + /// This field MUST be present for this version of the IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto ElemType + { + get { return elemType_; } + set + { + elemType_ = value; + } + } - /// Field number for the "elem_type" field. - public const int ElemTypeFieldNumber = 1; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType elemType_ = 0; - /// - /// This field MUST NOT have the value of UNDEFINED - /// This field MUST be present for this version of the IR. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType ElemType { - get { return elemType_; } - set { - elemType_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as Sequence); + } - /// Field number for the "shape" field. - public const int ShapeFieldNumber = 2; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto shape_; - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto Shape { - get { return shape_; } - set { - shape_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Sequence other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (!object.Equals(ElemType, other.ElemType)) return false; + return Equals(_unknownFields, other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as Tensor); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + if (elemType_ != null) hash ^= ElemType.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(Tensor other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (ElemType != other.ElemType) return false; - if (!object.Equals(Shape, other.Shape)) return false; - return true; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (ElemType != 0) hash ^= ElemType.GetHashCode(); - if (shape_ != null) hash ^= Shape.GetHashCode(); - return hash; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + if (elemType_ != null) + { + output.WriteRawTag(10); + output.WriteMessage(ElemType); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + if (elemType_ != null) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ElemType); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (ElemType != 0) { - output.WriteRawTag(8); - output.WriteEnum((int) ElemType); - } - if (shape_ != null) { - output.WriteRawTag(18); - output.WriteMessage(Shape); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Sequence other) + { + if (other == null) + { + return; + } + if (other.elemType_ != null) + { + if (elemType_ == null) + { + elemType_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); + } + ElemType.MergeFrom(other.ElemType); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (ElemType != 0) { - size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ElemType); - } - if (shape_ != null) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); - } - return size; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + { + if (elemType_ == null) + { + elemType_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); + } + input.ReadMessage(elemType_); + break; + } + } + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(Tensor other) { - if (other == null) { - return; - } - if (other.ElemType != 0) { - ElemType = other.ElemType; - } - if (other.shape_ != null) { - if (shape_ == null) { - shape_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto(); } - Shape.MergeFrom(other.Shape); - } - } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 8: { - elemType_ = (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType) input.ReadEnum(); - break; - } - case 18: { - if (shape_ == null) { - shape_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorShapeProto(); + /// + /// map<K,V> + /// + public sealed partial class Map : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Map()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Descriptor.NestedTypes[2]; } } - input.ReadMessage(shape_); - break; - } - } - } - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } + } - /// - /// repeated T - /// - public sealed partial class Sequence : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Sequence()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Map() + { + OnConstruction(); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Descriptor.NestedTypes[1]; } - } + partial void OnConstruction(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Map(Map other) : this() + { + keyType_ = other.keyType_; + ValueType = other.valueType_ != null ? other.ValueType.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Sequence() { - OnConstruction(); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Map Clone() + { + return new Map(this); + } - partial void OnConstruction(); + /// Field number for the "key_type" field. + public const int KeyTypeFieldNumber = 1; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType keyType_ = 0; + /// + /// This field MUST be present for this version of the IR. + /// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType KeyType + { + get { return keyType_; } + set + { + keyType_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Sequence(Sequence other) : this() { - ElemType = other.elemType_ != null ? other.ElemType.Clone() : null; - } + /// Field number for the "value_type" field. + public const int ValueTypeFieldNumber = 2; + private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto valueType_; + /// + /// This field MUST be present for this version of the IR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto ValueType + { + get { return valueType_; } + set + { + valueType_ = value; + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Sequence Clone() { - return new Sequence(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) + { + return Equals(other as Map); + } - /// Field number for the "elem_type" field. - public const int ElemTypeFieldNumber = 1; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto elemType_; - /// - /// The type and optional shape of each element of the sequence. - /// This field MUST be present for this version of the IR. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto ElemType { - get { return elemType_; } - set { - elemType_ = value; - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Map other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (KeyType != other.KeyType) return false; + if (!object.Equals(ValueType, other.ValueType)) return false; + return Equals(_unknownFields, other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as Sequence); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() + { + int hash = 1; + if (KeyType != 0) hash ^= KeyType.GetHashCode(); + if (valueType_ != null) hash ^= ValueType.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(Sequence other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (!object.Equals(ElemType, other.ElemType)) return false; - return true; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (elemType_ != null) hash ^= ElemType.GetHashCode(); - return hash; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) + { + if (KeyType != 0) + { + output.WriteRawTag(8); + output.WriteEnum((int)KeyType); + } + if (valueType_ != null) + { + output.WriteRawTag(18); + output.WriteMessage(ValueType); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() + { + int size = 0; + if (KeyType != 0) + { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int)KeyType); + } + if (valueType_ != null) + { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ValueType); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (elemType_ != null) { - output.WriteRawTag(10); - output.WriteMessage(ElemType); - } - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Map other) + { + if (other == null) + { + return; + } + if (other.KeyType != 0) + { + KeyType = other.KeyType; + } + if (other.valueType_ != null) + { + if (valueType_ == null) + { + valueType_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); + } + ValueType.MergeFrom(other.ValueType); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (elemType_ != null) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(ElemType); - } - return size; - } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: + { + keyType_ = (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType)input.ReadEnum(); + break; + } + case 18: + { + if (valueType_ == null) + { + valueType_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); + } + input.ReadMessage(valueType_); + break; + } + } + } + } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(Sequence other) { - if (other == null) { - return; - } - if (other.elemType_ != null) { - if (elemType_ == null) { - elemType_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); } - ElemType.MergeFrom(other.ElemType); - } - } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - if (elemType_ == null) { - elemType_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); - } - input.ReadMessage(elemType_); - break; - } - } - } } + #endregion - } + } - /// - /// map<K,V> - /// - public sealed partial class Map : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Map()); + /// + /// Operator Sets + /// + /// OperatorSets are uniquely identified by a (domain, opset_version) pair. + /// + public sealed partial class OperatorSetIdProto : pb::IMessage + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OperatorSetIdProto()); + private pb::UnknownFieldSet _unknownFields; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } + public static pb::MessageParser Parser { get { return _parser; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto.Descriptor.NestedTypes[2]; } + public static pbr::MessageDescriptor Descriptor + { + get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[9]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } + pbr::MessageDescriptor pb::IMessage.Descriptor + { + get { return Descriptor; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Map() { - OnConstruction(); + public OperatorSetIdProto() + { + OnConstruction(); } partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Map(Map other) : this() { - keyType_ = other.keyType_; - ValueType = other.valueType_ != null ? other.ValueType.Clone() : null; + public OperatorSetIdProto(OperatorSetIdProto other) : this() + { + domain_ = other.domain_; + version_ = other.version_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public Map Clone() { - return new Map(this); + public OperatorSetIdProto Clone() + { + return new OperatorSetIdProto(this); } - /// Field number for the "key_type" field. - public const int KeyTypeFieldNumber = 1; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType keyType_ = 0; + /// Field number for the "domain" field. + public const int DomainFieldNumber = 1; + private string domain_ = ""; /// - /// This field MUST be present for this version of the IR. - /// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + /// The domain of the operator set being identified. + /// The empty string ("") or absence of this field implies the operator + /// set that is defined as part of the ONNX specification. + /// This field MUST be present in this version of the IR when referring to any other operator set. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType KeyType { - get { return keyType_; } - set { - keyType_ = value; - } + public string Domain + { + get { return domain_; } + set + { + domain_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } } - /// Field number for the "value_type" field. - public const int ValueTypeFieldNumber = 2; - private global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto valueType_; + /// Field number for the "version" field. + public const int VersionFieldNumber = 2; + private long version_; /// - /// This field MUST be present for this version of the IR. + /// The version of the operator set being identified. + /// This field MUST be present in this version of the IR. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto ValueType { - get { return valueType_; } - set { - valueType_ = value; - } + public long Version + { + get { return version_; } + set + { + version_ = value; + } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as Map); + public override bool Equals(object other) + { + return Equals(other as OperatorSetIdProto); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(Map other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (KeyType != other.KeyType) return false; - if (!object.Equals(ValueType, other.ValueType)) return false; - return true; + public bool Equals(OperatorSetIdProto other) + { + if (ReferenceEquals(other, null)) + { + return false; + } + if (ReferenceEquals(other, this)) + { + return true; + } + if (Domain != other.Domain) return false; + if (Version != other.Version) return false; + return Equals(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (KeyType != 0) hash ^= KeyType.GetHashCode(); - if (valueType_ != null) hash ^= ValueType.GetHashCode(); - return hash; + public override int GetHashCode() + { + int hash = 1; + if (Domain.Length != 0) hash ^= Domain.GetHashCode(); + if (Version != 0L) hash ^= Version.GetHashCode(); + if (_unknownFields != null) + { + hash ^= _unknownFields.GetHashCode(); + } + return hash; } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); + public override string ToString() + { + return pb::JsonFormatter.ToDiagnosticString(this); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (KeyType != 0) { - output.WriteRawTag(8); - output.WriteEnum((int) KeyType); - } - if (valueType_ != null) { - output.WriteRawTag(18); - output.WriteMessage(ValueType); - } + public void WriteTo(pb::CodedOutputStream output) + { + if (Domain.Length != 0) + { + output.WriteRawTag(10); + output.WriteString(Domain); + } + if (Version != 0L) + { + output.WriteRawTag(16); + output.WriteInt64(Version); + } + if (_unknownFields != null) + { + _unknownFields.WriteTo(output); + } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (KeyType != 0) { - size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) KeyType); - } - if (valueType_ != null) { - size += 1 + pb::CodedOutputStream.ComputeMessageSize(ValueType); - } - return size; + public int CalculateSize() + { + int size = 0; + if (Domain.Length != 0) + { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Domain); + } + if (Version != 0L) + { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Version); + } + if (_unknownFields != null) + { + size += _unknownFields.CalculateSize(); + } + return size; } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(Map other) { - if (other == null) { - return; - } - if (other.KeyType != 0) { - KeyType = other.KeyType; - } - if (other.valueType_ != null) { - if (valueType_ == null) { - valueType_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); + public void MergeFrom(OperatorSetIdProto other) + { + if (other == null) + { + return; + } + if (other.Domain.Length != 0) + { + Domain = other.Domain; } - ValueType.MergeFrom(other.ValueType); - } + if (other.Version != 0L) + { + Version = other.Version; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 8: { - keyType_ = (global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TensorProto.Types.DataType) input.ReadEnum(); - break; - } - case 18: { - if (valueType_ == null) { - valueType_ = new global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.TypeProto(); + public void MergeFrom(pb::CodedInputStream input) + { + uint tag; + while ((tag = input.ReadTag()) != 0) + { + switch (tag) + { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + { + Domain = input.ReadString(); + break; + } + case 16: + { + Version = input.ReadInt64(); + break; + } } - input.ReadMessage(valueType_); - break; - } } - } } - } - - } - #endregion - - } - - /// - /// OperatorSets are uniquely identified by a (domain, opset_version) pair. - /// - public sealed partial class OperatorSetIdProto : pb::IMessage { - private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OperatorSetIdProto()); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pb::MessageParser Parser { get { return _parser; } } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public static pbr::MessageDescriptor Descriptor { - get { return global::Microsoft.ML.Runtime.UniversalModelFormat.Onnx.OnnxMlReflection.Descriptor.MessageTypes[9]; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - pbr::MessageDescriptor pb::IMessage.Descriptor { - get { return Descriptor; } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public OperatorSetIdProto() { - OnConstruction(); - } - - partial void OnConstruction(); - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public OperatorSetIdProto(OperatorSetIdProto other) : this() { - domain_ = other.domain_; - version_ = other.version_; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public OperatorSetIdProto Clone() { - return new OperatorSetIdProto(this); - } - - /// Field number for the "domain" field. - public const int DomainFieldNumber = 1; - private string domain_ = ""; - /// - /// The domain of the operator set being identified. - /// The empty string ("") or absence of this field implies the operator - /// set that is defined as part of the ONNX specification. - /// This field MUST be present in this version of the IR when referring to any other operator set. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public string Domain { - get { return domain_; } - set { - domain_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); - } - } - - /// Field number for the "version" field. - public const int VersionFieldNumber = 2; - private long version_; - /// - /// The version of the operator set being identified. - /// This field MUST be present in this version of the IR. - /// - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public long Version { - get { return version_; } - set { - version_ = value; - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override bool Equals(object other) { - return Equals(other as OperatorSetIdProto); } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public bool Equals(OperatorSetIdProto other) { - if (ReferenceEquals(other, null)) { - return false; - } - if (ReferenceEquals(other, this)) { - return true; - } - if (Domain != other.Domain) return false; - if (Version != other.Version) return false; - return true; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override int GetHashCode() { - int hash = 1; - if (Domain.Length != 0) hash ^= Domain.GetHashCode(); - if (Version != 0L) hash ^= Version.GetHashCode(); - return hash; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public override string ToString() { - return pb::JsonFormatter.ToDiagnosticString(this); - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void WriteTo(pb::CodedOutputStream output) { - if (Domain.Length != 0) { - output.WriteRawTag(10); - output.WriteString(Domain); - } - if (Version != 0L) { - output.WriteRawTag(16); - output.WriteInt64(Version); - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public int CalculateSize() { - int size = 0; - if (Domain.Length != 0) { - size += 1 + pb::CodedOutputStream.ComputeStringSize(Domain); - } - if (Version != 0L) { - size += 1 + pb::CodedOutputStream.ComputeInt64Size(Version); - } - return size; - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(OperatorSetIdProto other) { - if (other == null) { - return; - } - if (other.Domain.Length != 0) { - Domain = other.Domain; - } - if (other.Version != 0L) { - Version = other.Version; - } - } - - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public void MergeFrom(pb::CodedInputStream input) { - uint tag; - while ((tag = input.ReadTag()) != 0) { - switch(tag) { - default: - input.SkipLastField(); - break; - case 10: { - Domain = input.ReadString(); - break; - } - case 16: { - Version = input.ReadInt64(); - break; - } - } - } - } - - } - - #endregion + #endregion } diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md index c6fa3ecb6e..ee9f02f2c4 100644 --- a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md +++ b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md @@ -11,4 +11,4 @@ protoc.exe -I="E:\protobuf-csharp-port\lib" --csharp_out="E:\protobuf-csharp-port\lib" "E:\protobuf-csharp-port\lib\onnx-ml.proto3" ``` -## The proto3 file is current as of 02/07/2018. \ No newline at end of file +## The proto3 file is current as of 06/01/2018 and generated from onnx-ml.proto3 based on https://github.com/onnx/onnx/pull/1052 \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/OnnxTests.cs b/test/Microsoft.ML.Tests/OnnxTests.cs index 9ccd2f1615..f87aa92b29 100644 --- a/test/Microsoft.ML.Tests/OnnxTests.cs +++ b/test/Microsoft.ML.Tests/OnnxTests.cs @@ -1,4 +1,8 @@ -using Microsoft.ML.Data; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; using Microsoft.ML.Models; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; From 17a738a64d7be07d9dff444d08fc7f6f09820857 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 1 Jun 2018 11:11:46 -0700 Subject: [PATCH 15/22] update baselines and regenerate csharp APIs. --- src/Microsoft.ML/CSharpApi.cs | 25 -------- .../Common/EntryPoints/core_manifest.json | 60 ------------------- 2 files changed, 85 deletions(-) diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index ddabfeec31..2107269865 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2799,31 +2799,6 @@ public sealed partial class OnnxConverter /// public string DataFile { get; set; } - /// - /// Model file to load - /// - public string InputModelFile { get; set; } - - /// - /// Load transforms from model file? - /// - public bool? LoadTransforms { get; set; } - - /// - /// Random seed - /// - public int? RandomSeed { get; set; } - - /// - /// Verbose? - /// - public bool? Verbose { get; set; } - - /// - /// Desired degree of parallelism in the data pipeline - /// - public int? Parallel { get; set; } - public sealed class Output { diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 4a494634aa..a1e3ccf98a 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -2140,66 +2140,6 @@ "Required": true, "SortOrder": 10.0, "IsNullable": false - }, - { - "Name": "InputModelFile", - "Type": "String", - "Desc": "Model file to load", - "Aliases": [ - "in" - ], - "Required": false, - "SortOrder": 90.0, - "IsNullable": false, - "Default": null - }, - { - "Name": "LoadTransforms", - "Type": "Bool", - "Desc": "Load transforms from model file?", - "Aliases": [ - "loadTrans" - ], - "Required": false, - "SortOrder": 91.0, - "IsNullable": true, - "Default": null - }, - { - "Name": "RandomSeed", - "Type": "Int", - "Desc": "Random seed", - "Aliases": [ - "seed" - ], - "Required": false, - "SortOrder": 101.0, - "IsNullable": true, - "Default": null - }, - { - "Name": "Verbose", - "Type": "Bool", - "Desc": "Verbose?", - "Aliases": [ - "v" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": true, - "Default": null - }, - { - "Name": "Parallel", - "Type": "Int", - "Desc": "Desired degree of parallelism in the data pipeline", - "Aliases": [ - "n" - ], - "Required": false, - "SortOrder": 150.0, - "IsNullable": true, - "Default": null } ], "Outputs": [] From 5090d24f663294038aabaff20c473cd7eabe47d2 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 1 Jun 2018 13:53:54 -0700 Subject: [PATCH 16/22] Add link to the commit in ONNX MD file. --- src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md index ee9f02f2c4..c6b8a5283a 100644 --- a/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md +++ b/src/Microsoft.ML.UniversalModelFormat/Onnx/OnnxMl.md @@ -11,4 +11,4 @@ protoc.exe -I="E:\protobuf-csharp-port\lib" --csharp_out="E:\protobuf-csharp-port\lib" "E:\protobuf-csharp-port\lib\onnx-ml.proto3" ``` -## The proto3 file is current as of 06/01/2018 and generated from onnx-ml.proto3 based on https://github.com/onnx/onnx/pull/1052 \ No newline at end of file +## The proto3 file is current as of 06/01/2018 and generated from onnx-ml.proto3 based on the following commit https://github.com/onnx/onnx/commit/33e9cd4182fe468675241fba4ae8a16c2f0bd82f \ No newline at end of file From 4cfed38576ae98605898af7f612f2014ca4594f4 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 1 Jun 2018 13:56:37 -0700 Subject: [PATCH 17/22] PR feedback. --- src/Microsoft.ML.Console/Console.cs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/Microsoft.ML.Console/Console.cs b/src/Microsoft.ML.Console/Console.cs index 7e5f0dbe9b..1959dfdfea 100644 --- a/src/Microsoft.ML.Console/Console.cs +++ b/src/Microsoft.ML.Console/Console.cs @@ -6,16 +6,7 @@ namespace Microsoft.ML.Runtime.Tools.Console { public static class Console { - public static int Main(string[] args) - { - string all = string.Join(" ", args); - return Maml.MainAll(all); - } + public static int Main(string[] args) => Maml.Main(args); - public static unsafe int MainRaw(char* psz) - { - string args = new string(psz); - return Maml.MainAll(args); - } } } From 0a13ad7afdb9c225f609fc6cbb6428e9bec518ce Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 1 Jun 2018 13:58:27 -0700 Subject: [PATCH 18/22] cleanup. --- src/Microsoft.ML.Console/Console.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Microsoft.ML.Console/Console.cs b/src/Microsoft.ML.Console/Console.cs index 1959dfdfea..12e6254cce 100644 --- a/src/Microsoft.ML.Console/Console.cs +++ b/src/Microsoft.ML.Console/Console.cs @@ -7,6 +7,5 @@ namespace Microsoft.ML.Runtime.Tools.Console public static class Console { public static int Main(string[] args) => Maml.Main(args); - } } From 2ab729f55530b17f24895e0937cebda81b4dc5a9 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 4 Jun 2018 00:00:23 -0700 Subject: [PATCH 19/22] Add missing attributes to ONNX model. --- .../Model/Onnx/OnnxContext.cs | 11 ++++++++-- src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs | 13 ++++++++--- .../Model/Onnx/SaveOnnxCommand.cs | 5 ++++- .../Prediction/Calibrator.cs | 6 ++--- .../BreastCancer/SaveModelToOnnxTest.json | 22 +++++++++++++++---- 5 files changed, 43 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index fb3f58291c..9759ce0c6c 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -23,10 +23,14 @@ public sealed class OnnxContext private readonly HashSet _variableMap; private readonly HashSet _nodeNames; private readonly string _name; + private readonly string _producerName; private readonly IHost _host; private readonly string _domain; + private readonly string _producerVersion; + private readonly long _modelVersion; - public OnnxContext(IHostEnvironment env, string name, string domain) + public OnnxContext(IHostEnvironment env, string name, string producerName, + string producerVersion, long modelVersion, string domain) { Contracts.CheckValue(env, nameof(env)); Contracts.CheckValue(name, nameof(name)); @@ -41,6 +45,9 @@ public OnnxContext(IHostEnvironment env, string name, string domain) _variableMap = new HashSet(); _nodeNames = new HashSet(); _name = name; + _producerName = producerName; + _producerVersion = producerVersion; + _modelVersion = modelVersion; _domain = domain; } @@ -234,6 +241,6 @@ public void AddInputVariable(ColumnType type, string colName) /// Makes the ONNX model based on the context. /// public ModelProto MakeModel() - => OnnxUtils.MakeModel(_nodes, _name, _name, _domain, _inputs, _outputs, _intermediateValues); + => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues); } } diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs index 667787c990..5b63c661c2 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs @@ -153,7 +153,7 @@ private static AttributeProto MakeAttribute(string key, IEnumerable private static AttributeProto MakeAttribute(string key, bool value) => MakeAttribute(key, value ? 1 : 0); - public static NodeProto MakeNode(string opType, List inputs, List outputs, string name) + public static NodeProto MakeNode(string opType, List inputs, List outputs, string name, string domain = null) { Contracts.CheckNonEmpty(opType, nameof(opType)); Contracts.CheckValue(inputs, nameof(inputs)); @@ -165,7 +165,7 @@ public static NodeProto MakeNode(string opType, List inputs, List nodes, string producerName, string name, string domain, List inputs, + public static ModelProto MakeModel(List nodes, string producerName, string name, + string domain, string producerVersion, long modelVersion, List inputs, List outputs, List intermediateValues) { Contracts.CheckValue(nodes, nameof(nodes)); @@ -261,10 +262,16 @@ public static ModelProto MakeModel(List nodes, string producerName, s Contracts.CheckNonEmpty(producerName, nameof(producerName)); Contracts.CheckNonEmpty(name, nameof(name)); Contracts.CheckNonEmpty(domain, nameof(domain)); + Contracts.CheckNonEmpty(producerVersion, nameof(producerVersion)); var model = new ModelProto(); model.Domain = domain; model.ProducerName = producerName; + model.ProducerVersion = producerVersion; + model.IrVersion = (long)UniversalModelFormat.Onnx.Version.IrVersion; + model.ModelVersion = modelVersion; + model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 }); + model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx", Version = 6 }); model.Graph = new GraphProto(); var graph = model.Graph; graph.Node.Add(nodes); diff --git a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs index 614f5f7a92..b760eb7718 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs @@ -68,6 +68,9 @@ public sealed class Arguments : DataCommand.ArgumentsBase private readonly HashSet _inputsToDrop; private readonly HashSet _outputsToDrop; private readonly ITransformModel _model; + private const string ProducerName = "ML.Net"; + private const string ProducerVersion = "0.2.0.0000"; + private const long ModelVersion = 0; public SaveOnnxCommand(IHostEnvironment env, Arguments args) : base(env, args, LoadName) @@ -161,7 +164,7 @@ private void Run(IChannel ch) GetPipe(ch, view, out source, out end, out transforms); Host.Assert(transforms.Count == 0 || transforms.Last.Value == end); - var ctx = new OnnxContext(Host, _name, _domain); + var ctx = new OnnxContext(Host, _name, ProducerName, ProducerVersion, ModelVersion, _domain); // If we have a predictor, try to get the scorer for it. if (rawPred != null) { diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 9fe7ed70af..efa52d2ff6 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -1444,9 +1444,8 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, str string opType = "Affine"; string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true); var node = OnnxUtils.MakeNode(opType, new List { scoreProbablityColumnNames[0] }, - new List { linearOutput }, ctx.GetNodeName(opType)); + new List { linearOutput }, ctx.GetNodeName(opType), "ai.onnx"); - node.Domain = ""; OnnxUtils.NodeAddAttributes(node, "alpha", ParamA * -1); OnnxUtils.NodeAddAttributes(node, "beta", -0.0000001); @@ -1454,9 +1453,8 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, str opType = "Sigmoid"; node = OnnxUtils.MakeNode(opType, new List { linearOutput }, - new List { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType)); + new List { scoreProbablityColumnNames[1] }, ctx.GetNodeName(opType), "ai.onnx"); - node.Domain = ""; ctx.AddNode(node); return true; diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json index eb76e1527e..24ac402334 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json @@ -1,5 +1,7 @@ { - "producerName": "SaveModelToOnnxTest", + "irVersion": "3", + "producerName": "ML.Net", + "producerVersion": "0.2.0.0000", "domain": "Onnx", "graph": { "node": [ @@ -590,7 +592,8 @@ "f": -1E-07, "type": "FLOAT" } - ] + ], + "domain": "ai.onnx" }, { "input": [ @@ -600,7 +603,8 @@ "Probability" ], "name": "Sigmoid", - "opType": "Sigmoid" + "opType": "Sigmoid", + "domain": "ai.onnx" }, { "input": [ @@ -698,5 +702,15 @@ } } ] - } + }, + "opsetImport": [ + { + "domain": "ai.onnx.ml", + "version": "1" + }, + { + "domain": "ai.onnx", + "version": "6" + } + ] } \ No newline at end of file From 3dfd81f09366201427457e0b4e274a24868e436f Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 4 Jun 2018 22:07:06 -0700 Subject: [PATCH 20/22] cleanup. --- src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs | 2 +- .../BinaryClassification/BreastCancer/SaveModelToOnnxTest.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs index b760eb7718..d2dfc93fde 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs @@ -68,7 +68,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase private readonly HashSet _inputsToDrop; private readonly HashSet _outputsToDrop; private readonly ITransformModel _model; - private const string ProducerName = "ML.Net"; + private const string ProducerName = "ML.NET"; private const string ProducerVersion = "0.2.0.0000"; private const long ModelVersion = 0; diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json index 24ac402334..3afd6dc171 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/SaveModelToOnnxTest.json @@ -1,6 +1,6 @@ { "irVersion": "3", - "producerName": "ML.Net", + "producerName": "ML.NET", "producerVersion": "0.2.0.0000", "domain": "Onnx", "graph": { From 356233132ff6016879d9137f5f7977f5d175bd16 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 6 Jun 2018 11:10:26 -0700 Subject: [PATCH 21/22] add more commands. --- src/Microsoft.ML.Console/Microsoft.ML.Console.csproj | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj index 9fe1010c2c..8039371695 100644 --- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj @@ -12,7 +12,9 @@ + + \ No newline at end of file From d20f1a48343b74379f782b19413e924ed3ae5b95 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 6 Jun 2018 11:35:24 -0700 Subject: [PATCH 22/22] cleanup. --- test/Microsoft.ML.Tests/OnnxTests.cs | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/Microsoft.ML.Tests/OnnxTests.cs b/test/Microsoft.ML.Tests/OnnxTests.cs index f87aa92b29..6910aba70b 100644 --- a/test/Microsoft.ML.Tests/OnnxTests.cs +++ b/test/Microsoft.ML.Tests/OnnxTests.cs @@ -69,9 +69,6 @@ public void BinaryClassificationSaveModelToOnnxTest() var model = pipeline.Train(); var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", "BreastCancer"); - var modelOutpath = GetOutputPath(subDir, "SaveModelToOnnxTest.zip"); - DeleteOutputPath(modelOutpath); - var onnxPath = GetOutputPath(subDir, "SaveModelToOnnxTest.pb"); DeleteOutputPath(onnxPath);