Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 15, 2023
commit 2544fa8962df0575e39ef7f095ce6788b706b1c4
12 changes: 9 additions & 3 deletions machine_learning/apriori_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""


def load_data() -> list[list[str]]:
"""
Returns a sample transaction dataset.
Expand All @@ -32,6 +33,7 @@ def load_data() -> list[list[str]]:
]
return data


def generate_candidates(itemset: list[str], length: int):
"""
Generates candidate itemsets of size k from the given itemsets.
Expand All @@ -57,7 +59,9 @@ def generate_candidates(itemset: list[str], length: int):
return candidates


def prune(itemset: list[str], candidates: list[list[str]], length: int ) -> list[list[str]]:
def prune(
itemset: list[str], candidates: list[list[str]], length: int
) -> list[list[str]]:
# Prune candidate itemsets
"""
The goal of pruning is to filter out candidate itemsets that cannot be frequent. This is done by checking that every (k-1)-item subset of a candidate itemset is present in the frequent itemsets of the previous iteration; a candidate with any (k-1)-item subset missing from those frequent itemsets is discarded.
Expand Down Expand Up @@ -107,7 +111,9 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in
counts = [0] * len(itemset)
for i, transaction in enumerate(data):
for j, item in enumerate(itemset):
if item.issubset(transaction): # using set for faster membership checking
if item.issubset(
transaction
): # using set for faster membership checking
counts[j] += 1

# Prune infrequent itemsets
Expand Down Expand Up @@ -152,4 +158,4 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in
min_support = 2 # user-defined threshold or minimum support level
frequent_itemsets = apriori(data, min_support)
for itemset, support in frequent_itemsets:
print(f"{itemset}: {support}")
print(f"{itemset}: {support}")