From d71d6c66f9882d212a92ec13c0399228fb0d19ce Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 19 May 2023 09:03:56 +0200 Subject: [PATCH 1/4] use buffer in BatchView --- src/batchview.jl | 23 ++++++++++++++--------- test.jl | 11 +++++++++++ test/batchview.jl | 27 +++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 9 deletions(-) create mode 100644 test.jl diff --git a/src/batchview.jl b/src/batchview.jl index 7778dae..b535d72 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -134,13 +134,8 @@ Base.@propagate_inbounds function getobs(A::BatchView) return _getbatch(A, 1:numobs(A.data)) end -Base.@propagate_inbounds function Base.getindex(A::BatchView, i::Int) - obsindices = _batchrange(A, i) - _getbatch(A, obsindices) -end - -Base.@propagate_inbounds function Base.getindex(A::BatchView, is::AbstractVector) - obsindices = union((_batchrange(A, i) for i in is)...)::Vector{Int} +Base.@propagate_inbounds function Base.getindex(A::BatchView, i) + obsindices = _batchindexes(A, i) _getbatch(A, obsindices) end @@ -154,6 +149,15 @@ function _getbatch(A::BatchView{TElem, TData, Val{nothing}}, obsindices) where { getobs(A.data, obsindices) end +function getobs!(buffer, A::BatchView{TElem, TData, Val{nothing}}, i) where {TElem, TData} + obsindices = _batchindexes(A, i) + return _getbatch!(buffer, A, obsindices) +end + +function _getbatch!(buffer, A::BatchView{TElem, TData, Val{nothing}}, obsindices) where {TElem, TData} + return getobs!(buffer, A.data, obsindices) +end + Base.parent(A::BatchView) = A.data Base.eltype(::BatchView{Tel}) where Tel = Tel @@ -169,6 +173,9 @@ Base.iterate(A::BatchView, state = 1) = return startidx:endidx end +@inline _batchindexes(A::BatchView, i::Integer) = _batchrange(A, i) +@inline _batchindexes(A::BatchView, is::AbstractVector{<:Integer}) = union((_batchrange(A, i) for i in is)...)::Vector{Int} + function Base.showarg(io::IO, A::BatchView, toplevel) print(io, "BatchView(") Base.showarg(io, parent(A), false) @@ -178,5 +185,3 @@ function Base.showarg(io::IO, A::BatchView, toplevel) print(io, ')') toplevel && print(io, " with eltype ", nameof(eltype(A))) # simplify end - -# -------------------------------------------------------------------- diff --git a/test.jl b/test.jl new file mode 100644 index 0000000..4d96dc1 --- /dev/null +++ b/test.jl @@ -0,0 +1,11 @@ +using MLUtils +struct DummyData{X} + x::X +end +MLUtils.numobs(data::DummyData) = numobs(data.x) +MLUtils.getobs(data::DummyData, idx) = getobs(data.x, idx) +MLUtils.getobs!(buffer, data::DummyData, idx) = getobs!(buffer, data.x, idx) + +data = DummyData(rand(3,100)) +d1 = collect(DataLoader(data; batchsize=1, buffer=true)) # no error +d2 = collect(DataLoader(data; batchsize=-1, buffer=true)) # error diff --git a/test/batchview.jl b/test/batchview.jl index 39fbdcf..5640d72 100644 --- a/test/batchview.jl +++ b/test/batchview.jl @@ -116,4 +116,31 @@ using MLUtils: obsview @test bv[2] == 6:10 @test_throws BoundsError bv[3] end + + + @testset "getobs!" begin + buf1 = rand(4, 3) + bv = BatchView(X, batchsize=3) + @test @inferred(getobs!(buf1, bv, 2)) === buf1 + @test buf1 == getobs(bv, 2) + + buf2 = rand(4, 6) + @test @inferred(getobs!(buf2, bv, [1,3])) === buf2 + @test buf2 == getobs(bv, [1,3]) + + @testset "custom type" begin # issue #156 + struct DummyData{X} + x::X + end + MLUtils.numobs(data::DummyData) = numobs(data.x) + MLUtils.getobs(data::DummyData, idx) = getobs(data.x, idx) + MLUtils.getobs!(buffer, data::DummyData, idx) = getobs!(buffer, data.x, idx) + + data = DummyData(X) + buf = rand(4, 3) + bv = BatchView(data, batchsize=3) + @test @inferred(getobs!(buf, bv, 2)) === buf + @test buf == getobs(bv, 2) + end + end end From 797a0e5a9cf0ea041999075922561d18f52d0f9b Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 19 May 2023 09:04:12 +0200 Subject: [PATCH 2/4] cleanup --- test.jl | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 test.jl diff --git a/test.jl b/test.jl deleted file mode 100644 index 4d96dc1..0000000 --- a/test.jl +++ /dev/null @@ -1,11 +0,0 @@ -using MLUtils -struct DummyData{X} - x::X -end -MLUtils.numobs(data::DummyData) = numobs(data.x) -MLUtils.getobs(data::DummyData, idx) = getobs(data.x, idx) -MLUtils.getobs!(buffer, data::DummyData, idx) = getobs!(buffer, data.x, idx) - -data = DummyData(rand(3,100)) -d1 = collect(DataLoader(data; batchsize=1, buffer=true)) # no error -d2 = collect(DataLoader(data; batchsize=-1, buffer=true)) # error From deb63ef220d7d37b38afc6fa123cd25195f13f44 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 26 Jun 2024 13:38:10 +0200 Subject: [PATCH 3/4] remove specialization --- src/batchview.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/batchview.jl b/src/batchview.jl index b535d72..632a08b 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -139,22 +139,22 @@ Base.@propagate_inbounds function Base.getindex(A::BatchView, i) _getbatch(A, obsindices) end -function _getbatch(A::BatchView{TElem, TData, Val{true}}, obsindices) where {TElem, TData} +function _getbatch(A::BatchView{<:Any, <:Any, Val{true}}, obsindices) batch([getobs(A.data, i) for i in obsindices]) end -function _getbatch(A::BatchView{TElem, TData, Val{false}}, obsindices) where {TElem, TData} +function _getbatch(A::BatchView{<:Any, <:Any, Val{false}}, obsindices) return [getobs(A.data, i) for i in obsindices] end -function _getbatch(A::BatchView{TElem, TData, Val{nothing}}, obsindices) where {TElem, TData} +function _getbatch(A::BatchView{<:Any, <:Any, Val{nothing}}, obsindices) getobs(A.data, obsindices) end -function getobs!(buffer, A::BatchView{TElem, TData, Val{nothing}}, i) where {TElem, TData} +function getobs!(buffer, A::BatchView{<:Any, <:Any, Val{nothing}}, i) obsindices = _batchindexes(A, i) return _getbatch!(buffer, A, obsindices) end -function _getbatch!(buffer, A::BatchView{TElem, TData, Val{nothing}}, obsindices) where {TElem, TData} +function _getbatch!(buffer, A::BatchView{<:Any, <:Any, Val{nothing}}, obsindices) return getobs!(buffer, A.data, obsindices) end From 6ae585e72176370f52a2ff0163a77e13ec074072 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 26 Jun 2024 14:14:55 +0200 Subject: [PATCH 4/4] handle collate=nothing and collate=true --- src/batchview.jl | 17 ++++++++++++++++- test/batchview.jl | 5 +++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/batchview.jl b/src/batchview.jl index 632a08b..f5c1bad 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -149,7 +149,7 @@ function _getbatch(A::BatchView{<:Any, <:Any, Val{nothing}}, obsindices) getobs(A.data, obsindices) end -function getobs!(buffer, A::BatchView{<:Any, <:Any, Val{nothing}}, i) +function getobs!(buffer, A::BatchView, i) obsindices = _batchindexes(A, i) return _getbatch!(buffer, A, obsindices) end @@ -158,6 +158,21 @@ function _getbatch!(buffer, A::BatchView{<:Any, <:Any, Val{nothing}}, obsindices return getobs!(buffer, A.data, obsindices) end +# This collate=true specialization doesn't seem to be particularly useful, use collate=nothing instead. +function _getbatch!(buffer, A::BatchView{<:Any, <:Any, Val{true}}, obsindices) + for (i, idx) in enumerate(obsindices) + getobs!(buffer[i], A.data, idx) + end + return batch(buffer) +end + +function _getbatch!(buffer, A::BatchView{<:Any, <:Any, Val{false}}, obsindices) + for (i, idx) in enumerate(obsindices) + getobs!(buffer[i], A.data, idx) + end + return buffer +end + Base.parent(A::BatchView) = A.data Base.eltype(::BatchView{Tel}) where Tel = Tel diff --git a/test/batchview.jl b/test/batchview.jl index 5640d72..f4ca79d 100644 --- a/test/batchview.jl +++ b/test/batchview.jl @@ -123,6 +123,11 @@ using MLUtils: obsview bv = BatchView(X, batchsize=3) @test @inferred(getobs!(buf1, bv, 2)) === buf1 @test buf1 == getobs(bv, 2) + + buf12 = [rand(4) for _=1:3] + bv12 = BatchView(X, batchsize=3, collate=false) + @test @inferred(getobs!(buf12, bv12, 2)) === buf12 + @test buf12 == getobs(bv12, 2) buf2 = rand(4, 6) @test @inferred(getobs!(buf2, bv, [1,3])) === buf2