Skip to content

Commit 472864c

Browse files
committed
update
1 parent d59ae9f commit 472864c

File tree

5 files changed

+224
-0
lines changed

5 files changed

+224
-0
lines changed
File renamed without changes.
File renamed without changes.
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
34.62365962451697,78.0246928153624,0
2+
30.28671076822607,43.89499752400101,0
3+
35.84740876993872,72.90219802708364,0
4+
60.18259938620976,86.30855209546826,1
5+
79.0327360507101,75.3443764369103,1
6+
45.08327747668339,56.3163717815305,0
7+
61.10666453684766,96.51142588489624,1
8+
75.02474556738889,46.55401354116538,1
9+
76.09878670226257,87.42056971926803,1
10+
84.43281996120035,43.53339331072109,1
11+
95.86155507093572,38.22527805795094,0
12+
75.01365838958247,30.60326323428011,0
13+
82.30705337399482,76.48196330235604,1
14+
69.36458875970939,97.71869196188608,1
15+
39.53833914367223,76.03681085115882,0
16+
53.9710521485623,89.20735013750205,1
17+
69.07014406283025,52.74046973016765,1
18+
67.94685547711617,46.67857410673128,0
19+
70.66150955499435,92.92713789364831,1
20+
76.97878372747498,47.57596364975532,1
21+
67.37202754570876,42.83843832029179,0
22+
89.67677575072079,65.79936592745237,1
23+
50.534788289883,48.85581152764205,0
24+
34.21206097786789,44.20952859866288,0
25+
77.9240914545704,68.9723599933059,1
26+
62.27101367004632,69.95445795447587,1
27+
80.1901807509566,44.82162893218353,1
28+
93.114388797442,38.80067033713209,0
29+
61.83020602312595,50.25610789244621,0
30+
38.78580379679423,64.99568095539578,0
31+
61.379289447425,72.80788731317097,1
32+
85.40451939411645,57.05198397627122,1
33+
52.10797973193984,63.12762376881715,0
34+
52.04540476831827,69.43286012045222,1
35+
40.23689373545111,71.16774802184875,0
36+
54.63510555424817,52.21388588061123,0
37+
33.91550010906887,98.86943574220611,0
38+
64.17698887494485,80.90806058670817,1
39+
74.78925295941542,41.57341522824434,0
40+
34.1836400264419,75.2377203360134,0
41+
83.90239366249155,56.30804621605327,1
42+
51.54772026906181,46.85629026349976,0
43+
94.44336776917852,65.56892160559052,1
44+
82.36875375713919,40.61825515970618,0
45+
51.04775177128865,45.82270145776001,0
46+
62.22267576120188,52.06099194836679,0
47+
77.19303492601364,70.45820000180959,1
48+
97.77159928000232,86.7278223300282,1
49+
62.07306379667647,96.76882412413983,1
50+
91.56497449807442,88.69629254546599,1
51+
79.94481794066932,74.16311935043758,1
52+
99.2725269292572,60.99903099844988,1
53+
90.54671411399852,43.39060180650027,1
54+
34.52451385320009,60.39634245837173,0
55+
50.2864961189907,49.80453881323059,0
56+
49.58667721632031,59.80895099453265,0
57+
97.64563396007767,68.86157272420604,1
58+
32.57720016809309,95.59854761387875,0
59+
74.24869136721598,69.82457122657193,1
60+
71.79646205863379,78.45356224515052,1
61+
75.3956114656803,85.75993667331619,1
62+
35.28611281526193,47.02051394723416,0
63+
56.25381749711624,39.26147251058019,0
64+
30.05882244669796,49.59297386723685,0
65+
44.66826172480893,66.45008614558913,0
66+
66.56089447242954,41.09209807936973,0
67+
40.45755098375164,97.53518548909936,1
68+
49.07256321908844,51.88321182073966,0
69+
80.27957401466998,92.11606081344084,1
70+
66.74671856944039,60.99139402740988,1
71+
32.72283304060323,43.30717306430063,0
72+
64.0393204150601,78.03168802018232,1
73+
72.34649422579923,96.22759296761404,1
74+
60.45788573918959,73.09499809758037,1
75+
58.84095621726802,75.85844831279042,1
76+
99.82785779692128,72.36925193383885,1
77+
47.26426910848174,88.47586499559782,1
78+
50.45815980285988,75.80985952982456,1
79+
60.45555629271532,42.50840943572217,0
80+
82.22666157785568,42.71987853716458,0
81+
88.9138964166533,69.80378889835472,1
82+
94.83450672430196,45.69430680250754,1
83+
67.31925746917527,66.58935317747915,1
84+
57.23870631569862,59.51428198012956,1
85+
80.36675600171273,90.96014789746954,1
86+
68.46852178591112,85.59430710452014,1
87+
42.0754545384731,78.84478600148043,0
88+
75.47770200533905,90.42453899753964,1
89+
78.63542434898018,96.64742716885644,1
90+
52.34800398794107,60.76950525602592,0
91+
94.09433112516793,77.15910509073893,1
92+
90.44855097096364,87.50879176484702,1
93+
55.48216114069585,35.57070347228866,0
94+
74.49269241843041,84.84513684930135,1
95+
89.84580670720979,45.35828361091658,1
96+
83.48916274498238,48.38028579728175,1
97+
42.2617008099817,87.10385094025457,1
98+
99.31500880510394,68.77540947206617,1
99+
55.34001756003703,64.9319380069486,1
100+
74.77589300092767,89.52981289513276,1
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
__author__ = 'sherlock'
2+
3+
import torch
4+
from torch import nn
5+
from torch.autograd import Variable
6+
import numpy as np
7+
import matplotlib.pyplot as plt
8+
9+
torch.manual_seed(2017)
10+
# get the data and preprocessing
11+
with open('data.txt', 'r') as f:
12+
data_list = f.readlines()
13+
data_list = [i.split('\n')[0] for i in data_list]
14+
data_list = [i.split(',') for i in data_list]
15+
data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]
16+
17+
x0 = list(filter(lambda x: x[-1] == 0.0, data))
18+
x1 = list(filter(lambda x: x[-1] == 1.0, data))
19+
plot_x0_0 = [i[0] for i in x0]
20+
plot_x0_1 = [i[1] for i in x0]
21+
plot_x1_0 = [i[0] for i in x1]
22+
plot_x1_1 = [i[1] for i in x1]
23+
24+
plt.plot(plot_x0_0, plot_x0_1, 'ro', label='x_0')
25+
plt.plot(plot_x1_0, plot_x1_1, 'bo', label='x_1')
26+
plt.legend(loc='best')
27+
28+
# transform to tensor
29+
np_data = np.array(data, dtype=np.float32)
30+
x_data = torch.from_numpy(np_data[:, 0:2])
31+
y_data = torch.from_numpy(np_data[:, -1])
32+
33+
34+
# define logistic regression
35+
class LogisticRegression(nn.Module):
36+
def __init__(self):
37+
super(LogisticRegression, self).__init__()
38+
self.lr = nn.Linear(2, 1)
39+
self.sm = nn.Sigmoid()
40+
41+
def forward(self, x):
42+
x = self.lr(x)
43+
x = self.sm(x)
44+
return x
45+
46+
47+
logistic_model = LogisticRegression()
48+
if torch.cuda.is_available():
49+
logistic_model.cuda()
50+
51+
criterion = nn.BCELoss()
52+
optimizer = torch.optim.SGD(logistic_model.parameters(), lr=1e-3,
53+
momentum=0.9)
54+
55+
for epoch in range(50000):
56+
if torch.cuda.is_available():
57+
x = Variable(x_data).cuda()
58+
y = Variable(y_data).cuda()
59+
else:
60+
x = Variable(x_data)
61+
y = Variable(y_data)
62+
# ==================forward==================
63+
out = logistic_model(x)
64+
loss = criterion(out, y)
65+
print_loss = loss.data[0]
66+
67+
# ===================backward=================
68+
optimizer.zero_grad()
69+
loss.backward()
70+
optimizer.step()
71+
if (epoch+1) % 1000 == 0:
72+
print('epoch {}'.format(epoch+1))
73+
print('loss is {:.4f}'.format(print_loss))
74+
75+
torch.save(logistic_model.state_dict, './logistic_regression.pth')
76+
# ====================plot classification=================
77+
w0, w1 = logistic_model.lr.weight[0]
78+
w0 = w0.data[0]
79+
w1 = w1.data[0]
80+
b = logistic_model.lr.bias.data[0]
81+
plot_x = np.arange(30, 100, 0.1)
82+
plot_y = (-w0 * plot_x - b) / w1
83+
plt.plot(plot_x, plot_y)
84+
plt.show()

cahpter3_MLP/pytorch-basic.py renamed to cahpter3_MLP/pytorch_basic/pytorch-basic.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import torch
22
import numpy as np
33
from torch.autograd import Variable
4+
from torch.utils.data import Dataset, DataLoader
5+
from torchvision.datasets import ImageFolder
6+
import pandas as pd
47

58
# =============================Tensor================================
69
# Define 3x2 matrix with given values
@@ -58,3 +61,40 @@
5861

5962
y.backward(torch.FloatTensor([1, 0.1, 0.01]))
6063
print(x.grad)
64+
65+
66+
# ==============================nn.Module=================================
67+
class net_name(nn.Module):
68+
def __init__(self, other_arguments):
69+
super(net_name, self).__init__()
70+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
71+
# other network layer
72+
73+
def forward(self, x):
74+
x = self.conv1(x)
75+
return x
76+
77+
# ============================Dataset====================================
78+
79+
80+
class myDataset(Dataset):
81+
def __init__(self, csv_file, txt_file, root_dir, other_file):
82+
self.csv_data = pd.read_csv(csv_file)
83+
with open(txt_file, 'r') as f:
84+
data_list = f.readlines()
85+
self.txt_data = data_list
86+
self.root_dir = root_dir
87+
88+
def __len__(self):
89+
return len(self.csv_data)
90+
91+
def __getitem__(self, idx):
92+
data = (self.csv_data[idx], self.txt_data[idx])
93+
return data
94+
95+
96+
dataiter = DataLoader(myDataset, batch_size=32, shuffle=True,
97+
collate_fn=default_collate)
98+
99+
dset = ImageFolder(root='root_path', transform=None,
100+
loader=default_loader)

0 commit comments

Comments
 (0)