From 1b50dd9795606a7a71613285e86ad661f151ff42 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 14:48:37 +0100 Subject: [PATCH 01/22] q learning changes --- .../explorers/epsilon_greedy_explorer.jl | 15 ++--- .../test/policies/policies.jl | 2 +- .../test/policies/q_based_policy.jl | 55 +++++++++++++++++++ 3 files changed, 64 insertions(+), 8 deletions(-) create mode 100644 src/ReinforcementLearningCore/test/policies/q_based_policy.jl diff --git a/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl b/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl index 22c387487..f2d9ce559 100644 --- a/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl +++ b/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl @@ -99,13 +99,13 @@ get_ϵ(s::EpsilonGreedyExplorer) = get_ϵ(s, s.step) `NaN` will be filtered unless all the values are `NaN`. In that case, a random one will be returned. """ -function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::Vector{I}) where {I<:Real} +function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::A) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}} ϵ = get_ϵ(s) s.step += 1 rand(s.rng) >= ϵ ? rand(s.rng, find_all_max(values)[2]) : rand(s.rng, 1:length(values)) end -function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::Vector{I}) where {I<:Real} +function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::A) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}} ϵ = get_ϵ(s) s.step += 1 rand(s.rng) >= ϵ ? findmax(values)[2] : rand(s.rng, 1:length(values)) @@ -113,17 +113,18 @@ end ##### -RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, x, mask::Trues) = RLBase.plan!(s, x) +RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, x::A, mask::Trues) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}} = RLBase.plan!(s, x) -function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::Vector{I}, mask::M) where {I<:Real, M<:Union{BitVector, Vector{Bool}}} +function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::A, mask::M) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}, M<:Union{BitVector, Vector{Bool}}} ϵ = get_ϵ(s) s.step += 1 rand(s.rng) >= ϵ ? rand(s.rng, find_all_max(values, mask)[2]) : rand(s.rng, findall(mask)) end -RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, x::Vector{I}, mask::Trues) where{I<:Real} = RLBase.plan!(s, x) -function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::Vector{I}, mask::M) where {I<:Real, M<:Union{BitVector, Vector{Bool}}} +RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, x::A, mask::Trues) where{I<:Real, A<:Union{Vector{I}, SubArray{I}}} = RLBase.plan!(s, x) + +function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::A, mask::M) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}, M<:Union{BitVector, Vector{Bool}}} ϵ = get_ϵ(s) s.step += 1 rand(s.rng) >= ϵ ? findmax_masked(values, mask)[2] : rand(s.rng, findall(mask)) @@ -137,7 +138,7 @@ end Return the probability of selecting each action given the estimated `values` of each action. 
""" -function RLBase.prob(s::EpsilonGreedyExplorer{<:Any,true}, values) +function RLBase.prob(s::EpsilonGreedyExplorer{<:Any,true}, values::A) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}} ϵ, n = get_ϵ(s), length(values) probs = fill(ϵ / n, n) max_val_inds = find_all_max(values)[2] diff --git a/src/ReinforcementLearningCore/test/policies/policies.jl b/src/ReinforcementLearningCore/test/policies/policies.jl index 27d12ca7e..417358b1a 100644 --- a/src/ReinforcementLearningCore/test/policies/policies.jl +++ b/src/ReinforcementLearningCore/test/policies/policies.jl @@ -1,4 +1,4 @@ include("agent.jl") include("multi_agent.jl") include("learners/learners.jl") - +include("q_based_policy.jl") diff --git a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl new file mode 100644 index 000000000..166f0afcb --- /dev/null +++ b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl @@ -0,0 +1,55 @@ +@testset "QBasedPolicy" begin + + @testset "constructor" begin + q_approx = TabularQApproximator(n_state = 5, n_action = 10) + explorer = EpsilonGreedyExplorer(0.1) + p = QBasedPolicy(q_approx, explorer) + @test p.learner == q_approx + @test p.explorer == explorer + end + + @testset "plan!" begin + @testset "plan! without player argument" begin + env = TicTacToeEnv() + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) + explorer = EpsilonGreedyExplorer(0.1) + policy = QBasedPolicy(q_approx, explorer) + @test 1 <= RLBase.plan!(policy, env) <= 9 + end + + @testset "plan! with player argument" begin + env = TicTacToeEnv() + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) + explorer = EpsilonGreedyExplorer(0.1) + policy = QBasedPolicy(q_approx, explorer) + player = :player1 + @test 1 <= RLBase.plan!(policy, env) <= 9 + end + end + + # Test prob function + @testset "prob" begin + env = TicTacToeEnv() + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) + explorer = EpsilonGreedyExplorer(0.1) + policy = QBasedPolicy(q_approx, explorer) + # prob = RLBase.prob(p, env) + # Add assertions here + end + + # Test optimise! function + @testset "optimise!" 
begin + env = TicTacToeEnv() + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) + explorer = EpsilonGreedyExplorer(0.1) + policy = QBasedPolicy(q_approx, explorer) + s = PreActStage() + trajectory = Trajectory( + CircularArraySARTSTraces(; capacity = 1), + BatchSampler(1), + InsertSampleRatioController(n_inserted = -1), + ) + RLBase.optimise!(policy, s, trajectory) + # Add assertions here + end +end From 7b16d912b4fb52e312b839c88a37dff821b4a285 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 15:42:37 +0100 Subject: [PATCH 02/22] init td work --- src/ReinforcementLearningCore/Project.toml | 4 +- src/ReinforcementLearningFarm/Project.toml | 6 +- src/ReinforcementLearningFarm/README.md | 2 +- .../src/ReinforcementLearningFarm.jl | 2 + .../src/algorithms/algorithms.jl | 1 + .../src/algorithms/tabular/tabular.jl | 1 + .../src/algorithms/tabular/td_learner.jl | 78 +++++++++++++++++++ .../test/algorithms/algorithms.jl | 1 + .../test/algorithms/tabular/tabular.jl | 1 + .../test/algorithms/tabular/td_learner.jl | 41 ++++++++++ .../test/runtests.jl | 7 +- 11 files changed, 140 insertions(+), 4 deletions(-) create mode 100644 src/ReinforcementLearningFarm/src/algorithms/algorithms.jl create mode 100644 src/ReinforcementLearningFarm/src/algorithms/tabular/tabular.jl create mode 100644 src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl create mode 100644 src/ReinforcementLearningFarm/test/algorithms/algorithms.jl create mode 100644 src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl create mode 100644 src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl diff --git a/src/ReinforcementLearningCore/Project.toml b/src/ReinforcementLearningCore/Project.toml index 04a3b01ee..f5e916fbe 100644 --- a/src/ReinforcementLearningCore/Project.toml +++ b/src/ReinforcementLearningCore/Project.toml @@ -37,6 +37,7 @@ Metal = "1.0" ProgressMeter = "1" Reexport = "1" ReinforcementLearningBase = "0.12" +ReinforcementLearningFarm = "0.0.1" ReinforcementLearningTrajectories = "0.3.7" Statistics = "1" StatsBase = "0.32, 0.33, 0.34" @@ -52,9 +53,10 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" +ReinforcementLearningFarm = "14eff660-7080-4cec-bba2-cfb12cd77ac3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["CommonRLInterface", "CUDA", "cuDNN", "DomainSets", "Metal", "Preferences", "ReinforcementLearningEnvironments", "Test", "UUIDs"] +test = ["CommonRLInterface", "CUDA", "cuDNN", "DomainSets", "Metal", "Preferences", "ReinforcementLearningEnvironments", "ReinforcementLearningFarm", "Test", "UUIDs"] diff --git a/src/ReinforcementLearningFarm/Project.toml b/src/ReinforcementLearningFarm/Project.toml index ba93671d8..e72610565 100644 --- a/src/ReinforcementLearningFarm/Project.toml +++ b/src/ReinforcementLearningFarm/Project.toml @@ -3,22 +3,26 @@ uuid = "14eff660-7080-4cec-bba2-cfb12cd77ac3" version = "0.0.1" [deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44" ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6" [compat] ReinforcementLearningBase 
= "0.12" ReinforcementLearningCore = "0.14" +ReinforcementLearningEnvironments = "0.8" julia = "1.9" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Preferences = "21216c6a-2e73-6563-6e65-726566657250" +ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["CUDA", "Metal", "Preferences", "Test", "UUIDs", "cuDNN"] +test = ["CUDA", "Metal", "Preferences", "ReinforcementLearningEnvironments", "Test", "UUIDs", "cuDNN"] diff --git a/src/ReinforcementLearningFarm/README.md b/src/ReinforcementLearningFarm/README.md index 8c59fc295..8d23a3161 100644 --- a/src/ReinforcementLearningFarm/README.md +++ b/src/ReinforcementLearningFarm/README.md @@ -2,4 +2,4 @@ This project contains updated, tested algorithms compatible with ReinforcementLearning.jl v0.11+. -Unlike ReinforcementLearningZoo, the algorithms here have been domesticated. +Unlike `ReinforcementLearningZoo`, the algorithms here have been domesticated. diff --git a/src/ReinforcementLearningFarm/src/ReinforcementLearningFarm.jl b/src/ReinforcementLearningFarm/src/ReinforcementLearningFarm.jl index 8e22f7e58..7281421b4 100644 --- a/src/ReinforcementLearningFarm/src/ReinforcementLearningFarm.jl +++ b/src/ReinforcementLearningFarm/src/ReinforcementLearningFarm.jl @@ -5,4 +5,6 @@ using ReinforcementLearningCore const RLFarm = ReinforcementLearningFarm export RLFarm +include("algorithms/algorithms.jl") + end # module diff --git a/src/ReinforcementLearningFarm/src/algorithms/algorithms.jl b/src/ReinforcementLearningFarm/src/algorithms/algorithms.jl new file mode 100644 index 000000000..48fb131b9 --- /dev/null +++ b/src/ReinforcementLearningFarm/src/algorithms/algorithms.jl @@ -0,0 +1 @@ +include("tabular/tabular.jl") diff --git a/src/ReinforcementLearningFarm/src/algorithms/tabular/tabular.jl b/src/ReinforcementLearningFarm/src/algorithms/tabular/tabular.jl new file mode 100644 index 000000000..375c44b32 --- /dev/null +++ b/src/ReinforcementLearningFarm/src/algorithms/tabular/tabular.jl @@ -0,0 +1 @@ +include("td_learner.jl") diff --git a/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl new file mode 100644 index 000000000..f6d4b4390 --- /dev/null +++ b/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl @@ -0,0 +1,78 @@ +export TDLearner + +using LinearAlgebra: dot +using Distributions: pdf +import Base.push! + +using ReinforcementLearningCore: AbstractLearner, TabularApproximator + +""" + TDLearner(;approximator, γ=1.0, method, n=0) + +Use temporal-difference method to estimate state value or state-action value. + +# Fields +- `approximator` is `<:TabularApproximator`. +- `γ=1.0`, discount rate. +- `method`: only `:SARS` (Q-learning) is supported for the time being. +- `n=0`: the number of time steps used minus 1. 
+""" +Base.@kwdef mutable struct TDLearner{M,A} <: AbstractLearner where {A<:TabularApproximator,M<:Symbol} + approximator::A + γ::Float64 = 1.0 + n::Int = 0 + + function TDLearner(approximator::A, method::Symbol; γ=1.0, n=0) where {A<:TabularApproximator} + if method ∉ [:SARS] + @error "Method $method is not supported" + else + new{method, A}(approximator, γ, n) + end + end +end + +RLCore.forward(L::TDLearner, s) = RLCore.forward(L.approximator, s) +RLCore.forward(L::TDLearner, s, a) = RLCore.forward(L.approximator, s, a) + +Q(app::TabularApproximator, s, a) = RLCore.forward(app, s, a) +Q(app::TabularApproximator, s) = RLCore.forward(app, s) + +""" + Q!(app::TabularApproximator, s::Int, s_plus_one::Int, a::Int, α::Float64, π_::Float64, γ::Float64) + +Update the Q-value of the given state-action pair. +""" +function Q!( + app::TabularApproximator, + s::I1, + s_plus_one::I2, + a::I3, + α::F1, + π_::F2, + γ::Float64, +) where {I1<:Integer,I2<:Integer,I3<:Integer,F1<:AbstractFloat,F2<:AbstractFloat} + # Q-learning formula following https://github.com/JuliaPOMDP/TabularTDLearning.jl/blob/25c4d3888e178c51ed1ff448f36b0fcaf7c1d8e8/src/q_learn.jl#LL63C26-L63C95 + q_value_updated = α * (π_ + γ * maximum(Q(app, s_plus_one)) - Q(app, s, a)) + app.model[a, s] += q_value_updated + return Q(app, s, a) +end + +function _optimise!( + n::I1, + γ::F, + app::Approximator{Ar}, + s::I2, + s_next::I2, + a::I3, + r::F, +) where {I1<:Number,I2<:Number,I3<:Number,Ar<:AbstractArray,F<:AbstractFloat} + α = app.optimiser_state.eta + Q!(app, s, s_next, a, α, r, γ) +end + +function RLBase.optimise!( + L::TDLearner, + t::@NamedTuple{state::I1, next_state::I1, action::I2, reward::F2, terminal::Bool}, +) where {I1<:Number,I2<:Number,F2<:AbstractFloat} + _optimise!(L.n, L.γ, L.approximator, t.state, t.next_state, t.action, t.reward) +end diff --git a/src/ReinforcementLearningFarm/test/algorithms/algorithms.jl b/src/ReinforcementLearningFarm/test/algorithms/algorithms.jl new file mode 100644 index 000000000..48fb131b9 --- /dev/null +++ b/src/ReinforcementLearningFarm/test/algorithms/algorithms.jl @@ -0,0 +1 @@ +include("tabular/tabular.jl") diff --git a/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl b/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl new file mode 100644 index 000000000..375c44b32 --- /dev/null +++ b/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl @@ -0,0 +1 @@ +include("td_learner.jl") diff --git a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl new file mode 100644 index 000000000..2affa7b14 --- /dev/null +++ b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl @@ -0,0 +1,41 @@ +using Test +using Flux + +@testset "Test TDLearner creation" begin + approximator = TabularVApproximator(n_state=5) + @test TDLearner(approximator, :SARS, γ=0.95, n=0) isa TDLearner + + approximator = TabularQApproximator(n_state=5, n_action=3) + @test TDLearner(approximator, :SARS, γ=0.95, n=0) isa TDLearner +end + +# Test TDLearner struct +@testset "TDLearner struct" begin + approximator = TabularQApproximator(n_state=5, n_action=3) + learner = TDLearner(approximator, :SARS) + @test learner.approximator === approximator + @test learner.γ == 1.0 + @test learner.n == 0 +end + +# Test Q! function +@testset "Q! 
function" begin + approximator = TabularQApproximator(n_state=5, n_action=3) + s = 1 + s_plus_one = 2 + a = 3 + α = 0.1 + π_ = 0.5 + γ = 0.9 + RLFarm.Q!(approximator, s, s_plus_one, a, α, π_, γ) + @test RLFarm.Q(approximator, s, a) ≈ α * (π_ + γ * maximum(RLFarm.Q(approximator, s_plus_one)) - RLFarm.Q(approximator, s, a)) +end + +# Test optimise! function +@testset "optimise! function" begin + approximator = TabularQApproximator(n_state=5, n_action=3) + learner = TDLearner(approximator, :SARS) + t = (state=1, next_state=2, action=3, reward=0.5, terminal=false) + optimise!(learner, t) + @test true # Add your own assertion here +end diff --git a/src/ReinforcementLearningFarm/test/runtests.jl b/src/ReinforcementLearningFarm/test/runtests.jl index e3819a6b7..6e4db5d12 100644 --- a/src/ReinforcementLearningFarm/test/runtests.jl +++ b/src/ReinforcementLearningFarm/test/runtests.jl @@ -12,6 +12,11 @@ else end using Test -@testset "ReinforcementLearningZoo.jl" begin +using ReinforcementLearningBase +using ReinforcementLearningCore +using ReinforcementLearningEnvironments +using ReinforcementLearningFarm +@testset "ReinforcementLearningFarm.jl" begin + include("algorithms/algorithms.jl") end From d002258e2dd65a0ccab6236772142d7ecabb0db2 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 17:28:19 +0100 Subject: [PATCH 03/22] qlearning passes first check --- .../src/algorithms/tabular/td_learner.jl | 25 +++++++++++-------- .../test/algorithms/tabular/td_learner.jl | 16 ++++++++---- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl index f6d4b4390..9d4ed0e3d 100644 --- a/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl +++ b/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl @@ -5,6 +5,7 @@ using Distributions: pdf import Base.push! using ReinforcementLearningCore: AbstractLearner, TabularApproximator +using Flux """ TDLearner(;approximator, γ=1.0, method, n=0) @@ -43,31 +44,33 @@ Q(app::TabularApproximator, s) = RLCore.forward(app, s) Update the Q-value of the given state-action pair. 
""" function Q!( - app::TabularApproximator, + approx::TabularApproximator, s::I1, s_plus_one::I2, a::I3, - α::F1, - π_::F2, - γ::Float64, -) where {I1<:Integer,I2<:Integer,I3<:Integer,F1<:AbstractFloat,F2<:AbstractFloat} + π_::F1, + γ::Float64, # discount factor +) where {I1<:Integer,I2<:Integer,I3<:Integer,F1<:AbstractFloat} # Q-learning formula following https://github.com/JuliaPOMDP/TabularTDLearning.jl/blob/25c4d3888e178c51ed1ff448f36b0fcaf7c1d8e8/src/q_learn.jl#LL63C26-L63C95 - q_value_updated = α * (π_ + γ * maximum(Q(app, s_plus_one)) - Q(app, s, a)) - app.model[a, s] += q_value_updated - return Q(app, s, a) + # Terminology following https://en.wikipedia.org/wiki/Q-learning + estimate_optimal_future_value = maximum(Q(approx, s_plus_one)) + current_value = Q(approx, s, a) + raw_q_value = (π_ + γ * estimate_optimal_future_value - current_value) # Discount factor γ is applied here + q_value_updated = Flux.Optimise.apply!(approx.optimiser_state, :learning, [raw_q_value])[] # adust according to optimiser learning rate + approx.model[a, s] += q_value_updated + return Q(approx, s, a) end function _optimise!( n::I1, γ::F, - app::Approximator{Ar}, + approx::Approximator{Ar}, s::I2, s_next::I2, a::I3, r::F, ) where {I1<:Number,I2<:Number,I3<:Number,Ar<:AbstractArray,F<:AbstractFloat} - α = app.optimiser_state.eta - Q!(app, s, s_next, a, α, r, γ) + Q!(approx, s, s_next, a, r, γ) end function RLBase.optimise!( diff --git a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl index 2affa7b14..4b1af10c7 100644 --- a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl +++ b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl @@ -20,15 +20,21 @@ end # Test Q! function @testset "Q! function" begin - approximator = TabularQApproximator(n_state=5, n_action=3) + approximator = TabularQApproximator(n_state=5, n_action=3, opt = InvDecay(0.7)) s = 1 s_plus_one = 2 a = 3 - α = 0.1 - π_ = 0.5 + α = 1/(1 + approximator.optimiser_state.gamma) + π_ = 5.0 γ = 0.9 - RLFarm.Q!(approximator, s, s_plus_one, a, α, π_, γ) - @test RLFarm.Q(approximator, s, a) ≈ α * (π_ + γ * maximum(RLFarm.Q(approximator, s_plus_one)) - RLFarm.Q(approximator, s, a)) + approximator.model[2, s_plus_one] = 15 + approximator.model[a, s] = 2 + + # Following https://en.wikipedia.org/wiki/Q-learning#Algorithm + q_should_be = (1-α) * RLFarm.Q(approximator, s, a) + α * (π_ + γ * maximum(RLFarm.Q(approximator, s_plus_one))) + + @test RLFarm.Q!(approximator, s, s_plus_one, a, π_, γ) ≈ q_should_be + @test RLFarm.Q(approximator, s, a) ≈ q_should_be end # Test optimise! function From f10a4ff6e2757217182b16cf9354933b8192e315 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 17:36:46 +0100 Subject: [PATCH 04/22] Add test for TDLearner and optimise! 
function --- .../test/algorithms/tabular/tabular.jl | 6 +++++- .../test/algorithms/tabular/td_learner.jl | 11 ++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl b/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl index 375c44b32..4bb386172 100644 --- a/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl +++ b/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl @@ -1 +1,5 @@ -include("td_learner.jl") +@testset "TDLearner" begin + include("td_learner.jl") +end + + diff --git a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl index 4b1af10c7..36582c324 100644 --- a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl +++ b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl @@ -39,9 +39,14 @@ end # Test optimise! function @testset "optimise! function" begin - approximator = TabularQApproximator(n_state=5, n_action=3) + approximator = TabularQApproximator(n_state=5, n_action=3, opt = InvDecay(0.7)) learner = TDLearner(approximator, :SARS) - t = (state=1, next_state=2, action=3, reward=0.5, terminal=false) + + t = (state=1, next_state=2, action=3, reward=5.0, terminal=false) optimise!(learner, t) - @test true # Add your own assertion here + @test approximator.model[t.action, t.state] ≈ 2.9411764705882355 + for i in 1:500 + optimise!(learner, t) + end + @test approximator.model[t.action, t.state] ≈ t.reward atol=0.01 end From 862d08d256ed751dba38eb762f4aa7a96016f732 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 17:39:28 +0100 Subject: [PATCH 05/22] Refactor tabular approximators to force optimization functions specification --- .../src/policies/learners/tabular_approximator.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl b/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl index ef76d2a7e..f84b02856 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl @@ -19,10 +19,10 @@ function TabularApproximator(table::A, opt::O) where {A<:AbstractArray,O} TabularApproximator{A,O}(table, opt) end -TabularVApproximator(; n_state, init = 0.0, opt = InvDecay(1.0)) = +TabularVApproximator(; n_state, opt, init = 0.0) = TabularApproximator(fill(init, n_state), opt) -TabularQApproximator(; n_state, n_action, init = 0.0, opt = InvDecay(1.0)) = +TabularQApproximator(; n_state, n_action, opt, init = 0.0) = TabularApproximator(fill(init, n_action, n_state), opt) # Take Learner and Environment, get state, send to RLCore.forward(Learner, State) From d080cf73d13758c7ff19908c07dd16b38dd728e0 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 17:55:35 +0100 Subject: [PATCH 06/22] Update TabularApproximator constructors and tests --- .../src/policies/learners/approximator.jl | 2 +- .../test/policies/learners/tabular_approximator.jl | 10 +++++----- .../test/policies/q_based_policy.jl | 10 +++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/ReinforcementLearningCore/src/policies/learners/approximator.jl b/src/ReinforcementLearningCore/src/policies/learners/approximator.jl index 43fb9b955..1ffe7da48 100644 --- 
a/src/ReinforcementLearningCore/src/policies/learners/approximator.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/approximator.jl @@ -25,7 +25,7 @@ Constructs an `Approximator` object for reinforcement learning. # Returns An `Approximator` object. """ -function Approximator(; model, optimiser::Flux.Optimise.AbstractOptimiser, use_gpu=false) +function Approximator(; model, optimiser, use_gpu=false) optimiser_state = Flux.setup(optimiser, model) if use_gpu # Pass model to GPU (if available) upon creation return Approximator(gpu(model), gpu(optimiser_state)) diff --git a/src/ReinforcementLearningCore/test/policies/learners/tabular_approximator.jl b/src/ReinforcementLearningCore/test/policies/learners/tabular_approximator.jl index 25c3b2e86..f23caf063 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/tabular_approximator.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/tabular_approximator.jl @@ -4,21 +4,21 @@ using ReinforcementLearningCore using Flux @testset "Constructors" begin - @test TabularApproximator(fill(1, 10, 10), fill(1, 10)) isa TabularApproximator - @test TabularVApproximator(n_state = 10) isa + @test TabularApproximator(fill(1, 10, 10), InvDecay(0.5)) isa TabularApproximator + @test TabularVApproximator(n_state = 10, opt = InvDecay(0.5)) isa TabularApproximator{Vector{Float64},InvDecay} - @test TabularQApproximator(n_state = 10, n_action = 10) isa + @test TabularQApproximator(n_state = 10, n_action = 10, opt = InvDecay(0.5)) isa TabularApproximator{Matrix{Float64},InvDecay} end @testset "RLCore.forward" begin - v_approx = TabularVApproximator(n_state = 10) + v_approx = TabularVApproximator(n_state = 10, opt = InvDecay(0.5)) @test RLCore.forward(v_approx, 1) == 0.0 env = RockPaperScissorsEnv() @test RLCore.forward(v_approx, env) == 0.0 - q_approx = TabularQApproximator(n_state = 5, n_action = 10) + q_approx = TabularQApproximator(n_state = 5, n_action = 10, opt = InvDecay(0.5)) @test RLCore.forward(q_approx, 1) == zeros(Float64, 10) @test RLCore.forward(q_approx, 1, 5) == 0.0 @test RLCore.forward(q_approx, env) == zeros(10) diff --git a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl index 166f0afcb..8497512c9 100644 --- a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl @@ -1,7 +1,7 @@ @testset "QBasedPolicy" begin @testset "constructor" begin - q_approx = TabularQApproximator(n_state = 5, n_action = 10) + q_approx = TabularQApproximator(n_state = 5, n_action = 10, opt = InvDecay(0.5)) explorer = EpsilonGreedyExplorer(0.1) p = QBasedPolicy(q_approx, explorer) @test p.learner == q_approx @@ -11,7 +11,7 @@ @testset "plan!" begin @testset "plan! without player argument" begin env = TicTacToeEnv() - q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) explorer = EpsilonGreedyExplorer(0.1) policy = QBasedPolicy(q_approx, explorer) @test 1 <= RLBase.plan!(policy, env) <= 9 @@ -19,7 +19,7 @@ @testset "plan! 
with player argument" begin env = TicTacToeEnv() - q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) explorer = EpsilonGreedyExplorer(0.1) policy = QBasedPolicy(q_approx, explorer) player = :player1 @@ -30,7 +30,7 @@ # Test prob function @testset "prob" begin env = TicTacToeEnv() - q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) explorer = EpsilonGreedyExplorer(0.1) policy = QBasedPolicy(q_approx, explorer) # prob = RLBase.prob(p, env) @@ -40,7 +40,7 @@ # Test optimise! function @testset "optimise!" begin env = TicTacToeEnv() - q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) explorer = EpsilonGreedyExplorer(0.1) policy = QBasedPolicy(q_approx, explorer) s = PreActStage() From 11449b51d7a91bd00685186ecce3b0f97cf27ce4 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 17:57:30 +0100 Subject: [PATCH 07/22] Update CartPoleEnv initialization to use Float32 --- .../test/policies/learners/approximator.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ReinforcementLearningCore/test/policies/learners/approximator.jl b/src/ReinforcementLearningCore/test/policies/learners/approximator.jl index 53bff4c60..df8efeb20 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/approximator.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/approximator.jl @@ -29,7 +29,7 @@ using Flux optimiser = Adam() approximator = Approximator(model=model, optimiser=optimiser, use_gpu=false) - env = CartPoleEnv() + env = CartPoleEnv(T=Float32) output = RLCore.forward(approximator, env) @test typeof(output) == Array{Float32,1} @test length(output) == 2 From 845aad4d812f0060f69e8af75bc388cea71f2432 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 18:05:35 +0100 Subject: [PATCH 08/22] add missing test dependency to ci --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 016733c3f..0dcb7870a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -95,6 +95,7 @@ jobs: Pkg.develop(path="src/ReinforcementLearningBase") Pkg.develop(path="src/ReinforcementLearningCore") Pkg.develop(path="src/ReinforcementLearningEnvironments") + Pkg.develop(path="src/ReinforcementLearningFarm") Pkg.test("ReinforcementLearningCore", coverage=true)' - uses: julia-actions/julia-processcoverage@v1 with: From 527f3e767201a8161e569ab0670447a9d09a9612 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 18:06:41 +0100 Subject: [PATCH 09/22] fix tests --- .../test/algorithms/tabular/td_learner.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl index 36582c324..92be81442 100644 --- a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl +++ b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl @@ -2,7 +2,7 @@ using Test using Flux @testset "Test TDLearner creation" begin - approximator = TabularVApproximator(n_state=5) + approximator = 
TabularVApproximator(n_state=5, opt = InvDecay(0.7)) @test TDLearner(approximator, :SARS, γ=0.95, n=0) isa TDLearner approximator = TabularQApproximator(n_state=5, n_action=3) @@ -11,7 +11,7 @@ end # Test TDLearner struct @testset "TDLearner struct" begin - approximator = TabularQApproximator(n_state=5, n_action=3) + approximator = TabularQApproximator(n_state=5, n_action=3, opt = InvDecay(0.7)) learner = TDLearner(approximator, :SARS) @test learner.approximator === approximator @test learner.γ == 1.0 From aecc739bfeb6403aed39f841404165dca4fd7e3b Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 18:37:46 +0100 Subject: [PATCH 10/22] fix missing arg --- .../test/algorithms/tabular/td_learner.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl index 92be81442..86293e37d 100644 --- a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl +++ b/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl @@ -5,7 +5,7 @@ using Flux approximator = TabularVApproximator(n_state=5, opt = InvDecay(0.7)) @test TDLearner(approximator, :SARS, γ=0.95, n=0) isa TDLearner - approximator = TabularQApproximator(n_state=5, n_action=3) + approximator = TabularQApproximator(n_state=5, n_action=3, opt = InvDecay(0.7)) @test TDLearner(approximator, :SARS, γ=0.95, n=0) isa TDLearner end From 98d8831a6c612f18a29ac1e41a5a88fd01be989c Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 15 Mar 2024 18:48:31 +0100 Subject: [PATCH 11/22] Add ReinforcementLearningFarm package to development dependencies --- .buildkite/pipeline.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 26803e88b..0a9c3e707 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -17,6 +17,7 @@ steps: Pkg.develop(path="src/ReinforcementLearningBase") Pkg.develop(path="src/ReinforcementLearningEnvironments") Pkg.develop(path="src/ReinforcementLearningCore") + Pkg.develop(path="src/ReinforcementLearningFarm") println("+++ :julia: Running tests") Pkg.test("ReinforcementLearningCore", coverage=true) From 8f9b60632c6f9f754f9762b634b0dde229a835a8 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 10:59:14 +0100 Subject: [PATCH 12/22] per pr review --- .../src/algorithms/tabular/td_learner.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl index 9d4ed0e3d..f745302ee 100644 --- a/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl +++ b/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl @@ -39,24 +39,24 @@ Q(app::TabularApproximator, s, a) = RLCore.forward(app, s, a) Q(app::TabularApproximator, s) = RLCore.forward(app, s) """ - Q!(app::TabularApproximator, s::Int, s_plus_one::Int, a::Int, α::Float64, π_::Float64, γ::Float64) + bellman_update!(app::TabularApproximator, s::Int, s_plus_one::Int, a::Int, α::Float64, π_::Float64, γ::Float64) Update the Q-value of the given state-action pair. 
""" -function Q!( +function bellman_update!( approx::TabularApproximator, s::I1, s_plus_one::I2, a::I3, - π_::F1, + r::F1, # reward γ::Float64, # discount factor ) where {I1<:Integer,I2<:Integer,I3<:Integer,F1<:AbstractFloat} # Q-learning formula following https://github.com/JuliaPOMDP/TabularTDLearning.jl/blob/25c4d3888e178c51ed1ff448f36b0fcaf7c1d8e8/src/q_learn.jl#LL63C26-L63C95 # Terminology following https://en.wikipedia.org/wiki/Q-learning estimate_optimal_future_value = maximum(Q(approx, s_plus_one)) current_value = Q(approx, s, a) - raw_q_value = (π_ + γ * estimate_optimal_future_value - current_value) # Discount factor γ is applied here - q_value_updated = Flux.Optimise.apply!(approx.optimiser_state, :learning, [raw_q_value])[] # adust according to optimiser learning rate + raw_q_value = (r + γ * estimate_optimal_future_value - current_value) # Discount factor γ is applied here + q_value_updated = Flux.Optimise.update!(approx.optimiser_state, :learning, [raw_q_value])[] # adust according to optimiser learning rate approx.model[a, s] += q_value_updated return Q(approx, s, a) end @@ -70,7 +70,7 @@ function _optimise!( a::I3, r::F, ) where {I1<:Number,I2<:Number,I3<:Number,Ar<:AbstractArray,F<:AbstractFloat} - Q!(approx, s, s_next, a, r, γ) + bellman_update!(approx, s, s_next, a, r, γ) end function RLBase.optimise!( From 6c96eaa34ca7fe3395fbbb2bc915cb7f1dd5baef Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 11:23:17 +0100 Subject: [PATCH 13/22] fix type signature --- .../src/policies/q_based_policy.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl index b712c3144..5a68848d9 100644 --- a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl @@ -19,16 +19,16 @@ end Flux.@layer QBasedPolicy trainable=(learner,) -function RLBase.plan!(p::QBasedPolicy{L,Ex}, env::E) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv} - RLBase.plan!(p.explorer, p.learner, env) +function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv} + RLBase.plan!(policy.explorer, policy.learner, env) end -function RLBase.plan!(p::QBasedPolicy{L,Ex}, env::E, player::Symbol) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv} - RLBase.plan!(p.explorer, p.learner, env, player) +function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E, player::Symbol) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv} + RLBase.plan!(policy.explorer, policy.learner, env, player) end -RLBase.prob(p::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} = - prob(p.explorer, forward(p.learner, env), legal_action_space_mask(env)) +RLBase.prob(policy::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} = + prob(policy.explorer, forward(policy.learner, env), legal_action_space_mask(env)) #the internal learner defines the optimization stage. 
-RLBase.optimise!(p::QBasedPolicy, s::AbstractStage, trajectory::Trajectory) = RLBase.optimise!(p.learner, s, trajectory) +RLBase.optimise!(policy::QBasedPolicy, stage::AbstractStage, trajectory::Trajectory) = RLBase.optimise!(policy.learner, stage, trajectory) From 7cee80e93444687eb3a62c44ceaff293d5ee9957 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 12:01:26 +0100 Subject: [PATCH 14/22] move tdlearner to rlcore --- .../src/policies/agent/agent_base.jl | 2 +- .../explorers/epsilon_greedy_explorer.jl | 14 +++++++------- .../src/policies/learners/learners.jl | 1 + .../src/policies/learners}/td_learner.jl | 0 .../src/policies/q_based_policy.jl | 14 ++++++++++---- .../test/policies/learners/learners.jl | 2 ++ .../test/policies/learners}/td_learner.jl | 2 +- .../test/policies/q_based_policy.jl | 19 ++++++++++++++++++- .../src/algorithms/tabular/tabular.jl | 1 - .../test/algorithms/tabular/tabular.jl | 5 ----- 10 files changed, 40 insertions(+), 20 deletions(-) rename src/{ReinforcementLearningFarm/src/algorithms/tabular => ReinforcementLearningCore/src/policies/learners}/td_learner.jl (100%) rename src/{ReinforcementLearningFarm/test/algorithms/tabular => ReinforcementLearningCore/test/policies/learners}/td_learner.jl (99%) diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl b/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl index adbc0bb84..5769a9fa7 100644 --- a/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl +++ b/src/ReinforcementLearningCore/src/policies/agent/agent_base.jl @@ -37,7 +37,7 @@ RLBase.optimise!(::SyncTrajectoryStyle, agent::AbstractAgent, stage::S) where {S # already spawn a task to optimise inner policy when initializing the agent RLBase.optimise!(::AsyncTrajectoryStyle, agent::AbstractAgent, stage::S) where {S<:AbstractStage} = nothing -#by default, optimise does nothing at all stage +#by default, optimise does nothing at all stages function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end Flux.@layer Agent trainable=(policy,) diff --git a/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl b/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl index f2d9ce559..7d8925e87 100644 --- a/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl +++ b/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl @@ -99,13 +99,13 @@ get_ϵ(s::EpsilonGreedyExplorer) = get_ϵ(s, s.step) `NaN` will be filtered unless all the values are `NaN`. In that case, a random one will be returned. """ -function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::A) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}} +function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::A) where {I<:Real, A<:AbstractArray{I}} ϵ = get_ϵ(s) s.step += 1 rand(s.rng) >= ϵ ? rand(s.rng, find_all_max(values)[2]) : rand(s.rng, 1:length(values)) end -function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::A) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}} +function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::A) where {I<:Real, A<:AbstractArray{I}} ϵ = get_ϵ(s) s.step += 1 rand(s.rng) >= ϵ ? 
findmax(values)[2] : rand(s.rng, 1:length(values)) @@ -113,18 +113,18 @@ end ##### -RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, x::A, mask::Trues) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}} = RLBase.plan!(s, x) +RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, x::A, mask::Trues) where {I<:Real, A<:AbstractArray{I}} = RLBase.plan!(s, x) -function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::A, mask::M) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}, M<:Union{BitVector, Vector{Bool}}} +function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,true}, values::A, mask::M) where {I<:Real, A<:AbstractArray{I}, M<:Union{BitVector, Vector{Bool}}} ϵ = get_ϵ(s) s.step += 1 rand(s.rng) >= ϵ ? rand(s.rng, find_all_max(values, mask)[2]) : rand(s.rng, findall(mask)) end -RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, x::A, mask::Trues) where{I<:Real, A<:Union{Vector{I}, SubArray{I}}} = RLBase.plan!(s, x) +RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, x::A, mask::Trues) where{I<:Real, A<:AbstractArray{I}} = RLBase.plan!(s, x) -function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::A, mask::M) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}, M<:Union{BitVector, Vector{Bool}}} +function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::A, mask::M) where {I<:Real, A<:AbstractArray{I}, M<:Union{BitVector, Vector{Bool}}} ϵ = get_ϵ(s) s.step += 1 rand(s.rng) >= ϵ ? findmax_masked(values, mask)[2] : rand(s.rng, findall(mask)) @@ -138,7 +138,7 @@ end Return the probability of selecting each action given the estimated `values` of each action. """ -function RLBase.prob(s::EpsilonGreedyExplorer{<:Any,true}, values::A) where {I<:Real, A<:Union{Vector{I}, SubArray{I}}} +function RLBase.prob(s::EpsilonGreedyExplorer{<:Any,true}, values::A) where {I<:Real, A<:AbstractArray{I}} ϵ, n = get_ϵ(s), length(values) probs = fill(ϵ / n, n) max_val_inds = find_all_max(values)[2] diff --git a/src/ReinforcementLearningCore/src/policies/learners/learners.jl b/src/ReinforcementLearningCore/src/policies/learners/learners.jl index 28fafdd9f..6ccd38041 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/learners.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/learners.jl @@ -1,4 +1,5 @@ include("abstract_learner.jl") include("approximator.jl") include("tabular_approximator.jl") +include("td_learner.jl") include("target_network.jl") diff --git a/src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl similarity index 100% rename from src/ReinforcementLearningFarm/src/algorithms/tabular/td_learner.jl rename to src/ReinforcementLearningCore/src/policies/learners/td_learner.jl diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl index 5a68848d9..a10c663ee 100644 --- a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl @@ -10,7 +10,7 @@ action of an environment at its current state. It is typically a table or a neur QBasedPolicy can be queried for an action with `RLBase.plan!`, the explorer will affect the action selection accordingly. 
""" -Base.@kwdef mutable struct QBasedPolicy{L,E} <: AbstractPolicy +Base.@kwdef mutable struct QBasedPolicy{L<:TDLearner,E<:AbstractExplorer} <: AbstractPolicy "estimate the Q value" learner::L "select the action based on Q values calculated by the learner" @@ -19,16 +19,22 @@ end Flux.@layer QBasedPolicy trainable=(learner,) -function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv} +function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E) where {Ex<:AbstractExplorer,L<:TDLearner,E<:AbstractEnv} RLBase.plan!(policy.explorer, policy.learner, env) end -function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E, player::Symbol) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv} +function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E, player::Symbol) where {Ex<:AbstractExplorer,L<:TDLearner,E<:AbstractEnv} RLBase.plan!(policy.explorer, policy.learner, env, player) end -RLBase.prob(policy::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} = +RLBase.prob(policy::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:TDLearner,Ex<:AbstractExplorer} = prob(policy.explorer, forward(policy.learner, env), legal_action_space_mask(env)) #the internal learner defines the optimization stage. RLBase.optimise!(policy::QBasedPolicy, stage::AbstractStage, trajectory::Trajectory) = RLBase.optimise!(policy.learner, stage, trajectory) + +function RLBase.optimise!(policy::QBasedPolicy, stage::AbstractStage, trajectory::Trajectory) + for batch in trajectory.container + optimise!(policy.learner, stage, batch) + end +end diff --git a/src/ReinforcementLearningCore/test/policies/learners/learners.jl b/src/ReinforcementLearningCore/test/policies/learners/learners.jl index 94b6e0487..d336799eb 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/learners.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/learners.jl @@ -1,5 +1,7 @@ @testset "approximators.jl" begin + include("abstract_learner.jl") include("approximator.jl") include("tabular_approximator.jl") include("target_network.jl") + include("td_learner.jl") end diff --git a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl b/src/ReinforcementLearningCore/test/policies/learners/td_learner.jl similarity index 99% rename from src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl rename to src/ReinforcementLearningCore/test/policies/learners/td_learner.jl index 86293e37d..3aae407e4 100644 --- a/src/ReinforcementLearningFarm/test/algorithms/tabular/td_learner.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/td_learner.jl @@ -41,7 +41,7 @@ end @testset "optimise! 
function" begin approximator = TabularQApproximator(n_state=5, n_action=3, opt = InvDecay(0.7)) learner = TDLearner(approximator, :SARS) - + t = (state=1, next_state=2, action=3, reward=5.0, terminal=false) optimise!(learner, t) @test approximator.model[t.action, t.state] ≈ 2.9411764705882355 diff --git a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl index 8497512c9..c5a1deed6 100644 --- a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl @@ -31,9 +31,26 @@ @testset "prob" begin env = TicTacToeEnv() q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) + learner = TDLearner(q_approx, :SARS) explorer = EpsilonGreedyExplorer(0.1) policy = QBasedPolicy(q_approx, explorer) - # prob = RLBase.prob(p, env) + trajectory = Trajectory( + CircularArraySARTSTraces(; + capacity = 1, + state = Int64 => (), + action = Int8 => (), + reward = Float64 => (), + terminal = Bool => (), + ), + DummySampler(), + InsertSampleRatioController(), + ) + t = (state=2, action=3) + push!(trajectory, t) + t = (next_state=3, reward=5.0, terminal=false) + push!(trajectory, t) + optimise!(policy, PostActStage(), trajectory) + prob = RLBase.prob(policy, env) # Add assertions here end diff --git a/src/ReinforcementLearningFarm/src/algorithms/tabular/tabular.jl b/src/ReinforcementLearningFarm/src/algorithms/tabular/tabular.jl index 375c44b32..e69de29bb 100644 --- a/src/ReinforcementLearningFarm/src/algorithms/tabular/tabular.jl +++ b/src/ReinforcementLearningFarm/src/algorithms/tabular/tabular.jl @@ -1 +0,0 @@ -include("td_learner.jl") diff --git a/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl b/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl index 4bb386172..e69de29bb 100644 --- a/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl +++ b/src/ReinforcementLearningFarm/test/algorithms/tabular/tabular.jl @@ -1,5 +0,0 @@ -@testset "TDLearner" begin - include("td_learner.jl") -end - - From e9af3b7f63e89741f501a05497b918e0d8135687 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 12:26:36 +0100 Subject: [PATCH 15/22] rewire qlearning --- .../src/policies/learners/td_learner.jl | 2 ++ .../test/policies/q_based_policy.jl | 24 +++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl index f745302ee..4fc604032 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl @@ -79,3 +79,5 @@ function RLBase.optimise!( ) where {I1<:Number,I2<:Number,F2<:AbstractFloat} _optimise!(L.n, L.γ, L.approximator, t.state, t.next_state, t.action, t.reward) end + +RLBase.optimise!(L::TDLearner{:SARS}, stage::PostActStage, trace::NamedTuple) = RLBase.optimise!(L, trace) diff --git a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl index c5a1deed6..e0a53fd4c 100644 --- a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl @@ -33,7 +33,7 @@ q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) learner = TDLearner(q_approx, 
:SARS) explorer = EpsilonGreedyExplorer(0.1) - policy = QBasedPolicy(q_approx, explorer) + policy = QBasedPolicy(learner, explorer) trajectory = Trajectory( CircularArraySARTSTraces(; capacity = 1, @@ -59,14 +59,24 @@ env = TicTacToeEnv() q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) explorer = EpsilonGreedyExplorer(0.1) - policy = QBasedPolicy(q_approx, explorer) - s = PreActStage() + learner = TDLearner(q_approx, :SARS) + policy = QBasedPolicy(learner, explorer) trajectory = Trajectory( - CircularArraySARTSTraces(; capacity = 1), - BatchSampler(1), - InsertSampleRatioController(n_inserted = -1), + CircularArraySARTSTraces(; + capacity = 1, + state = Int64 => (), + action = Int8 => (), + reward = Float64 => (), + terminal = Bool => (), + ), + DummySampler(), + InsertSampleRatioController(), ) - RLBase.optimise!(policy, s, trajectory) + t = (state=2, action=3) + push!(trajectory, t) + t = (next_state=3, reward=5.0, terminal=false) + push!(trajectory, t) + RLBase.optimise!(policy, PostActStage(), trajectory) # Add assertions here end end From 87760ff0c9f132c0292b1f85cbb8d283645f563d Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 15:02:32 +0100 Subject: [PATCH 16/22] fix tests --- .../src/policies/learners/td_learner.jl | 16 ++++++++++++---- .../src/policies/q_based_policy.jl | 6 ------ .../test/policies/q_based_policy.jl | 15 +++++++++------ 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl index 4fc604032..368f05c86 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl @@ -32,11 +32,11 @@ Base.@kwdef mutable struct TDLearner{M,A} <: AbstractLearner where {A<:TabularAp end end -RLCore.forward(L::TDLearner, s) = RLCore.forward(L.approximator, s) -RLCore.forward(L::TDLearner, s, a) = RLCore.forward(L.approximator, s, a) +RLCore.forward(L::TDLearner, s::Int) = RLCore.forward(L.approximator, s) +RLCore.forward(L::TDLearner, s::Int, a::Int) = RLCore.forward(L.approximator, s, a) -Q(app::TabularApproximator, s, a) = RLCore.forward(app, s, a) -Q(app::TabularApproximator, s) = RLCore.forward(app, s) +Q(app::TabularApproximator, s::Int, a::Int) = RLCore.forward(app, s, a) +Q(app::TabularApproximator, s::Int) = RLCore.forward(app, s) """ bellman_update!(app::TabularApproximator, s::Int, s_plus_one::Int, a::Int, α::Float64, π_::Float64, γ::Float64) @@ -80,4 +80,12 @@ function RLBase.optimise!( _optimise!(L.n, L.γ, L.approximator, t.state, t.next_state, t.action, t.reward) end +function RLBase.optimise!(learner::TDLearner, stage, trajectory::Trajectory) + for batch in trajectory.container + optimise!(learner, stage, batch) + end +end + +# TDLearner{:SARS} optimises at the PostActStage RLBase.optimise!(L::TDLearner{:SARS}, stage::PostActStage, trace::NamedTuple) = RLBase.optimise!(L, trace) + diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl index a10c663ee..53774e23c 100644 --- a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl @@ -32,9 +32,3 @@ RLBase.prob(policy::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:TDLearner,Ex #the internal learner defines the optimization stage. 
RLBase.optimise!(policy::QBasedPolicy, stage::AbstractStage, trajectory::Trajectory) = RLBase.optimise!(policy.learner, stage, trajectory) - -function RLBase.optimise!(policy::QBasedPolicy, stage::AbstractStage, trajectory::Trajectory) - for batch in trajectory.container - optimise!(policy.learner, stage, batch) - end -end diff --git a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl index e0a53fd4c..501bac32a 100644 --- a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl @@ -3,8 +3,9 @@ @testset "constructor" begin q_approx = TabularQApproximator(n_state = 5, n_action = 10, opt = InvDecay(0.5)) explorer = EpsilonGreedyExplorer(0.1) - p = QBasedPolicy(q_approx, explorer) - @test p.learner == q_approx + learner = TDLearner(q_approx, :SARS) + p = QBasedPolicy(learner, explorer) + @test p.learner == learner @test p.explorer == explorer end @@ -12,16 +13,18 @@ @testset "plan! without player argument" begin env = TicTacToeEnv() q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) + learner = TDLearner(q_approx, :SARS) explorer = EpsilonGreedyExplorer(0.1) - policy = QBasedPolicy(q_approx, explorer) + policy = QBasedPolicy(learner, explorer) @test 1 <= RLBase.plan!(policy, env) <= 9 end @testset "plan! with player argument" begin env = TicTacToeEnv() q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) + learner = TDLearner(q_approx, :SARS) explorer = EpsilonGreedyExplorer(0.1) - policy = QBasedPolicy(q_approx, explorer) + policy = QBasedPolicy(learner, explorer) player = :player1 @test 1 <= RLBase.plan!(policy, env) <= 9 end @@ -38,7 +41,7 @@ CircularArraySARTSTraces(; capacity = 1, state = Int64 => (), - action = Int8 => (), + action = Int64 => (), reward = Float64 => (), terminal = Bool => (), ), @@ -65,7 +68,7 @@ CircularArraySARTSTraces(; capacity = 1, state = Int64 => (), - action = Int8 => (), + action = Int64 => (), reward = Float64 => (), terminal = Bool => (), ), From b1c5c125601b8304ef31eece8cff7df552c19bcb Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 15:08:25 +0100 Subject: [PATCH 17/22] naming --- .../src/policies/learners/td_learner.jl | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl index 368f05c86..544b2defc 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl @@ -45,32 +45,32 @@ Update the Q-value of the given state-action pair. 
""" function bellman_update!( approx::TabularApproximator, - s::I1, - s_plus_one::I2, - a::I3, - r::F1, # reward + state::I1, + next_state::I2, + action::I3, + reward::F1, γ::Float64, # discount factor ) where {I1<:Integer,I2<:Integer,I3<:Integer,F1<:AbstractFloat} # Q-learning formula following https://github.com/JuliaPOMDP/TabularTDLearning.jl/blob/25c4d3888e178c51ed1ff448f36b0fcaf7c1d8e8/src/q_learn.jl#LL63C26-L63C95 # Terminology following https://en.wikipedia.org/wiki/Q-learning - estimate_optimal_future_value = maximum(Q(approx, s_plus_one)) - current_value = Q(approx, s, a) - raw_q_value = (r + γ * estimate_optimal_future_value - current_value) # Discount factor γ is applied here + estimate_optimal_future_value = maximum(Q(approx, next_state)) + current_value = Q(approx, state, action) + raw_q_value = (reward + γ * estimate_optimal_future_value - current_value) # Discount factor γ is applied here q_value_updated = Flux.Optimise.update!(approx.optimiser_state, :learning, [raw_q_value])[] # adust according to optimiser learning rate - approx.model[a, s] += q_value_updated - return Q(approx, s, a) + approx.model[action, state] += q_value_updated + return Q(approx, state, action) end function _optimise!( n::I1, γ::F, approx::Approximator{Ar}, - s::I2, - s_next::I2, - a::I3, - r::F, + state::I2, + next_state::I2, + action::I3, + reward::F, ) where {I1<:Number,I2<:Number,I3<:Number,Ar<:AbstractArray,F<:AbstractFloat} - bellman_update!(approx, s, s_next, a, r, γ) + bellman_update!(approx, state, next_state, action, reward, γ) end function RLBase.optimise!( @@ -80,12 +80,12 @@ function RLBase.optimise!( _optimise!(L.n, L.γ, L.approximator, t.state, t.next_state, t.action, t.reward) end -function RLBase.optimise!(learner::TDLearner, stage, trajectory::Trajectory) +function RLBase.optimise!(learner::TDLearner, stage::AbstractStage, trajectory::Trajectory) for batch in trajectory.container optimise!(learner, stage, batch) end end # TDLearner{:SARS} optimises at the PostActStage -RLBase.optimise!(L::TDLearner{:SARS}, stage::PostActStage, trace::NamedTuple) = RLBase.optimise!(L, trace) +RLBase.optimise!(learner::TDLearner{:SARS}, stage::PostActStage, trace::NamedTuple) = RLBase.optimise!(learner, trace) From 5aed0b2d6c9aa173372e7754e1217aef9bf1a4da Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 15:41:57 +0100 Subject: [PATCH 18/22] fixes --- .../src/policies/learners/tabular_approximator.jl | 3 ++- src/ReinforcementLearningCore/test/policies/q_based_policy.jl | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl b/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl index f84b02856..9ad743e32 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl @@ -16,7 +16,8 @@ For `table` of 2-d, it will serve as a state-action value approximator. 
function TabularApproximator(table::A, opt::O) where {A<:AbstractArray,O} n = ndims(table) n <= 2 || throw(ArgumentError("the dimension of table must be <= 2")) - TabularApproximator{A,O}(table, opt) + optimiser_state = Flux.setup(optimiser, table) + TabularApproximator{A,O}(table, optimiser_state) end TabularVApproximator(; n_state, opt, init = 0.0) = diff --git a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl index 501bac32a..991096511 100644 --- a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl @@ -77,8 +77,10 @@ ) t = (state=2, action=3) push!(trajectory, t) - t = (next_state=3, reward=5.0, terminal=false) + next_state = 4 + t = (action=3, state=next_state, reward=5.0, terminal=false) push!(trajectory, t) + trajectory.container[1] RLBase.optimise!(policy, PostActStage(), trajectory) # Add assertions here end From 1366548e069d8f19f2815b920fc1cbcacb0ecf7c Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 17:42:18 +0100 Subject: [PATCH 19/22] Update learners.jl and related files --- .../src/policies/learners/approximator.jl | 45 ------------------ .../learners/flux_model_approximator.jl | 47 +++++++++++++++++++ .../src/policies/learners/learners.jl | 2 +- .../policies/learners/tabular_approximator.jl | 36 +++++++------- .../src/policies/learners/target_network.jl | 20 ++++---- .../src/policies/learners/td_learner.jl | 22 +++++---- .../policies/learners/abstract_learner.jl | 39 +++++++-------- ...oximator.jl => flux_model_approximator.jl} | 12 ++--- .../test/policies/learners/learners.jl | 2 +- .../policies/learners/tabular_approximator.jl | 14 +++--- .../test/policies/learners/target_network.jl | 17 +++---- .../test/policies/learners/td_learner.jl | 32 +++++++------ .../test/policies/q_based_policy.jl | 13 ++--- .../test/utils/networks.jl | 2 +- 14 files changed, 156 insertions(+), 147 deletions(-) delete mode 100644 src/ReinforcementLearningCore/src/policies/learners/approximator.jl create mode 100644 src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl rename src/ReinforcementLearningCore/test/policies/learners/{approximator.jl => flux_model_approximator.jl} (74%) diff --git a/src/ReinforcementLearningCore/src/policies/learners/approximator.jl b/src/ReinforcementLearningCore/src/policies/learners/approximator.jl deleted file mode 100644 index 1ffe7da48..000000000 --- a/src/ReinforcementLearningCore/src/policies/learners/approximator.jl +++ /dev/null @@ -1,45 +0,0 @@ -using Flux - -""" - Approximator(model, optimiser) - -Wraps a Flux trainable model and implements the `RLBase.optimise!(::Approximator, ::Gradient)` -interface. See the RLCore documentation for more information on proper usage. -""" -struct Approximator{M,O} <: AbstractLearner - model::M - optimiser_state::O -end - - -""" - Approximator(; model, optimiser, usegpu=false) - -Constructs an `Approximator` object for reinforcement learning. - -# Arguments -- `model`: The model used for approximation. -- `optimiser`: The optimizer used for updating the model. -- `usegpu`: A boolean indicating whether to use GPU for computation. Default is `false`. - -# Returns -An `Approximator` object. 
-""" -function Approximator(; model, optimiser, use_gpu=false) - optimiser_state = Flux.setup(optimiser, model) - if use_gpu # Pass model to GPU (if available) upon creation - return Approximator(gpu(model), gpu(optimiser_state)) - else - return Approximator(model, optimiser_state) - end -end - -Approximator(model, optimiser::Flux.Optimise.AbstractOptimiser; use_gpu=false) = Approximator(model=model, optimiser=optimiser, use_gpu=use_gpu) - -Flux.@layer Approximator trainable=(model,) - -forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...) -forward(A::Approximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(A, x)) - -RLBase.optimise!(A::Approximator, grad::NamedTuple) = - Flux.Optimise.update!(A.optimiser_state, A.model, grad.model) diff --git a/src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl b/src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl new file mode 100644 index 000000000..2e203e6ac --- /dev/null +++ b/src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl @@ -0,0 +1,47 @@ +export FluxModelApproximator + +using Flux + +""" + FluxModelApproximator(model, optimiser) + +Wraps a Flux trainable model and implements the `RLBase.optimise!(::FluxModelApproximator, ::Gradient)` +interface. See the RLCore documentation for more information on proper usage. +""" +struct FluxModelApproximator{M,O} <: AbstractLearner + model::M + optimiser_state::O +end + + +""" + FluxModelApproximator(; model, optimiser, usegpu=false) + +Constructs an `FluxModelApproximator` object for reinforcement learning. + +# Arguments +- `model`: The model used for approximation. +- `optimiser`: The optimizer used for updating the model. +- `usegpu`: A boolean indicating whether to use GPU for computation. Default is `false`. + +# Returns +An `FluxModelApproximator` object. +""" +function FluxModelApproximator(; model, optimiser, use_gpu=false) + optimiser_state = Flux.setup(optimiser, model) + if use_gpu # Pass model to GPU (if available) upon creation + return FluxModelApproximator(gpu(model), gpu(optimiser_state)) + else + return FluxModelApproximator(model, optimiser_state) + end +end + +FluxModelApproximator(model, optimiser::Flux.Optimise.AbstractOptimiser; use_gpu=false) = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=use_gpu) + +Flux.@layer FluxModelApproximator trainable=(model,) + +forward(A::FluxModelApproximator, args...; kwargs...) = A.model(args...; kwargs...) 
+forward(A::FluxModelApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(A, x)) + +RLBase.optimise!(A::FluxModelApproximator, grad::NamedTuple) = + Flux.Optimise.update!(A.optimiser_state, A.model, grad.model) diff --git a/src/ReinforcementLearningCore/src/policies/learners/learners.jl b/src/ReinforcementLearningCore/src/policies/learners/learners.jl index 6ccd38041..b925f58ec 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/learners.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/learners.jl @@ -1,5 +1,5 @@ include("abstract_learner.jl") -include("approximator.jl") +include("flux_model_approximator.jl") include("tabular_approximator.jl") include("td_learner.jl") include("target_network.jl") diff --git a/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl b/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl index 9ad743e32..546660279 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl @@ -1,11 +1,14 @@ export TabularApproximator, TabularVApproximator, TabularQApproximator -const TabularApproximator = Approximator{A,O} where {A<:AbstractArray,O} -const TabularQApproximator = Approximator{A,O} where {A<:AbstractArray,O} -const TabularVApproximator = Approximator{A,O} where {A<:AbstractVector,O} +struct TabularApproximator{A} <: AbstractLearner where {A<:AbstractArray} + model::A +end + +const TabularQApproximator = TabularApproximator{A} where {A<:AbstractMatrix} +const TabularVApproximator = TabularApproximator{A} where {A<:AbstractVector} """ - TabularApproximator(table<:AbstractArray, opt) + TabularApproximator(table<:AbstractArray) For `table` of 1-d, it will serve as a state value approximator. For `table` of 2-d, it will serve as a state-action value approximator. @@ -13,35 +16,34 @@ For `table` of 2-d, it will serve as a state-action value approximator. !!! warning For `table` of 2-d, the first dimension is action and the second dimension is state. 
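A minimal usage sketch (illustrative only; it mirrors the constructors and `forward` methods defined below):

```julia
q = TabularQApproximator(n_state = 5, n_action = 3)  # 3×5 table initialised to 0.0
RLCore.forward(q, 1)     # action values of state 1 (a view of column 1)
RLCore.forward(q, 1, 2)  # value of action 2 in state 1
```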
""" -function TabularApproximator(table::A, opt::O) where {A<:AbstractArray,O} +function TabularApproximator(table::A) where {A<:AbstractArray} n = ndims(table) n <= 2 || throw(ArgumentError("the dimension of table must be <= 2")) - optimiser_state = Flux.setup(optimiser, table) - TabularApproximator{A,O}(table, optimiser_state) + TabularApproximator{A}(table) end -TabularVApproximator(; n_state, opt, init = 0.0) = - TabularApproximator(fill(init, n_state), opt) +TabularVApproximator(; n_state, init = 0.0) = + TabularApproximator(fill(init, n_state)) -TabularQApproximator(; n_state, n_action, opt, init = 0.0) = - TabularApproximator(fill(init, n_action, n_state), opt) +TabularQApproximator(; n_state, n_action, init = 0.0) = + TabularApproximator(fill(init, n_action, n_state)) # Take Learner and Environment, get state, send to RLCore.forward(Learner, State) forward(L::TabularVApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(L, x)) forward(L::TabularQApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(L, x)) RLCore.forward( - app::TabularVApproximator{R,O}, + app::TabularVApproximator{R}, s::I, -) where {R<:AbstractVector,O,I} = @views app.model[s] +) where {R<:AbstractVector,I} = @views app.model[s] RLCore.forward( - app::TabularQApproximator{R,O}, + app::TabularQApproximator{R}, s::I, -) where {R<:AbstractArray,O,I} = @views app.model[:, s] +) where {R<:AbstractArray,I} = @views app.model[:, s] RLCore.forward( - app::TabularQApproximator{R,O}, + app::TabularQApproximator{R}, s::I1, a::I2, -) where {R<:AbstractArray,O,I1,I2} = @views app.model[a, s] +) where {R<:AbstractArray,I1,I2} = @views app.model[a, s] diff --git a/src/ReinforcementLearningCore/src/policies/learners/target_network.jl b/src/ReinforcementLearningCore/src/policies/learners/target_network.jl index 74003644c..26e32b6cc 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/target_network.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/target_network.jl @@ -1,14 +1,14 @@ -export Approximator, TargetNetwork, target, model +export TargetNetwork, target, model using Flux -target(ap::Approximator) = ap.model #see TargetNetwork -model(ap::Approximator) = ap.model #see TargetNetwork +target(ap::FluxModelApproximator) = ap.model #see TargetNetwork +model(ap::FluxModelApproximator) = ap.model #see TargetNetwork """ - TargetNetwork(network::Approximator; sync_freq::Int = 1, ρ::Float32 = 0f0) + TargetNetwork(network::FluxModelApproximator; sync_freq::Int = 1, ρ::Float32 = 0f0) -Wraps an Approximator to hold a target network that is updated towards the model of the +Wraps an FluxModelApproximator to hold a target network that is updated towards the model of the approximator. - `sync_freq` is the number of updates of `network` between each update of the `target`. - ρ (\rho) is "how much of the target is kept when updating it". @@ -21,11 +21,11 @@ Implements the `RLBase.optimise!(::TargetNetwork, ::Gradient)` interface to upda and the target with weights replacement or Polyak averaging. Note to developers: `model(::TargetNetwork)` will return the trainable Flux model -and `target(::TargetNetwork)` returns the target model and `target(::Approximator)` +and `target(::TargetNetwork)` returns the target model and `target(::FluxModelApproximator)` returns the non-trainable Flux model. See the RLCore documentation. 
""" mutable struct TargetNetwork{M} - network::Approximator{M} + network::FluxModelApproximator{M} target::M sync_freq::Int ρ::Float32 @@ -46,13 +46,13 @@ Constructs a target network for reinforcement learning. # Returns A `TargetNetwork` object. """ -function TargetNetwork(network::Approximator; sync_freq = 1, ρ = 0f0, use_gpu = false) +function TargetNetwork(network::FluxModelApproximator; sync_freq = 1, ρ = 0f0, use_gpu = false) @assert 0 <= ρ <= 1 "ρ must in [0,1]" ρ = Float32(ρ) if use_gpu - @assert typeof(gpu(network.model)) == typeof(network.model) "`Approximator` model is not on GPU. Please set `use_gpu=false`` or ensure model is on GPU, by setting `use_gpu=true` when constructing `Approximator`." - # NOTE: model is pushed to gpu in Approximator, need to transfer to cpu before deepcopy, then push target model to gpu + @assert typeof(gpu(network.model)) == typeof(network.model) "`FluxModelApproximator` model is not on GPU. Please set `use_gpu=false`` or ensure model is on GPU, by setting `use_gpu=true` when constructing `FluxModelApproximator`." + # NOTE: model is pushed to gpu in FluxModelApproximator, need to transfer to cpu before deepcopy, then push target model to gpu target = gpu(deepcopy(cpu(network.model))) else target = deepcopy(network.model) diff --git a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl index 544b2defc..175c749ab 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl @@ -8,7 +8,7 @@ using ReinforcementLearningCore: AbstractLearner, TabularApproximator using Flux """ - TDLearner(;approximator, γ=1.0, method, n=0) + TDLearner(;approximator, method, γ=1.0, α=0.01, n=0) Use temporal-difference method to estimate state value or state-action value. @@ -20,14 +20,15 @@ Use temporal-difference method to estimate state value or state-action value. 
""" Base.@kwdef mutable struct TDLearner{M,A} <: AbstractLearner where {A<:TabularApproximator,M<:Symbol} approximator::A - γ::Float64 = 1.0 + γ::Float64 = 1.0 # discount factor + α::Float64 = 0.01 # learning rate n::Int = 0 - function TDLearner(approximator::A, method::Symbol; γ=1.0, n=0) where {A<:TabularApproximator} + function TDLearner(approximator::A, method::Symbol; γ=1.0, α=0.01, n=0) where {A<:TabularApproximator} if method ∉ [:SARS] @error "Method $method is not supported" else - new{method, A}(approximator, γ, n) + new{method, A}(approximator, γ, α, n) end end end @@ -50,34 +51,35 @@ function bellman_update!( action::I3, reward::F1, γ::Float64, # discount factor + α::Float64, # learning rate ) where {I1<:Integer,I2<:Integer,I3<:Integer,F1<:AbstractFloat} # Q-learning formula following https://github.com/JuliaPOMDP/TabularTDLearning.jl/blob/25c4d3888e178c51ed1ff448f36b0fcaf7c1d8e8/src/q_learn.jl#LL63C26-L63C95 # Terminology following https://en.wikipedia.org/wiki/Q-learning estimate_optimal_future_value = maximum(Q(approx, next_state)) current_value = Q(approx, state, action) raw_q_value = (reward + γ * estimate_optimal_future_value - current_value) # Discount factor γ is applied here - q_value_updated = Flux.Optimise.update!(approx.optimiser_state, :learning, [raw_q_value])[] # adust according to optimiser learning rate - approx.model[action, state] += q_value_updated + approx.model[action, state] += α * raw_q_value return Q(approx, state, action) end function _optimise!( n::I1, - γ::F, - approx::Approximator{Ar}, + γ::F, # discount factor + α::F, # learning rate + approx::TabularApproximator{Ar}, state::I2, next_state::I2, action::I3, reward::F, ) where {I1<:Number,I2<:Number,I3<:Number,Ar<:AbstractArray,F<:AbstractFloat} - bellman_update!(approx, state, next_state, action, reward, γ) + bellman_update!(approx, state, next_state, action, reward, γ, α) end function RLBase.optimise!( L::TDLearner, t::@NamedTuple{state::I1, next_state::I1, action::I2, reward::F2, terminal::Bool}, ) where {I1<:Number,I2<:Number,F2<:AbstractFloat} - _optimise!(L.n, L.γ, L.approximator, t.state, t.next_state, t.action, t.reward) + _optimise!(L.n, L.γ, L.α, L.approximator, t.state, t.next_state, t.action, t.reward) end function RLBase.optimise!(learner::TDLearner, stage::AbstractStage, trajectory::Trajectory) diff --git a/src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl b/src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl index 52700d3cd..fbf9ec09b 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl @@ -1,32 +1,30 @@ using Test using Flux +using ReinforcementLearningCore, ReinforcementLearningBase + +# Mock explorer, environment, and learner +struct MockExplorer <: AbstractExplorer end +struct MockEnv <: AbstractEnv end +struct MockLearner <: AbstractLearner end @testset "AbstractLearner Tests" begin @testset "Forward" begin - # Mock environment and learner - struct MockEnv <: AbstractEnv end - struct MockLearner <: AbstractLearner end - function RLCore.forward(::MockLearner, ::AbstractState) - return rand(2) + function RLCore.forward(::MockLearner, state::Int) + return [1.0, 2.0] end + RLBase.state(::MockEnv, ::Observation{Any}, ::DefaultPlayer) = 1 + env = MockEnv() learner = MockLearner() - output = forward(learner, env) - - @test typeof(output) == Array{Float64,1} - @test length(output) == 2 + output = RLCore.forward(learner, 
env) + @test output == Float64[1.0, 2.0] end @testset "Plan" begin - # Mock explorer, environment, and learner - struct MockExplorer <: AbstractExplorer end - struct MockEnv <: AbstractEnv end - struct MockLearner <: AbstractLearner end - - function RLBase.plan!(::MockExplorer, ::AbstractState, ::AbstractActionSpace) + function RLBase.plan!(::MockExplorer, learner::MockLearner, env::MockEnv) return rand(2) end @@ -42,11 +40,11 @@ using Flux @testset "Plan with Player" begin # Mock explorer, environment, and learner - struct MockExplorer <: AbstractExplorer end - struct MockEnv <: AbstractEnv end - struct MockLearner <: AbstractLearner end + function RLBase.action_space(::MockEnv, ::Symbol) + return [1, 2] + end - function RLBase.plan!(::MockExplorer, ::AbstractState, ::AbstractActionSpace) + function RLBase.plan!(::MockExplorer, learner::MockLearner, env::MockEnv, p::Symbol) return rand(2) end @@ -62,12 +60,11 @@ using Flux end @testset "optimise!" begin - struct MockLearner <: AbstractLearner end tr = Trajectory( CircularArraySARTSTraces(; capacity = 1_000), BatchSampler(1), InsertSampleRatioController(n_inserted = -1), ) - @test optimise!(MockLearner(), PreActStage(), tr) is nothing + @test optimise!(MockLearner(), PreActStage(), tr) == nothing end end diff --git a/src/ReinforcementLearningCore/test/policies/learners/approximator.jl b/src/ReinforcementLearningCore/test/policies/learners/flux_model_approximator.jl similarity index 74% rename from src/ReinforcementLearningCore/test/policies/learners/approximator.jl rename to src/ReinforcementLearningCore/test/policies/learners/flux_model_approximator.jl index df8efeb20..fd5359d46 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/approximator.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/flux_model_approximator.jl @@ -1,13 +1,13 @@ using Test using Flux -@testset "Approximator Tests" begin +@testset "FluxModelApproximator Tests" begin @testset "Creation, with use_gpu = true toggle" begin model = Chain(Dense(10, 5, relu), Dense(5, 2)) optimiser = Adam() - approximator = Approximator(model=model, optimiser=optimiser, use_gpu=true) + approximator = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=true) - @test approximator isa Approximator + @test approximator isa FluxModelApproximator @test typeof(approximator.model) == typeof(gpu(model)) @test approximator.optimiser_state isa NamedTuple end @@ -15,7 +15,7 @@ using Flux @testset "Forward" begin model = Chain(Dense(10, 5, relu), Dense(5, 2)) optimiser = Adam() - approximator = Approximator(model=model, optimiser=optimiser, use_gpu=false) + approximator = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=false) input = rand(Float32, 10) output = RLCore.forward(approximator, input) @@ -27,7 +27,7 @@ using Flux @testset "Forward to environment" begin model = Chain(Dense(4, 5, relu), Dense(5, 2)) optimiser = Adam() - approximator = Approximator(model=model, optimiser=optimiser, use_gpu=false) + approximator = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=false) env = CartPoleEnv(T=Float32) output = RLCore.forward(approximator, env) @@ -38,7 +38,7 @@ using Flux @testset "Optimise" begin model = Chain(Dense(10, 5, relu), Dense(5, 2)) optimiser = Adam() - approximator = Approximator(model=model, optimiser=optimiser) + approximator = FluxModelApproximator(model=model, optimiser=optimiser) input = rand(Float32, 10) diff --git a/src/ReinforcementLearningCore/test/policies/learners/learners.jl 
b/src/ReinforcementLearningCore/test/policies/learners/learners.jl index d336799eb..9d0c474d6 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/learners.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/learners.jl @@ -1,6 +1,6 @@ @testset "approximators.jl" begin include("abstract_learner.jl") - include("approximator.jl") + include("flux_model_approximator.jl") include("tabular_approximator.jl") include("target_network.jl") include("td_learner.jl") diff --git a/src/ReinforcementLearningCore/test/policies/learners/tabular_approximator.jl b/src/ReinforcementLearningCore/test/policies/learners/tabular_approximator.jl index f23caf063..2876e99c1 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/tabular_approximator.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/tabular_approximator.jl @@ -4,21 +4,21 @@ using ReinforcementLearningCore using Flux @testset "Constructors" begin - @test TabularApproximator(fill(1, 10, 10), InvDecay(0.5)) isa TabularApproximator - @test TabularVApproximator(n_state = 10, opt = InvDecay(0.5)) isa - TabularApproximator{Vector{Float64},InvDecay} - @test TabularQApproximator(n_state = 10, n_action = 10, opt = InvDecay(0.5)) isa - TabularApproximator{Matrix{Float64},InvDecay} + @test TabularApproximator(fill(1, 10, 10)) isa TabularApproximator + @test TabularVApproximator(n_state = 10) isa + TabularApproximator{Vector{Float64}} + @test TabularQApproximator(n_state = 10, n_action = 10) isa + TabularApproximator{Matrix{Float64}} end @testset "RLCore.forward" begin - v_approx = TabularVApproximator(n_state = 10, opt = InvDecay(0.5)) + v_approx = TabularVApproximator(n_state = 10) @test RLCore.forward(v_approx, 1) == 0.0 env = RockPaperScissorsEnv() @test RLCore.forward(v_approx, env) == 0.0 - q_approx = TabularQApproximator(n_state = 5, n_action = 10, opt = InvDecay(0.5)) + q_approx = TabularQApproximator(n_state = 5, n_action = 10) @test RLCore.forward(q_approx, 1) == zeros(Float64, 10) @test RLCore.forward(q_approx, 1, 5) == 0.0 @test RLCore.forward(q_approx, env) == zeros(10) diff --git a/src/ReinforcementLearningCore/test/policies/learners/target_network.jl b/src/ReinforcementLearningCore/test/policies/learners/target_network.jl index 87c72c30b..564e92f04 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/target_network.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/target_network.jl @@ -1,17 +1,18 @@ using Test using Flux using ReinforcementLearningCore + @testset "TargetNetwork Tests" begin @testset "Creation" begin model = Chain(Dense(10, 5, relu), Dense(5, 2)) optimiser = Adam() if ((@isdefined CUDA) && CUDA.functional()) || ((@isdefined Metal) && Metal.functional()) - @test_throws "AssertionError: `Approximator` model is not on GPU." TargetNetwork(Approximator(model, optimiser), use_gpu=true) + @test_throws "AssertionError: `FluxModelApproximator` model is not on GPU." 
TargetNetwork(FluxModelApproximator(model, optimiser), use_gpu=true) end - @test TargetNetwork(Approximator(model=model, optimiser=optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork - @test TargetNetwork(Approximator(model, optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork + @test TargetNetwork(FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork + @test TargetNetwork(FluxModelApproximator(model, optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork - approx = Approximator(model, optimiser, use_gpu=false) + approx = FluxModelApproximator(model, optimiser, use_gpu=false) target_network = TargetNetwork(approx, use_gpu=false) @@ -25,7 +26,7 @@ using ReinforcementLearningCore @testset "Forward" begin model = Chain(Dense(10, 5, relu), Dense(5, 2)) - target_network = TargetNetwork(Approximator(model, Adam())) + target_network = TargetNetwork(FluxModelApproximator(model, Adam())) input = rand(Float32, 10) output = RLCore.forward(target_network, input) @@ -37,7 +38,7 @@ using ReinforcementLearningCore @testset "Optimise" begin optimiser = Adam() model = Chain(Dense(10, 5, relu), Dense(5, 2)) - approximator = Approximator(model, optimiser) + approximator = FluxModelApproximator(model, optimiser) target_network = TargetNetwork(approximator) input = rand(Float32, 10) grad = Flux.Zygote.gradient(target_network) do model @@ -53,7 +54,7 @@ using ReinforcementLearningCore @testset "Sync" begin optimiser = Adam() - model = Approximator(Chain(Dense(10, 5, relu), Dense(5, 2)), optimiser) + model = FluxModelApproximator(Chain(Dense(10, 5, relu), Dense(5, 2)), optimiser) target_network = TargetNetwork(model, sync_freq=2, ρ=0.5) input = rand(Float32, 10) @@ -72,7 +73,7 @@ end @testset "TargetNetwork" begin m = Chain(Dense(4,1)) - app = Approximator(model = m, optimiser = Flux.Adam(), use_gpu=true) + app = FluxModelApproximator(model = m, optimiser = Flux.Adam(), use_gpu=true) tn = TargetNetwork(app, sync_freq = 3, use_gpu=true) @test typeof(model(tn)) == typeof(target(tn)) p1 = Flux.destructure(model(tn))[1] diff --git a/src/ReinforcementLearningCore/test/policies/learners/td_learner.jl b/src/ReinforcementLearningCore/test/policies/learners/td_learner.jl index 3aae407e4..0379ad184 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/td_learner.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/td_learner.jl @@ -2,50 +2,54 @@ using Test using Flux @testset "Test TDLearner creation" begin - approximator = TabularVApproximator(n_state=5, opt = InvDecay(0.7)) + approximator = TabularVApproximator(n_state=5) @test TDLearner(approximator, :SARS, γ=0.95, n=0) isa TDLearner - approximator = TabularQApproximator(n_state=5, n_action=3, opt = InvDecay(0.7)) + approximator = TabularQApproximator(n_state=5, n_action=3) @test TDLearner(approximator, :SARS, γ=0.95, n=0) isa TDLearner end # Test TDLearner struct @testset "TDLearner struct" begin - approximator = TabularQApproximator(n_state=5, n_action=3, opt = InvDecay(0.7)) + approximator = TabularQApproximator(n_state=5, n_action=3) learner = TDLearner(approximator, :SARS) @test learner.approximator === approximator @test learner.γ == 1.0 @test learner.n == 0 end -# Test Q! function -@testset "Q! function" begin - approximator = TabularQApproximator(n_state=5, n_action=3, opt = InvDecay(0.7)) +# Test bellman_update! function +@testset "bellman_update! 
function" begin + learner = TDLearner(TabularQApproximator(n_state=5, n_action=3), :SARS) + approximator = learner.approximator s = 1 s_plus_one = 2 a = 3 - α = 1/(1 + approximator.optimiser_state.gamma) + α = learner.α π_ = 5.0 γ = 0.9 approximator.model[2, s_plus_one] = 15 approximator.model[a, s] = 2 # Following https://en.wikipedia.org/wiki/Q-learning#Algorithm - q_should_be = (1-α) * RLFarm.Q(approximator, s, a) + α * (π_ + γ * maximum(RLFarm.Q(approximator, s_plus_one))) + q_should_be = (1-α) * RLCore.Q(approximator, s, a) + α * (π_ + γ * maximum(RLCore.Q(approximator, s_plus_one))) - @test RLFarm.Q!(approximator, s, s_plus_one, a, π_, γ) ≈ q_should_be - @test RLFarm.Q(approximator, s, a) ≈ q_should_be + @test RLCore.bellman_update!(approximator, s, s_plus_one, a, π_, γ, α) ≈ q_should_be + @test RLCore.Q(approximator, s, a) ≈ q_should_be end # Test optimise! function @testset "optimise! function" begin - approximator = TabularQApproximator(n_state=5, n_action=3, opt = InvDecay(0.7)) - learner = TDLearner(approximator, :SARS) + learner = TDLearner(TabularQApproximator(n_state=5, n_action=3), :SARS) + approximator = learner.approximator t = (state=1, next_state=2, action=3, reward=5.0, terminal=false) optimise!(learner, t) - @test approximator.model[t.action, t.state] ≈ 2.9411764705882355 - for i in 1:500 + @test learner.approximator.model[t.action, t.state] ≈ 0.05 + optimise!(learner, t) + @test learner.approximator.model[t.action, t.state] ≈ 0.0995 + + for i in 1:1000 optimise!(learner, t) end @test approximator.model[t.action, t.state] ≈ t.reward atol=0.01 diff --git a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl index 991096511..69e3c4777 100644 --- a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl @@ -1,7 +1,7 @@ @testset "QBasedPolicy" begin @testset "constructor" begin - q_approx = TabularQApproximator(n_state = 5, n_action = 10, opt = InvDecay(0.5)) + q_approx = TabularQApproximator(n_state = 5, n_action = 10) explorer = EpsilonGreedyExplorer(0.1) learner = TDLearner(q_approx, :SARS) p = QBasedPolicy(learner, explorer) @@ -12,7 +12,7 @@ @testset "plan!" begin @testset "plan! without player argument" begin env = TicTacToeEnv() - q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) learner = TDLearner(q_approx, :SARS) explorer = EpsilonGreedyExplorer(0.1) policy = QBasedPolicy(learner, explorer) @@ -21,7 +21,7 @@ @testset "plan! 
with player argument" begin env = TicTacToeEnv() - q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) learner = TDLearner(q_approx, :SARS) explorer = EpsilonGreedyExplorer(0.1) policy = QBasedPolicy(learner, explorer) @@ -33,7 +33,7 @@ # Test prob function @testset "prob" begin env = TicTacToeEnv() - q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) learner = TDLearner(q_approx, :SARS) explorer = EpsilonGreedyExplorer(0.1) policy = QBasedPolicy(learner, explorer) @@ -50,7 +50,8 @@ ) t = (state=2, action=3) push!(trajectory, t) - t = (next_state=3, reward=5.0, terminal=false) + next_state = 4 + t = (action=3, state=next_state, reward=5.0, terminal=false) push!(trajectory, t) optimise!(policy, PostActStage(), trajectory) prob = RLBase.prob(policy, env) @@ -60,7 +61,7 @@ # Test optimise! function @testset "optimise!" begin env = TicTacToeEnv() - q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env)), opt = InvDecay(0.5)) + q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) explorer = EpsilonGreedyExplorer(0.1) learner = TDLearner(q_approx, :SARS) policy = QBasedPolicy(learner, explorer) diff --git a/src/ReinforcementLearningCore/test/utils/networks.jl b/src/ReinforcementLearningCore/test/utils/networks.jl index c995366ff..f070dc75c 100644 --- a/src/ReinforcementLearningCore/test/utils/networks.jl +++ b/src/ReinforcementLearningCore/test/utils/networks.jl @@ -5,7 +5,7 @@ import ReinforcementLearningBase: RLBase @testset "Approximators" begin #= These may need to be updated due to recent changes @testset "TabularApproximator" begin - A = TabularVApproximator(; n_state = 2, opt = InvDecay(1.0)) + A = TabularVApproximator(; n_state = 2) @test A(1) == 0.0 @test A(2) == 0.0 From 60326de9121ce2af4bb3517908549be8b1745675 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 17:44:32 +0100 Subject: [PATCH 20/22] add note --- .../src/policies/learners/td_learner.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl index 175c749ab..15cbaf3b9 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/td_learner.jl @@ -88,6 +88,6 @@ function RLBase.optimise!(learner::TDLearner, stage::AbstractStage, trajectory:: end end -# TDLearner{:SARS} optimises at the PostActStage +# TDLearner{:SARS} is optimized at the PostActStage RLBase.optimise!(learner::TDLearner{:SARS}, stage::PostActStage, trace::NamedTuple) = RLBase.optimise!(learner, trace) From a0330b8e55d0bcff695d52f80129ae49ce27405b Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Tue, 19 Mar 2024 18:30:17 +0100 Subject: [PATCH 21/22] add missing q-learning tests --- .../test/policies/q_based_policy.jl | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl index 69e3c4777..cbe3382ab 100644 --- a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl +++ 
b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl @@ -55,7 +55,7 @@ push!(trajectory, t) optimise!(policy, PostActStage(), trajectory) prob = RLBase.prob(policy, env) - # Add assertions here + @test prob.p == [0.9111111111111111, 0.011111111111111112, 0.011111111111111112, 0.011111111111111112, 0.011111111111111112, 0.011111111111111112, 0.011111111111111112, 0.011111111111111112, 0.011111111111111112] end # Test optimise! function @@ -63,7 +63,7 @@ env = TicTacToeEnv() q_approx = TabularQApproximator(n_state = 5, n_action = length(action_space(env))) explorer = EpsilonGreedyExplorer(0.1) - learner = TDLearner(q_approx, :SARS) + learner = TDLearner(q_approx, :SARS, γ=0.95, α=0.01, n=0) policy = QBasedPolicy(learner, explorer) trajectory = Trajectory( CircularArraySARTSTraces(; @@ -76,13 +76,20 @@ DummySampler(), InsertSampleRatioController(), ) - t = (state=2, action=3) + t = (state=4, action=3) push!(trajectory, t) next_state = 4 t = (action=3, state=next_state, reward=5.0, terminal=false) push!(trajectory, t) - trajectory.container[1] + RLBase.optimise!(policy, PostActStage(), trajectory) - # Add assertions here + @test policy.learner.approximator.model[t.action, t.state] ≈ 0.05 + RLBase.optimise!(policy, PostActStage(), trajectory) + @test policy.learner.approximator.model[t.action, t.state] ≈ 0.09997500000000001 + + for i in 1:100000 + RLBase.optimise!(policy, PostActStage(), trajectory) + end + @test policy.learner.approximator.model[t.action, t.state] ≈ t.reward / (1-policy.learner.γ) atol=0.01 end end From fb61c6c02e0411f27197b75b67ce1e9f239e482a Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Wed, 20 Mar 2024 10:17:38 +0100 Subject: [PATCH 22/22] Update FluxModelApproximator to FluxApproximator --- .../policies/learners/flux_approximator.jl | 47 +++++++++++++++++++ .../learners/flux_model_approximator.jl | 47 ------------------- .../src/policies/learners/learners.jl | 2 +- .../src/policies/learners/target_network.jl | 18 +++---- ...l_approximator.jl => flux_approximator.jl} | 12 ++--- .../test/policies/learners/learners.jl | 2 +- .../test/policies/learners/target_network.jl | 16 +++---- 7 files changed, 72 insertions(+), 72 deletions(-) create mode 100644 src/ReinforcementLearningCore/src/policies/learners/flux_approximator.jl delete mode 100644 src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl rename src/ReinforcementLearningCore/test/policies/learners/{flux_model_approximator.jl => flux_approximator.jl} (74%) diff --git a/src/ReinforcementLearningCore/src/policies/learners/flux_approximator.jl b/src/ReinforcementLearningCore/src/policies/learners/flux_approximator.jl new file mode 100644 index 000000000..02bc20087 --- /dev/null +++ b/src/ReinforcementLearningCore/src/policies/learners/flux_approximator.jl @@ -0,0 +1,47 @@ +export FluxApproximator + +using Flux + +""" + FluxApproximator(model, optimiser) + +Wraps a Flux trainable model and implements the `RLBase.optimise!(::FluxApproximator, ::Gradient)` +interface. See the RLCore documentation for more information on proper usage. +""" +struct FluxApproximator{M,O} <: AbstractLearner + model::M + optimiser_state::O +end + + +""" + FluxApproximator(; model, optimiser, usegpu=false) + +Constructs an `FluxApproximator` object for reinforcement learning. + +# Arguments +- `model`: The model used for approximation. +- `optimiser`: The optimizer used for updating the model. +- `usegpu`: A boolean indicating whether to use GPU for computation. Default is `false`. 
+ +# Returns +An `FluxApproximator` object. +""" +function FluxApproximator(; model, optimiser, use_gpu=false) + optimiser_state = Flux.setup(optimiser, model) + if use_gpu # Pass model to GPU (if available) upon creation + return FluxApproximator(gpu(model), gpu(optimiser_state)) + else + return FluxApproximator(model, optimiser_state) + end +end + +FluxApproximator(model, optimiser::Flux.Optimise.AbstractOptimiser; use_gpu=false) = FluxApproximator(model=model, optimiser=optimiser, use_gpu=use_gpu) + +Flux.@layer FluxApproximator trainable=(model,) + +forward(A::FluxApproximator, args...; kwargs...) = A.model(args...; kwargs...) +forward(A::FluxApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(A, x)) + +RLBase.optimise!(A::FluxApproximator, grad::NamedTuple) = + Flux.Optimise.update!(A.optimiser_state, A.model, grad.model) diff --git a/src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl b/src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl deleted file mode 100644 index 2e203e6ac..000000000 --- a/src/ReinforcementLearningCore/src/policies/learners/flux_model_approximator.jl +++ /dev/null @@ -1,47 +0,0 @@ -export FluxModelApproximator - -using Flux - -""" - FluxModelApproximator(model, optimiser) - -Wraps a Flux trainable model and implements the `RLBase.optimise!(::FluxModelApproximator, ::Gradient)` -interface. See the RLCore documentation for more information on proper usage. -""" -struct FluxModelApproximator{M,O} <: AbstractLearner - model::M - optimiser_state::O -end - - -""" - FluxModelApproximator(; model, optimiser, usegpu=false) - -Constructs an `FluxModelApproximator` object for reinforcement learning. - -# Arguments -- `model`: The model used for approximation. -- `optimiser`: The optimizer used for updating the model. -- `usegpu`: A boolean indicating whether to use GPU for computation. Default is `false`. - -# Returns -An `FluxModelApproximator` object. -""" -function FluxModelApproximator(; model, optimiser, use_gpu=false) - optimiser_state = Flux.setup(optimiser, model) - if use_gpu # Pass model to GPU (if available) upon creation - return FluxModelApproximator(gpu(model), gpu(optimiser_state)) - else - return FluxModelApproximator(model, optimiser_state) - end -end - -FluxModelApproximator(model, optimiser::Flux.Optimise.AbstractOptimiser; use_gpu=false) = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=use_gpu) - -Flux.@layer FluxModelApproximator trainable=(model,) - -forward(A::FluxModelApproximator, args...; kwargs...) = A.model(args...; kwargs...) 
-forward(A::FluxModelApproximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(A, x)) - -RLBase.optimise!(A::FluxModelApproximator, grad::NamedTuple) = - Flux.Optimise.update!(A.optimiser_state, A.model, grad.model) diff --git a/src/ReinforcementLearningCore/src/policies/learners/learners.jl b/src/ReinforcementLearningCore/src/policies/learners/learners.jl index b925f58ec..0290ea7bd 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/learners.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/learners.jl @@ -1,5 +1,5 @@ include("abstract_learner.jl") -include("flux_model_approximator.jl") +include("flux_approximator.jl") include("tabular_approximator.jl") include("td_learner.jl") include("target_network.jl") diff --git a/src/ReinforcementLearningCore/src/policies/learners/target_network.jl b/src/ReinforcementLearningCore/src/policies/learners/target_network.jl index 26e32b6cc..7a3b8490a 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/target_network.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/target_network.jl @@ -2,13 +2,13 @@ export TargetNetwork, target, model using Flux -target(ap::FluxModelApproximator) = ap.model #see TargetNetwork -model(ap::FluxModelApproximator) = ap.model #see TargetNetwork +target(ap::FluxApproximator) = ap.model #see TargetNetwork +model(ap::FluxApproximator) = ap.model #see TargetNetwork """ - TargetNetwork(network::FluxModelApproximator; sync_freq::Int = 1, ρ::Float32 = 0f0) + TargetNetwork(network::FluxApproximator; sync_freq::Int = 1, ρ::Float32 = 0f0) -Wraps an FluxModelApproximator to hold a target network that is updated towards the model of the +Wraps an FluxApproximator to hold a target network that is updated towards the model of the approximator. - `sync_freq` is the number of updates of `network` between each update of the `target`. - ρ (\rho) is "how much of the target is kept when updating it". @@ -21,11 +21,11 @@ Implements the `RLBase.optimise!(::TargetNetwork, ::Gradient)` interface to upda and the target with weights replacement or Polyak averaging. Note to developers: `model(::TargetNetwork)` will return the trainable Flux model -and `target(::TargetNetwork)` returns the target model and `target(::FluxModelApproximator)` +and `target(::TargetNetwork)` returns the target model and `target(::FluxApproximator)` returns the non-trainable Flux model. See the RLCore documentation. """ mutable struct TargetNetwork{M} - network::FluxModelApproximator{M} + network::FluxApproximator{M} target::M sync_freq::Int ρ::Float32 @@ -46,13 +46,13 @@ Constructs a target network for reinforcement learning. # Returns A `TargetNetwork` object. """ -function TargetNetwork(network::FluxModelApproximator; sync_freq = 1, ρ = 0f0, use_gpu = false) +function TargetNetwork(network::FluxApproximator; sync_freq = 1, ρ = 0f0, use_gpu = false) @assert 0 <= ρ <= 1 "ρ must in [0,1]" ρ = Float32(ρ) if use_gpu - @assert typeof(gpu(network.model)) == typeof(network.model) "`FluxModelApproximator` model is not on GPU. Please set `use_gpu=false`` or ensure model is on GPU, by setting `use_gpu=true` when constructing `FluxModelApproximator`." - # NOTE: model is pushed to gpu in FluxModelApproximator, need to transfer to cpu before deepcopy, then push target model to gpu + @assert typeof(gpu(network.model)) == typeof(network.model) "`FluxApproximator` model is not on GPU. Please set `use_gpu=false`` or ensure model is on GPU, by setting `use_gpu=true` when constructing `FluxApproximator`." 
+ # NOTE: model is pushed to gpu in FluxApproximator, need to transfer to cpu before deepcopy, then push target model to gpu target = gpu(deepcopy(cpu(network.model))) else target = deepcopy(network.model) diff --git a/src/ReinforcementLearningCore/test/policies/learners/flux_model_approximator.jl b/src/ReinforcementLearningCore/test/policies/learners/flux_approximator.jl similarity index 74% rename from src/ReinforcementLearningCore/test/policies/learners/flux_model_approximator.jl rename to src/ReinforcementLearningCore/test/policies/learners/flux_approximator.jl index fd5359d46..b5fd19749 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/flux_model_approximator.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/flux_approximator.jl @@ -1,13 +1,13 @@ using Test using Flux -@testset "FluxModelApproximator Tests" begin +@testset "FluxApproximator Tests" begin @testset "Creation, with use_gpu = true toggle" begin model = Chain(Dense(10, 5, relu), Dense(5, 2)) optimiser = Adam() - approximator = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=true) + approximator = FluxApproximator(model=model, optimiser=optimiser, use_gpu=true) - @test approximator isa FluxModelApproximator + @test approximator isa FluxApproximator @test typeof(approximator.model) == typeof(gpu(model)) @test approximator.optimiser_state isa NamedTuple end @@ -15,7 +15,7 @@ using Flux @testset "Forward" begin model = Chain(Dense(10, 5, relu), Dense(5, 2)) optimiser = Adam() - approximator = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=false) + approximator = FluxApproximator(model=model, optimiser=optimiser, use_gpu=false) input = rand(Float32, 10) output = RLCore.forward(approximator, input) @@ -27,7 +27,7 @@ using Flux @testset "Forward to environment" begin model = Chain(Dense(4, 5, relu), Dense(5, 2)) optimiser = Adam() - approximator = FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=false) + approximator = FluxApproximator(model=model, optimiser=optimiser, use_gpu=false) env = CartPoleEnv(T=Float32) output = RLCore.forward(approximator, env) @@ -38,7 +38,7 @@ using Flux @testset "Optimise" begin model = Chain(Dense(10, 5, relu), Dense(5, 2)) optimiser = Adam() - approximator = FluxModelApproximator(model=model, optimiser=optimiser) + approximator = FluxApproximator(model=model, optimiser=optimiser) input = rand(Float32, 10) diff --git a/src/ReinforcementLearningCore/test/policies/learners/learners.jl b/src/ReinforcementLearningCore/test/policies/learners/learners.jl index 9d0c474d6..07deb8f48 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/learners.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/learners.jl @@ -1,6 +1,6 @@ @testset "approximators.jl" begin include("abstract_learner.jl") - include("flux_model_approximator.jl") + include("flux_approximator.jl") include("tabular_approximator.jl") include("target_network.jl") include("td_learner.jl") diff --git a/src/ReinforcementLearningCore/test/policies/learners/target_network.jl b/src/ReinforcementLearningCore/test/policies/learners/target_network.jl index 564e92f04..e9182ddaa 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/target_network.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/target_network.jl @@ -7,12 +7,12 @@ using ReinforcementLearningCore model = Chain(Dense(10, 5, relu), Dense(5, 2)) optimiser = Adam() if ((@isdefined CUDA) && CUDA.functional()) || ((@isdefined Metal) && Metal.functional()) - @test_throws 
"AssertionError: `FluxModelApproximator` model is not on GPU." TargetNetwork(FluxModelApproximator(model, optimiser), use_gpu=true) + @test_throws "AssertionError: `FluxApproximator` model is not on GPU." TargetNetwork(FluxApproximator(model, optimiser), use_gpu=true) end - @test TargetNetwork(FluxModelApproximator(model=model, optimiser=optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork - @test TargetNetwork(FluxModelApproximator(model, optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork + @test TargetNetwork(FluxApproximator(model=model, optimiser=optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork + @test TargetNetwork(FluxApproximator(model, optimiser, use_gpu=true), use_gpu=true) isa TargetNetwork - approx = FluxModelApproximator(model, optimiser, use_gpu=false) + approx = FluxApproximator(model, optimiser, use_gpu=false) target_network = TargetNetwork(approx, use_gpu=false) @@ -26,7 +26,7 @@ using ReinforcementLearningCore @testset "Forward" begin model = Chain(Dense(10, 5, relu), Dense(5, 2)) - target_network = TargetNetwork(FluxModelApproximator(model, Adam())) + target_network = TargetNetwork(FluxApproximator(model, Adam())) input = rand(Float32, 10) output = RLCore.forward(target_network, input) @@ -38,7 +38,7 @@ using ReinforcementLearningCore @testset "Optimise" begin optimiser = Adam() model = Chain(Dense(10, 5, relu), Dense(5, 2)) - approximator = FluxModelApproximator(model, optimiser) + approximator = FluxApproximator(model, optimiser) target_network = TargetNetwork(approximator) input = rand(Float32, 10) grad = Flux.Zygote.gradient(target_network) do model @@ -54,7 +54,7 @@ using ReinforcementLearningCore @testset "Sync" begin optimiser = Adam() - model = FluxModelApproximator(Chain(Dense(10, 5, relu), Dense(5, 2)), optimiser) + model = FluxApproximator(Chain(Dense(10, 5, relu), Dense(5, 2)), optimiser) target_network = TargetNetwork(model, sync_freq=2, ρ=0.5) input = rand(Float32, 10) @@ -73,7 +73,7 @@ end @testset "TargetNetwork" begin m = Chain(Dense(4,1)) - app = FluxModelApproximator(model = m, optimiser = Flux.Adam(), use_gpu=true) + app = FluxApproximator(model = m, optimiser = Flux.Adam(), use_gpu=true) tn = TargetNetwork(app, sync_freq = 3, use_gpu=true) @test typeof(model(tn)) == typeof(target(tn)) p1 = Flux.destructure(model(tn))[1]