Skip to content

Commit d6e3f87

Browse files
authored
Merge pull request #879 from timeseriesAI/m2
M2
2 parents 2db1421 + af49493 commit d6e3f87

21 files changed

+1350
-986
lines changed

nbs/006_data.core.ipynb

Lines changed: 130 additions & 130 deletions
Large diffs are not rendered by default.

nbs/010_data.transforms.ipynb

Lines changed: 280 additions & 142 deletions
Large diffs are not rendered by default.

nbs/012_data.image.ipynb

Lines changed: 28 additions & 27 deletions
Large diffs are not rendered by default.

nbs/022_tslearner.ipynb

Lines changed: 92 additions & 97 deletions
Large diffs are not rendered by default.

nbs/026_callback.noisy_student.ipynb

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"metadata": {},
4444
"outputs": [],
4545
"source": [
46-
"#|export \n",
46+
"#|export\n",
4747
"from tsai.imports import *\n",
4848
"from tsai.utils import *\n",
4949
"from tsai.data.preprocessing import *\n",
@@ -61,26 +61,26 @@
6161
"#|export\n",
6262
"\n",
6363
"# This is an unofficial implementation of noisy student based on:\n",
64-
"# Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with noisy student improves imagenet classification. \n",
64+
"# Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with noisy student improves imagenet classification.\n",
6565
"# In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10687-10698).\n",
6666
"# Official tensorflow implementation available in https://github.com/google-research/noisystudent\n",
6767
"\n",
6868
"\n",
6969
"class NoisyStudent(Callback):\n",
70-
" \"\"\"A callback to implement the Noisy Student approach. In the original paper this was used in combination with noise: \n",
70+
" \"\"\"A callback to implement the Noisy Student approach. In the original paper this was used in combination with noise:\n",
7171
" - stochastic depth: .8\n",
7272
" - RandAugment: N=2, M=27\n",
7373
" - dropout: .5\n",
74-
" \n",
74+
"\n",
7575
" Steps:\n",
7676
" 1. Build the dl you will use as a teacher\n",
7777
" 2. Create dl2 with the pseudolabels (either soft or hard preds)\n",
7878
" 3. Pass any required batch_tfms to the callback\n",
79-
" \n",
79+
"\n",
8080
" \"\"\"\n",
81-
" \n",
82-
" def __init__(self, dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True, \n",
83-
" pseudolabel_sample_weight:float=1., verbose=False): \n",
81+
"\n",
82+
" def __init__(self, dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True,\n",
83+
" pseudolabel_sample_weight:float=1., verbose=False):\n",
8484
" r'''\n",
8585
" Args:\n",
8686
" dl2: dataloader with the pseudolabels\n",
@@ -90,18 +90,18 @@
9090
" do_setup: perform a transform setup on the labeled dataset.\n",
9191
" pseudolabel_sample_weight: weight of each pseudolabel sample relative to the labeled one of the loss.\n",
9292
" '''\n",
93-
" \n",
93+
"\n",
9494
" self.dl2, self.bs, self.l2pl_ratio, self.batch_tfms, self.do_setup, self.verbose = dl2, bs, l2pl_ratio, batch_tfms, do_setup, verbose\n",
9595
" self.pl_sw = pseudolabel_sample_weight\n",
96-
" \n",
96+
"\n",
9797
" def before_fit(self):\n",
9898
" if self.batch_tfms is None: self.batch_tfms = self.dls.train.after_batch\n",
9999
" self.old_bt = self.dls.train.after_batch # Remove and store dl.train.batch_tfms\n",
100100
" self.old_bs = self.dls.train.bs\n",
101-
" self.dls.train.after_batch = noop \n",
101+
" self.dls.train.after_batch = noop\n",
102102
"\n",
103103
" if self.do_setup and self.batch_tfms:\n",
104-
" for bt in self.batch_tfms: \n",
104+
" for bt in self.batch_tfms:\n",
105105
" bt.setup(self.dls.train)\n",
106106
"\n",
107107
" if self.bs is None: self.bs = self.dls.train.bs\n",
@@ -111,12 +111,12 @@
111111
" pv(f'labels / pseudolabels per training batch : {self.dls.train.bs} / {self.dl2.bs}', self.verbose)\n",
112112
" rel_weight = (self.dls.train.bs/self.dl2.bs) * (len(self.dl2.dataset)/len(self.dls.train.dataset))\n",
113113
" pv(f'relative labeled/ pseudolabel sample weight in dataset: {rel_weight:.1f}', self.verbose)\n",
114-
" \n",
114+
"\n",
115115
" self.dl2iter = iter(self.dl2)\n",
116-
" \n",
116+
"\n",
117117
" self.old_loss_func = self.learn.loss_func\n",
118118
" self.learn.loss_func = self.loss\n",
119-
" \n",
119+
"\n",
120120
" def before_batch(self):\n",
121121
" if self.training:\n",
122122
" X, y = self.x, self.y\n",
@@ -125,26 +125,26 @@
125125
" self.dl2iter = iter(self.dl2)\n",
126126
" X2, y2 = next(self.dl2iter)\n",
127127
" if y.ndim == 1 and y2.ndim == 2: y = torch.eye(self.learn.dls.c, device=y.device)[y]\n",
128-
" \n",
128+
"\n",
129129
" X_comb, y_comb = concat(X, X2), concat(y, y2)\n",
130-
" \n",
131-
" if self.batch_tfms is not None: \n",
130+
"\n",
131+
" if self.batch_tfms is not None:\n",
132132
" X_comb = compose_tfms(X_comb, self.batch_tfms, split_idx=0)\n",
133133
" y_comb = compose_tfms(y_comb, self.batch_tfms, split_idx=0)\n",
134134
" self.learn.xb = (X_comb,)\n",
135135
" self.learn.yb = (y_comb,)\n",
136136
" pv(f'\\nX: {X.shape} X2: {X2.shape} X_comb: {X_comb.shape}', self.verbose)\n",
137137
" pv(f'y: {y.shape} y2: {y2.shape} y_comb: {y_comb.shape}', self.verbose)\n",
138-
" \n",
139-
" def loss(self, output, target): \n",
138+
"\n",
139+
" def loss(self, output, target):\n",
140140
" if target.ndim == 2: _, target = target.max(dim=1)\n",
141-
" if self.training and self.pl_sw != 1: \n",
141+
" if self.training and self.pl_sw != 1:\n",
142142
" loss = (1 - self.pl_sw) * self.old_loss_func(output[:self.dls.train.bs], target[:self.dls.train.bs])\n",
143143
" loss += self.pl_sw * self.old_loss_func(output[self.dls.train.bs:], target[self.dls.train.bs:])\n",
144-
" return loss \n",
145-
" else: \n",
144+
" return loss\n",
145+
" else:\n",
146146
" return self.old_loss_func(output, target)\n",
147-
" \n",
147+
"\n",
148148
" def after_fit(self):\n",
149149
" self.dls.train.after_batch = self.old_bt\n",
150150
" self.learn.loss_func = self.old_loss_func\n",
@@ -170,7 +170,8 @@
170170
"outputs": [],
171171
"source": [
172172
"dsid = 'NATOPS'\n",
173-
"X, y, splits = get_UCR_data(dsid, return_split=False)"
173+
"X, y, splits = get_UCR_data(dsid, return_split=False)\n",
174+
"X = X.astype(np.float32)"
174175
]
175176
},
176177
{
@@ -229,10 +230,10 @@
229230
" <tbody>\n",
230231
" <tr>\n",
231232
" <td>0</td>\n",
232-
" <td>1.884984</td>\n",
233-
" <td>1.809759</td>\n",
234-
" <td>0.166667</td>\n",
235-
" <td>00:06</td>\n",
233+
" <td>1.782144</td>\n",
234+
" <td>1.758471</td>\n",
235+
" <td>0.250000</td>\n",
236+
" <td>00:00</td>\n",
236237
" </tr>\n",
237238
" </tbody>\n",
238239
"</table>"
@@ -249,7 +250,7 @@
249250
"output_type": "stream",
250251
"text": [
251252
"\n",
252-
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 58])\n",
253+
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 41])\n",
253254
"y: torch.Size([171]) y2: torch.Size([85]) y_comb: torch.Size([256])\n"
254255
]
255256
}
@@ -323,10 +324,10 @@
323324
" <tbody>\n",
324325
" <tr>\n",
325326
" <td>0</td>\n",
326-
" <td>1.894964</td>\n",
327-
" <td>1.814770</td>\n",
328-
" <td>0.177778</td>\n",
329-
" <td>00:03</td>\n",
327+
" <td>1.898401</td>\n",
328+
" <td>1.841182</td>\n",
329+
" <td>0.155556</td>\n",
330+
" <td>00:00</td>\n",
330331
" </tr>\n",
331332
" </tbody>\n",
332333
"</table>"
@@ -343,7 +344,7 @@
343344
"output_type": "stream",
344345
"text": [
345346
"\n",
346-
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 45])\n",
347+
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 51])\n",
347348
"y: torch.Size([171, 6]) y2: torch.Size([85, 6]) y_comb: torch.Size([256, 6])\n"
348349
]
349350
}
@@ -353,6 +354,7 @@
353354
"soft_preds = False\n",
354355
"\n",
355356
"pseudolabels = ToNumpyCategory()(y) if soft_preds else OneHot()(y)\n",
357+
"pseudolabels = pseudolabels.astype(np.float32)\n",
356358
"dsets2 = TSDatasets(pseudolabeled_data, pseudolabels)\n",
357359
"dl2 = TSDataLoader(dsets2, num_workers=0)\n",
358360
"noisy_student_cb = NoisyStudent(dl2, bs=256, l2pl_ratio=2, verbose=True)\n",
@@ -380,9 +382,9 @@
380382
"name": "stdout",
381383
"output_type": "stream",
382384
"text": [
383-
"/Users/nacho/notebooks/tsai/nbs/026_callback.noisy_student.ipynb saved at 2023-01-21 14:30:23\n",
385+
"/Users/nacho/notebooks/tsai/nbs/026_callback.noisy_student.ipynb saved at 2024-02-10 21:53:24\n",
384386
"Correct notebook to script conversion! 😃\n",
385-
"Saturday 21/01/23 14:30:25 CET\n"
387+
"Saturday 10/02/24 21:53:27 CET\n"
386388
]
387389
},
388390
{

0 commit comments

Comments
 (0)