Skip to content

Commit b5521ab

Browse files
committed
Cleaned config
1 parent fb4787d commit b5521ab

File tree

6 files changed

+36
-34
lines changed

6 files changed

+36
-34
lines changed

src/main/java/com/secureai/DQNMain.java

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import com.secureai.utils.RLStatTrainingListener;
1313
import com.secureai.utils.YAML;
1414
import org.apache.log4j.BasicConfigurator;
15+
import org.deeplearning4j.optimize.listeners.PerformanceListener;
16+
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
1517
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
1618
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteDense;
1719
import org.deeplearning4j.rl4j.network.dqn.DQN;
@@ -32,26 +34,26 @@ public static void main(String... args) throws IOException {
3234
ActionSet actionSet = YAML.parse(String.format("data/action-sets/action-set-%s.yml", argsMap.getOrDefault("actionSet", "paper")), ActionSet.class);
3335

3436
QLearning.QLConfiguration qlConfiguration = new QLearning.QLConfiguration(
35-
Integer.parseInt(argsMap.getOrDefault("seed", "123")), //Random seed
37+
Integer.parseInt(argsMap.getOrDefault("seed", "123")), //Random seed
3638
Integer.parseInt(argsMap.getOrDefault("maxEpochStep", "1000")), //Max step By epoch
37-
Integer.parseInt(argsMap.getOrDefault("maxStep", "4000")), //Max step
38-
Integer.parseInt(argsMap.getOrDefault("expRepMaxSize", "15000")), //Max size of experience replay
39-
Integer.parseInt(argsMap.getOrDefault("batchSize", "64")), //size of batches
39+
Integer.parseInt(argsMap.getOrDefault("maxStep", "40000")), //Max step
40+
Integer.parseInt(argsMap.getOrDefault("expRepMaxSize", "15000")), //Max size of experience replay
41+
Integer.parseInt(argsMap.getOrDefault("batchSize", "64")), //size of batches
4042
Integer.parseInt(argsMap.getOrDefault("targetDqnUpdateFreq", "2000")), //target update (hard)
41-
Integer.parseInt(argsMap.getOrDefault("updateStart", "10")), //num step noop warmup
42-
Double.parseDouble(argsMap.getOrDefault("rewardFactor", "1")), //reward scaling
43-
Double.parseDouble(argsMap.getOrDefault("gamma", "0.75")), //gamma
43+
Integer.parseInt(argsMap.getOrDefault("updateStart", "10")), //num step noop warmup
44+
Double.parseDouble(argsMap.getOrDefault("rewardFactor", "1")), //reward scaling
45+
Double.parseDouble(argsMap.getOrDefault("gamma", "0.75")), //gamma
4446
Double.parseDouble(argsMap.getOrDefault("errorClamp", "0.9")), //td-error clipping
45-
Float.parseFloat(argsMap.getOrDefault("minEpsilon", "0.1f")), //min epsilon
46-
Integer.parseInt(argsMap.getOrDefault("epsilonNbStep", "1000")), //num step for eps greedy anneal
47-
Boolean.parseBoolean(argsMap.getOrDefault("doubleDQN", "false")) //double DQN
47+
Float.parseFloat(argsMap.getOrDefault("minEpsilon", "0.1")), //min epsilon
48+
Integer.parseInt(argsMap.getOrDefault("epsilonNbStep", "15000")), //num step for eps greedy anneal
49+
Boolean.parseBoolean(argsMap.getOrDefault("doubleDQN", "false")) //double DQN
4850
);
4951

5052
SystemEnvironment mdp = new SystemEnvironment(topology, actionSet);
5153
FilteredMultiLayerNetwork nn = new NNBuilder().build(mdp.getObservationSpace().size(), mdp.getActionSpace().getSize(), Integer.parseInt(argsMap.getOrDefault("layers", "3")));
5254
nn.setMultiLayerNetworkPredictionFilter(input -> mdp.getActionSpace().actionsMask(input));
53-
//nn.setListeners(new ScoreIterationListener(100));
54-
//nn.setListeners(new PerformanceListener(1, true, true));
55+
nn.setListeners(new ScoreIterationListener(100));
56+
nn.setListeners(new PerformanceListener(1, true, true));
5557
System.out.println(nn.summary());
5658

5759
String dqnType = argsMap.getOrDefault("dqn", "standard");

src/main/java/com/secureai/DynDQNMain.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,19 @@ public static void setup() {
115115
ActionSet actionSet = YAML.parse(String.format("data/action-sets/action-set-%s.yml", actionSetId), ActionSet.class);
116116

117117
QLearning.QLConfiguration qlConfiguration = new QLearning.QLConfiguration(
118-
Integer.parseInt(argsMap.getOrDefault("seed", "123")), //Random seed
119-
Integer.parseInt(argsMap.getOrDefault("maxEpochStep", "200")), //Max step By epoch
120-
Integer.parseInt(argsMap.getOrDefault("maxStep", "8000")), //Max step
121-
Integer.parseInt(argsMap.getOrDefault("expRepMaxSize", "150000")), //Max size of experience replay
122-
Integer.parseInt(argsMap.getOrDefault("batchSize", "32")), //size of batches
123-
Integer.parseInt(argsMap.getOrDefault("targetDqnUpdateFreq", "500")), //target update (hard)
124-
Integer.parseInt(argsMap.getOrDefault("updateStart", "10")), //num step noop warmup
125-
Double.parseDouble(argsMap.getOrDefault("rewardFactor", "1")), //reward scaling
126-
Double.parseDouble(argsMap.getOrDefault("gamma", "5")), //gamma
127-
Double.parseDouble(argsMap.getOrDefault("errorClamp", ".8")), //td-error clipping
128-
Float.parseFloat(argsMap.getOrDefault("minEpsilon", "0.1f")), //min epsilon
129-
Integer.parseInt(argsMap.getOrDefault("epsilonNbStep", "10000")), //num step for eps greedy anneal
130-
Boolean.parseBoolean(argsMap.getOrDefault("doubleDQN", "false")) //double DQN
118+
Integer.parseInt(argsMap.getOrDefault("seed", "123")), //Random seed
119+
Integer.parseInt(argsMap.getOrDefault("maxEpochStep", "1000")), //Max step By epoch
120+
Integer.parseInt(argsMap.getOrDefault("maxStep", "40000")), //Max step
121+
Integer.parseInt(argsMap.getOrDefault("expRepMaxSize", "15000")), //Max size of experience replay
122+
Integer.parseInt(argsMap.getOrDefault("batchSize", "64")), //size of batches
123+
Integer.parseInt(argsMap.getOrDefault("targetDqnUpdateFreq", "2000")), //target update (hard)
124+
Integer.parseInt(argsMap.getOrDefault("updateStart", "10")), //num step noop warmup
125+
Double.parseDouble(argsMap.getOrDefault("rewardFactor", "1")), //reward scaling
126+
Double.parseDouble(argsMap.getOrDefault("gamma", "0.75")), //gamma
127+
Double.parseDouble(argsMap.getOrDefault("errorClamp", "0.9")), //td-error clipping
128+
Float.parseFloat(argsMap.getOrDefault("minEpsilon", "0.1")), //min epsilon
129+
Integer.parseInt(argsMap.getOrDefault("epsilonNbStep", "15000")), //num step for eps greedy anneal
130+
Boolean.parseBoolean(argsMap.getOrDefault("doubleDQN", "false")) //double DQN
131131
);
132132

133133
SystemEnvironment newMdp = new SystemEnvironment(topology, actionSet);

src/main/java/com/secureai/QNMain.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ public static void main(String... args) throws IOException {
2727

2828
QLearning.QNConfiguration qnConfiguration = new QLearning.QNConfiguration(
2929
Integer.parseInt(argsMap.getOrDefault("seed", "123")), //Random seed
30-
Integer.parseInt(argsMap.getOrDefault("episodes", "80000")), //episodes
31-
Integer.parseInt(argsMap.getOrDefault("maxEpochStep", "128")), //max step
30+
Integer.parseInt(argsMap.getOrDefault("episodes", "40000")), //episodes
31+
Integer.parseInt(argsMap.getOrDefault("maxEpisodeStep", "400")), //max step
3232
Double.parseDouble(argsMap.getOrDefault("learningRate", "0.9")), //alpha
3333
Double.parseDouble(argsMap.getOrDefault("discountFactor", "0.75")), //gamma
34-
Float.parseFloat(argsMap.getOrDefault("minEpsilon", "0.1f")), //min epsilon
35-
Integer.parseInt(argsMap.getOrDefault("epsilonNbStep", "50000")) //num step for eps greedy anneal
34+
Float.parseFloat(argsMap.getOrDefault("minEpsilon", "0.1")), //min epsilon
35+
Integer.parseInt(argsMap.getOrDefault("epsilonNbStep", "15000")) //num step for eps greedy anneal
3636
);
3737

3838
FilteredDynamicQTable qTable = new FilteredDynamicQTable(mdp.getActionSpace().getSize());

src/main/java/com/secureai/VIMain.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ public static void main(String... args) throws IOException {
2727

2828
ValueIteration.VIConfiguration viConfiguration = new ValueIteration.VIConfiguration(
2929
Integer.parseInt(argsMap.getOrDefault("seed", "123")), //Random seed
30-
Integer.parseInt(argsMap.getOrDefault("iterations", "10")), //iterations
31-
Double.parseDouble(argsMap.getOrDefault("gamma", ".5")), //gamma
30+
Integer.parseInt(argsMap.getOrDefault("iterations", "10")), //iterations
31+
Double.parseDouble(argsMap.getOrDefault("gamma", "0.5")), //gamma
3232
Double.parseDouble(argsMap.getOrDefault("epsilon", "1e-5")) //epsilon
3333
);
3434

src/main/java/com/secureai/rl/qn/QLearning.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public void train() {
6565
O state = this.mdp.reset();
6666
double rewards = 0;
6767
int j = 0;
68-
for (; j < this.conf.maxEpochStep && !this.mdp.isDone(); j++) { // batches
68+
for (; j < this.conf.maxEpisodeStep && !this.mdp.isDone(); j++) { // batches
6969
StepReply<O> step = this.trainStep(state);
7070
state = step.getObservation();
7171

@@ -108,7 +108,7 @@ public double evaluate(int episodes) {
108108
public static class QNConfiguration {
109109
int seed;
110110
int episodes;
111-
int maxEpochStep;
111+
int maxEpisodeStep;
112112
double learningRate;
113113
double discountFactor;
114114
float minEpsilon;

src/main/java/com/secureai/system/SystemRewardFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public SystemRewardFunction(SystemEnvironment environment) {
2323

2424
@Override
2525
public double reward(SystemState oldState, SystemAction systemAction, SystemState currentState) {
26-
if (oldState.equals(currentState)) return -2;
26+
if (oldState.equals(currentState)) return -2; // This is the reward if the policy chooses an action that cannot be run or that keeps the system in the same state
2727
Action action = this.environment.getActionSet().getActions().get(systemAction.getActionId());
2828
return -(Config.TIME_WEIGHT * (action.getExecutionTime() / this.maxExecutionTime) + Config.COST_WEIGHT * (action.getExecutionCost() / this.maxExecutionCost));
2929
}

0 commit comments

Comments (0)