Skip to content

Commit d2774d6

Browse files
committed
RBM.go
1 parent ee7e7b6 commit d2774d6

File tree

2 files changed

+228
-0
lines changed

2 files changed

+228
-0
lines changed

go/RBM.go

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"math/rand"
6+
u "./utils"
7+
)
8+
9+
type RBM struct {
10+
N int
11+
n_visible int
12+
n_hidden int
13+
W [][]float64
14+
hbias []float64
15+
vbias []float64
16+
}
17+
18+
19+
func RBM__construct(this *RBM, N int, n_visible int, n_hidden int, W [][]float64, hbias []float64, vbias []float64) {
20+
a := 1.0 / float64(n_visible)
21+
22+
this.N = N
23+
this.n_visible = n_visible
24+
this.n_hidden = n_hidden
25+
26+
if W == nil {
27+
this.W = make([][]float64, n_hidden)
28+
for i := 0; i < n_hidden; i++ { this.W[i] = make([]float64, n_visible) }
29+
30+
for i := 0; i < n_hidden; i++ {
31+
for j := 0; j < n_visible; j++ {
32+
this.W[i][j] = u.Uniform(-a, a)
33+
}
34+
}
35+
} else {
36+
this.W = W
37+
}
38+
39+
if hbias == nil {
40+
this.hbias = make([]float64, n_hidden)
41+
} else {
42+
this.hbias = hbias
43+
}
44+
45+
if vbias == nil {
46+
this.vbias = make([]float64, n_visible)
47+
} else {
48+
this.vbias = vbias
49+
}
50+
}
51+
52+
func RBM_contrastive_divergence(this *RBM, input []int, lr float64, k int) {
53+
ph_mean := make([]float64, this.n_hidden)
54+
ph_sample := make([]int, this.n_hidden)
55+
nv_means := make([]float64, this.n_visible)
56+
nv_samples := make([]int, this.n_visible)
57+
nh_means := make([]float64, this.n_hidden)
58+
nh_samples := make([]int, this.n_hidden)
59+
60+
/* CD-k */
61+
RBM_sample_h_given_v(this, input, ph_mean, ph_sample)
62+
63+
for step := 0; step < k; step++ {
64+
if step == 0 {
65+
RBM_gibbs_hvh(this, ph_sample, nv_means, nv_samples, nh_means, nh_samples)
66+
} else {
67+
RBM_gibbs_hvh(this, nh_samples, nv_means, nv_samples, nh_means, nh_samples)
68+
}
69+
}
70+
71+
for i := 0; i < this.n_hidden; i++ {
72+
for j := 0; j < this.n_visible; j++ {
73+
this.W[i][j] += lr * (ph_mean[i] * float64(input[j]) - nh_means[i] * float64(nv_samples[j])) / float64(this.N)
74+
}
75+
this.hbias[i] += lr * (float64(ph_sample[i]) - nh_means[i]) / float64(this.N)
76+
}
77+
78+
for i := 0; i < this.n_visible; i++ {
79+
this.vbias[i] += lr * float64(input[i] - nv_samples[i]) / float64(this.N)
80+
}
81+
}
82+
83+
func RBM_sample_h_given_v(this *RBM, v0_sample []int, mean []float64, sample []int) {
84+
for i := 0; i < this.n_hidden; i++ {
85+
mean[i] = RBM_propup(this, v0_sample, this.W[i], this.hbias[i])
86+
sample[i] = u.Binomial(1, mean[i])
87+
}
88+
}
89+
90+
func RBM_sample_v_given_h(this *RBM, h0_sample []int, mean []float64, sample []int) {
91+
for i := 0; i < this.n_visible; i++ {
92+
mean[i] = RBM_propdown(this, h0_sample, i, this.vbias[i])
93+
sample[i] = u.Binomial(1, mean[i])
94+
}
95+
}
96+
97+
func RBM_propup(this *RBM, v []int, w []float64, b float64) float64 {
98+
pre_sigmoid_activation := 0.0
99+
100+
for j := 0; j < this.n_visible; j++ {
101+
pre_sigmoid_activation += w[j] * float64(v[j])
102+
}
103+
pre_sigmoid_activation += b
104+
105+
return u.Sigmoid(pre_sigmoid_activation)
106+
}
107+
108+
func RBM_propdown(this *RBM, h []int, i int, b float64) float64 {
109+
pre_sigmoid_activation := 0.0
110+
111+
for j := 0; j < this.n_hidden; j++ {
112+
pre_sigmoid_activation += this.W[j][i] * float64(h[j])
113+
}
114+
pre_sigmoid_activation += b
115+
116+
return u.Sigmoid(pre_sigmoid_activation)
117+
}
118+
119+
func RBM_gibbs_hvh(this *RBM, h0_sample []int, nv_means []float64, nv_samples []int, nh_means []float64, nh_samples []int) {
120+
RBM_sample_v_given_h(this, h0_sample, nv_means, nv_samples)
121+
RBM_sample_h_given_v(this, nv_samples, nh_means, nh_samples)
122+
}
123+
124+
func RBM_reconstruct(this *RBM, v []int, reconstructed_v []float64) {
125+
h := make([]float64, this.n_hidden)
126+
var pre_sigmoid_activation float64
127+
128+
for i := 0; i < this.n_hidden; i++ {
129+
h[i] = RBM_propup(this, v, this.W[i], this.hbias[i])
130+
}
131+
132+
for i := 0; i < this.n_visible; i++ {
133+
pre_sigmoid_activation = 0.0
134+
for j := 0; j < this.n_hidden; j++ {
135+
pre_sigmoid_activation += this.W[j][i] * h[j]
136+
}
137+
pre_sigmoid_activation += this.vbias[i]
138+
139+
reconstructed_v[i] = u.Sigmoid(pre_sigmoid_activation)
140+
}
141+
}
142+
143+
144+
func test_rbm() {
145+
rand.Seed(0)
146+
147+
learning_rate := 0.1
148+
training_epochs := 1000
149+
k := 1
150+
151+
train_N := 6
152+
test_N := 2
153+
n_visible := 6
154+
n_hidden := 3
155+
156+
// training data
157+
train_X := [][]int {
158+
{1, 1, 1, 0, 0, 0},
159+
{1, 0, 1, 0, 0, 0},
160+
{1, 1, 1, 0, 0, 0},
161+
{0, 0, 1, 1, 1, 0},
162+
{0, 0, 1, 0, 1, 0},
163+
{0, 0, 1, 1, 1, 0},
164+
}
165+
166+
167+
// construct RBM
168+
var rbm RBM
169+
RBM__construct(&rbm, train_N, n_visible, n_hidden, nil, nil, nil)
170+
171+
// train
172+
for epoch := 0; epoch < training_epochs; epoch++ {
173+
for i := 0; i < train_N; i++ {
174+
RBM_contrastive_divergence(&rbm, train_X[i], learning_rate, k)
175+
}
176+
}
177+
178+
// test data
179+
test_X := [][]int {
180+
{1, 1, 0, 0, 0, 0},
181+
{0, 0, 0, 1, 1, 0},
182+
}
183+
reconstructed_X := make([][]float64, test_N)
184+
for i := 0; i < test_N; i++ { reconstructed_X[i] = make([]float64, n_visible)}
185+
186+
187+
// test
188+
for i := 0; i < test_N; i++ {
189+
RBM_reconstruct(&rbm, test_X[i], reconstructed_X[i])
190+
for j := 0; j < n_visible; j++ {
191+
fmt.Printf("%.5f ", reconstructed_X[i][j])
192+
}
193+
fmt.Printf("\n")
194+
}
195+
}
196+
197+
198+
func main() {
199+
test_rbm()
200+
}

go/utils/utils.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package utils
2+
3+
import (
4+
"math"
5+
"math/rand"
6+
)
7+
8+
func Uniform(min float64, max float64) float64 {
9+
return rand.Float64() * (max - min) + min
10+
}
11+
12+
func Binomial(n int, p float64) int {
13+
if p < 0 || p > 1 { return 0 }
14+
15+
c := 0
16+
var r float64
17+
18+
for i := 0; i < n; i++ {
19+
r = rand.Float64()
20+
if r < p { c++ }
21+
}
22+
23+
return c
24+
}
25+
26+
func Sigmoid(x float64) float64 {
27+
return 1.0 / (1.0 + math.Exp(-x))
28+
}

0 commit comments

Comments
 (0)