Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Use repository line dataset
  • Loading branch information
konitaro524 committed Jun 7, 2025
commit 1bd3c05cce0ec2f49dd1301c2e00c2b61722c41b
24 changes: 2 additions & 22 deletions line_forecast.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,9 @@
import torch
import torch.nn as nn
from models.s4.s4d import S4D
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader


class LineDataset(Dataset):
"""Synthetic dataset of linear sequences for forecasting."""

def __init__(self, seq_len=10, pred_len=1, size=1000):
super().__init__()
self.seq_len = seq_len
self.pred_len = pred_len
self.size = size

def __len__(self):
return self.size

def __getitem__(self, idx):
slope = torch.rand(1) * 2 - 1 # [-1, 1]
intercept = torch.rand(1) * 2 - 1
t = torch.arange(self.seq_len + self.pred_len, dtype=torch.float)
y = slope * t + intercept
x = y[: self.seq_len].unsqueeze(-1)
target = y[self.seq_len :].unsqueeze(-1)
return x, target
from src.dataloaders.datasets.line import LineDataset


class ForecastModel(nn.Module):
Expand Down