Skip to content

Instantly share code, notes, and snippets.

@willtebbutt
Created January 24, 2025 15:41
Show Gist options
  • Save willtebbutt/16bce3962aa3e9afc11465d0073ccdda to your computer and use it in GitHub Desktop.
Save willtebbutt/16bce3962aa3e9afc11465d0073ccdda to your computer and use it in GitHub Desktop.
"""
    successor_count(b::CC.BasicBlock)::Int

Return how many successor blocks `b` has, i.e. the length of `b.succs`.
"""
function successor_count(b::CC.BasicBlock)::Int
    successors = b.succs
    return length(successors)
end
"""
    in_block(n::SSAValue, b::CC.BasicBlock)::Bool

Return `true` if statement `n` is one of the statements of block `b`, i.e. if `n.id` lies
in `b.stmts`.
"""
function in_block(n::SSAValue, b::CC.BasicBlock)::Bool
    stmt_index = n.id
    return stmt_index in b.stmts
end
"""
    edge_count(blocks::Vector{CC.BasicBlock})::Int

Return the total number of edges in the control-flow graph described by `blocks`, i.e. the
sum of the number of successors of every block. An empty `blocks` has `0` edges.
"""
# `init=0` makes the reduction well-defined for an empty `blocks`; without it `sum` with a
# mapping function throws an `ArgumentError` on empty collections.
edge_count(blocks::Vector{CC.BasicBlock})::Int = sum(successor_count, blocks; init=0)
"""
    Edge(tail::Int, head::Int)

A _directed_ edge pointing from block `tail` to block `head`, where both are indices into a
`Vector{CC.BasicBlock}`.
"""
struct Edge
    tail::Int
    head::Int
end
"""
    edges(blocks::Vector{CC.BasicBlock})::Vector{Edge}

Return every edge of the control-flow graph described by `blocks`: one `Edge(i, j)` for
each successor `j` listed in `blocks[i].succs`. Callers should not rely on the order of the
result. An empty `blocks` yields an empty vector.
"""
function edges(blocks::Vector{CC.BasicBlock})::Vector{Edge}
    # Collected in a single pass, so no separate edge count has to be computed up front and
    # kept in sync with a fill loop, and no unchecked `@inbounds` writes are needed.
    return Edge[
        Edge(blk_id, succ) for (blk_id, blk) in enumerate(blocks) for succ in blk.succs
    ]
end
"""
    backedges(dom_tree::CC.DomTree, blocks::Vector{CC.BasicBlock})::Vector{Edge}

Return the subset of the edges of `blocks` which are _backedges_, i.e. edges whose `head`
dominates their `tail` according to `dom_tree`. Edges appear in the same order as in
`edges(blocks)`.

`dom_tree` can be built with `CC.construct_domtree(blocks)`.
"""
function backedges(dom_tree::CC.DomTree, blocks::Vector{CC.BasicBlock})::Vector{Edge}
    all_edges = edges(blocks)
    return [edge for edge in all_edges if CC.dominates(dom_tree, edge.head, edge.tail)]
end
"""
    NaturalLoop(backedge::Edge, blocks::Vector{Int})

Represents a natural loop, see e.g. https://web.cs.wpi.edu/~kal/PLT/PLT8.6.4.html .
Contains the defining `backedge` (a directed edge from `tail` to `head` such that `head`
dominates `tail`), and `blocks`, the indices of all basic blocks belonging to the loop.

Two `NaturalLoop`s compare equal (`==` and `isequal`) when their backedges are equal and
their `blocks` vectors hold the same elements in the same order. `hash` is consistent with
this, so `NaturalLoop`s can be used as `Dict` keys or `Set` elements.
"""
struct NaturalLoop
    backedge::Edge
    blocks::Vector{Int}
end

function Base.:(==)(x::NaturalLoop, y::NaturalLoop)
    return (x.backedge == y.backedge) && (x.blocks == y.blocks)
end

# `==` compares `blocks` by value rather than by identity, so `hash` must do the same:
# `isequal(x, y)` (which falls back to `==`) has to imply `hash(x) == hash(y)`.
function Base.hash(x::NaturalLoop, h::UInt)
    return hash(x.blocks, hash(x.backedge, hash(NaturalLoop, h)))
end
"""
    natural_loops(
        blocks::Vector{CC.BasicBlock}, dom_tree::CC.DomTree=CC.construct_domtree(blocks),
    )::Vector{NaturalLoop}

Discover all of the natural loops present in `blocks`, one per backedge.

`dom_tree` can be constructed using `CC.construct_domtree(blocks)` (the default). The
`blocks` field of each returned `NaturalLoop` is sorted in ascending order and contains no
duplicates.
"""
function natural_loops(
    blocks::Vector{CC.BasicBlock}, dom_tree::CC.DomTree=CC.construct_domtree(blocks),
)::Vector{NaturalLoop}
    # Working memory which is re-used for all natural loops. Every block is pushed onto the
    # stack at most once per loop (see below), so `length(blocks)` bounds its size.
    work_stack = Stack{Int}()
    sizehint!(work_stack.memory, length(blocks))
    is_discovered = fill(false, length(blocks))

    # The natural loop of a backedge `b` is `b.head` together with every block from which
    # `b.tail` can be reached without passing through `b.head`.
    #
    # This is computed with a depth-first search (DFS)-style procedure -- it's perhaps best
    # thought of as a reversed DFS, because we push predecessors onto the work stack rather
    # than successors. We start at `b.tail` and never move past `b.head`, which is marked
    # as discovered up front. This is correct because `b.head dom b.tail` implies that each
    # predecessor of `b.tail` is either `b.head` itself or is dominated by `b.head`. By
    # induction, the same holds for the predecessors of those predecessors, and so on.
    return map(backedges(dom_tree, blocks)) do backedge
        # Reset the working memory. The header is always part of the loop, and is marked as
        # discovered so that the search stops there.
        fill!(is_discovered, false)
        is_discovered[backedge.head] = true
        loop_blocks = Int[backedge.head]

        # For a self-loop (`tail == head`) the tail is already discovered, and the loop
        # consists of the header alone.
        if !is_discovered[backedge.tail]
            is_discovered[backedge.tail] = true
            push!(work_stack, backedge.tail)
        end

        while !isempty(work_stack)
            s = pop!(work_stack)
            push!(loop_blocks, s)
            for p in blocks[s].preds
                # Mark blocks as discovered when they are *pushed*, not when they are
                # popped. Otherwise a block that is a predecessor of several loop blocks
                # can be pushed more than once before it is first popped, and would then
                # appear more than once in `loop_blocks`.
                if !is_discovered[p]
                    is_discovered[p] = true
                    push!(work_stack, p)
                end
            end
        end

        # `loop_blocks` already contains the header; sort for a canonical representation.
        return NaturalLoop(backedge, sort!(loop_blocks))
    end
end
"""
    in_loop(n::SSAValue, blocks::Vector{CC.BasicBlock}, loop::NaturalLoop)::Bool

Return `true` if statement `n` lies in any of the blocks of `loop`. The entries of
`loop.blocks` are indices into `blocks`.
"""
function in_loop(n::SSAValue, blocks::Vector{CC.BasicBlock}, loop::NaturalLoop)::Bool
    for block_index in loop.blocks
        # Stop at the first loop block that contains `n`.
        in_block(n, blocks[block_index]) && return true
    end
    return false
end
"""
    loop_invariant_ssas(ir::IRCode, loop::NaturalLoop)::Vector{Bool}

Return a vector with one element per statement in `ir`. The `n`th element is `true` if the
`n`th statement has been identified as loop invariant for `loop`, and `false` if it could
not be proven to be.

This is a conservative approximation: a statement is only marked as invariant if it lies
outside the blocks of `loop` and its `SSAValue` is used as an argument of a `:call` or
`:invoke` expression inside the loop. Invariant statements which are not used in this way
are reported as `false`.

# Note

This analysis could be made more precise, for example by running LICM before identifying
loop invariants (currently the only criterion is whether an argument is defined outside
of the natural loop), or by treating a call as loop invariant when it is `:effect_free`
(as reported by `Base.infer_effects`) and all of its arguments are themselves loop
invariant.
"""
function loop_invariant_ssas(ir::IRCode, loop::NaturalLoop)::Vector{Bool}
    num_stmts = length(ir.stmts)
    blocks = ir.cfg.blocks

    # Mark every statement which lives in one of the loop's blocks. Iterating the blocks
    # directly avoids concatenating their statement ranges into a temporary collection.
    is_in_loop = fill(false, num_stmts)
    for blk_id in loop.blocks, stmt_id in blocks[blk_id].stmts
        is_in_loop[stmt_id] = true
    end

    # An SSA argument of a call / invoke inside the loop which is defined outside of the
    # loop cannot change between iterations, so it is loop invariant.
    is_loop_invariant = fill(false, num_stmts)
    for blk_id in loop.blocks, stmt_id in blocks[blk_id].stmts
        # NOTE(review): `InstructionStream.inst` was renamed to `stmt` in newer Julia
        # versions -- confirm against the Julia versions this code must support.
        inst = ir.stmts.inst[stmt_id]
        (Meta.isexpr(inst, :call) || Meta.isexpr(inst, :invoke)) || continue
        for arg in inst.args
            if arg isa SSAValue && !is_in_loop[arg.id]
                is_loop_invariant[arg.id] = true
            end
        end
    end
    return is_loop_invariant
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment