Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add similar/setindex!, update argmax/argmin #10

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 63 additions & 9 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ end
Base.getindex(x::OneHotArray, ::Colon) = BitVector(reshape(x, :))
Base.getindex(x::OneHotArray{<:Any, <:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x

function Base.similar(::OneHotArray{T, L}, ::Type{Bool}, dims::Dims) where {T, L}
if first(dims) == L
indices = ones(T, Base.tail(dims))
return OneHotArray(indices, first(dims))
else
return BitArray(undef, dims)
Comment on lines +50 to +55
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you say what this similar method is for?

It seems surprising to me to return something which isn't writable. Elsewhere the pattern is that the method which takes a size returns a full matrix, but without the size, a structured one:

julia> similar(Diagonal(1:3), Float32)
3×3 Diagonal{Float32, Vector{Float32}}:
 0.0   ⋅    ⋅ 
  ⋅   0.0   ⋅ 
  ⋅    ⋅   2.09389f-37

julia> similar(Diagonal(1:3), Float32, (3,3))
3×3 Matrix{Float32}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

end
end

function Base.setindex!(x::OneHotLike{<:Any, <:Any, N}, v::Bool, i::Integer, I::Vararg{Integer, N}) where N
@boundscheck checkbounds(x, i, I...)
if v
_indices(x)[I...] = i
else
error("OneHotArray cannot be set with false values")
end
end

function Base.showarg(io::IO, x::OneHotArray, toplevel)
print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")
Base.showarg(io, x.indices, false)
Expand All @@ -69,6 +87,30 @@ Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N,
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}

_onehot_compatible(x::OneHotLike) = _isonehot(x)
_onehot_compatible(x::AbstractVector{Bool}) = count(x) == 1
_onehot_compatible(x::AbstractArray{Bool}) = all(isone, reduce(+, x; dims=1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly tidier, maybe:

Suggested change
_onehot_compatible(x::AbstractArray{Bool}) = all(isone, reduce(+, x; dims=1))
_onehot_compatible(x::AbstractArray{Bool}) = all(isone, count(x; dims=1))

_onehot_compatible(x::AbstractArray) = _onehot_compatible(BitArray(x))

function OneHotArray(x::OneHotLike)
!_onehot_compatible(x) && error("Array is not onehot compatible")
return x
end

function OneHotArray(x::AbstractVector)
!_onehot_compatible(x) && error("Array is not onehot compatible")
Comment on lines +100 to +101
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What arrays do you want this to accept, which aren't AbstractVector{Bool}?

And could "not onehot compatible" explain a bit more what it means by compatible?

return OneHotArray(findfirst(x), length(x))
end

function OneHotArray(x::AbstractArray)
!_onehot_compatible(x) && error("Array is not onehot compatible")
dims = size(x)
dim1, dim2 = dims[1], reduce(*, Base.tail(dims))
rx = reshape(x, (dim1, dim2))
indices = UInt32[findfirst(==(true), col) for col in eachcol(rx)]
return OneHotArray(reshape(indices, Base.tail(dims)), dim1)
end

function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
Expand All @@ -80,11 +122,9 @@ end
Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1)

# optimized concatenation for matrices and vectors of same parameters
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)
# optimized concatenation for arrays of same parameters
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any}} =
OneHotArray(reduce(vcat, _indices.(xs); init = _indices(x)), L)

MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L)

Expand All @@ -94,7 +134,21 @@ Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}

Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

Base.argmax(x::OneHotLike; dims = Colon()) =
(_isonehot(x) && dims == 1) ?
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)
function Base.argmax(x::OneHotLike; dims = Colon())
if _isonehot(x) && dims == 1
cart_inds = CartesianIndex.(_indices(x), CartesianIndices(_indices(x)))
return reshape(cart_inds, (1, size(_indices(x))...))
else
return argmax(BitArray(x); dims=dims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is copying here better than invokeing the more generic method?

end
end

function Base.argmin(x::OneHotLike; dims = Colon())
if _isonehot(x) && dims == 1
labelargs = ifelse.(_indices(x) .== 1, 2, 1)
cart_inds = CartesianIndex.(labelargs, CartesianIndices(_indices(x)))
return reshape(cart_inds, (1, size(_indices(x))...))
else
return argmin(BitArray(x); dims=dims)
end
end
18 changes: 18 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ end
@test_throws BoundsError oa[:, :]
end

@testset "Converting" begin
compat_arr = BitArray(OneHotArray(rand(1:5, (3, 5)), 5))

@test_throws Exception OneHotArray([1 0 0; 0 1 0])
@test OneHotArray(compat_arr) == compat_arr
@test OneHotArray(oa) === oa
end

@testset "Concatenating" begin
# vector cat
@test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10)
Expand Down Expand Up @@ -101,6 +109,16 @@ end
@test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3)
end

@testset "Base.argmin" begin
# argmin test
@test argmin(ov) == argmin(convert(Array{Bool}, ov))
@test argmin(om) == argmin(convert(Array{Bool}, om))
@test argmin(om; dims = 1) == argmin(convert(Array{Bool}, om); dims = 1)
@test argmin(om; dims = 2) == argmin(convert(Array{Bool}, om); dims = 2)
@test argmin(oa; dims = 1) == argmin(convert(Array{Bool}, oa); dims = 1)
@test argmin(oa; dims = 3) == argmin(convert(Array{Bool}, oa); dims = 3)
end

@testset "Forward map to broadcast" begin
@test map(identity, oa) == oa
@test map(x -> 2 * x, oa) == 2 .* oa
Expand Down
18 changes: 10 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
using OneHotArrays
using Test

@testset "OneHotArray" begin
include("array.jl")
end
@testset verbose=true "OneHotArrays" begin
@testset "Array" begin
include("array.jl")
end

@testset "Constructors" begin
include("onehot.jl")
end
@testset "Constructors" begin
include("onehot.jl")
end

@testset "Linear Algebra" begin
include("linalg.jl")
@testset "Linear Algebra" begin
include("linalg.jl")
end
end