Skip to content

Commit d3cc1de

Browse files
committed
Merge pull request keras-team#2027 from sudeepraja/master
Added fix for training h5py dataset on Graph model
2 parents b0303f0 + 404a30d commit d3cc1de

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

keras/models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,24 +111,24 @@ def standardize_weights(y, sample_weight=None, class_weight=None,
111111
if sample_weight_mode != 'temporal':
112112
raise Exception('"sample_weight_mode '
113113
'should be None or "temporal".')
114-
if y.ndim < 3:
114+
if len(y.shape) < 3:
115115
raise Exception('Timestep-wise sample weighting (use of '
116116
'sample_weight_mode="temporal") is restricted to '
117117
'outputs that are at least 3D, i.e. that have '
118118
'a time dimension.')
119-
if sample_weight is not None and sample_weight.ndim != 2:
119+
if sample_weight is not None and len(sample_weight.shape) != 2:
120120
raise Exception('In order to use timestep-wise sample weighting, '
121121
'you should pass a 2D sample_weight array.')
122122
else:
123-
if sample_weight is not None and sample_weight.ndim != 1:
123+
if sample_weight is not None and len(sample_weight.shape) != 1:
124124
raise Exception('In order to use timestep-wise sample weights, '
125125
'you should specify sample_weight_mode="temporal" '
126126
'in compile(). If you just mean to use '
127127
'sample-wise weights, make sure your '
128128
'sample_weight array is 1D.')
129129
if sample_weight is not None:
130-
assert sample_weight.ndim <= y.ndim
131-
assert y.shape[:sample_weight.ndim] == sample_weight.shape
130+
assert len(sample_weight.shape) <= len(y.shape)
131+
assert y.shape[:len(sample_weight.shape)] == sample_weight.shape
132132
return sample_weight
133133
elif isinstance(class_weight, dict):
134134
if len(y.shape) > 2:

0 commit comments

Comments
 (0)