Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Another round of ISchema cleanup
  • Loading branch information
Pete Luferenko committed Oct 6, 2018
commit b27d7c46fc92ffa557ef01befe9d386924a50855
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Data/IDataView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

using System;
using System.Collections.Generic;
using Microsoft.ML.Runtime.Internal.Utilities;

namespace Microsoft.ML.Runtime.Data
{
/// <summary>
/// Interface for schema information.
/// Legacy interface for schema information.
/// Please avoid implementing this interface, use <see cref="Schema"/>.
/// </summary>
public interface ISchema
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ private static void GetColumnArgs(ColumnType type,
/// <summary>
/// Create a schema shape out of the fully defined schema.
/// </summary>
public static SchemaShape Create(ISchema schema)
public static SchemaShape Create(Schema schema)
{
Contracts.CheckValue(schema, nameof(schema));
var cols = new List<Column>();
Expand Down
18 changes: 9 additions & 9 deletions src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public static IEnumerable<T> Prepend<T>(this IEnumerable<T> tail, params T[] hea
/// The filter function is called for each column, passing in the schema and the column index, and returns
/// true if the column should be considered, false if the column should be skipped.
/// </summary>
public static uint GetMaxMetadataKind(this ISchema schema, out int colMax, string metadataKind, Func<ISchema, int, bool> filterFunc = null)
public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string metadataKind, Func<Schema, int, bool> filterFunc = null)
{
uint max = 0;
colMax = -1;
Expand All @@ -258,7 +258,7 @@ public static uint GetMaxMetadataKind(this ISchema schema, out int colMax, strin
/// Returns the set of column ids which match the value of specified metadata kind.
/// The metadata type should be a KeyType with raw type U4.
/// </summary>
public static IEnumerable<int> GetColumnSet(this ISchema schema, string metadataKind, uint value)
public static IEnumerable<int> GetColumnSet(this Schema schema, string metadataKind, uint value)
{
for (int col = 0; col < schema.ColumnCount; col++)
{
Expand All @@ -277,7 +277,7 @@ public static IEnumerable<int> GetColumnSet(this ISchema schema, string metadata
/// Returns the set of column ids which match the value of specified metadata kind.
/// The metadata type should be of type text.
/// </summary>
public static IEnumerable<int> GetColumnSet(this ISchema schema, string metadataKind, string value)
public static IEnumerable<int> GetColumnSet(this Schema schema, string metadataKind, string value)
{
for (int col = 0; col < schema.ColumnCount; col++)
{
Expand All @@ -298,7 +298,7 @@ public static IEnumerable<int> GetColumnSet(this ISchema schema, string metadata
/// * has a SlotNames metadata
/// * metadata type is VBuffer&lt;ReadOnlyMemory&lt;char&gt;&gt; of length N
/// </summary>
public static bool HasSlotNames(this ISchema schema, int col, int vectorSize)
public static bool HasSlotNames(this Schema schema, int col, int vectorSize)
{
if (vectorSize == 0)
return false;
Expand All @@ -323,7 +323,7 @@ public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Column
schema.Schema.GetMetadata(Kinds.SlotNames, list[0].Index, ref slotNames);
}

public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
public static bool HasKeyNames(this Schema schema, int col, int keyCount)
{
if (keyCount == 0)
return false;
Expand All @@ -345,7 +345,7 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
/// <param name="col">Which column in the schema to query</param>
/// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> metadata
/// set to the scalar value true</returns>
public static bool IsNormalized(this ISchema schema, int col)
public static bool IsNormalized(this Schema schema, int col)
{
Contracts.CheckValue(schema, nameof(schema));
var value = default(bool);
Expand Down Expand Up @@ -393,7 +393,7 @@ public static bool HasSlotNames(this SchemaShape.Column col)
/// <param name="col">The column</param>
/// <param name="value">The value to return, if successful</param>
/// <returns>True if the metadata of the right type exists, false otherwise</returns>
public static bool TryGetMetadata<T>(this ISchema schema, PrimitiveType type, string kind, int col, ref T value)
public static bool TryGetMetadata<T>(this Schema schema, PrimitiveType type, string kind, int col, ref T value)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckValue(type, nameof(type));
Expand All @@ -408,7 +408,7 @@ public static bool TryGetMetadata<T>(this ISchema schema, PrimitiveType type, st
/// <summary>
/// Return whether the given column index is hidden in the given schema.
/// </summary>
public static bool IsHidden(this ISchema schema, int col)
public static bool IsHidden(this Schema schema, int col)
{
Contracts.CheckValue(schema, nameof(schema));
string name = schema.GetColumnName(col);
Expand All @@ -426,7 +426,7 @@ public static bool IsHidden(this ISchema schema, int col)
/// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical
/// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
/// </summary>
public static bool TryGetCategoricalFeatureIndices(ISchema schema, int colIndex, out int[] categoricalFeatures)
public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, out int[] categoricalFeatures)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.Check(colIndex >= 0, nameof(colIndex));
Expand Down
12 changes: 9 additions & 3 deletions src/Microsoft.ML.Core/Data/Schema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

namespace Microsoft.ML.Runtime.Data
{
#pragma warning disable CS0618 // Type or member is obsolete
public sealed class Schema : ISchema
Copy link
Contributor

@TomFinley TomFinley Oct 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Schema [](start = 24, length = 6)

Are you planning on writing actual documentation for this eventually? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do right now


In reply to: 224531123 [](ancestors = 224531123)

#pragma warning restore CS0618 // Type or member is obsolete
{
private readonly Column[] _columns;
private readonly Dictionary<string, int> _nameMap;
Expand Down Expand Up @@ -182,10 +184,12 @@ public Schema(IEnumerable<Column> columns)

public IEnumerable<(int index, Column column)> GetColumns() => _nameMap.Values.Select(idx => (idx, _columns[idx]));

/// <summary>
/// Manufacture an instance of <see cref="Schema"/> out of any <see cref="ISchema"/>.
/// </summary>
#pragma warning disable CS0618 // Type or member is obsolete
/// <summary>
/// Manufacture an instance of <see cref="Schema"/> out of any <see cref="ISchema"/>.
/// </summary>
Copy link
Member

@sfilipi sfilipi Oct 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indenting #Resolved

public static Schema Create(ISchema inputSchema)
#pragma warning restore CS0618 // Type or member is obsolete
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));

Expand All @@ -207,7 +211,9 @@ public static Schema Create(ISchema inputSchema)
return new Schema(columns);
}

#pragma warning disable CS0618 // Type or member is obsolete
private static Delegate GetMetadataGetterDelegate<TValue>(ISchema schema, int col, string kind)
#pragma warning restore CS0618 // Type or member is obsolete
{
// REVIEW: We are facing a choice here: cache 'value' and get rid of 'schema' reference altogether,
// or retain the reference but be more memory efficient. This code should not stick around for too long
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName
/// and the column name was explicitly specified. If the column is not found
/// and the column name was not explicitly specified, it returns null.
/// </summary>
public static string FindColumnOrNull(IExceptionContext ectx, ISchema schema, Optional<string> value)
public static string FindColumnOrNull(IExceptionContext ectx, Schema schema, Optional<string> value)
{
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(schema, nameof(schema));
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,11 @@ private sealed class FoldHelper
public struct FoldResult
{
public readonly Dictionary<string, IDataView> Metrics;
public readonly ISchema ScoreSchema;
public readonly Schema ScoreSchema;
public readonly RoleMappedData PerInstanceResults;
public readonly RoleMappedSchema TrainSchema;

public FoldResult(Dictionary<string, IDataView> metrics, ISchema scoreSchema, RoleMappedData perInstance, RoleMappedSchema trainSchema)
public FoldResult(Dictionary<string, IDataView> metrics, Schema scoreSchema, RoleMappedData perInstance, RoleMappedSchema trainSchema)
{
Metrics = metrics;
ScoreSchema = scoreSchema;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/ScoreCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ private void RunCore(IChannel ch)
/// Whether a column should be added, assuming it's not hidden
/// (i.e.: this doesn't check for hidden
/// </summary>
private bool ShouldAddColumn(ISchema schema, int i, uint scoreSet, bool outputNamesAndLabels)
private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNamesAndLabels)
{
uint scoreSetId = 0;
if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env,
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, newView, input.Data), OutputData = newView };
}

private static bool ShouldAddColumn(ISchema schema, int i, string[] extraColumns, uint scoreSet)
private static bool ShouldAddColumn(Schema schema, int i, string[] extraColumns, uint scoreSet)
{
uint scoreSetId = 0;
if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
Expand Down
14 changes: 7 additions & 7 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public static Dictionary<string, Func<IHostEnvironment, IMamlEvaluator>> Instanc
}
}

public static IMamlEvaluator GetEvaluator(IHostEnvironment env, ISchema schema)
public static IMamlEvaluator GetEvaluator(IHostEnvironment env, Schema schema)
{
Contracts.CheckValueOrNull(env);
ReadOnlyMemory<char> tmp = default;
Expand Down Expand Up @@ -90,7 +90,7 @@ private static bool CheckScoreColumnKindIsKnown(ISchema schema, int col)
}

// Lambda used as validator/filter in calls to GetMaxMetadataKind.
private static bool CheckScoreColumnKind(ISchema schema, int col)
private static bool CheckScoreColumnKind(Schema schema, int col)
{
var columnType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, col);
return columnType != null && columnType.IsText;
Expand All @@ -101,7 +101,7 @@ private static bool CheckScoreColumnKind(ISchema schema, int col)
/// most recent score set of the given kind. If there is no such score set and defName is specifed it
/// uses defName. Otherwise, it throws.
/// </summary>
public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, ISchema schema, string name, string argName, string kind,
public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, Schema schema, string name, string argName, string kind,
string valueKind = MetadataUtils.Const.ScoreValueKind.Score, string defName = null)
{
Contracts.CheckValueOrNull(ectx);
Expand Down Expand Up @@ -155,7 +155,7 @@ public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, ISchema sche
/// Otherwise, if colScore is part of a score set, this looks in the score set for a column
/// with the given valueKind. If none is found, it returns null.
/// </summary>
public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, ISchema schema, string name, string argName,
public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, Schema schema, string name, string argName,
int colScore, string valueKind, Func<ColumnType, bool> testType)
{
Contracts.CheckValueOrNull(ectx);
Expand Down Expand Up @@ -939,7 +939,7 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string
return AppendRowsDataView.Create(env, null, views.Select(keyToValue).Select(selectDropNonVarLenthCol).ToArray());
}

private static IEnumerable<int> FindHiddenColumns(ISchema schema, string colName)
private static IEnumerable<int> FindHiddenColumns(Schema schema, string colName)
{
for (int i = 0; i < schema.ColumnCount; i++)
{
Expand Down Expand Up @@ -985,7 +985,7 @@ private static IDataView AddVarLengthColumn<TSrc>(IHostEnvironment env, IDataVie
(ref VBuffer<TSrc> src, ref VBuffer<TSrc> dst) => src.CopyTo(ref dst));
}

private static List<string> GetMetricNames(IChannel ch, ISchema schema, IRow row, Func<int, bool> ignoreCol,
private static List<string> GetMetricNames(IChannel ch, Schema schema, IRow row, Func<int, bool> ignoreCol,
ValueGetter<double>[] getters, ValueGetter<VBuffer<double>>[] vBufferGetters)
{
ch.AssertValue(schema);
Expand Down Expand Up @@ -1205,7 +1205,7 @@ private static void UpdateSums(int isWeightedCol, int stratCol, int stratVal, Ag
Contracts.Assert(iMetric == metricNames.Count);
}

internal static IDataView GetAverageToDataView(IHostEnvironment env, ISchema schema, AggregatedMetric[] agg, AggregatedMetric[] weightedAgg,
internal static IDataView GetAverageToDataView(IHostEnvironment env, Schema schema, AggregatedMetric[] agg, AggregatedMetric[] weightedAgg,
int numFolds, int stratCol, int stratVal, int isWeightedCol, int foldCol, bool hasStdev, List<string> nonAveragedCols = null)
{
Contracts.AssertValue(env);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,14 @@ private static VersionInfo GetVersionInfo()
private readonly Schema.MetadataRow _labelMetadata;
private readonly Schema.MetadataRow _scoreMetadata;

public MultiOutputRegressionPerInstanceEvaluator(IHostEnvironment env, ISchema schema, string scoreCol,
public MultiOutputRegressionPerInstanceEvaluator(IHostEnvironment env, Schema schema, string scoreCol,
string labelCol)
: base(env, schema, scoreCol, labelCol)
{
CheckInputColumnTypes(schema, out _labelType, out _scoreType, out _labelMetadata, out _scoreMetadata);
}

private MultiOutputRegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, ISchema schema)
private MultiOutputRegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, Schema schema)
: base(env, ctx, schema)
{
CheckInputColumnTypes(schema, out _labelType, out _scoreType, out _labelMetadata, out _scoreMetadata);
Expand All @@ -426,7 +426,7 @@ public static MultiOutputRegressionPerInstanceEvaluator Create(IHostEnvironment
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());

return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, schema);
return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, Schema.Create(schema));
}

public override void Save(ModelSaveContext ctx)
Expand Down Expand Up @@ -544,7 +544,7 @@ public override Delegate[] CreateGetters(IRow input, Func<int, bool> activeCols,
return getters;
}

private void CheckInputColumnTypes(ISchema schema, out ColumnType labelType, out ColumnType scoreType,
private void CheckInputColumnTypes(Schema schema, out ColumnType labelType, out ColumnType scoreType,
out Schema.MetadataRow labelMetadata, out Schema.MetadataRow scoreMetadata)
{
Host.AssertNonEmpty(ScoreCol);
Expand Down Expand Up @@ -575,7 +575,7 @@ private void CheckInputColumnTypes(ISchema schema, out ColumnType labelType, out
scoreMetadata = builder.GetMetadataRow();
}

private ValueGetter<uint> GetScoreColumnSetId(ISchema schema)
private ValueGetter<uint> GetScoreColumnSetId(Schema schema)
{
int c;
var max = schema.GetMaxMetadataKind(out c, MetadataUtils.Kinds.ScoreColumnSetId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ private static VersionInfo GetVersionInfo()
private readonly ReadOnlyMemory<char>[] _classNames;
private readonly ColumnType[] _types;

public MultiClassPerInstanceEvaluator(IHostEnvironment env, ISchema schema, ColumnInfo scoreInfo, string labelCol)
public MultiClassPerInstanceEvaluator(IHostEnvironment env, Schema schema, ColumnInfo scoreInfo, string labelCol)
: base(env, schema, Contracts.CheckRef(scoreInfo, nameof(scoreInfo)).Name, labelCol)
{
CheckInputColumnTypes(schema);
Expand Down
4 changes: 0 additions & 4 deletions src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,6 @@ private sealed class Bindings : BindingsBase
private readonly int _truncationLevel;
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> _slotNamesGetter;

public override Schema AsSchema { get; }

public Bindings(IExceptionContext ectx, ISchema input, bool user, string labelCol, string scoreCol, string groupCol,
int truncationLevel)
: base(ectx, input, labelCol, scoreCol, groupCol, user, Ndcg, Dcg, MaxDcg)
Expand All @@ -669,8 +667,6 @@ public Bindings(IExceptionContext ectx, ISchema input, bool user, string labelCo
_outputType = new VectorType(NumberType.R8, _truncationLevel);
_slotNamesType = new VectorType(TextType.Instance, _truncationLevel);
_slotNamesGetter = SlotNamesGetter;

AsSchema = Schema.Create(this);
}

protected override ColumnType GetColumnTypeCore(int iinfo)
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public sealed class BoundPfaContext
private readonly bool _allowSet;
private readonly IHost _host;

public BoundPfaContext(IHostEnvironment env, ISchema inputSchema, HashSet<string> toDrop, bool allowSet)
public BoundPfaContext(IHostEnvironment env, Schema inputSchema, HashSet<string> toDrop, bool allowSet)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(BoundPfaContext));
Expand All @@ -54,7 +54,7 @@ public BoundPfaContext(IHostEnvironment env, ISchema inputSchema, HashSet<string
SetInput(inputSchema, toDrop);
}

private void SetInput(ISchema schema, HashSet<string> toDrop)
private void SetInput(Schema schema, HashSet<string> toDrop)
{
var recordType = new JObject();
recordType["type"] = "record";
Expand Down
Loading