|
43 | 43 | "metadata": {}, |
44 | 44 | "outputs": [], |
45 | 45 | "source": [ |
46 | | - "#|export \n", |
| 46 | + "#|export\n", |
47 | 47 | "from tsai.imports import *\n", |
48 | 48 | "from tsai.utils import *\n", |
49 | 49 | "from tsai.data.preprocessing import *\n", |
|
61 | 61 | "#|export\n", |
62 | 62 | "\n", |
63 | 63 | "# This is an unofficial implementation of noisy student based on:\n", |
64 | | - "# Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with noisy student improves imagenet classification. \n", |
| 64 | + "# Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with noisy student improves imagenet classification.\n", |
65 | 65 | "# In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10687-10698).\n", |
66 | 66 | "# Official tensorflow implementation available at https://github.com/google-research/noisystudent\n", |
67 | 67 | "\n", |
68 | 68 | "\n", |
69 | 69 | "class NoisyStudent(Callback):\n", |
70 | | - " \"\"\"A callback to implement the Noisy Student approach. In the original paper this was used in combination with noise: \n", |
| 70 | + " \"\"\"A callback to implement the Noisy Student approach. In the original paper this was used in combination with noise:\n", |
71 | 71 | " - stochastic depth: .8\n", |
72 | 72 | " - RandAugment: N=2, M=27\n", |
73 | 73 | " - dropout: .5\n", |
74 | | - " \n", |
| 74 | + "\n", |
75 | 75 | " Steps:\n", |
76 | 76 | "    1. Build the dls you will use to train the teacher\n", |
77 | 77 | " 2. Create dl2 with the pseudolabels (either soft or hard preds)\n", |
78 | 78 | " 3. Pass any required batch_tfms to the callback\n", |
79 | | - " \n", |
| 79 | + "\n", |
80 | 80 | " \"\"\"\n", |
81 | | - " \n", |
82 | | - " def __init__(self, dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True, \n", |
83 | | - " pseudolabel_sample_weight:float=1., verbose=False): \n", |
| 81 | + "\n", |
| 82 | + " def __init__(self, dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True,\n", |
| 83 | + " pseudolabel_sample_weight:float=1., verbose=False):\n", |
84 | 84 | " r'''\n", |
85 | 85 | " Args:\n", |
86 | 86 | " dl2: dataloader with the pseudolabels\n", |
|
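The steps listed in the docstring above translate into a short end-to-end setup. A minimal sketch, assuming a teacher is first trained on the labeled NATOPS split; `X_unlabeled` is a hypothetical stand-in for your unlabeled data, and the `get_ts_dls` / `ts_learner` / `get_X_preds` calls follow the usual tsai API — verify the exact signatures on your install:

```python
from tsai.all import *

# 1. Build the dls and train the teacher on the labeled data
X, y, splits = get_UCR_data('NATOPS', return_split=False)
X = X.astype(np.float32)
dls = get_ts_dls(X, y, splits=splits, tfms=[None, TSCategorize()])
teacher = ts_learner(dls, InceptionTime, metrics=accuracy)
teacher.fit_one_cycle(10)

# 2. Create dl2 with the pseudolabels (hard preds here; soft probas also work)
X_unlabeled = X[splits[1]]  # hypothetical stand-in for a real unlabeled set
_, _, hard_preds = teacher.get_X_preds(X_unlabeled)
pseudolabels = OneHot()(np.asarray(hard_preds)).astype(np.float32)
dsets2 = TSDatasets(X_unlabeled, pseudolabels)
dl2 = TSDataLoader(dsets2, num_workers=0)

# 3. Pass any required batch_tfms to the callback and train the student
noisy_student_cb = NoisyStudent(dl2, bs=256, l2pl_ratio=2, verbose=True)
student = ts_learner(dls, InceptionTime, metrics=accuracy, cbs=noisy_student_cb)
student.fit_one_cycle(10)
```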
90 | 90 | " do_setup: perform a transform setup on the labeled dataset.\n", |
91 | 91 | "        pseudolabel_sample_weight: weight of each pseudolabel sample relative to each labeled sample in the loss.\n", |
92 | 92 | " '''\n", |
93 | | - " \n", |
| 93 | + "\n", |
94 | 94 | " self.dl2, self.bs, self.l2pl_ratio, self.batch_tfms, self.do_setup, self.verbose = dl2, bs, l2pl_ratio, batch_tfms, do_setup, verbose\n", |
95 | 95 | " self.pl_sw = pseudolabel_sample_weight\n", |
96 | | - " \n", |
| 96 | + "\n", |
97 | 97 | " def before_fit(self):\n", |
98 | 98 | " if self.batch_tfms is None: self.batch_tfms = self.dls.train.after_batch\n", |
99 | 99 | " self.old_bt = self.dls.train.after_batch # Remove and store dl.train.batch_tfms\n", |
100 | 100 | " self.old_bs = self.dls.train.bs\n", |
101 | | - " self.dls.train.after_batch = noop \n", |
| 101 | + " self.dls.train.after_batch = noop\n", |
102 | 102 | "\n", |
103 | 103 | " if self.do_setup and self.batch_tfms:\n", |
104 | | - " for bt in self.batch_tfms: \n", |
| 104 | + " for bt in self.batch_tfms:\n", |
105 | 105 | " bt.setup(self.dls.train)\n", |
106 | 106 | "\n", |
107 | 107 | " if self.bs is None: self.bs = self.dls.train.bs\n", |
|
111 | 111 | " pv(f'labels / pseudolabels per training batch : {self.dls.train.bs} / {self.dl2.bs}', self.verbose)\n", |
112 | 112 | " rel_weight = (self.dls.train.bs/self.dl2.bs) * (len(self.dl2.dataset)/len(self.dls.train.dataset))\n", |
113 | 113 | "        pv(f'relative labeled / pseudolabel sample weight in dataset: {rel_weight:.1f}', self.verbose)\n", |
114 | | - " \n", |
| 114 | + "\n", |
115 | 115 | " self.dl2iter = iter(self.dl2)\n", |
116 | | - " \n", |
| 116 | + "\n", |
117 | 117 | " self.old_loss_func = self.learn.loss_func\n", |
118 | 118 | " self.learn.loss_func = self.loss\n", |
119 | | - " \n", |
| 119 | + "\n", |
120 | 120 | " def before_batch(self):\n", |
121 | 121 | " if self.training:\n", |
122 | 122 | " X, y = self.x, self.y\n", |
|
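The `rel_weight` diagnostic printed above is easy to reproduce by hand. Using the 171/85 batch split reported by the verbose runs below, and assuming (hypothetically) 180 samples in both the labeled and pseudolabeled datasets, labeled data is sampled roughly twice as often per epoch, which matches `l2pl_ratio=2`:

```python
# Hand-check of: rel_weight = (train.bs / dl2.bs) * (len(dl2.dataset) / len(train.dataset))
train_bs, dl2_bs = 171, 85      # labels / pseudolabels per training batch (from the verbose output)
n_labeled, n_pseudo = 180, 180  # hypothetical dataset sizes

rel_weight = (train_bs / dl2_bs) * (n_pseudo / n_labeled)
print(f'{rel_weight:.1f}')      # -> 2.0
```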
125 | 125 | " self.dl2iter = iter(self.dl2)\n", |
126 | 126 | " X2, y2 = next(self.dl2iter)\n", |
127 | 127 | " if y.ndim == 1 and y2.ndim == 2: y = torch.eye(self.learn.dls.c, device=y.device)[y]\n", |
128 | | - " \n", |
| 128 | + "\n", |
129 | 129 | " X_comb, y_comb = concat(X, X2), concat(y, y2)\n", |
130 | | - " \n", |
131 | | - " if self.batch_tfms is not None: \n", |
| 130 | + "\n", |
| 131 | + " if self.batch_tfms is not None:\n", |
132 | 132 | " X_comb = compose_tfms(X_comb, self.batch_tfms, split_idx=0)\n", |
133 | 133 | " y_comb = compose_tfms(y_comb, self.batch_tfms, split_idx=0)\n", |
134 | 134 | " self.learn.xb = (X_comb,)\n", |
135 | 135 | " self.learn.yb = (y_comb,)\n", |
136 | 136 | " pv(f'\\nX: {X.shape} X2: {X2.shape} X_comb: {X_comb.shape}', self.verbose)\n", |
137 | 137 | " pv(f'y: {y.shape} y2: {y2.shape} y_comb: {y_comb.shape}', self.verbose)\n", |
138 | | - " \n", |
139 | | - " def loss(self, output, target): \n", |
| 138 | + "\n", |
| 139 | + " def loss(self, output, target):\n", |
140 | 140 | " if target.ndim == 2: _, target = target.max(dim=1)\n", |
141 | | - " if self.training and self.pl_sw != 1: \n", |
| 141 | + " if self.training and self.pl_sw != 1:\n", |
142 | 142 | " loss = (1 - self.pl_sw) * self.old_loss_func(output[:self.dls.train.bs], target[:self.dls.train.bs])\n", |
143 | 143 | " loss += self.pl_sw * self.old_loss_func(output[self.dls.train.bs:], target[self.dls.train.bs:])\n", |
144 | | - " return loss \n", |
145 | | - " else: \n", |
| 144 | + " return loss\n", |
| 145 | + " else:\n", |
146 | 146 | " return self.old_loss_func(output, target)\n", |
147 | | - " \n", |
| 147 | + "\n", |
148 | 148 | " def after_fit(self):\n", |
149 | 149 | " self.dls.train.after_batch = self.old_bt\n", |
150 | 150 | " self.learn.loss_func = self.old_loss_func\n", |
|
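The `loss` override above does two things: it collapses 2D (one-hot or soft) targets back to class indices, and, when `pseudolabel_sample_weight != 1`, it splits the combined batch at `dls.train.bs` and reweights the two halves. The same logic in isolation, with illustrative shapes and `F.cross_entropy` standing in for the learner's stored loss function:

```python
import torch
import torch.nn.functional as F

train_bs, pl_sw = 4, 0.5                          # labeled samples per batch; pseudolabel weight (illustrative)
output = torch.randn(6, 3)                        # logits for 4 labeled + 2 pseudolabeled samples
target = torch.eye(3)[torch.randint(0, 3, (6,))]  # 2D targets, as with one-hot pseudolabels

if target.ndim == 2:                              # collapse one-hot -> class indices
    _, target = target.max(dim=1)

loss = (1 - pl_sw) * F.cross_entropy(output[:train_bs], target[:train_bs])
loss += pl_sw * F.cross_entropy(output[train_bs:], target[train_bs:])
```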
170 | 170 | "outputs": [], |
171 | 171 | "source": [ |
172 | 172 | "dsid = 'NATOPS'\n", |
173 | | - "X, y, splits = get_UCR_data(dsid, return_split=False)" |
| 173 | + "X, y, splits = get_UCR_data(dsid, return_split=False)\n", |
| 174 | + "X = X.astype(np.float32)" |
174 | 175 | ] |
175 | 176 | }, |
176 | 177 | { |
|
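The added `astype(np.float32)` guards against a dtype mismatch: depending on the dataset and version, `get_UCR_data` may return float64 arrays, while the model parameters are float32. The cast is a no-op when the data is already float32:

```python
import numpy as np
from tsai.data.external import get_UCR_data

X, y, splits = get_UCR_data('NATOPS', return_split=False)
print(X.dtype)            # if this prints float64, the cast below matters
X = X.astype(np.float32)  # match the float32 model parameters
```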
229 | 230 | " <tbody>\n", |
230 | 231 | " <tr>\n", |
231 | 232 | " <td>0</td>\n", |
232 | | - " <td>1.884984</td>\n", |
233 | | - " <td>1.809759</td>\n", |
234 | | - " <td>0.166667</td>\n", |
235 | | - " <td>00:06</td>\n", |
| 233 | + " <td>1.782144</td>\n", |
| 234 | + " <td>1.758471</td>\n", |
| 235 | + " <td>0.250000</td>\n", |
| 236 | + " <td>00:00</td>\n", |
236 | 237 | " </tr>\n", |
237 | 238 | " </tbody>\n", |
238 | 239 | "</table>" |
|
249 | 250 | "output_type": "stream", |
250 | 251 | "text": [ |
251 | 252 | "\n", |
252 | | - "X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 58])\n", |
| 253 | + "X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 41])\n", |
253 | 254 | "y: torch.Size([171]) y2: torch.Size([85]) y_comb: torch.Size([256])\n" |
254 | 255 | ] |
255 | 256 | } |
|
323 | 324 | " <tbody>\n", |
324 | 325 | " <tr>\n", |
325 | 326 | " <td>0</td>\n", |
326 | | - " <td>1.894964</td>\n", |
327 | | - " <td>1.814770</td>\n", |
328 | | - " <td>0.177778</td>\n", |
329 | | - " <td>00:03</td>\n", |
| 327 | + " <td>1.898401</td>\n", |
| 328 | + " <td>1.841182</td>\n", |
| 329 | + " <td>0.155556</td>\n", |
| 330 | + " <td>00:00</td>\n", |
330 | 331 | " </tr>\n", |
331 | 332 | " </tbody>\n", |
332 | 333 | "</table>" |
|
343 | 344 | "output_type": "stream", |
344 | 345 | "text": [ |
345 | 346 | "\n", |
346 | | - "X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 45])\n", |
| 347 | + "X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 51])\n", |
347 | 348 | "y: torch.Size([171, 6]) y2: torch.Size([85, 6]) y_comb: torch.Size([256, 6])\n" |
348 | 349 | ] |
349 | 350 | } |
|
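The `y: torch.Size([171, 6])` line in this second run shows the alignment path in `before_batch`: the labeled targets arrive as class indices, and because the pseudolabels are 2D the callback one-hot-encodes them with `torch.eye` before concatenating. In isolation, with the 6 NATOPS classes:

```python
import torch

c = 6                                     # number of classes (NATOPS)
y = torch.randint(0, c, (171,))           # labeled targets as class indices
y2 = torch.rand(85, c)                    # 2D pseudolabels (soft or one-hot)

if y.ndim == 1 and y2.ndim == 2:          # the callback's alignment rule
    y = torch.eye(c, device=y.device)[y]  # -> torch.Size([171, 6])

y_comb = torch.cat([y, y2])               # -> torch.Size([256, 6])
print(y.shape, y_comb.shape)
```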
353 | 354 | "soft_preds = False\n", |
354 | 355 | "\n", |
355 | 356 | "pseudolabels = ToNumpyCategory()(y) if soft_preds else OneHot()(y)\n", |
| 357 | + "pseudolabels = pseudolabels.astype(np.float32)\n", |
356 | 358 | "dsets2 = TSDatasets(pseudolabeled_data, pseudolabels)\n", |
357 | 359 | "dl2 = TSDataLoader(dsets2, num_workers=0)\n", |
358 | 360 | "noisy_student_cb = NoisyStudent(dl2, bs=256, l2pl_ratio=2, verbose=True)\n", |
|
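The second added cast serves the same purpose on the target side: numpy one-hot encodings default to float64, while the one-hot targets the callback builds with `torch.eye` are float32, so casting keeps `concat(y, y2)` dtype-consistent. A numpy-only sketch (3 classes are illustrative):

```python
import numpy as np

y_hard = np.array([0, 2, 1])                   # hard teacher predictions (illustrative)
onehot = np.eye(3)[y_hard]                     # one-hot encoding -> float64 by default
pseudolabels = onehot.astype(np.float32)       # match torch.eye's float32 one-hot targets
print(onehot.dtype, '->', pseudolabels.dtype)  # float64 -> float32
```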
380 | 382 | "name": "stdout", |
381 | 383 | "output_type": "stream", |
382 | 384 | "text": [ |
383 | | - "/Users/nacho/notebooks/tsai/nbs/026_callback.noisy_student.ipynb saved at 2023-01-21 14:30:23\n", |
| 385 | + "/Users/nacho/notebooks/tsai/nbs/026_callback.noisy_student.ipynb saved at 2024-02-10 21:53:24\n", |
384 | 386 | "Correct notebook to script conversion! 😃\n", |
385 | | - "Saturday 21/01/23 14:30:25 CET\n" |
| 387 | + "Saturday 10/02/24 21:53:27 CET\n" |
386 | 388 | ] |
387 | 389 | }, |
388 | 390 | { |
|