Skip to content

Commit 8e77d0f

Browse files
authored
Converted Schema to a class (dotnet#1167)
* Created a Schema class for eager schema. Converted existing row mappers to use Schema.
1 parent bee7f17 commit 8e77d0f

File tree

151 files changed

+1447
-1241
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

151 files changed

+1447
-1241
lines changed

src/Microsoft.ML.Api/DataViewConstructionUtils.cs

Lines changed: 29 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Runtime.Data;
6+
using Microsoft.ML.Runtime.Internal.Utilities;
7+
using Microsoft.ML.Runtime.Model;
58
using System;
69
using System.Collections.Generic;
710
using System.IO;
811
using System.Linq;
912
using System.Reflection;
10-
using Microsoft.ML.Runtime.Data;
11-
using Microsoft.ML.Runtime.Internal.Utilities;
12-
using Microsoft.ML.Runtime.Model;
1313

1414
namespace Microsoft.ML.Runtime.Api
1515
{
@@ -83,7 +83,7 @@ public sealed class InputRow<TRow> : InputRowBase<TRow>, IRowBackedBy<TRow>
8383
public override long Position => _position;
8484

8585
public InputRow(IHostEnvironment env, InternalSchemaDefinition schemaDef)
86-
: base(env, new SchemaProxy(schemaDef), schemaDef, MakePeeks(schemaDef), c => true)
86+
: base(env, new Schema(GetSchemaColumns(schemaDef)), schemaDef, MakePeeks(schemaDef), c => true)
8787
{
8888
_position = -1;
8989
}
@@ -136,11 +136,11 @@ public abstract class InputRowBase<TRow> : IRow
136136

137137
public long Batch => 0;
138138

139-
public ISchema Schema { get; }
139+
public Schema Schema { get; }
140140

141141
public abstract long Position { get; }
142142

143-
public InputRowBase(IHostEnvironment env, ISchema schema, InternalSchemaDefinition schemaDef, Delegate[] peeks, Func<int, bool> predicate)
143+
public InputRowBase(IHostEnvironment env, Schema schema, InternalSchemaDefinition schemaDef, Delegate[] peeks, Func<int, bool> predicate)
144144
{
145145
Contracts.AssertValue(env);
146146
Host = env.Register("Row");
@@ -326,27 +326,30 @@ public abstract class DataViewBase<TRow> : IDataView
326326
{
327327
protected readonly IHost Host;
328328

329-
private readonly SchemaProxy _schema;
329+
private readonly Schema _schema;
330+
private readonly InternalSchemaDefinition _schemaDefn;
330331

331332
// The array of generated methods that extract the fields of the current row object.
332333
private readonly Delegate[] _peeks;
333334

334335
public abstract bool CanShuffle { get; }
335336

336-
public ISchema Schema => _schema;
337+
public Schema Schema => _schema;
337338

338339
protected DataViewBase(IHostEnvironment env, string name, InternalSchemaDefinition schemaDefn)
339340
{
340341
Contracts.AssertValue(env);
341342
env.AssertNonWhiteSpace(name);
342343
Host = env.Register(name);
343344
Host.AssertValue(schemaDefn);
344-
_schema = new SchemaProxy(schemaDefn);
345-
int n = _schema.SchemaDefn.Columns.Length;
345+
346+
_schemaDefn = schemaDefn;
347+
_schema = new Schema(GetSchemaColumns(schemaDefn));
348+
int n = schemaDefn.Columns.Length;
346349
_peeks = new Delegate[n];
347350
for (var i = 0; i < n; i++)
348351
{
349-
var currentColumn = _schema.SchemaDefn.Columns[i];
352+
var currentColumn = schemaDefn.Columns[i];
350353
_peeks[i] = currentColumn.IsComputed
351354
? currentColumn.Generator
352355
: ApiUtils.GeneratePeek<DataViewBase<TRow>, TRow>(currentColumn);
@@ -381,7 +384,7 @@ public abstract class DataViewCursorBase : InputRowBase<TRow>, IRowCursor
381384

382385
protected DataViewCursorBase(IHostEnvironment env, DataViewBase<TRow> dataView,
383386
Func<int, bool> predicate)
384-
: base(env, dataView.Schema, dataView._schema.SchemaDefn, dataView._peeks, predicate)
387+
: base(env, dataView.Schema, dataView._schemaDefn, dataView._peeks, predicate)
385388
{
386389
Contracts.AssertValue(env);
387390
Ch = env.Start("Cursor");
@@ -747,72 +750,20 @@ protected override bool MoveManyCore(long count)
747750
}
748751
}
749752

750-
private sealed class SchemaProxy : ISchema
753+
internal static Schema.Column[] GetSchemaColumns(InternalSchemaDefinition schemaDefn)
751754
{
752-
public readonly InternalSchemaDefinition SchemaDefn;
753-
754-
public SchemaProxy(InternalSchemaDefinition schemaDefn)
755-
{
756-
SchemaDefn = schemaDefn;
757-
}
758-
759-
public int ColumnCount
760-
{
761-
get { return SchemaDefn.Columns.Length; }
762-
}
763-
764-
public bool TryGetColumnIndex(string name, out int col)
765-
{
766-
col = Array.FindIndex(SchemaDefn.Columns, c => c.ColumnName == name);
767-
return col >= 0;
768-
}
769-
770-
public string GetColumnName(int col)
755+
Contracts.AssertValue(schemaDefn);
756+
var columns = new Schema.Column[schemaDefn.Columns.Length];
757+
for (int i = 0; i < columns.Length; i++)
771758
{
772-
CheckColumnInRange(col);
773-
return SchemaDefn.Columns[col].ColumnName;
759+
var col = schemaDefn.Columns[i];
760+
var meta = new Schema.Metadata.Builder();
761+
foreach (var kvp in col.Metadata)
762+
meta.Add(new Schema.Column(kvp.Value.Kind, kvp.Value.MetadataType, null), kvp.Value.GetGetterDelegate());
763+
columns[i] = new Schema.Column(col.ColumnName, col.ColumnType, meta.GetMetadata());
774764
}
775765

776-
public ColumnType GetColumnType(int col)
777-
{
778-
CheckColumnInRange(col);
779-
return SchemaDefn.Columns[col].ColumnType;
780-
}
781-
782-
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
783-
{
784-
CheckColumnInRange(col);
785-
var columnMetadata = SchemaDefn.Columns[col].Metadata;
786-
if (columnMetadata == null)
787-
yield break;
788-
foreach (var kvp in columnMetadata.Select(x => new KeyValuePair<string, ColumnType>(x.Key, x.Value.MetadataType)))
789-
yield return kvp;
790-
}
791-
792-
public ColumnType GetMetadataTypeOrNull(string kind, int col)
793-
{
794-
if (string.IsNullOrEmpty(kind))
795-
throw MetadataUtils.ExceptGetMetadata();
796-
CheckColumnInRange(col);
797-
var column = SchemaDefn.Columns[col];
798-
return column.Metadata.ContainsKey(kind) ? column.Metadata[kind].MetadataType : null;
799-
}
800-
801-
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
802-
{
803-
var metadataType = GetMetadataTypeOrNull(kind, col);
804-
if (metadataType == null)
805-
throw MetadataUtils.ExceptGetMetadata();
806-
807-
var metadata = SchemaDefn.Columns[col].Metadata[kind];
808-
metadata.GetGetter<TValue>()(ref value);
809-
}
810-
811-
private void CheckColumnInRange(int columnIndex)
812-
{
813-
if (columnIndex < 0 || columnIndex >= SchemaDefn.Columns.Length)
814-
throw Contracts.Except("Column index must be between 0 and {0}", SchemaDefn.Columns.Length);
815-
}
766+
return columns;
816767
}
817768
}
818769

@@ -833,6 +784,8 @@ public abstract partial class MetadataInfo
833784

834785
public abstract ValueGetter<TDst> GetGetter<TDst>();
835786

787+
internal abstract Delegate GetGetterDelegate();
788+
836789
protected MetadataInfo(string kind, ColumnType metadataType)
837790
{
838791
Contracts.AssertValueOrNull(metadataType);
@@ -951,6 +904,8 @@ public override ValueGetter<TDst> GetGetter<TDst>()
951904
throw Contracts.ExceptNotImpl("Type '{0}' is not yet supported.", typeT.FullName);
952905
}
953906

907+
internal override Delegate GetGetterDelegate() => Utils.MarshalInvoke(GetGetter<int>, MetadataType.RawType);
908+
954909
public class TElement
955910
{
956911
}

src/Microsoft.ML.Api/MapTransform.cs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
6-
using System.IO;
75
using Microsoft.ML.Runtime.Data;
86
using Microsoft.ML.Runtime.Internal.Utilities;
7+
using System;
8+
using System.IO;
99

1010
namespace Microsoft.ML.Runtime.Api
1111
{
@@ -25,7 +25,8 @@ internal sealed class MapTransform<TSrc, TDst> : LambdaTransformBase, ITransform
2525
{
2626
private const string RegistrationNameTemplate = "MapTransform<{0}, {1}>";
2727
private readonly Action<TSrc, TDst> _mapAction;
28-
private readonly MergedSchema _schema;
28+
private readonly InternalSchemaDefinition _addedSchema;
29+
private readonly ColumnBindings _bindings;
2930

3031
// Memorized input schema definition. Needed for re-apply.
3132
private readonly SchemaDefinition _inputSchemaDefinition;
@@ -67,7 +68,8 @@ public MapTransform(IHostEnvironment env, IDataView source, Action<TSrc, TDst> m
6768
? InternalSchemaDefinition.Create(typeof(TDst), SchemaDefinition.Direction.Write)
6869
: InternalSchemaDefinition.Create(typeof(TDst), outputSchemaDefinition);
6970

70-
_schema = MergedSchema.Create(Source.Schema, outSchema);
71+
_addedSchema = outSchema;
72+
_bindings = new ColumnBindings(Data.Schema.Create(Source.Schema), DataViewConstructionUtils.GetSchemaColumns(outSchema));
7173
}
7274

7375
/// <summary>
@@ -83,12 +85,13 @@ private MapTransform(IHostEnvironment env, MapTransform<TSrc, TDst> transform, I
8385
_mapAction = transform._mapAction;
8486
_typedSource = TypedCursorable<TSrc>.Create(Host, newSource, false, transform._inputSchemaDefinition);
8587

86-
_schema = MergedSchema.Create(newSource.Schema, transform._schema.AddedSchema);
88+
_addedSchema = transform._addedSchema;
89+
_bindings = new ColumnBindings(Data.Schema.Create(newSource.Schema), DataViewConstructionUtils.GetSchemaColumns(_addedSchema));
8790
}
8891

8992
public bool CanShuffle => Source.CanShuffle;
9093

91-
public ISchema Schema => _schema;
94+
public Schema Schema => _bindings.Schema;
9295

9396
public long? GetRowCount(bool lazy = true)
9497
{
@@ -104,7 +107,7 @@ public IRowCursor GetRowCursor(Func<int, bool> predicate, IRandom rand = null)
104107
if (DataViewUtils.TryCreateConsolidatingCursor(out curs, this, predicate, Host, rand))
105108
return curs;
106109

107-
var activeInputs = _schema.GetActiveInput(predicate);
110+
var activeInputs = _bindings.GetActiveInput(predicate);
108111
Func<int, bool> srcPredicate = c => activeInputs[c];
109112

110113
var input = _typedSource.GetCursor(srcPredicate, rand == null ? (int?)null : rand.Next());
@@ -116,7 +119,7 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun
116119
Host.CheckValue(predicate, nameof(predicate));
117120
Host.CheckValueOrNull(rand);
118121

119-
var activeInputs = _schema.GetActiveInput(predicate);
122+
var activeInputs = _bindings.GetActiveInput(predicate);
120123
Func<int, bool> srcPredicate = c => activeInputs[c];
121124

122125
var inputs = _typedSource.GetCursorSet(out consolidator, srcPredicate, n, rand);
@@ -140,7 +143,7 @@ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
140143
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
141144
{
142145
Host.CheckValue(predicate, nameof(predicate));
143-
var activeInput = _schema.GetActiveInput(predicate);
146+
var activeInput = _bindings.GetActiveInput(predicate);
144147
Func<int, bool> srcPredicate =
145148
c =>
146149
{
@@ -150,7 +153,7 @@ public Func<int, bool> GetDependencies(Func<int, bool> predicate)
150153
return _typedSource.GetDependencies(srcPredicate);
151154
}
152155

153-
ISchema IRowToRowMapper.InputSchema => Source.Schema;
156+
Schema IRowToRowMapper.InputSchema => Source.Schema;
154157

155158
public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
156159
{
@@ -181,9 +184,9 @@ private IRow GetAppendedRow(Func<int, bool> active, TDst dst)
181184
// REVIEW: This is quite odd (for a cursor to create an IDataView). Consider cleaning up your
182185
// programming model for this. Note that you don't use the IDataView, only a cursor around a single row that
183186
// is owned by this cursor. Seems like that cursor implementation could be decoupled from any IDataView class.
184-
var appendedDataView = new DataViewConstructionUtils.SingleRowLoopDataView<TDst>(Host, _schema.AddedSchema);
187+
var appendedDataView = new DataViewConstructionUtils.SingleRowLoopDataView<TDst>(Host, _addedSchema);
185188
appendedDataView.SetCurrentRowObject(dst);
186-
return appendedDataView.GetRowCursor(i => active(_schema.MapIinfoToCol(i)));
189+
return appendedDataView.GetRowCursor(i => active(_bindings.AddedColumnIndices[i]));
187190
}
188191

189192
private sealed class Cursor : SynchronizedCursorBase<IRowCursor<TSrc>>, IRowCursor
@@ -210,7 +213,7 @@ public Cursor(IHost host, MapTransform<TSrc, TDst> owner, IRowCursor<TSrc> input
210213
CursorChannelAttribute.TrySetCursorChannel(host, _dst, Ch);
211214
}
212215

213-
public ISchema Schema => _row.Schema;
216+
public Schema Schema => _row.Schema;
214217

215218
public bool IsColumnActive(int col)
216219
{
@@ -257,7 +260,7 @@ private sealed class Row : IRow
257260

258261
public long Position => _input.Position;
259262

260-
public ISchema Schema { get; }
263+
public Schema Schema { get; }
261264

262265
public Row(IRowReadableAs<TSrc> input, MapTransform<TSrc, TDst> parent, Func<int, bool> active, TSrc src, TDst dst)
263266
{
@@ -281,7 +284,7 @@ public ValueGetter<TValue> GetGetter<TValue>(int col)
281284
public ValueGetter<TValue> GetGetterCore<TValue>(int col, Action checkIsGood)
282285
{
283286
bool isSrc;
284-
int index = _parent._schema.MapColumnIndex(out isSrc, col);
287+
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
285288
if (isSrc)
286289
return _input.GetGetter<TValue>(index);
287290

src/Microsoft.ML.Api/MergedSchema.cs

Lines changed: 0 additions & 58 deletions
This file was deleted.

src/Microsoft.ML.Api/PredictionEngine.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ internal PredictionEngine(IHostEnvironment env, Stream modelStream, bool ignoreM
149149
{
150150
}
151151

152-
private static Func<ISchema, IRowToRowMapper> StreamChecker(IHostEnvironment env, Stream modelStream)
152+
private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env, Stream modelStream)
153153
{
154154
env.CheckValue(modelStream, nameof(modelStream));
155155
return schema =>
@@ -173,14 +173,14 @@ internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool i
173173
{
174174
}
175175

176-
private static Func<ISchema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
176+
private static Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
177177
{
178178
ectx.CheckValue(transformer, nameof(transformer));
179179
ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
180180
return transformer.GetRowToRowMapper;
181181
}
182182

183-
private PredictionEngine(IHostEnvironment env, Func<ISchema, IRowToRowMapper> makeMapper, bool ignoreMissingColumns,
183+
private PredictionEngine(IHostEnvironment env, Func<Schema, IRowToRowMapper> makeMapper, bool ignoreMissingColumns,
184184
SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition)
185185
{
186186
Contracts.CheckValue(env, nameof(env));

0 commit comments

Comments
 (0)