1212import com .secureai .utils .RLStatTrainingListener ;
1313import com .secureai .utils .YAML ;
1414import org .apache .log4j .BasicConfigurator ;
15+ import org .deeplearning4j .optimize .listeners .PerformanceListener ;
16+ import org .deeplearning4j .optimize .listeners .ScoreIterationListener ;
1517import org .deeplearning4j .rl4j .learning .sync .qlearning .QLearning ;
1618import org .deeplearning4j .rl4j .learning .sync .qlearning .discrete .QLearningDiscreteDense ;
1719import org .deeplearning4j .rl4j .network .dqn .DQN ;
@@ -32,26 +34,26 @@ public static void main(String... args) throws IOException {
3234 ActionSet actionSet = YAML .parse (String .format ("data/action-sets/action-set-%s.yml" , argsMap .getOrDefault ("actionSet" , "paper" )), ActionSet .class );
3335
3436 QLearning .QLConfiguration qlConfiguration = new QLearning .QLConfiguration (
35- Integer .parseInt (argsMap .getOrDefault ("seed" , "123" )), //Random seed
37+ Integer .parseInt (argsMap .getOrDefault ("seed" , "123" )), //Random seed
3638 Integer .parseInt (argsMap .getOrDefault ("maxEpochStep" , "1000" )), //Max step By epoch
37- Integer .parseInt (argsMap .getOrDefault ("maxStep" , "4000 " )), //Max step
38- Integer .parseInt (argsMap .getOrDefault ("expRepMaxSize" , "15000" )), //Max size of experience replay
39- Integer .parseInt (argsMap .getOrDefault ("batchSize" , "64" )), //size of batches
39+ Integer .parseInt (argsMap .getOrDefault ("maxStep" , "40000 " )), //Max step
40+ Integer .parseInt (argsMap .getOrDefault ("expRepMaxSize" , "15000" )), //Max size of experience replay
41+ Integer .parseInt (argsMap .getOrDefault ("batchSize" , "64" )), //size of batches
4042 Integer .parseInt (argsMap .getOrDefault ("targetDqnUpdateFreq" , "2000" )), //target update (hard)
41- Integer .parseInt (argsMap .getOrDefault ("updateStart" , "10" )), //num step noop warmup
42- Double .parseDouble (argsMap .getOrDefault ("rewardFactor" , "1" )), //reward scaling
43- Double .parseDouble (argsMap .getOrDefault ("gamma" , "0.75" )), //gamma
43+ Integer .parseInt (argsMap .getOrDefault ("updateStart" , "10" )), //num step noop warmup
44+ Double .parseDouble (argsMap .getOrDefault ("rewardFactor" , "1" )), //reward scaling
45+ Double .parseDouble (argsMap .getOrDefault ("gamma" , "0.75" )), //gamma
4446 Double .parseDouble (argsMap .getOrDefault ("errorClamp" , "0.9" )), //td-error clipping
45- Float .parseFloat (argsMap .getOrDefault ("minEpsilon" , "0.1f " )), //min epsilon
46- Integer .parseInt (argsMap .getOrDefault ("epsilonNbStep" , "1000 " )), //num step for eps greedy anneal
47- Boolean .parseBoolean (argsMap .getOrDefault ("doubleDQN" , "false" )) //double DQN
47+ Float .parseFloat (argsMap .getOrDefault ("minEpsilon" , "0.1 " )), //min epsilon
48+ Integer .parseInt (argsMap .getOrDefault ("epsilonNbStep" , "15000 " )), //num step for eps greedy anneal
49+ Boolean .parseBoolean (argsMap .getOrDefault ("doubleDQN" , "false" )) //double DQN
4850 );
4951
5052 SystemEnvironment mdp = new SystemEnvironment (topology , actionSet );
5153 FilteredMultiLayerNetwork nn = new NNBuilder ().build (mdp .getObservationSpace ().size (), mdp .getActionSpace ().getSize (), Integer .parseInt (argsMap .getOrDefault ("layers" , "3" )));
5254 nn .setMultiLayerNetworkPredictionFilter (input -> mdp .getActionSpace ().actionsMask (input ));
53- // nn.setListeners(new ScoreIterationListener(100));
54- // nn.setListeners(new PerformanceListener(1, true, true));
55+ nn .setListeners (new ScoreIterationListener (100 ));
56+ nn .setListeners (new PerformanceListener (1 , true , true ));
5557 System .out .println (nn .summary ());
5658
5759 String dqnType = argsMap .getOrDefault ("dqn" , "standard" );
0 commit comments