Skip to content

Commit 3ce798f

Browse files
authored
Added input vector shape compatibility check and test
1 parent 3e281a9 commit 3ce798f

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

Pytest/linear_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def train_linear_model(X,y,
1717
assert test_frac < 1.0, "Test set fraction must be between 0.0 and 1.0"
1818
assert test_frac > 0, "Test set fraction must be between 0.0 and 1.0"
1919
assert isinstance(filename, str), "Filename must be a string"
20+
assert X.shape[0] == y.shape[0], "Row numbers of X and y data must be identical"
2021

2122
# Shaping
2223
if len(X.shape) == 1:

Pytest/test_linear_model.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# Pytest module for testing linear regression model function
2+
# Dr. Tirthajyoti Sarkar, Fremont, CA
3+
4+
15
from joblib import load, dump
26
import numpy as np
37
from linear_model import train_linear_model
@@ -164,26 +168,31 @@ def test_wrong_input_raises_assertion():
164168
#=================================
165169
# TEST SUITES
166170
#=================================
167-
# Check that it handles the case of: X is a string
171+
# Test that it handles the case of: X is a string
168172
msg = train_linear_model('X',y)
169173
assert isinstance(msg, AssertionError)
170174
assert msg.args[0] == "X must be a Numpy array"
171-
# Check that it handles the case of: y is a string
175+
# Test that it handles the case of: y is a string
172176
msg = train_linear_model(X,'y')
173177
assert isinstance(msg, AssertionError)
174178
assert msg.args[0] == "y must be a Numpy array"
175-
# Check that it handles the case of: test_frac is a string
179+
# Test that it handles the case of: test_frac is a string
176180
msg = train_linear_model(X,y, test_frac='0.2')
177181
assert isinstance(msg, AssertionError)
178182
assert msg.args[0] == "Test set fraction must be a floating point number"
179-
# Check that it handles the case of: test_frac is within 0.0 and 1.0
183+
# Test that it handles the case of: test_frac is within 0.0 and 1.0
180184
msg = train_linear_model(X,y, test_frac=-0.2)
181185
assert isinstance(msg, AssertionError)
182186
assert msg.args[0] == "Test set fraction must be between 0.0 and 1.0"
183187
msg = train_linear_model(X,y, test_frac=1.2)
184188
assert isinstance(msg, AssertionError)
185189
assert msg.args[0] == "Test set fraction must be between 0.0 and 1.0"
186-
# Check that it handles the case of: filename for model save a string
190+
# Test that it handles the case of: filename for model save a string
187191
msg = train_linear_model(X,y, filename = 2.0)
188192
assert isinstance(msg, AssertionError)
189-
assert msg.args[0] == "Filename must be a string"
193+
assert msg.args[0] == "Filename must be a string"
194+
# Test that function is checking input vector shape compatibility
195+
X = X.reshape(10,10)
196+
msg = train_linear_model(X,y, filename='testing')
197+
assert isinstance(msg, AssertionError)
198+
assert msg.args[0] == "Row numbers of X and y data must be identical"

0 commit comments

Comments
 (0)