Skip to content

Commit 6165fb5

Browse files
committed
update
1 parent 472864c commit 6165fb5

File tree

6 files changed

+92
-2
lines changed

6 files changed

+92
-2
lines changed

cahpter3_MLP/logistic_regression/logistic_regression.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,18 @@ def forward(self, x):
6363
out = logistic_model(x)
6464
loss = criterion(out, y)
6565
print_loss = loss.data[0]
66-
66+
mask = out.ge(0.5).float()
67+
correct = (mask == y).sum()
68+
acc = correct.data[0] / x.size(0)
6769
# ===================backward=================
6870
optimizer.zero_grad()
6971
loss.backward()
7072
optimizer.step()
7173
if (epoch+1) % 1000 == 0:
74+
print('*'*10)
7275
print('epoch {}'.format(epoch+1))
7376
print('loss is {:.4f}'.format(print_loss))
74-
77+
print('acc is {:.4f}'.format(acc))
7578
torch.save(logistic_model.state_dict, './logistic_regression.pth')
7679
# ====================plot classification=================
7780
w0, w1 = logistic_model.lr.weight[0]
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
__author__ = 'sherlock'
2+
3+
import torch
4+
from torch import nn
5+
from torch.autograd import Variable
6+
import numpy as np
7+
import matplotlib.pyplot as plt
8+
9+
torch.manual_seed(2017)
10+
# get the data and preprocessing
11+
with open('data.txt', 'r') as f:
12+
data_list = f.readlines()
13+
data_list = [i.split('\n')[0] for i in data_list]
14+
data_list = [i.split(',') for i in data_list]
15+
data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]
16+
17+
x0 = list(filter(lambda x: x[-1] == 0.0, data))
18+
x1 = list(filter(lambda x: x[-1] == 1.0, data))
19+
plot_x0_0 = [i[0] for i in x0]
20+
plot_x0_1 = [i[1] for i in x0]
21+
plot_x1_0 = [i[0] for i in x1]
22+
plot_x1_1 = [i[1] for i in x1]
23+
24+
plt.plot(plot_x0_0, plot_x0_1, 'ro', label='x_0')
25+
plt.plot(plot_x1_0, plot_x1_1, 'bo', label='x_1')
26+
plt.legend(loc='best')
27+
28+
# transform to tensor
29+
np_data = np.array(data, dtype=np.float32)
30+
x_data = torch.from_numpy(np_data[:, 0:2])
31+
y_data = torch.from_numpy(np_data[:, -1])
32+
33+
34+
# define logistic regression
35+
class LogisticRegression(nn.Module):
36+
def __init__(self):
37+
super(LogisticRegression, self).__init__()
38+
self.lr = nn.Linear(2, 1)
39+
self.sm = nn.Sigmoid()
40+
41+
def forward(self, x):
42+
x = self.lr(x)
43+
x = self.sm(x)
44+
return x
45+
46+
47+
logistic_model = LogisticRegression()
48+
if torch.cuda.is_available():
49+
logistic_model.cuda()
50+
51+
criterion = nn.BCELoss()
52+
optimizer = torch.optim.SGD(logistic_model.parameters(), lr=1e-3,
53+
momentum=0.9)
54+
55+
for epoch in range(50000):
56+
if torch.cuda.is_available():
57+
x = Variable(x_data).cuda()
58+
y = Variable(y_data).cuda()
59+
else:
60+
x = Variable(x_data)
61+
y = Variable(y_data)
62+
# ==================forward==================
63+
out = logistic_model(x)
64+
loss = criterion(out, y)
65+
print_loss = loss.data[0]
66+
mask = out.ge(0.5).float()
67+
correct = (mask == y).sum()
68+
acc = correct.data[0] / x.size(0)
69+
# ===================backward=================
70+
optimizer.zero_grad()
71+
loss.backward()
72+
optimizer.step()
73+
if (epoch+1) % 1000 == 0:
74+
print('*'*10)
75+
print('epoch {}'.format(epoch+1))
76+
print('loss is {:.4f}'.format(print_loss))
77+
print('acc is {:.4f}'.format(acc))
78+
torch.save(logistic_model.state_dict, './logistic_regression.pth')
79+
# ====================plot classification=================
80+
w0, w1 = logistic_model.lr.weight[0]
81+
w0 = w0.data[0]
82+
w1 = w1.data[0]
83+
b = logistic_model.lr.bias.data[0]
84+
plot_x = np.arange(30, 100, 0.1)
85+
plot_y = (-w0 * plot_x - b) / w1
86+
plt.plot(plot_x, plot_y)
87+
plt.show()

0 commit comments

Comments
 (0)