Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor data_safety_checker to return instead of print
  • Loading branch information
tianyizheng02 committed Oct 22, 2022
commit 63afdaf754b1f9ff9c0309867cd1432b6338ba40
17 changes: 10 additions & 7 deletions machine_learning/forecasting/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ def interquartile_range_checker(train_user: list) -> float:
return low_lim


def data_safety_checker(list_vote: list, actual_result: float) -> None:
def data_safety_checker(list_vote: list, actual_result: float) -> bool:
"""
Used to review all the votes (list result prediction)
and compare it to the actual result.
input : list of predictions
output : print whether it's safe or not
>>> data_safety_checker([2,3,4],5.0)
Today's data is not safe.
>>> data_safety_checker([2, 3, 4], 5.0)
False
"""
safe = 0
not_safe = 0
Expand All @@ -107,10 +107,10 @@ def data_safety_checker(list_vote: list, actual_result: float) -> None:
safe = not_safe + 1
else:
if abs(abs(i) - abs(actual_result)) <= 0.1:
safe = safe + 1
safe += 1
else:
not_safe = not_safe + 1
print(f"Today's data is {'not ' if safe <= not_safe else ''}safe.")
not_safe += 1
return safe > not_safe


if __name__ == "__main__":
Expand Down Expand Up @@ -155,4 +155,7 @@ def data_safety_checker(list_vote: list, actual_result: float) -> None:
res_vote.append(support_vector_regressor(x_train, x_test, trn_user))

# check the safety of today's data
data_safety_checker(res_vote, tst_user)
if data_safety_checker(res_vote, tst_user):
print("Today's data is safe.")
else:
print("Today's data is not safe.")