2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
+ using Microsoft . ML . Runtime . Data ;
6
+ using Microsoft . ML . Runtime . Internal . Utilities ;
7
+ using Microsoft . ML . Runtime . Model ;
5
8
using System ;
6
9
using System . Collections . Generic ;
7
10
using System . IO ;
8
11
using System . Linq ;
9
12
using System . Reflection ;
10
- using Microsoft . ML . Runtime . Data ;
11
- using Microsoft . ML . Runtime . Internal . Utilities ;
12
- using Microsoft . ML . Runtime . Model ;
13
13
14
14
namespace Microsoft . ML . Runtime . Api
15
15
{
@@ -83,7 +83,7 @@ public sealed class InputRow<TRow> : InputRowBase<TRow>, IRowBackedBy<TRow>
83
83
public override long Position => _position ;
84
84
85
85
public InputRow ( IHostEnvironment env , InternalSchemaDefinition schemaDef )
86
- : base ( env , new SchemaProxy ( schemaDef ) , schemaDef , MakePeeks ( schemaDef ) , c => true )
86
+ : base ( env , new Schema ( GetSchemaColumns ( schemaDef ) ) , schemaDef , MakePeeks ( schemaDef ) , c => true )
87
87
{
88
88
_position = - 1 ;
89
89
}
@@ -136,11 +136,11 @@ public abstract class InputRowBase<TRow> : IRow
136
136
137
137
public long Batch => 0 ;
138
138
139
- public ISchema Schema { get ; }
139
+ public Schema Schema { get ; }
140
140
141
141
public abstract long Position { get ; }
142
142
143
- public InputRowBase ( IHostEnvironment env , ISchema schema , InternalSchemaDefinition schemaDef , Delegate [ ] peeks , Func < int , bool > predicate )
143
+ public InputRowBase ( IHostEnvironment env , Schema schema , InternalSchemaDefinition schemaDef , Delegate [ ] peeks , Func < int , bool > predicate )
144
144
{
145
145
Contracts . AssertValue ( env ) ;
146
146
Host = env . Register ( "Row" ) ;
@@ -326,27 +326,30 @@ public abstract class DataViewBase<TRow> : IDataView
326
326
{
327
327
protected readonly IHost Host ;
328
328
329
- private readonly SchemaProxy _schema ;
329
+ private readonly Schema _schema ;
330
+ private readonly InternalSchemaDefinition _schemaDefn ;
330
331
331
332
// The array of generated methods that extract the fields of the current row object.
332
333
private readonly Delegate [ ] _peeks ;
333
334
334
335
public abstract bool CanShuffle { get ; }
335
336
336
- public ISchema Schema => _schema ;
337
+ public Schema Schema => _schema ;
337
338
338
339
protected DataViewBase ( IHostEnvironment env , string name , InternalSchemaDefinition schemaDefn )
339
340
{
340
341
Contracts . AssertValue ( env ) ;
341
342
env . AssertNonWhiteSpace ( name ) ;
342
343
Host = env . Register ( name ) ;
343
344
Host . AssertValue ( schemaDefn ) ;
344
- _schema = new SchemaProxy ( schemaDefn ) ;
345
- int n = _schema . SchemaDefn . Columns . Length ;
345
+
346
+ _schemaDefn = schemaDefn ;
347
+ _schema = new Schema ( GetSchemaColumns ( schemaDefn ) ) ;
348
+ int n = schemaDefn . Columns . Length ;
346
349
_peeks = new Delegate [ n ] ;
347
350
for ( var i = 0 ; i < n ; i ++ )
348
351
{
349
- var currentColumn = _schema . SchemaDefn . Columns [ i ] ;
352
+ var currentColumn = schemaDefn . Columns [ i ] ;
350
353
_peeks [ i ] = currentColumn . IsComputed
351
354
? currentColumn . Generator
352
355
: ApiUtils . GeneratePeek < DataViewBase < TRow > , TRow > ( currentColumn ) ;
@@ -381,7 +384,7 @@ public abstract class DataViewCursorBase : InputRowBase<TRow>, IRowCursor
381
384
382
385
protected DataViewCursorBase ( IHostEnvironment env , DataViewBase < TRow > dataView ,
383
386
Func < int , bool > predicate )
384
- : base ( env , dataView . Schema , dataView . _schema . SchemaDefn , dataView . _peeks , predicate )
387
+ : base ( env , dataView . Schema , dataView . _schemaDefn , dataView . _peeks , predicate )
385
388
{
386
389
Contracts . AssertValue ( env ) ;
387
390
Ch = env . Start ( "Cursor" ) ;
@@ -747,72 +750,20 @@ protected override bool MoveManyCore(long count)
747
750
}
748
751
}
749
752
750
- private sealed class SchemaProxy : ISchema
753
+ internal static Schema . Column [ ] GetSchemaColumns ( InternalSchemaDefinition schemaDefn )
751
754
{
752
- public readonly InternalSchemaDefinition SchemaDefn ;
753
-
754
- public SchemaProxy ( InternalSchemaDefinition schemaDefn )
755
- {
756
- SchemaDefn = schemaDefn ;
757
- }
758
-
759
- public int ColumnCount
760
- {
761
- get { return SchemaDefn . Columns . Length ; }
762
- }
763
-
764
- public bool TryGetColumnIndex ( string name , out int col )
765
- {
766
- col = Array . FindIndex ( SchemaDefn . Columns , c => c . ColumnName == name ) ;
767
- return col >= 0 ;
768
- }
769
-
770
- public string GetColumnName ( int col )
755
+ Contracts . AssertValue ( schemaDefn ) ;
756
+ var columns = new Schema . Column [ schemaDefn . Columns . Length ] ;
757
+ for ( int i = 0 ; i < columns . Length ; i ++ )
771
758
{
772
- CheckColumnInRange ( col ) ;
773
- return SchemaDefn . Columns [ col ] . ColumnName ;
759
+ var col = schemaDefn . Columns [ i ] ;
760
+ var meta = new Schema . Metadata . Builder ( ) ;
761
+ foreach ( var kvp in col . Metadata )
762
+ meta . Add ( new Schema . Column ( kvp . Value . Kind , kvp . Value . MetadataType , null ) , kvp . Value . GetGetterDelegate ( ) ) ;
763
+ columns [ i ] = new Schema . Column ( col . ColumnName , col . ColumnType , meta . GetMetadata ( ) ) ;
774
764
}
775
765
776
- public ColumnType GetColumnType ( int col )
777
- {
778
- CheckColumnInRange ( col ) ;
779
- return SchemaDefn . Columns [ col ] . ColumnType ;
780
- }
781
-
782
- public IEnumerable < KeyValuePair < string , ColumnType > > GetMetadataTypes ( int col )
783
- {
784
- CheckColumnInRange ( col ) ;
785
- var columnMetadata = SchemaDefn . Columns [ col ] . Metadata ;
786
- if ( columnMetadata == null )
787
- yield break ;
788
- foreach ( var kvp in columnMetadata . Select ( x => new KeyValuePair < string , ColumnType > ( x . Key , x . Value . MetadataType ) ) )
789
- yield return kvp ;
790
- }
791
-
792
- public ColumnType GetMetadataTypeOrNull ( string kind , int col )
793
- {
794
- if ( string . IsNullOrEmpty ( kind ) )
795
- throw MetadataUtils . ExceptGetMetadata ( ) ;
796
- CheckColumnInRange ( col ) ;
797
- var column = SchemaDefn . Columns [ col ] ;
798
- return column . Metadata . ContainsKey ( kind ) ? column . Metadata [ kind ] . MetadataType : null ;
799
- }
800
-
801
- public void GetMetadata < TValue > ( string kind , int col , ref TValue value )
802
- {
803
- var metadataType = GetMetadataTypeOrNull ( kind , col ) ;
804
- if ( metadataType == null )
805
- throw MetadataUtils . ExceptGetMetadata ( ) ;
806
-
807
- var metadata = SchemaDefn . Columns [ col ] . Metadata [ kind ] ;
808
- metadata . GetGetter < TValue > ( ) ( ref value ) ;
809
- }
810
-
811
- private void CheckColumnInRange ( int columnIndex )
812
- {
813
- if ( columnIndex < 0 || columnIndex >= SchemaDefn . Columns . Length )
814
- throw Contracts . Except ( "Column index must be between 0 and {0}" , SchemaDefn . Columns . Length ) ;
815
- }
766
+ return columns ;
816
767
}
817
768
}
818
769
@@ -833,6 +784,8 @@ public abstract partial class MetadataInfo
833
784
834
785
public abstract ValueGetter < TDst > GetGetter < TDst > ( ) ;
835
786
787
+ internal abstract Delegate GetGetterDelegate ( ) ;
788
+
836
789
protected MetadataInfo ( string kind , ColumnType metadataType )
837
790
{
838
791
Contracts . AssertValueOrNull ( metadataType ) ;
@@ -951,6 +904,8 @@ public override ValueGetter<TDst> GetGetter<TDst>()
951
904
throw Contracts . ExceptNotImpl ( "Type '{0}' is not yet supported." , typeT . FullName ) ;
952
905
}
953
906
907
+ internal override Delegate GetGetterDelegate ( ) => Utils . MarshalInvoke ( GetGetter < int > , MetadataType . RawType ) ;
908
+
954
909
public class TElement
955
910
{
956
911
}
0 commit comments