Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improve
  • Loading branch information
RiptideBo committed Sep 22, 2017
commit 53b6fe15c93aa4a69d1f42710c320efa7ab5ae26
26 changes: 22 additions & 4 deletions Neural_Network/neuralnetwork_bp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import matplotlib.pyplot as plt

class Bpnw():
class Bpnn():

def __init__(self,n_layer1,n_layer2,n_layer3,rate_w=0.3,rate_t=0.3):
'''
Expand Down Expand Up @@ -38,7 +38,7 @@ def sig_plain(self,x):
def do_round(self,x):
return round(x, 3)

def trian(self,patterns,data_train, data_teach, n_repeat, error_accuracy,draw_e = bool):
def trian(self,patterns,data_train, data_teach, n_repeat, error_accuracy, draw_e=False):
'''
:param patterns: the number of patterns
:param data_train: training data x; numpy.ndarray
Expand Down Expand Up @@ -127,8 +127,26 @@ def predict(self,data_test):


def main():
#I will fish the mian function later
pass
#example data
data_x = [[1,2,3,4],
[5,6,7,8],
[2,2,3,4],
[7,7,8,8]]
data_y = [[1,0,0,0],
[0,1,0,0],
[0,0,1,0],
[0,0,0,1]]

test_x = [[1,2,3,4],
[3,2,3,4]]

#building network model
model = Bpnn(4,10,4)
#training the model
model.trian(patterns=4,data_train=data_x,data_teach=data_y,
n_repeat=100,error_accuracy=0.1,draw_e=True)
#predicting data
model.predict(test_x)

if __name__ == '__main__':
main()