Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
change constraints tests to new constraint api
  • Loading branch information
phreeza authored and tleeuwenburg committed Jul 3, 2015
commit acef25270349905926730a47cbdf4380ece5ad94
14 changes: 10 additions & 4 deletions tests/auto/keras/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@ def test_maxnorm(self):
def test_nonneg(self):
from keras.constraints import nonneg

normed = nonneg(self.example_array)
nonneg_instance = nonneg()

normed = nonneg_instance(self.example_array)
assert (np.all(np.min(normed.eval(), axis=1) == 0.))

def test_identity(self):
from keras.constraints import identity

normed = identity(self.example_array)
identity_instance = identity()

normed = identity_instance(self.example_array)
assert (np.all(normed == self.example_array))

def test_identity_oddballs(self):
Expand All @@ -45,14 +49,16 @@ def test_identity_oddballs(self):
but it should in the current implementation.
"""
from keras.constraints import identity
identity_instance = identity()

oddball_examples = ["Hello", [1], -1, None]
assert(oddball_examples == identity(oddball_examples))
assert(oddball_examples == identity_instance(oddball_examples))

def test_unitnorm(self):
from keras.constraints import unitnorm
unitnorm_instance = unitnorm()

normalized = unitnorm(self.example_array)
normalized = unitnorm_instance(self.example_array)

norm_of_normalized = np.sqrt(np.sum(normalized.eval()**2, axis=1))
difference = norm_of_normalized - 1. #in the unit norm constraint, it should be equal to 1.
Expand Down