Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add similar/setindex!, update argmax/argmin #10

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

TLipede
Copy link
Contributor

@TLipede TLipede commented Mar 11, 2022

This addresses #6. Also adds some conversion methods, and expands the arrays that the fast hcat should work with (as far as I can tell this should still be valid).

Comment on lines +50 to +55
function Base.similar(::OneHotArray{T, L}, ::Type{Bool}, dims::Dims) where {T, L}
if first(dims) == L
indices = ones(T, Base.tail(dims))
return OneHotArray(indices, first(dims))
else
return BitArray(undef, dims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you say what this similar method is for?

It seems surprising to me to return something which isn't writable. Elsewhere the pattern is that the method which takes a size returns a full matrix, but without the size, a structured one:

julia> similar(Diagonal(1:3), Float32)
3×3 Diagonal{Float32, Vector{Float32}}:
 0.0   ⋅    ⋅ 
  ⋅   0.0   ⋅ 
  ⋅    ⋅   2.09389f-37

julia> similar(Diagonal(1:3), Float32, (3,3))
3×3 Matrix{Float32}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

Comment on lines +100 to +101
function OneHotArray(x::AbstractVector)
!_onehot_compatible(x) && error("Array is not onehot compatible")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What arrays do you want this to accept, which aren't AbstractVector{Bool}?

And could "not onehot compatible" explain a bit more what it means by compatible?

cart_inds = CartesianIndex.(_indices(x), CartesianIndices(_indices(x)))
return reshape(cart_inds, (1, size(_indices(x))...))
else
return argmax(BitArray(x); dims=dims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is copying here better than invokeing the more generic method?

@@ -69,6 +87,30 @@ Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N,
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}

_onehot_compatible(x::OneHotLike) = _isonehot(x)
_onehot_compatible(x::AbstractVector{Bool}) = count(x) == 1
_onehot_compatible(x::AbstractArray{Bool}) = all(isone, reduce(+, x; dims=1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly tidier, maybe:

Suggested change
_onehot_compatible(x::AbstractArray{Bool}) = all(isone, reduce(+, x; dims=1))
_onehot_compatible(x::AbstractArray{Bool}) = all(isone, count(x; dims=1))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants