Created
August 31, 2018 23:58
-
-
Save zsunberg/d4013bd63c71352e6e8269a196c08f1b to your computer and use it in GitHub Desktop.
Grid world benchmark showing that the current julia compiler cannot handle multiple state types. Output for julia 1.0 at bottom.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using POMDPs | |
using POMDPModelTools | |
using POMDPSimulators | |
using POMDPPolicies | |
using StaticArrays | |
using Parameters | |
using Random | |
using BenchmarkTools | |
using POMDPModels | |
using Test | |
# Common | |
const GWPos = SVector{2,Int} | |
abstract type AbstractGridWorld{S} <: MDP{S, Symbol} end | |
const dir = Dict(:up=>GWPos(0,1), :down=>GWPos(0,-1), :left=>GWPos(-1,0), :right=>GWPos(1,0)) | |
const aind = Dict(:up=>1, :down=>2, :left=>3, :right=>4) | |
POMDPs.actions(mdp::AbstractGridWorld) = (:up, :down, :left, :right) | |
Base.rand(rng::AbstractRNG, t::Tuple) = t[rand(rng, 1:length(t))] # don't know why this doesn't work out of the box | |
POMDPs.n_states(mdp::AbstractGridWorld) = prod(mdp.size) + 1 | |
POMDPs.n_actions(mdp::AbstractGridWorld) = 4 | |
POMDPs.discount(mdp::AbstractGridWorld) = mdp.discount | |
POMDPs.actionindex(mdp::AbstractGridWorld, a::Symbol) = aind[a] | |
POMDPs.reward(mdp::AbstractGridWorld, s::GWPos, a::Symbol) = get(mdp.rewards, s, 0.0) | |
POMDPs.initialstate(mdp::AbstractGridWorld, rng::AbstractRNG) = GWPos(rand(rng, 1:mdp.size[1]), rand(rng, 1:mdp.size[2])) | |
# attempts to eliminate extraneous allocations | |
# @inline clamp2(v::GWPos, l, u) = GWPos(clamp(v[1], l[1], u[1]), clamp(v[2], l[2], u[2])) | |
# function neighbors(mdp::AbstractGridWorld, s) | |
# return (GWPos(s[1], min(s[2]+1, mdp.size[2])), # up | |
# GWPos(s[1], max(s[2]-1, 1)), # down | |
# GWPos(max(s[1]-1, 1), s[1]), # left | |
# GWPos(min(s[1]+1, mdp.size[1]), s[2]) | |
# ) | |
# end | |
###################################### | |
# Simple version using TerminalState # | |
###################################### | |
const StateTypes = Union{GWPos, TerminalState} | |
@with_kw struct SimpleGridWorld <: AbstractGridWorld{StateTypes} | |
size::Tuple{Int, Int} = (10,10) | |
rewards::Dict{GWPos, Float64} = Dict(GWPos(4,3)=>-10.0, GWPos(4,6)=>-5.0, GWPos(9,3)=>10.0, GWPos(8,8)=>3.0) | |
terminate_in::Set{GWPos} = Set((GWPos(4,3), GWPos(4,6), GWPos(9,3), GWPos(8,8))) | |
tprob::Float64 = 0.7 | |
discount::Float64 = 0.95 | |
end | |
function POMDPs.states(mdp::SimpleGridWorld) | |
ss = vec(StateTypes[GWPos(x, y) for x in 1:mdp.size[1], y in mdp.size[2]]) | |
push!(ss, terminalstate) | |
return ss | |
end | |
POMDPs.stateindex(mdp::SimpleGridWorld, s::GWPos) = LinearIndices(mdp.size)[s...] | |
POMDPs.stateindex(mdp::SimpleGridWorld, s::TerminalState) = prod(mdp.size) + 1 | |
function POMDPs.transition(mdp::SimpleGridWorld, s::GWPos, a::Symbol) | |
if s in mdp.terminate_in | |
return Deterministic(terminalstate) | |
end | |
## This causes allocations | |
# neighbors = map(actions(mdp)) do act | |
# clamp2(s+dir[act], (1,1), mdp.size) # clamp out of bounds to inbounds | |
# end | |
neighbors = map(actions(mdp)) do act | |
s + dir[act] | |
end | |
probs = map(actions(mdp)) do act | |
if act == a | |
return mdp.tprob # probability of transitioning to the desired cell | |
else | |
return (1.0 - mdp.tprob)/3 # probability of transitioning to another cell | |
end | |
end | |
return SparseCat(neighbors, probs) | |
end | |
####################### | |
# Type-stable version # | |
####################### | |
const tv = GWPos(-1,-1) | |
@with_kw struct SimpleTypeStableGridWorld <: AbstractGridWorld{GWPos} | |
size::Tuple{Int, Int} = (10,10) | |
rewards::Dict{GWPos, Float64} = Dict(GWPos(4,3)=>-10.0, GWPos(4,6)=>-5.0, GWPos(9,3)=>10.0, GWPos(8,8)=>3.0) | |
terminate_in::Set{GWPos} = Set((GWPos(4,3), GWPos(4,6), GWPos(9,3), GWPos(8,8))) | |
tprob::Float64 = 0.7 | |
discount::Float64 = 0.95 | |
end | |
function POMDPs.states(mdp::SimpleTypeStableGridWorld) | |
ss = vec(GWPos[GWPos(x, y) for x in 1:mdp.size[1], y in mdp.size[2]]) | |
push!(ss, GWPos(-1,-1)) | |
return ss | |
end | |
function POMDPs.stateindex(mdp::SimpleTypeStableGridWorld, s::GWPos) | |
if all(s.>0) | |
return LinearIndices(mdp.size)[s...] | |
else | |
return prod(mdp.size + 1) | |
end | |
end | |
function POMDPs.transition(mdp::SimpleTypeStableGridWorld, s::GWPos, a::Symbol) | |
if s in mdp.terminate_in | |
return SparseCat((tv, tv, tv, tv), (1.0, 0.0, 0.0, 0.0)) | |
end | |
## This causes allocations | |
# neighbors = map(actions(mdp)) do act | |
# clamp2(s+dir[act], (1,1), mdp.size) # clamp out of bounds to inbounds | |
# end | |
neighbors = map(actions(mdp)) do act | |
s + dir[act] | |
end | |
probs = map(actions(mdp)) do act | |
if act == a | |
return mdp.tprob # probability of transitioning to the desired cell | |
else | |
return (1.0 - mdp.tprob)/3 # probability of transitioning to another cell | |
end | |
end | |
return SparseCat(neighbors, probs) | |
end | |
POMDPs.isterminal(::SimpleTypeStableGridWorld, s::GWPos) = any(s.<0) | |
#################### | |
# Benchmark Script # | |
#################### | |
mdps = [GridWorld(terminals = Set()), | |
SimpleGridWorld(terminate_in = Set()), | |
SimpleTypeStableGridWorld(terminate_in = Set()) | |
] | |
@inferred transition(mdps[3], GWPos(1,1), :up) | |
for m in mdps | |
@show typeof(m) | |
policy = RandomPolicy(m, rng=MersenneTwister(7)) | |
rosim = RolloutSimulator(max_steps=10_000, rng=MersenneTwister(2)) | |
@btime simulate($rosim, $m, $policy) | |
end | |
############ | |
## OUTPUT ## | |
############ | |
#= | |
julia> include("gw_bench.jl") | |
WARNING: redefining constant dir | |
WARNING: redefining constant aind | |
typeof(m) = GridWorld | |
3.265 ms (40000 allocations: 3.05 MiB) | |
typeof(m) = SimpleGridWorld | |
100.529 ms (436593 allocations: 12.55 MiB) | |
typeof(m) = SimpleTypeStableGridWorld | |
12.241 μs (73 allocations: 2.09 KiB) | |
=# |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment