__author__ = 'sherlock'

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(2017)
# load the data and preprocess it: each line is "feature1,feature2,label"
with open('data.txt', 'r') as f:
    data_list = f.readlines()
    data_list = [i.strip() for i in data_list]
    data_list = [i.split(',') for i in data_list]
    data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]

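# Example of the expected file layout (these values are made up for
# illustration; only the "two floats then a 0/1 label" shape is implied
# by the parsing above):
#   34.6,78.0,0.0
#   60.2,86.3,1.0
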
# split the points by class label for plotting
x0 = list(filter(lambda x: x[-1] == 0.0, data))
x1 = list(filter(lambda x: x[-1] == 1.0, data))
plot_x0_0 = [i[0] for i in x0]
plot_x0_1 = [i[1] for i in x0]
plot_x1_0 = [i[0] for i in x1]
plot_x1_1 = [i[1] for i in x1]

plt.plot(plot_x0_0, plot_x0_1, 'ro', label='x_0')
plt.plot(plot_x1_0, plot_x1_1, 'bo', label='x_1')
plt.legend(loc='best')

# transform to tensors; keep labels as an (N, 1) column so they match the model output
np_data = np.array(data, dtype=np.float32)
x_data = torch.from_numpy(np_data[:, 0:2])
y_data = torch.from_numpy(np_data[:, -1]).unsqueeze(1)


# define logistic regression: a single linear layer followed by a sigmoid
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.lr = nn.Linear(2, 1)  # two features in, one logit out
        self.sm = nn.Sigmoid()     # squash the logit to a probability in (0, 1)

    def forward(self, x):
        x = self.lr(x)
        x = self.sm(x)
        return x

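# Note on the decision rule: predicting class 1 when sigmoid(w0*x0 + w1*x1 + b) >= 0.5
# is equivalent to w0*x0 + w1*x1 + b >= 0, since sigmoid(z) >= 0.5 exactly when z >= 0.
# The boundary plotted at the end of this script is the line where that expression is 0.
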
logistic_model = LogisticRegression()
if torch.cuda.is_available():
    logistic_model.cuda()

criterion = nn.BCELoss()
optimizer = torch.optim.SGD(logistic_model.parameters(), lr=1e-3,
                            momentum=0.9)

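# For reference, BCELoss computes the binary cross entropy
#   loss = -mean( y * log(p) + (1 - y) * log(1 - p) )
# where p is the model's sigmoid output and y is the 0/1 label.
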
for epoch in range(50000):
    if torch.cuda.is_available():
        x = x_data.cuda()
        y = y_data.cuda()
    else:
        x = x_data
        y = y_data
    # ==================forward==================
    out = logistic_model(x)
    loss = criterion(out, y)
    print_loss = loss.item()
    mask = out.ge(0.5).float()  # threshold the predicted probabilities at 0.5
    correct = (mask == y).sum()
    acc = correct.item() / x.size(0)
    # ===================backward=================
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 1000 == 0:
        print('*' * 10)
        print('epoch {}'.format(epoch + 1))
        print('loss is {:.4f}'.format(print_loss))
        print('acc is {:.4f}'.format(acc))
torch.save(logistic_model.state_dict(), './logistic_regression.pth')
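
# Optional sanity check: a fresh model can load the saved weights back via the
# standard state_dict round trip (path matches the save call above).
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('./logistic_regression.pth'))
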
# ====================plot the decision boundary=================
w0, w1 = logistic_model.lr.weight[0]
w0 = w0.item()
w1 = w1.item()
b = logistic_model.lr.bias.item()
plot_x = np.arange(30, 100, 0.1)
plot_y = (-w0 * plot_x - b) / w1  # the line where w0*x + w1*y + b = 0
plt.plot(plot_x, plot_y)
plt.show()
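
# A minimal inference sketch on one new point (the feature values below are
# arbitrary placeholders, not taken from the dataset):
with torch.no_grad():
    sample = torch.tensor([[60.0, 60.0]])
    if torch.cuda.is_available():
        sample = sample.cuda()
    prob = logistic_model(sample)
    print('P(label=1) for the sample: {:.4f}'.format(prob.item()))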