Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Back to built-in EmptyDataView
  • Loading branch information
FranklinWhale authored Nov 12, 2022
commit 7de47eeeb86eff84f44bbca84733a0bb718e483d
61 changes: 6 additions & 55 deletions src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -80,7 +79,7 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat
/// <param name="stream">The stream to write the protobuf model to.</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream) =>
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(inputData.Schema)).WriteTo(stream);
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(catalog.GetEnvironment(), inputData.Schema)).WriteTo(stream);

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format and writes to a stream.
Expand All @@ -92,7 +91,7 @@ public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransform
/// <param name="stream">The stream to write the protobuf model to.</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion, Stream stream) =>
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(inputData.Schema), opSetVersion).WriteTo(stream);
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(catalog.GetEnvironment(), inputData.Schema), opSetVersion).WriteTo(stream);

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format and writes to a stream.
Expand All @@ -104,7 +103,7 @@ public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransform
/// <param name="outputColumns">List of output columns we want to keep.</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream, params string[] outputColumns) =>
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(inputData.Schema), outputColumns).WriteTo(stream);
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(catalog.GetEnvironment(), inputData.Schema), outputColumns).WriteTo(stream);

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format and writes to a stream.
Expand All @@ -115,7 +114,7 @@ public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransform
/// <param name="stream">The stream to write the protobuf model to.</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, DataViewSchema inputSchema, Stream stream) =>
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(inputSchema)).WriteTo(stream);
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(catalog.GetEnvironment(), inputSchema)).WriteTo(stream);

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format and writes to a stream.
Expand All @@ -127,7 +126,7 @@ public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransform
/// <param name="stream">The stream to write the protobuf model to.</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, DataViewSchema inputSchema, int opSetVersion, Stream stream) =>
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(inputSchema), opSetVersion).WriteTo(stream);
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(catalog.GetEnvironment(), inputSchema), opSetVersion).WriteTo(stream);

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format and writes to a stream.
Expand All @@ -139,54 +138,6 @@ public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransform
/// <param name="outputColumns">List of output columns we want to keep.</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, DataViewSchema inputSchema, Stream stream, params string[] outputColumns) =>
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(inputSchema), outputColumns).WriteTo(stream);

private sealed class EmptyDataView : IDataView
{
private readonly DataViewSchema _schema;

public EmptyDataView(DataViewSchema schema)
{
_schema = schema;
}

public DataViewSchema Schema => _schema;

public bool CanShuffle => true;

public long? GetRowCount() => 0L;

public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
=> new EmptyDataViewRowCursor(Schema);

public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
=> Array.Empty<DataViewRowCursor>();
}

private sealed class EmptyDataViewRowCursor : DataViewRowCursor
{
private readonly DataViewSchema _schema;

public EmptyDataViewRowCursor(DataViewSchema schema)
{
_schema = schema;
}

public override DataViewSchema Schema => _schema;

public override long Position => -1L;

public override bool IsColumnActive(DataViewSchema.Column column) => false;

public override bool MoveNext() => false;

public override long Batch => 0L;

public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
=> throw new InvalidOperationException();

public override ValueGetter<DataViewRowId> GetIdGetter()
=> throw new InvalidOperationException();
}
ConvertToOnnxProtobuf(catalog, transform, new EmptyDataView(catalog.GetEnvironment(), inputSchema), outputColumns).WriteTo(stream);
}
}