Skip to content

Commit d696c71

Browse files
committed
onetime kmean is done
1 parent e6cc1ca commit d696c71

File tree

1 file changed

+81
-20
lines changed

1 file changed

+81
-20
lines changed

Mapping/kmean_clustering/kmean_clustering.py

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,83 @@
22
33
Object clustering with k-mean algorithm
44
5-
65
author: Atsushi Sakai (@Atsushi_twi)
76
87
"""
98

9+
import numpy as np
10+
import math
1011
import matplotlib.pyplot as plt
1112
import random
1213

1314

14-
class Cluster:
15+
class Clusters:
16+
17+
def __init__(self, x, y, nlabel):
18+
self.x = x
19+
self.y = y
20+
self.ndata = len(self.x)
21+
self.nlabel = nlabel
22+
self.labels = [random.randint(0, nlabel - 1)
23+
for _ in range(self.ndata)]
24+
self.cx = [0.0 for _ in range(nlabel)]
25+
self.cy = [0.0 for _ in range(nlabel)]
26+
27+
28+
def init_clusters(rx, ry, nc):
29+
30+
clusters = Clusters(rx, ry, nc)
31+
32+
return clusters
33+
34+
35+
def calc_centroid(clusters):
36+
37+
for ic in range(clusters.nlabel):
38+
x, y = calc_labeled_points(ic, clusters)
39+
ndata = len(x)
40+
clusters.cx[ic] = sum(x) / ndata
41+
clusters.cy[ic] = sum(y) / ndata
42+
43+
return clusters
1544

16-
def __init__(self):
17-
self.x = []
18-
self.y = []
19-
self.cx = None
20-
self.cy = None
45+
46+
def update_clusters(clusters):
47+
cost = 0.0
48+
49+
for ip in range(clusters.ndata):
50+
px = clusters.x[ip]
51+
py = clusters.y[ip]
52+
53+
dx = [icx - px for icx in clusters.cx]
54+
dy = [icy - py for icy in clusters.cy]
55+
56+
dlist = [math.sqrt(idx**2 + idy**2) for (idx, idy) in zip(dx, dy)]
57+
mind = min(dlist)
58+
min_id = dlist.index(mind)
59+
clusters.labels[ip] = min_id
60+
cost += min_id
61+
62+
return clusters, cost
2163

2264

2365
def kmean_clustering(rx, ry, nc):
2466

25-
minx, maxx = min(rx), max(rx)
26-
miny, maxy = min(ry), max(ry)
67+
clusters = init_clusters(rx, ry, nc)
68+
clusters = calc_centroid(clusters)
2769

28-
clusters = [Cluster() for i in range(nc)]
70+
MAX_LOOP = 10
71+
DCOST_TH = 1.0
72+
pcost = 100.0
73+
for loop in range(MAX_LOOP):
74+
print("Loop:", loop)
75+
clusters, cost = update_clusters(clusters)
76+
clusters = calc_centroid(clusters)
2977

30-
for c in clusters:
31-
c.cx = random.uniform(minx, maxx)
32-
c.cy = random.uniform(miny, maxy)
78+
dcost = abs(cost - pcost)
79+
if dcost < DCOST_TH:
80+
break
81+
pcost = cost
3382

3483
return clusters
3584

@@ -40,17 +89,30 @@ def calc_raw_data():
4089

4190
cx = [0.0, 5.0]
4291
cy = [0.0, 5.0]
43-
np = 30
92+
npoints = 30
4493
rand_d = 3.0
4594

4695
for (icx, icy) in zip(cx, cy):
47-
for _ in range(np):
96+
for _ in range(npoints):
4897
rx.append(icx + rand_d * (random.random() - 0.5))
4998
ry.append(icy + rand_d * (random.random() - 0.5))
5099

51100
return rx, ry
52101

53102

103+
def calc_labeled_points(ic, clusters):
104+
105+
inds = np.array([i for i in range(clusters.ndata)
106+
if clusters.labels[i] == ic])
107+
tx = np.array(clusters.x)
108+
ty = np.array(clusters.y)
109+
110+
x = tx[inds]
111+
y = ty[inds]
112+
113+
return x, y
114+
115+
54116
def main():
55117
print(__file__ + " start!!")
56118

@@ -59,11 +121,10 @@ def main():
59121
ncluster = 2
60122
clusters = kmean_clustering(rx, ry, ncluster)
61123

62-
for c in clusters:
63-
print(c.cx, c.cy)
64-
plt.plot(c.cx, c.cy, "x")
65-
66-
plt.plot(rx, ry, ".")
124+
for ic in range(clusters.nlabel):
125+
x, y = calc_labeled_points(ic, clusters)
126+
plt.plot(x, y, "x")
127+
plt.plot(clusters.cx, clusters.cy, "o")
67128
plt.show()
68129

69130

0 commit comments

Comments
 (0)