Add similar and setindex! #6

Open
darsnack opened this issue Mar 5, 2022 · 4 comments

@darsnack
Member

darsnack commented Mar 5, 2022

See the original Flux issue and part of the solution here.

@TLipede
Contributor

TLipede commented Mar 5, 2022

@darsnack what are your thoughts on how setindex! should behave? Are you thinking it should just be treated as a regular BitArray? Or should it effectively change the label for a particular data point?

Edit: Actually, I think I follow now from the thread.

@darsnack
Member Author

darsnack commented Mar 5, 2022

Yeah, it should check that the one-hot property is preserved, or else error. So if the first index isn't a `Colon`, it should error. Otherwise, check that the data itself is one-hot. The easy case is a `Colon` for the first index with data that is `OneHotLike` and has the right shape.
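A minimal sketch of that Colon-first rule, using a hypothetical `ToyOneHot` matrix type as a stand-in for the real `OneHotArray` (its fields and constructor here are assumptions, not the package's internals). It shows only the slow path for a plain `AbstractVector{Bool}`: the assigned column is validated as one-hot before the stored label is updated, and any scalar write with a non-Colon first index is rejected. A fast path for `OneHotLike` values is not shown.

```julia
# Hypothetical toy stand-in for a one-hot matrix: column j is hot at row indices[j].
# This is NOT the OneHotArrays.jl type; it only illustrates the proposed rule.
struct ToyOneHot <: AbstractMatrix{Bool}
    indices::Vector{Int}
    nlabels::Int
end

Base.size(x::ToyOneHot) = (x.nlabels, length(x.indices))
Base.getindex(x::ToyOneHot, i::Int, j::Int) = x.indices[j] == i

# Whole-column assignment (`x[:, j] = v`): `v` must itself be one-hot.
function Base.setindex!(x::ToyOneHot, v::AbstractVector{Bool}, ::Colon, j::Int)
    length(v) == x.nlabels ||
        throw(DimensionMismatch("column must have length $(x.nlabels)"))
    count(v) == 1 || throw(ArgumentError("assigned column is not one-hot"))
    x.indices[j] = findfirst(v)  # store the label of the single `true`
    return x
end

# Scalar writes with a non-Colon first index are rejected outright.
Base.setindex!(::ToyOneHot, ::Bool, ::Int, ::Int) =
    throw(ArgumentError("setindex! with a non-Colon first index could break the one-hot property"))
```

For example, with `x = ToyOneHot([1, 3], 3)`, the assignment `x[:, 2] = [false, true, false]` moves column 2's hot entry to label 2, while `x[:, 1] = [true, true, false]` or `x[1, 1] = true` throws an `ArgumentError`.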

@TLipede
Contributor

TLipede commented Mar 8, 2022

I've been having a play around with this, and I have something that works as you'd expect for cases like hcat/vcat, but I'm not sure whether I have the right approach. I have something like the following:

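```julia
# Return a OneHotArray only when the requested first dimension equals the
# number of labels L; otherwise fall back to a BitArray.
function Base.similar(::OneHotArray{T, L}, ::Type{Bool}, dims::Dims) where {T, L}
  if first(dims) == L
    # Placeholder indices (every entry hot at label 1); callers are expected to overwrite them.
    indices = ones(T, Base.tail(dims))
    return OneHotArray(indices, first(dims))
  else
    return BitArray(undef, dims)
  end
end

# Scalar write: `i` indexes the label dimension, `I...` the remaining dimensions.
function Base.setindex!(x::OneHotArray{<:Any, <:Any, N}, v::Bool, i::Integer, I::Vararg{Integer, N}) where N
  @boundscheck checkbounds(x, i, I...)
  if v
    x.indices[I...] = i  # make label i the hot entry
  elseif x.indices[I...] == i
    x.indices[I...] = 0  # clear the hot entry, leaving no hot label
  end
  return x
end
```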

Here I take the view that if `dims[1] == L`, you're probably trying to create a genuine `OneHotArray` (e.g. `hcat`, indexing with `:` as the first index, etc.). Anything else should return a `BitArray`.

The problem stems from the fact that not every `Array{Bool}` is representable as a `OneHotArray`, yet some algorithms may inadvertently try to store one as such. For example, `argmax(x::OneHotLike; dims=3)` calls `fill!(similar(x, reduced_inds), value)`, which leads to errors or incorrect results depending on how `setindex!` is implemented.
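To make that failure mode concrete, here is a rough illustration that uses only Base and a plain 3×2×1 `Bool` array standing in for a one-hot array with L = 3 labels (OneHotArrays.jl is not loaded). Reducing over `dims=3` keeps the label dimension, so a `similar` that returns a `OneHotArray` whenever the first dimension equals L would be asked for exactly this shape, and the subsequent `fill!` writes a value that no one-hot array can hold. This mirrors, but does not call, Base's reduction internals, which are implementation details and may differ between Julia versions.

```julia
# Assumed stand-in: a one-hot array with L = 3 labels and 2 columns,
# stored as a plain 3×2×1 Bool array.
A = zeros(Bool, 3, 2, 1)
A[1, 1, 1] = true
A[3, 2, 1] = true

# Reducing over dims=3 keeps the label dimension, so the buffer requested
# from `similar` has the same L×n×1 shape a one-hot array would have.
ri = axes(sum(A; dims=3))   # (Base.OneTo(3), Base.OneTo(2), Base.OneTo(1))
buf = fill!(similar(A, Bool, ri), true)

# Every column of `buf` now has L hot entries, which no one-hot array can store.
println(vec(count(buf; dims=1)))   # prints [3, 3]
```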

I think that these are the options on how to proceed, but maybe I'm missing an obviously better way; I'd be keen to hear your thoughts:

  1. Define a special case for `similar(x, dims::Tuple{Vararg{Base.OneTo}})` to handle the argmin/argmax case. This feels prone to breakage if/when Base's implementation changes, or if other packages do something similar.
  2. Define special implementations/force invocation of particular methods for various array functions in this package.
  3. Have `similar` always create a `BitArray`, along with e.g. a conversion back to `OneHotArray` that errors if necessary (was this what you meant by having the one-hot property preserved?). This would always work, but I guess it would be slower and use more memory.
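Option 3 could look roughly like the following sketch. `onehot_indices` is a hypothetical helper (not part of OneHotArrays.jl) that takes the `BitArray`/`Array{Bool}` produced by `similar` and recovers one label per column, throwing when a column is not one-hot. Wrapping the result back into a `OneHotArray` is left out so the example does not depend on the package. The extra pass over the data and the intermediate allocation are the cost noted above.

```julia
# Hypothetical helper for option 3: recover per-column labels from a Bool matrix,
# throwing if any column is not one-hot. Not part of OneHotArrays.jl.
function onehot_indices(B::AbstractMatrix{Bool})
    indices = Vector{Int}(undef, size(B, 2))
    for (j, col) in enumerate(eachcol(B))
        n = count(col)
        n == 1 || throw(ArgumentError("column $j has $n hot entries, expected 1"))
        indices[j] = findfirst(col)  # row of the single `true`
    end
    return indices
end
```

For example, a `falses(3, 2)` matrix with `true` at `[1, 1]` and `[3, 2]` yields `[1, 3]`. Setting a second `true` in either column makes the call throw.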

What do you think?

@darsnack
Member Author

I was thinking that `setindex!` with `Colon` as the first index should accept any `v` that is genuinely one-hot. We can have a fast path when `v isa OneHotLike` and a slower one for `Array{<:Bool}`.

When the first index is not a `Colon`, the `if v` branch in your implementation looks good. The `elseif` branch breaks with what we've allowed for one-hot arrays so far (it sets an index to 0, leaving that slice with no hot entry), so I would rather error in that situation.
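A minimal sketch of that scalar rule, on a hypothetical `ToyOneHot` matrix type standing in for the real `OneHotArray` (its layout is an assumption, not the package's internals). Writing `true` moves the column's hot entry, writing `false` to the current hot entry throws instead of zeroing the index, and writing `false` anywhere else is a no-op.

```julia
# Hypothetical toy stand-in for a one-hot matrix: column j is hot at row indices[j].
struct ToyOneHot <: AbstractMatrix{Bool}
    indices::Vector{Int}
    nlabels::Int
end

Base.size(x::ToyOneHot) = (x.nlabels, length(x.indices))
Base.getindex(x::ToyOneHot, i::Int, j::Int) = x.indices[j] == i

function Base.setindex!(x::ToyOneHot, v::Bool, i::Int, j::Int)
    checkbounds(x, i, j)
    if v
        x.indices[j] = i  # moving the hot entry keeps exactly one `true`
    elseif x.indices[j] == i
        # Clearing the only hot entry would leave the column with none: error instead.
        throw(ArgumentError("clearing the hot entry of column $j would break the one-hot property"))
    end
    # Writing `false` to an already-cold entry changes nothing.
    return x
end
```

With `x = ToyOneHot([1, 3], 3)`, `x[2, 1] = true` moves column 1 to label 2, `x[1, 2] = false` leaves everything unchanged, and `x[3, 2] = false` throws an `ArgumentError`.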

I think your view on `similar` is the correct one: we should only return a `OneHotArray` when the first dimension matches. Always returning a `OneHotArray` would break a lot of things that currently work for free by returning an `Array{Bool}`, while matching on the first dimension should still get us a faster path where possible.
