@@ -111,24 +111,24 @@ def standardize_weights(y, sample_weight=None, class_weight=None,
111111 if sample_weight_mode != 'temporal' :
112112 raise Exception ('"sample_weight_mode '
113113 'should be None or "temporal".' )
114- if y . ndim < 3 :
114+ if len ( y . shape ) < 3 :
115115 raise Exception ('Timestep-wise sample weighting (use of '
116116 'sample_weight_mode="temporal") is restricted to '
117117 'outputs that are at least 3D, i.e. that have '
118118 'a time dimension.' )
119- if sample_weight is not None and sample_weight .ndim != 2 :
119+ if sample_weight is not None and len ( sample_weight .shape ) != 2 :
120120 raise Exception ('In order to use timestep-wise sample weighting, '
121121 'you should pass a 2D sample_weight array.' )
122122 else :
123- if sample_weight is not None and sample_weight .ndim != 1 :
123+ if sample_weight is not None and len ( sample_weight .shape ) != 1 :
124124 raise Exception ('In order to use timestep-wise sample weights, '
125125 'you should specify sample_weight_mode="temporal" '
126126 'in compile(). If you just mean to use '
127127 'sample-wise weights, make sure your '
128128 'sample_weight array is 1D.' )
129129 if sample_weight is not None :
130- assert sample_weight .ndim <= y . ndim
131- assert y .shape [:sample_weight .ndim ] == sample_weight .shape
130+ assert len ( sample_weight .shape ) <= len ( y . shape )
131+ assert y .shape [:len ( sample_weight .shape ) ] == sample_weight .shape
132132 return sample_weight
133133 elif isinstance (class_weight , dict ):
134134 if len (y .shape ) > 2 :
0 commit comments