Created
November 19, 2014 17:53
-
-
Save Jutho/85982d26d63842d71749 to your computer and use it in GitHub Desktop.
Cartesian iteration, with a correction for N = 0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Multidimensional iterators
module IteratorsMD

import Base: start, _start, done, next, getindex, setindex!, linearindexing
import Base: @nref, @ncall, @nif, @nexprs, LinearFast, LinearSlow

export eachindex

# Traits for linear indexing
linearindexing(::BitArray) = LinearFast()

# Iterator/state
abstract CartesianIndex{N} # the state for all multidimensional iterators
abstract IndexIterator{N}  # Iterator that visits the index associated with each element

# gen_cartesian(N) defines, the first time it is called for a given N, the concrete types
#   CartesianIndex_N <: CartesianIndex{N}  (fields I_1::Int, ..., I_N::Int)
#   IndexIterator_N  <: IndexIterator{N}   (field dims::CartesianIndex_N)
# together with constructors from an NTuple{N,Int}, by eval'ing them into this module.
# It returns the symbols (indextype, itertype).
# It must be defined before the first top-level call to it further down.
let implemented = IntSet()
    global gen_cartesian
    function gen_cartesian(N::Int)
        indextype = symbol("CartesianIndex_$N")
        itertype = symbol("IndexIterator_$N")
        if !in(N, implemented)
            fieldnames = [symbol("I_$i") for i = 1:N]
            fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N]
            extype = Expr(:type, false,
                          Expr(:(<:), indextype, Expr(:curly, :CartesianIndex, N)),
                          Expr(:block, fields...))
            exindices = Expr[:(index[$i]) for i = 1:N]
            totalex = quote
                # type definition of state
                $extype
                # constructor from tuple; when N == 0, exindices is empty and this
                # forwards to the zero-argument constructor, matching the zero fields
                # declared above
                $indextype(index::NTuple{$N,Int}) = $indextype($(exindices...))
                # type definition of iterator
                immutable $itertype <: IndexIterator{$N}
                    dims::$indextype
                end
                # constructor from tuple
                $itertype(dims::NTuple{$N,Int}) = $itertype($indextype(dims))
            end
            eval(totalex)
            push!(implemented, N)
        end
        return indextype, itertype
    end
end

stagedfunction Base.call{N}(::Type{CartesianIndex}, index::NTuple{N,Int})
    indextype, itertype = gen_cartesian(N)
    return :($indextype(index))
end

stagedfunction Base.call{N}(::Type{IndexIterator}, index::NTuple{N,Int})
    indextype, itertype = gen_cartesian(N)
    return :($itertype(index))
end

# indexing
stagedfunction getindex{T,N}(A::AbstractArray{T,N}, index::CartesianIndex{N})
    :(@nref $N A d->getfield(index, d))
end

stagedfunction setindex!{T,N}(A::AbstractArray{T,N}, v, index::CartesianIndex{N})
    :((@nref $N A d->getfield(index, d)) = v)
end

# iteration
# The iteration state is a tuple (finished::Bool, index::CartesianIndex{N}).
# `done` only inspects the flag, so `start` must set it when there is nothing to visit.
eachindex(A::AbstractArray) = IndexIterator(size(A))

stagedfunction start{N}(iter::IndexIterator{N})
    indextype, _ = gen_cartesian(N)
    args = fill(:s, N)
    # Build `false || dims_1 == 0 || ... || dims_N == 0`. When N == 0 this is just
    # `false`: a zero-dimensional space has exactly one point.
    anyzero = false
    for i = 1:N
        anyzero = :($anyzero || getfield(iter.dims, $i) == 0)
    end
    quote
        isdone = $anyzero
        s = ifelse(isdone, typemax(Int), 1)
        return isdone, $indextype($(args...))
    end
end

stagedfunction _start{T,N}(A::AbstractArray{T,N}, ::LinearSlow)
    indextype, _ = gen_cartesian(N)
    args = fill(:s, N)
    quote
        # An empty array must report finished immediately; otherwise the first `next`
        # would perform an @inbounds read at an out-of-range index.
        isdone = isempty(A)
        s = ifelse(isdone, typemax(Int), 1)
        return isdone, $indextype($(args...))
    end
end

gen_cartesian(1) # define CartesianIndex_1, used by the two range methods below
next(R::StepRange, state::(Bool, CartesianIndex{1})) = R[state[2].I_1], (state[2].I_1 == length(R), CartesianIndex_1(state[2].I_1 + 1))
next{T}(R::UnitRange{T}, state::(Bool, CartesianIndex{1})) = R[state[2].I_1], (state[2].I_1 == length(R), CartesianIndex_1(state[2].I_1 + 1))

stagedfunction next{T,N}(A::AbstractArray{T,N}, state::(Bool, CartesianIndex{N}))
    indextype, _ = gen_cartesian(N)
    # For N == 0 there is a single element, so iteration is finished after it.
    finishedex = (N == 0 ? true : :(getfield(newindex, $N) > size(A, $N)))
    meta = Expr(:meta, :inline)
    quote
        $meta
        index = state[2]
        @inbounds v = A[index]
        # Odometer step on the *index* (not on the state tuple): increment the first
        # dimension still below its extent and reset the earlier ones to 1; if none is,
        # push the last dimension past its extent so `finished` becomes true.
        newindex = @nif $N d->(getfield(index, d) < size(A, d)) d->@ncall($N, $indextype, k->(k > d ? getfield(index, k) : k == d ? getfield(index, k) + 1 : 1))
        finished = $finishedex
        v, (finished, newindex)
    end
end

stagedfunction next{N}(iter::IndexIterator{N}, state::(Bool, CartesianIndex{N}))
    indextype, _ = gen_cartesian(N)
    # For N == 0 there is a single index, so iteration is finished after it.
    finishedex = (N == 0 ? true : :(getfield(newindex, $N) > getfield(iter.dims, $N)))
    meta = Expr(:meta, :inline)
    quote
        $meta
        index = state[2]
        # Same odometer step as for arrays, bounded by iter.dims instead of size(A).
        newindex = @nif $N d->(getfield(index, d) < getfield(iter.dims, d)) d->@ncall($N, $indextype, k->(k > d ? getfield(index, k) : k == d ? getfield(index, k) + 1 : 1))
        finished = $finishedex
        index, (finished, newindex)
    end
end

# `done` only reads the finished flag carried in the state. The range-specific methods
# have identical bodies; they are kept as separate, more specific definitions.
done(R::StepRange, state::(Bool, CartesianIndex{1})) = state[1]
done(R::UnitRange, state::(Bool, CartesianIndex{1})) = state[1]
done(R::FloatRange, state::(Bool, CartesianIndex{1})) = state[1]
done{T,N}(A::AbstractArray{T,N}, state::(Bool, CartesianIndex{N})) = state[1]
done{N}(iter::IndexIterator{N}, state::(Bool, CartesianIndex{N})) = state[1]

end # IteratorsMD
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment