Skip to content
Prev Previous commit
Next Next commit
split if/else blocks into separate instance methods
  • Loading branch information
jeffreydominic committed Feb 16, 2023
commit 03eaf4384b6574573ef55082545161dbe4dfda9b
354 changes: 187 additions & 167 deletions medsegpy/data/transforms/transform_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,75 +249,84 @@ def __init__(
# Precompute masks
self.precomputed_masks = []
if sampling_pattern == "poisson":
if not (min_height == max_height == min_width == max_width):
raise ValueError("Only square patches are allowed if sampling pattern is 'poisson'")

hole_size = min_width
# Check max_perc_area_to_remove
# - Assuming a hexagonal packing of circles will output the
# most number of samples when using Poisson disc sampling.
# - From https://mathworld.wolfram.com/CirclePacking.html:
# Max packing density when using hexagonal packing is
# pi / (2 * sqrt(3))
max_pos_area = (img_shape[0] * img_shape[1]) * (np.pi / (2 * np.sqrt(3)))
max_num_patches = max_pos_area // ((np.pi / 2) * (hole_size ** 2))
max_pos_perc_area = np.round(
(max_num_patches * (hole_size ** 2)) / (img_shape[0] * img_shape[1]), decimals=3
self.create_poisson_disc_masks()
else:
self.create_random_masks()

def create_poisson_disc_masks(self):
"""Precomputes masks using poisson-disc sampling"""

if not (self.min_height == self.max_height == self.min_width == self.max_width):
raise ValueError("Only square patches are allowed if sampling pattern is 'poisson'")

hole_size = self.min_width
# Check max_perc_area_to_remove
# - Assuming a hexagonal packing of circles will output the
# most number of samples when using Poisson disc sampling.
# - From https://mathworld.wolfram.com/CirclePacking.html:
# Max packing density when using hexagonal packing is
# pi / (2 * sqrt(3))
max_pos_area = (self.img_shape[0] * self.img_shape[1]) * (np.pi / (2 * np.sqrt(3)))
max_num_patches = max_pos_area // ((np.pi / 2) * (hole_size ** 2))
max_pos_perc_area = np.round(
(max_num_patches * (hole_size ** 2)) / (self.img_shape[0] * self.img_shape[1]),
decimals=3,
)
if self.max_perc_area_to_remove >= max_pos_perc_area:
raise ValueError(
f"Value of 'max_perc_area_to_remove' "
f"(= {self.max_perc_area_to_remove}) is too large. "
f"An overestimate of the maximum possible % area that can "
f"be corrupted given {hole_size} x {hole_size} patches "
f"is: {max_pos_perc_area}. Make sure "
f"'max_perc_area_to_remove' is less than this value."
)
if max_perc_area_to_remove >= max_pos_perc_area:
raise ValueError(
f"Value of 'max_perc_area_to_remove' "
f"(= {max_perc_area_to_remove}) is too large. "
f"An overestimate of the maximum possible % area that can "
f"be corrupted given {hole_size} x {hole_size} patches "
f"is: {max_pos_perc_area}. Make sure "
f"'max_perc_area_to_remove' is less than this value."
)

# If `sampling_pattern` is "poisson", precompute
# `num_precompute` masks
num_samples = ((img_shape[0] * img_shape[1]) * max_perc_area_to_remove) // (
hole_size ** 2
# If `sampling_pattern` is "poisson", precompute
# `num_precompute` masks
num_samples = ((self.img_shape[0] * self.img_shape[1]) * self.max_perc_area_to_remove) // (
hole_size ** 2
)
logger.info("Precomputing masks...")
for _ in tqdm.tqdm(range(self.num_precompute)):
_, patch_mask = generate_poisson_disc_mask(
(self.img_shape[0], self.img_shape[1]),
min_distance=hole_size * np.sqrt(2),
num_samples=num_samples,
patch_size=hole_size,
k=10,
)
logger.info("Precomputing masks...")
for _ in tqdm.tqdm(range(num_precompute)):
_, patch_mask = generate_poisson_disc_mask(
(img_shape[0], img_shape[1]),
min_distance=hole_size * np.sqrt(2),
num_samples=num_samples,
patch_size=hole_size,
k=10,
)
self.precomputed_masks.append(patch_mask)
logger.info("Finished precomputing masks!")
else:
logger.info("Precomputing masks...")
img_height = img_shape[0]
img_width = img_shape[1]
max_area_to_remove = int((img_height * img_width) * max_perc_area_to_remove)
for _ in tqdm.tqdm(range(num_precompute)):
patch_mask = np.zeros(img_shape)
cur_num_holes = 0
if max_holes < 0:
num_holes = np.inf
else:
num_holes = np.random.randint(min_holes, max_holes + 1)
while (
cur_num_holes <= num_holes and np.count_nonzero(patch_mask) < max_area_to_remove
):
hole_width = np.random.randint(min_width, max_width + 1)
hole_height = np.random.randint(min_height, max_height + 1)
tl_x = np.random.randint(img_width - hole_width)
tl_y = np.random.randint(img_height - hole_height)
br_x = tl_x + hole_width
br_y = tl_y + hole_height
hole_area = (br_x - tl_x) * (br_y - tl_y)
if hole_area > max_area_to_remove and np.count_nonzero(patch_mask) == 0:
continue
patch_mask[tl_y:br_y, tl_x:br_x] = 1
cur_num_holes += 1
self.precomputed_masks.append(patch_mask)
logger.info("Finished precomputing masks!")
self.precomputed_masks.append(patch_mask)
logger.info("Finished precomputing masks!")

def create_random_masks(self):
"""Precomputes masks using random sampling"""

logger.info("Precomputing masks...")
img_height = self.img_shape[0]
img_width = self.img_shape[1]
max_area_to_remove = int((img_height * img_width) * self.max_perc_area_to_remove)
for _ in tqdm.tqdm(range(self.num_precompute)):
patch_mask = np.zeros(self.img_shape)
cur_num_holes = 0
if self.max_holes < 0:
num_holes = np.inf
else:
num_holes = np.random.randint(self.min_holes, self.max_holes + 1)
while cur_num_holes <= num_holes and np.count_nonzero(patch_mask) < max_area_to_remove:
hole_width = np.random.randint(self.min_width, self.max_width + 1)
hole_height = np.random.randint(self.min_height, self.max_height + 1)
tl_x = np.random.randint(img_width - hole_width)
tl_y = np.random.randint(img_height - hole_height)
br_x = tl_x + hole_width
br_y = tl_y + hole_height
hole_area = (br_x - tl_x) * (br_y - tl_y)
if hole_area > max_area_to_remove and np.count_nonzero(patch_mask) == 0:
continue
patch_mask[tl_y:br_y, tl_x:br_x] = 1
cur_num_holes += 1
self.precomputed_masks.append(patch_mask)
logger.info("Finished precomputing masks!")

def get_transform(self, img: np.ndarray) -> MedTransform:
"""
Expand Down Expand Up @@ -432,111 +441,122 @@ def __init__(

self.precomputed_masks = []
if sampling_pattern == "poisson":
assert max_height % 2 == 0, f"'max_height' (= {max_height}) must be even"
assert min_height == max_height, (
f"If sampling_pattern is 'poisson', min_height (= {min_height}) "
f"must equal max_height (= {max_height})"
)
assert is_square, "Only square patches are allowed if sampling pattern is 'poisson'"

patch_size = max_height
# Check max_perc_area_to_modify
# -- Assuming a hexagonal packing of circles will output the
# most number of samples when using Poisson disc sampling.
# -- From https://mathworld.wolfram.com/CirclePacking.html:
# Max packing density when using hexagonal packing is
# pi / (2 * sqrt(3))
max_pos_area = (img_shape[0] * img_shape[1]) * (np.pi / (2 * np.sqrt(3)))
max_num_patches = max_pos_area // ((np.pi / 2) * (patch_size ** 2))
max_pos_perc_area = np.round(
(max_num_patches * (patch_size ** 2)) / (img_shape[0] * img_shape[1]), decimals=3
self.create_poisson_disc_masks()
else:
self.create_random_masks()

def create_poisson_disc_masks(self):
"""Precomputes masks using poisson-disc sampling"""

assert self.max_height % 2 == 0, f"'max_height' (= {self.max_height}) must be even"
assert self.min_height == self.max_height, (
f"If sampling_pattern is 'poisson', min_height (= {self.min_height}) "
f"must equal max_height (= {self.max_height})"
)
assert self.is_square, "Only square patches are allowed if sampling pattern is 'poisson'"

patch_size = self.max_height
# Check max_perc_area_to_modify
# -- Assuming a hexagonal packing of circles will output the
# most number of samples when using Poisson disc sampling.
# -- From https://mathworld.wolfram.com/CirclePacking.html:
# Max packing density when using hexagonal packing is
# pi / (2 * sqrt(3))
max_pos_area = (self.img_shape[0] * self.img_shape[1]) * (np.pi / (2 * np.sqrt(3)))
max_num_patches = max_pos_area // ((np.pi / 2) * (patch_size ** 2))
max_pos_perc_area = np.round(
(max_num_patches * (patch_size ** 2)) / (self.img_shape[0] * self.img_shape[1]),
decimals=3,
)
if self.max_perc_area_to_modify >= max_pos_perc_area:
raise ValueError(
f"Value of 'max_perc_area_to_modify' "
f"(= {self.max_perc_area_to_modify}) is too large. "
f"An overestimate of the maximum possible % area that can "
f"be corrupted given {patch_size} x {patch_size} patches "
f"is: {max_pos_perc_area}. Make sure "
f"'max_perc_area_to_modify' is less than this value."
)
if max_perc_area_to_modify >= max_pos_perc_area:
raise ValueError(
f"Value of 'max_perc_area_to_modify' "
f"(= {max_perc_area_to_modify}) is too large. "
f"An overestimate of the maximum possible % area that can "
f"be corrupted given {patch_size} x {patch_size} patches "
f"is: {max_pos_perc_area}. Make sure "
f"'max_perc_area_to_modify' is less than this value."
)

# If `sampling_pattern` is "poisson", precompute
# `num_precompute` masks
num_samples = ((img_shape[0] * img_shape[1]) * max_perc_area_to_modify) // (
patch_size ** 2
# If `sampling_pattern` is "poisson", precompute
# `num_precompute` masks
num_samples = ((self.img_shape[0] * self.img_shape[1]) * self.max_perc_area_to_modify) // (
patch_size ** 2
)
assert num_samples >= 2, f"Number of samples (= {num_samples}) must be >= 2"
# Ensure number of samples is even
if num_samples % 2:
num_samples -= 1

logger.info("Precomputing masks...")
for _ in tqdm.tqdm(range(self.num_precompute)):
pd_mask, _ = generate_poisson_disc_mask(
(self.img_shape[0], self.img_shape[1]),
min_distance=patch_size * np.sqrt(2),
num_samples=num_samples,
patch_size=patch_size,
k=10,
)
assert num_samples >= 2, f"Number of samples (= {num_samples}) must be >= 2"
# Ensure number of samples is even
if num_samples % 2:
num_samples -= 1

logger.info("Precomputing masks...")
for _ in tqdm.tqdm(range(num_precompute)):
pd_mask, _ = generate_poisson_disc_mask(
(img_shape[0], img_shape[1]),
min_distance=patch_size * np.sqrt(2),
num_samples=num_samples,
patch_size=patch_size,
k=10,
)
self.precomputed_masks.append(pd_mask)
logger.info("Finished precomputing masks!")
else:
img_height = img_shape[0]
img_width = img_shape[1]
area_check = np.zeros((img_height, img_width))
max_area_to_modify = int((img_height * img_width) * max_perc_area_to_modify)

logger.info("Precomputing masks...")
for _ in tqdm.tqdm(range(num_precompute)):
patch_coords = []
area_check[:] = 0
if max_iterations < 0:
num_iterations = np.inf
self.precomputed_masks.append(pd_mask)
logger.info("Finished precomputing masks!")

def create_random_masks(self):
"""Precomputes masks using random sampling"""

img_height = self.img_shape[0]
img_width = self.img_shape[1]
area_check = np.zeros((img_height, img_width))
max_area_to_modify = int((img_height * img_width) * self.max_perc_area_to_modify)

logger.info("Precomputing masks...")
for _ in tqdm.tqdm(range(self.num_precompute)):
patch_coords = []
area_check[:] = 0
if self.max_iterations < 0:
num_iterations = np.inf
else:
num_iterations = np.random.randint(self.min_iterations, self.max_iterations + 1)
while (len(patch_coords) / 4) <= num_iterations and np.count_nonzero(
area_check
) < max_area_to_modify:
patch_height = np.random.randint(self.min_height, self.max_height + 1)
if self.is_square:
patch_width = patch_height
else:
patch_width = np.random.randint(self.min_width, self.max_width + 1)
# Get coordinates of first patch
tl_x_1 = np.random.randint(img_width - patch_width)
tl_y_1 = np.random.randint(img_height - patch_height)
br_x_1 = tl_x_1 + patch_width
br_y_1 = tl_y_1 + patch_height
patch_area = (br_x_1 - tl_x_1) * (br_y_1 - tl_y_1)
if patch_area > (max_area_to_modify / 2) and not patch_coords:
continue
patch_coords.append([tl_x_1, tl_y_1])
patch_coords.append([br_x_1, br_y_1])

# Get coordinates of second patch, ensuring the second
# patch will not overlap with the first patch
tl_y_2 = np.random.randint(img_height - patch_height)
if tl_y_2 <= tl_y_1 - patch_height or tl_y_2 >= br_y_1:
tl_x_2 = np.random.randint(img_width - patch_width)
else:
num_iterations = np.random.randint(min_iterations, max_iterations + 1)
while (len(patch_coords) / 4) <= num_iterations and np.count_nonzero(
area_check
) < max_area_to_modify:
patch_height = np.random.randint(min_height, max_height + 1)
if is_square:
patch_width = patch_height
else:
patch_width = np.random.randint(min_width, max_width + 1)
# Get coordinates of first patch
tl_x_1 = np.random.randint(img_width - patch_width)
tl_y_1 = np.random.randint(img_height - patch_height)
br_x_1 = tl_x_1 + patch_width
br_y_1 = tl_y_1 + patch_height
patch_area = (br_x_1 - tl_x_1) * (br_y_1 - tl_y_1)
if patch_area > (max_area_to_modify / 2) and not patch_coords:
continue
patch_coords.append([tl_x_1, tl_y_1])
patch_coords.append([br_x_1, br_y_1])

# Get coordinates of second patch, ensuring the second
# patch will not overlap with the first patch
tl_y_2 = np.random.randint(img_height - patch_height)
if tl_y_2 <= tl_y_1 - patch_height or tl_y_2 >= br_y_1:
tl_x_2 = np.random.randint(img_width - patch_width)
else:
possible_columns = list(range(tl_x_1 - patch_width + 1)) + list(
range(br_x_1, img_width - patch_width)
)
tl_x_2 = np.random.choice(np.array(possible_columns))
br_x_2 = tl_x_2 + patch_width
br_y_2 = tl_y_2 + patch_height
patch_coords.append([tl_x_2, tl_y_2])
patch_coords.append([br_x_2, br_y_2])

# Record the areas of image that were modified by the current
# pair of patches
area_check[tl_y_1:br_y_1, tl_x_1:br_x_1] = 1
area_check[tl_y_2:br_y_2, tl_x_2:br_x_2] = 1
coord_matrix = np.array(patch_coords).T
self.precomputed_masks.append(coord_matrix)
logger.info("Finished precomputing masks!")
possible_columns = list(range(tl_x_1 - patch_width + 1)) + list(
range(br_x_1, img_width - patch_width)
)
tl_x_2 = np.random.choice(np.array(possible_columns))
br_x_2 = tl_x_2 + patch_width
br_y_2 = tl_y_2 + patch_height
patch_coords.append([tl_x_2, tl_y_2])
patch_coords.append([br_x_2, br_y_2])

# Record the areas of image that were modified by the current
# pair of patches
area_check[tl_y_1:br_y_1, tl_x_1:br_x_1] = 1
area_check[tl_y_2:br_y_2, tl_x_2:br_x_2] = 1
coord_matrix = np.array(patch_coords).T
self.precomputed_masks.append(coord_matrix)
logger.info("Finished precomputing masks!")

def get_transform(self, img: np.ndarray) -> MedTransform:
"""
Expand Down