Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: type error, code refactor
  • Loading branch information
Jeel Gajera committed Oct 15, 2023
commit 12cb7a27913c2f8315000c1d65cee2e2eca61b6f
19 changes: 7 additions & 12 deletions machine_learning/apriori_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""

from typing import List, Tuple


def load_data() -> List[List[str]]:
def load_data() -> list[list[str]]:
"""
Returns a sample transaction dataset.

Expand All @@ -35,7 +32,7 @@ def load_data() -> List[List[str]]:
]
return data

def generate_candidates(itemset: List[str], length: int) -> List[List[str]]:
def generate_candidates(itemset: list[str], length: int):
"""
Generates candidate itemsets of size k from the given itemsets.

Expand All @@ -47,8 +44,6 @@ def generate_candidates(itemset: List[str], length: int) -> List[List[str]]:
>>> generate_candidates(itemsets, 3)
[['milk', 'bread', 'butter']]
"""

def generate_candidates(itemset: List[str], length: int):
candidates = []
for i in range(len(itemset)):
for j in range(i + 1, len(itemset)):
Expand All @@ -63,8 +58,8 @@ def generate_candidates(itemset: List[str], length: int):


def prune(
itemset: List[str], candidates: List[List[str]], length: int
) -> List[List[str]]:
itemset: list[str], candidates: list[list[str]], length: int
) -> list[list[str]]:
# Prune candidate itemsets
"""
The goal of pruning is to filter out candidate itemsets that are not frequent. This is done by checking if all the (k-1) subsets of a candidate itemset are present in the frequent itemsets of the previous iteration (valid subsequences of the frequent itemsets from the previous iteration).
Expand Down Expand Up @@ -93,7 +88,7 @@ def prune(
return pruned


def apriori(data: List[List[str]], min_support: int) -> List[Tuple[List[str], int]]:
def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
"""
Returns a list of frequent itemsets and their support counts.

Expand Down Expand Up @@ -138,11 +133,11 @@ def apriori(data: List[List[str]], min_support: int) -> List[Tuple[List[str], in
Apriori algorithm for finding frequent itemsets.

Args:
data (List[List[str]]): A list of transactions, where each transaction is a list of items.
data (list[list[str]]): A list of transactions, where each transaction is a list of items.
min_support (int): The minimum support threshold for frequent itemsets.

Returns:
List[Tuple[List[str], int]]: A list of frequent itemsets along with their support counts.
list[tuple[list[str], int]]: A list of frequent itemsets along with their support counts.

Example:
>>> data = [["milk", "bread"], ["milk", "butter"], ["milk", "bread", "nuts"]]
Expand Down