Skip to content

Commit dd620d3

Browse files
committed
update bn and data aug
1 parent 5f66e7c commit dd620d3

File tree

2 files changed

+7
-169
lines changed

2 files changed

+7
-169
lines changed

chapter4_CNN/batch-normalization.ipynb

Lines changed: 1 addition & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -4,57 +4,7 @@
44
"cell_type": "markdown",
55
"metadata": {},
66
"source": [
7-
"# 批标准化\n",
8-
"在我们正式进入模型的构建和训练之前,我们会先讲一讲数据预处理和批标准化,因为模型训练并不容易,特别是一些非常复杂的模型,并不能非常好的训练得到收敛的结果,所以对数据增加一些预处理,同时使用批标准化能够得到非常好的收敛结果,这也是卷积网络能够训练到非常深的层的一个重要原因。"
9-
]
10-
},
11-
{
12-
"cell_type": "markdown",
13-
"metadata": {},
14-
"source": [
15-
"## 数据预处理\n",
16-
"目前数据预处理最常见的方法就是中心化和标准化,中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间,下面是一个简单的图示\n",
17-
"\n",
18-
"![](https://ws1.sinaimg.cn/large/006tKfTcly1fmqouzer3xj30ij06n0t8.jpg)\n",
19-
"\n",
20-
"这两种方法非常的常见,如果你还记得,前面我们在神经网络的部分就已经使用了这个方法实现了数据标准化,至于另外一些方法,比如 PCA 或者 白噪声已经用得非常少了。"
21-
]
22-
},
23-
{
24-
"cell_type": "markdown",
25-
"metadata": {},
26-
"source": [
27-
"## Batch Normalization\n",
28-
"前面在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。\n",
29-
"\n",
30-
"所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。"
31-
]
32-
},
33-
{
34-
"cell_type": "markdown",
35-
"metadata": {},
36-
"source": [
37-
"batch normalization 的实现非常简单,对于给定的一个 batch 的数据 $B = \\{x_1, x_2, \\cdots, x_m\\}$算法的公式如下\n",
38-
"\n",
39-
"$$\n",
40-
"\\mu_B = \\frac{1}{m} \\sum_{i=1}^m x_i\n",
41-
"$$\n",
42-
"$$\n",
43-
"\\sigma^2_B = \\frac{1}{m} \\sum_{i=1}^m (x_i - \\mu_B)^2\n",
44-
"$$\n",
45-
"$$\n",
46-
"\\hat{x}_i = \\frac{x_i - \\mu_B}{\\sqrt{\\sigma^2_B + \\epsilon}}\n",
47-
"$$\n",
48-
"$$\n",
49-
"y_i = \\gamma \\hat{x}_i + \\beta\n",
50-
"$$"
51-
]
52-
},
53-
{
54-
"cell_type": "markdown",
55-
"metadata": {},
56-
"source": [
57-
"第一行和第二行是计算出一个 batch 中数据的均值和方差,接着使用第三个公式对 batch 中的每个数据点做标准化,$\\epsilon$ 是为了计算稳定引入的一个小的常数,通常取 $10^{-5}$,最后利用权重修正得到最后的输出结果,非常的简单,下面我们可以实现一下简单的一维的情况,也就是神经网络中的情况"
7+
"# 批标准化"
588
]
599
},
6010
{
@@ -146,19 +96,6 @@
14696
"print(y)"
14797
]
14898
},
149-
{
150-
"cell_type": "markdown",
151-
"metadata": {},
152-
"source": [
153-
"可以看到这里一共是 5 个数据点,三个特征,每一列表示一个特征的不同数据点,使用批标准化之后,每一列都变成了标准的正态分布\n",
154-
"\n",
155-
"这个时候会出现一个问题,就是测试的时候该使用批标准化吗?\n",
156-
"\n",
157-
"答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替\n",
158-
"\n",
159-
"下面我们实现以下能够区分训练状态和测试状态的批标准化方法"
160-
]
161-
},
16299
{
163100
"cell_type": "code",
164101
"execution_count": 4,
@@ -320,13 +257,6 @@
320257
"train(net, train_data, test_data, 10, optimizer, criterion)"
321258
]
322259
},
323-
{
324-
"cell_type": "markdown",
325-
"metadata": {},
326-
"source": [
327-
"这里的 $\\gamma$ 和 $\\beta$ 都作为参数进行训练,初始化为随机的高斯分布,`moving_mean` 和 `moving_var` 都初始化为 0,并不是更新的参数,训练完 10 次之后,我们可以看看移动平均和移动方差被修改为了多少"
328-
]
329-
},
330260
{
331261
"cell_type": "code",
332262
"execution_count": 11,
@@ -360,20 +290,6 @@
360290
"print(net.moving_mean[:10])"
361291
]
362292
},
363-
{
364-
"cell_type": "markdown",
365-
"metadata": {},
366-
"source": [
367-
"可以看到,这些值已经在训练的过程中进行了修改,在测试过程中,我们不需要再计算均值和方差,直接使用移动平均和移动方差即可"
368-
]
369-
},
370-
{
371-
"cell_type": "markdown",
372-
"metadata": {},
373-
"source": [
374-
"作为对比,我们看看不使用批标准化的结果"
375-
]
376-
},
377293
{
378294
"cell_type": "code",
379295
"execution_count": 12,
@@ -409,27 +325,6 @@
409325
"train(no_bn_net, train_data, test_data, 10, optimizer, criterion)"
410326
]
411327
},
412-
{
413-
"cell_type": "markdown",
414-
"metadata": {},
415-
"source": [
416-
"可以看到虽然最后的结果两种情况一样,但是如果我们看前几次的情况,可以看到使用批标准化的情况能够更快的收敛,因为这只是一个小网络,所以用不用批标准化都能够收敛,但是对于更加深的网络,使用批标准化在训练的时候能够很快地收敛"
417-
]
418-
},
419-
{
420-
"cell_type": "markdown",
421-
"metadata": {},
422-
"source": [
423-
"从上面可以看到,我们自己实现了 2 维情况的批标准化,对应于卷积的 4 维情况的标准化是类似的,只需要沿着通道的维度进行均值和方差的计算,但是我们自己实现批标准化是很累的,pytorch 当然也为我们内置了批标准化的函数,一维和二维分别是 `torch.nn.BatchNorm1d()` 和 `torch.nn.BatchNorm2d()`,不同于我们的实现,pytorch 不仅将 $\\gamma$ 和 $\\beta$ 作为训练的参数,也将 `moving_mean` 和 `moving_var` 也作为参数进行训练"
424-
]
425-
},
426-
{
427-
"cell_type": "markdown",
428-
"metadata": {},
429-
"source": [
430-
"下面我们在卷积网络下试用一下批标准化看看效果"
431-
]
432-
},
433328
{
434329
"cell_type": "code",
435330
"execution_count": null,
@@ -562,13 +457,6 @@
562457
"source": [
563458
"train(net, train_data, test_data, 5, optimizer, criterion)"
564459
]
565-
},
566-
{
567-
"cell_type": "markdown",
568-
"metadata": {},
569-
"source": [
570-
"之后介绍一些著名的网络结构的时候,我们会慢慢认识到批标准化的重要性,使用 pytorch 能够非常方便地添加批标准化层"
571-
]
572460
}
573461
],
574462
"metadata": {

chapter4_CNN/data-augumentation.ipynb

Lines changed: 6 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,7 @@
44
"cell_type": "markdown",
55
"metadata": {},
66
"source": [
7-
"# 数据增强\n",
8-
"前面我们已经讲了几个非常著名的卷积网络的结构,但是单单只靠这些网络并不能取得 state-of-the-art 的结果,现实问题往往更加复杂,所以为了在现实中的数据集上取得成功,还需要应用一些额外的数据增强方法和网训练络的技巧。\n",
9-
"\n",
10-
"2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。"
11-
]
12-
},
13-
{
14-
"cell_type": "markdown",
15-
"metadata": {},
16-
"source": [
17-
"## 常用的数据增强方法\n",
18-
"常用的数据增强方法如下: \n",
19-
"1.对图片进行一定比例缩放 \n",
20-
"2.对图片进行随机位置的截取 \n",
21-
"3.对图片进行随机的水平和竖直翻转 \n",
22-
"4.对图片进行随机角度的旋转 \n",
23-
"5.对图片进行亮度、对比度和颜色的随机变化\n",
24-
"\n",
25-
"这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法"
7+
"# 数据增强"
268
]
279
},
2810
{
@@ -66,8 +48,7 @@
6648
"cell_type": "markdown",
6749
"metadata": {},
6850
"source": [
69-
"### 随机比例放缩\n",
70-
"随机比例缩放主要使用的是 `torchvision.transforms.Resize()` 这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看[文档](http://pytorch.org/docs/0.3.0/torchvision/transforms.html)"
51+
"### 随机比例放缩"
7152
]
7253
},
7354
{
@@ -109,8 +90,7 @@
10990
"cell_type": "markdown",
11091
"metadata": {},
11192
"source": [
112-
"### 随机位置截取\n",
113-
"随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 `torchvision.transforms.RandomCrop()`,传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 `torchvision.transforms.CenterCrop()`,同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取"
93+
"### 随机位置截取"
11494
]
11595
},
11696
{
@@ -192,8 +172,7 @@
192172
"cell_type": "markdown",
193173
"metadata": {},
194174
"source": [
195-
"### 随机的水平和竖直方向翻转\n",
196-
"对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 `torchvision.transforms.RandomHorizontalFlip()` 和 `torchvision.transforms.RandomVerticalFlip()`"
175+
"### 随机的水平和竖直方向翻转"
197176
]
198177
},
199178
{
@@ -250,8 +229,7 @@
250229
"cell_type": "markdown",
251230
"metadata": {},
252231
"source": [
253-
"### 随机角度旋转\n",
254-
"一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 `torchvision.transforms.RandomRotation()` 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转"
232+
"### 随机角度旋转"
255233
]
256234
},
257235
{
@@ -282,8 +260,7 @@
282260
"cell_type": "markdown",
283261
"metadata": {},
284262
"source": [
285-
"### 亮度、对比度和颜色的变化\n",
286-
"除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 `torchvision.transforms.ColorJitter()` 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色"
263+
"### 亮度、对比度和颜色的变化"
287264
]
288265
},
289266
{
@@ -361,15 +338,6 @@
361338
"color_im"
362339
]
363340
},
364-
{
365-
"cell_type": "markdown",
366-
"metadata": {},
367-
"source": [
368-
"\n",
369-
"\n",
370-
"上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 `torchvision.transforms.Compose()`,下面我们举个例子"
371-
]
372-
},
373341
{
374342
"cell_type": "code",
375343
"execution_count": 23,
@@ -429,15 +397,6 @@
429397
"plt.show()"
430398
]
431399
},
432-
{
433-
"cell_type": "markdown",
434-
"metadata": {},
435-
"source": [
436-
"可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据\n",
437-
"\n",
438-
"下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用前面讲的 ResNet 进行训练 "
439-
]
440-
},
441400
{
442401
"cell_type": "code",
443402
"execution_count": 1,
@@ -599,15 +558,6 @@
599558
"source": [
600559
"train(net, train_data, test_data, 10, optimizer, criterion)"
601560
]
602-
},
603-
{
604-
"cell_type": "markdown",
605-
"metadata": {},
606-
"source": [
607-
"从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。\n",
608-
"\n",
609-
"而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。"
610-
]
611561
}
612562
],
613563
"metadata": {

0 commit comments

Comments
 (0)