75 changes: 75 additions & 0 deletions configs/custom_configs/my_custom_dataset_config.yaml
@@ -0,0 +1,75 @@
TRAIN:
ENABLE: True
BATCH_SIZE: 4
EVAL_PERIOD: 5
CHECKPOINT_PERIOD: 1
AUTO_RESUME: False
CHECKPOINT_FILE_PATH: ""

TEST:
ENABLE: True
BATCH_SIZE: 8

DATA:
PATH_PREFIX: "<Enter your data path>"
PATH_TO_DATA_DIR: "<Enter your data path>"
PATH_LABEL_SEPARATOR: ","
DECODING_BACKEND: "pyav"
NUM_FRAMES: 32
SAMPLING_RATE: 2
TRAIN_JITTER_SCALES: [256, 320]
TRAIN_CROP_SIZE: 224
TEST_CROP_SIZE: 256
INPUT_CHANNEL_NUM: [3, 3]

MODEL:
NUM_CLASSES: 14 # <-- Change the number of classes here
ARCH: slowfast
MODEL_NAME: SlowFast
LOSS_FUNC: cross_entropy
DROPOUT_RATE: 0.5

SLOWFAST:
ALPHA: 4
BETA_INV: 8
FUSION_CONV_CHANNEL_RATIO: 2
FUSION_KERNEL_SZ: 7

RESNET:
ZERO_INIT_FINAL_BN: True
WIDTH_PER_GROUP: 64
NUM_GROUPS: 1
DEPTH: 50
TRANS_FUNC: bottleneck_transform
STRIDE_1X1: False
NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]

NONLOCAL:
LOCATION: [[[], []], [[], []], [[], []], [[], []]]
GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
INSTANTIATION: dot_product

BN:
USE_PRECISE_STATS: True
NUM_BATCHES_PRECISE: 200

SOLVER:
BASE_LR: 0.0005
LR_POLICY: cosine
MAX_EPOCH: 50
MOMENTUM: 0.9
WEIGHT_DECAY: 1e-4
WARMUP_EPOCHS: 5.0
WARMUP_START_LR: 0.0001
OPTIMIZING_METHOD: sgd

DATA_LOADER:
NUM_WORKERS: 2
PIN_MEMORY: True

NUM_GPUS: 1
NUM_SHARDS: 1
RNG_SEED: 0
OUTPUT_DIR: "./checkpoints/isl_split_model"
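
As a quick sanity check, this config can be loaded with the repository's standard yacs helpers. The snippet below is a minimal sketch, assuming a working slowfast install and the YAML path added in this PR; the "custom" dataset name refers to the class registered in my_custom_dataset.py further down.

# Minimal sketch (assumptions: slowfast installed, YAML path as in this PR).
from slowfast.config.defaults import get_cfg

cfg = get_cfg()
cfg.merge_from_file("configs/custom_configs/my_custom_dataset_config.yaml")
cfg.TRAIN.DATASET = "custom"      # matches the Custom class registered below
cfg.TEST.DATASET = "custom"
print(cfg.MODEL.NUM_CLASSES)      # 14, as set in the YAML above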
4 changes: 2 additions & 2 deletions setup.py
@@ -23,8 +23,8 @@
"opencv-python",
"pandas",
"torchvision>=0.4.2",
"PIL",
"sklearn",
"pillow",
"scikit-learn",
"tensorboard",
"fairscale",
],
1 change: 1 addition & 0 deletions slowfast/datasets/__init__.py
@@ -7,6 +7,7 @@
from .imagenet import Imagenet # noqa
from .kinetics import Kinetics # noqa
from .ssv2 import Ssv2 # noqa
from .my_custom_dataset import Custom  # noqa

try:
from .ptv_datasets import Ptvcharades, Ptvkinetics, Ptvssv2 # noqa
29 changes: 13 additions & 16 deletions slowfast/datasets/decoder.py
@@ -269,7 +269,8 @@ def torchvision_decode(
decode_all_video (bool): if True, the entire video was decoded.
"""
# Convert the bytes to a tensor.
video_tensor = torch.from_numpy(np.frombuffer(video_handle, dtype=np.uint8))
    # Wrap the raw handle in a numpy array before building the uint8 byte tensor.
video_tensor = torch.from_numpy(np.frombuffer(np.array(video_handle), dtype=np.uint8))

decode_all_video = True
video_start_pts, video_end_pts = 0, -1
@@ -418,6 +419,10 @@ def pyav_decode(
else:
# Perform selective decoding.
decode_all_video = False
        # sampling_rate / num_frames may arrive as per-pathway lists; use the first entry here.
sampling_rate = sampling_rate[0] if isinstance(sampling_rate, list) else sampling_rate
num_frames = num_frames[0] if isinstance(num_frames, list) else num_frames

clip_size = np.maximum(
1.0, np.ceil(sampling_rate * (num_frames - 1) / target_fps * fps)
)
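
With the values from the config above (NUM_FRAMES: 32, SAMPLING_RATE: 2), this clip-size computation works out as in the small standalone sketch below; the source fps and target_fps of 30 are assumptions for illustration only.

# Illustrative arithmetic only; fps values are assumed, not taken from a real video.
import numpy as np

sampling_rate, num_frames = 2, 32        # after unwrapping the per-pathway lists above
fps, target_fps = 30.0, 30.0
clip_size = np.maximum(1.0, np.ceil(sampling_rate * (num_frames - 1) / target_fps * fps))
print(clip_size)                         # 62.0 source frames span one sampled clip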
@@ -572,28 +577,20 @@ def decode(
[None] * num_decode,
)
augment_vid = gaussian_prob > 0.0 or time_diff_prob > 0.0
    # Sample each requested clip from the decoded frames.
for k in range(num_decode):
T = num_frames[k]
# Perform temporal sampling from the decoded video.

if decode_all_video:
frames = frames_decoded[0]
if augment_vid:
frames = frames.clone()
start_idx, end_idx = (
start_end_delta_time[k, 0],
start_end_delta_time[k, 1],
)
frames = frames_decoded[0].clone() if augment_vid else frames_decoded[0]
start_idx, end_idx = start_end_delta_time[k, 0], start_end_delta_time[k, 1]
else:
frames = frames_decoded[k]
# video is already trimmed so we just need subsampling
start_idx, end_idx, clip_position = get_start_end_idx(
frames.shape[0], clip_sizes[k], 0, 1
)
start_idx, end_idx, _ = get_start_end_idx(frames.shape[0], clip_sizes[k], 0, 1)

if augment_vid:
frames, time_diff_aug[k] = transform.augment_raw_frames(
frames, time_diff_prob, gaussian_prob
)
frames, time_diff_aug[k] = transform.augment_raw_frames(frames, time_diff_prob, gaussian_prob)

frames_k = temporal_sampling(frames, start_idx, end_idx, T)
frames_out[k] = frames_k

47 changes: 32 additions & 15 deletions slowfast/datasets/kinetics.py
@@ -370,7 +370,11 @@ def __getitem__(self, index):
for _ in range(num_aug):
idx += 1
f_out[idx] = frames_decoded[i].clone()
time_idx_out[idx] = time_idx_decoded[i, :]
                # Only copy the decoded time index when the decoder actually produced one.
if time_idx_decoded is not None and isinstance(time_idx_decoded, np.ndarray):
time_idx_out[idx] = time_idx_decoded[i, :]
else:
time_idx_out[idx] = None

f_out[idx] = f_out[idx].float()
f_out[idx] = f_out[idx] / 255.0
@@ -450,22 +454,35 @@ def __getitem__(self, index):
if self.cfg.AUG.GEN_MASK_LOADER:
mask = self._gen_mask()
f_out[idx] = f_out[idx] + [torch.Tensor(), mask]
# Final outputs
# Ensure everything is collate-compatible
frames = f_out[0] if num_out == 1 else f_out
time_idx = np.array(time_idx_out)
if (
num_aug * num_decode > 1
and not self.cfg.MODEL.MODEL_NAME == "ContrastiveModel"
):
label = [label] * num_aug * num_decode
index = [index] * num_aug * num_decode
if self.cfg.DATA.DUMMY_LOAD:
if self.dummy_output is None:
self.dummy_output = (frames, label, index, time_idx, {})
            time_idx = torch.stack([
                torch.tensor(t, dtype=torch.int32) if t is not None
                else torch.zeros(1, dtype=torch.int32)  # placeholder when no time index was decoded
                for t in time_idx_out
            ])


            # Replicate label and index once per decoded/augmented clip
if num_aug * num_decode > 1 and not self.cfg.MODEL.MODEL_NAME == "ContrastiveModel":
label = [label] * (num_aug * num_decode)
index = [index] * (num_aug * num_decode)

# Convert everything to torch/numpy-compatible types
if isinstance(frames, list):
frames = [f.clone() for f in frames]
elif isinstance(frames, torch.Tensor):
frames = frames.clone()

label = torch.tensor(label) if isinstance(label, list) else label
index = torch.tensor(index) if isinstance(index, list) else index

            # time_idx is already a tensor at this point; no further conversion is needed.

return frames, label, index, time_idx, {}
else:
logger.warning(
"Failed to fetch video after {} retries.".format(self._num_retries)
)
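
The time-index padding above exists so that a batch still collates when some clips come back without a decoded time index. The fragment below is a synthetic illustration of that idea; the row length of 3 is an assumption, not a value taken from the dataset.

# Synthetic illustration of padding missing time indices before stacking.
import torch

time_idx_out = [torch.tensor([0, 15, 0], dtype=torch.int32), None]
time_idx = torch.stack([
    t if t is not None else torch.zeros(3, dtype=torch.int32)  # placeholder row, assumed length
    for t in time_idx_out
])
print(time_idx.shape)  # torch.Size([2, 3])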



def _gen_mask(self):
if self.cfg.AUG.MASK_TUBE:
25 changes: 20 additions & 5 deletions slowfast/datasets/loader.py
@@ -18,6 +18,13 @@
from . import utils as utils
from .build import build_dataset

# Collate wrapper that drops samples which failed to load (returned None).
def safe_collate(batch):
batch = [b for b in batch if b is not None]
if len(batch) == 0:
return None
return default_collate(batch)
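
A tiny synthetic example of what safe_collate does when one sample in the batch failed to decode; tensor shapes are illustrative only, and the collation relies on the default_collate helper this module already imports.

# Illustrative only: the middle sample failed to load and comes back as None.
import torch

batch = [(torch.zeros(3, 8, 56, 56), 0), None, (torch.ones(3, 8, 56, 56), 1)]
out = safe_collate(batch)
print(out[0].shape, out[1])  # torch.Size([2, 3, 8, 56, 56]) tensor([0, 1])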


def multiple_samples_collate(batch, fold=False):
"""
Expand Down Expand Up @@ -121,7 +128,7 @@ def construct_loader(cfg, split, is_precise_bn=False):
num_workers=cfg.DATA_LOADER.NUM_WORKERS,
pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
drop_last=drop_last,
collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
collate_fn=safe_collate,
worker_init_fn=utils.loader_worker_init_fn(dataset),
)
else:
@@ -138,10 +145,18 @@ def construct_loader(cfg, split, is_precise_bn=False):
num_workers=cfg.DATA_LOADER.NUM_WORKERS,
pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
worker_init_fn=utils.loader_worker_init_fn(dataset),
            collate_fn=safe_collate,
)
else:
# Create a sampler for multi-process training
sampler = utils.create_sampler(dataset, shuffle, cfg)

# Define safe wrapper for multiple_samples_collate
def multiple_samples_safe(batch):
batch = [b for b in batch if b is not None]
if len(batch) == 0:
return None
return multiple_samples_collate(batch, fold="imagenet" in dataset_name)
# Create a loader
if cfg.DETECTION.ENABLE:
collate_func = detection_collate
@@ -154,11 +169,11 @@ def construct_loader(cfg, split, is_precise_bn=False):
and split in ["train"]
and not cfg.MODEL.MODEL_NAME == "ContrastiveModel"
):
collate_func = partial(
multiple_samples_collate, fold="imagenet" in dataset_name
)
                # Use the None-filtering wrapper defined above so failed samples do not break collation.
                collate_func = multiple_samples_safe
else:
collate_func = None
collate_func = safe_collate

loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
10 changes: 4 additions & 6 deletions slowfast/datasets/multigrid_helper.py
@@ -7,13 +7,11 @@
import torch
from torch.utils.data.sampler import Sampler

TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])

if TORCH_MAJOR >= 1 and TORCH_MINOR >= 8:
_int_classes = int
else:
# torch._six was removed in newer PyTorch releases; fall back to the builtin int.
try:
from torch._six import int_classes as _int_classes
except ImportError:
_int_classes = (int,)


class ShortCycleBatchSampler(Sampler):
8 changes: 8 additions & 0 deletions slowfast/datasets/my_custom_dataset.py
@@ -0,0 +1,8 @@
from slowfast.datasets.build import DATASET_REGISTRY
from slowfast.datasets.kinetics import Kinetics

@DATASET_REGISTRY.register()
class Custom(Kinetics):
def __init__(self, cfg, split):
super().__init__(cfg, split)
# You can override methods here if needed
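
Because Custom subclasses Kinetics unchanged, it expects the Kinetics-style annotation layout: train.csv, val.csv and test.csv under DATA.PATH_TO_DATA_DIR, where each line is a video path and a label joined by DATA.PATH_LABEL_SEPARATOR ("," in the config above). The sketch below shows how the registered dataset could be built through the registry; the annotation directory path is a placeholder.

# Minimal sketch; each CSV line looks like:  relative/path/to/clip_0001.mp4,3
from slowfast.config.defaults import get_cfg
from slowfast.datasets.build import build_dataset

cfg = get_cfg()
cfg.merge_from_file("configs/custom_configs/my_custom_dataset_config.yaml")
cfg.DATA.PATH_TO_DATA_DIR = "/path/to/annotations"  # hypothetical location of the CSV files
train_set = build_dataset("custom", cfg, "train")    # registry resolves "custom" -> Custom
print(len(train_set))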