diff --git a/src/MLUtils.jl b/src/MLUtils.jl index 8454231..95b42bd 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -22,7 +22,7 @@ export mapobs, groupobs, joinobs, shuffleobs - + include("batchview.jl") export batchsize, BatchView diff --git a/src/batchview.jl b/src/batchview.jl index 90c27dc..c1cd63a 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -100,7 +100,7 @@ Return the fixed size of each batch in `data`. """ batchsize(A::BatchView) = A.batchsize -numobs(A::BatchView) = A.count +Base.length(A::BatchView) = A.count getobs(A::BatchView) = getobs(A.data) getobs(A::BatchView, i::Int) = getobs(A.data, _batchrange(A, i)) @@ -119,6 +119,10 @@ function Base.getindex(A::BatchView, is::AbstractVector) obsview(A.data, obsindices) end +# override AbstractDataContainer default +Base.iterate(A::BatchView, state = 1) = + (state > numobs(A)) ? nothing : (A[state], state + 1) + obsview(A::BatchView) = A obsview(A::BatchView, i) = A[i] diff --git a/src/observation.jl b/src/observation.jl index ff4d7fd..f831a56 100644 --- a/src/observation.jl +++ b/src/observation.jl @@ -2,8 +2,14 @@ numobs(data) Return the total number of observations contained in `data`. + If `data` does not have `numobs` defined, then this function falls back to `length(data)`. +Authors of custom data containers should implement +`Base.length` for their type instead of `numobs`. +`numobs` should only be implemented for types where there is a +difference between `numobs` and `Base.length` +(such as multi-dimensional arrays). See also [`getobs`](@ref) """ @@ -18,16 +24,20 @@ numobs(data) = length(data) Return the observations corresponding to the observation-index `idx`. Note that `idx` can be any type as long as `data` has defined `getobs` for that type. + If `data` does not have `getobs` defined, then this function falls back to `data[idx]`. +Authors of custom data containers should implement +`Base.getindex` for their type instead of `getobs`. +`getobs` should only be implemented for types where there is a +difference between `getobs` and `Base.getindex` +(such as multi-dimensional arrays). The returned observation(s) should be in the form intended to be passed as-is to some learning algorithm. There is no strict interface requirement on how this "actual data" must look like. - Every author behind some custom data container can make this decision themselves. - The output should be consistent when `idx` is a scalar vs vector. See also [`getobs!`](@ref) and [`numobs`](@ref) @@ -64,13 +74,10 @@ getobs!(buffer, data, idx) = getobs(data, idx) abstract type AbstractDataContainer end -Base.getindex(x::AbstractDataContainer, i) = getobs(x, i) -Base.length(x::AbstractDataContainer) = numobs(x) -Base.size(x::AbstractDataContainer) = (length(x),) - +Base.size(x::AbstractDataContainer) = (numobs(x),) Base.iterate(x::AbstractDataContainer, state = 1) = - (state > length(x)) ? nothing : (x[state], state + 1) -Base.lastindex(x::AbstractDataContainer) = length(x) + (state > numobs(x)) ? nothing : (getobs(x, state), state + 1) +Base.lastindex(x::AbstractDataContainer) = numobs(x) # -------------------------------------------------------------------- # Arrays diff --git a/src/obstransform.jl b/src/obstransform.jl index a5630f5..24f521e 100644 --- a/src/obstransform.jl +++ b/src/obstransform.jl @@ -1,7 +1,7 @@ # mapobs -struct MappedData{F,D} +struct MappedData{F,D} <: AbstractDataContainer f::F data::D end @@ -9,9 +9,9 @@ end Base.show(io::IO, data::MappedData) = print(io, "mapobs($(data.f), $(summary(data.data)))") Base.show(io::IO, data::MappedData{F,<:AbstractArray}) where {F} = print(io, "mapobs($(data.f), $(ShowLimit(data.data, limit=80)))") -numobs(data::MappedData) = numobs(data.data) -getobs(data::MappedData, idx::Int) = data.f(getobs(data.data, idx)) -getobs(data::MappedData, idxs::AbstractVector) = data.f.(getobs(data.data, idxs)) +Base.length(data::MappedData) = numobs(data.data) +Base.getindex(data::MappedData, idx::Int) = data.f(getobs(data.data, idx)) +Base.getindex(data::MappedData, idxs::AbstractVector) = data.f.(getobs(data.data, idxs)) """ @@ -38,14 +38,14 @@ Returns a tuple of transformed data containers. mapobs(fs::Tuple, data) = Tuple(mapobs(f, data) for f in fs) -struct NamedTupleData{TData,F} +struct NamedTupleData{TData,F} <: AbstractDataContainer data::TData namedfs::NamedTuple{F} end -numobs(data::NamedTupleData) = numobs(getfield(data, :data)) +Base.length(data::NamedTupleData) = numobs(getfield(data, :data)) -function getobs(data::NamedTupleData{TData,F}, idx::Int) where {TData,F} +function Base.getindex(data::NamedTupleData{TData,F}, idx::Int) where {TData,F} obs = getobs(getfield(data, :data), idx) namedfs = getfield(data, :namedfs) return NamedTuple{F}(f(obs) for f in namedfs) @@ -126,16 +126,16 @@ end # joinumobs -struct JoinedData{T,N} +struct JoinedData{T,N} <: AbstractDataContainer datas::NTuple{N,T} ns::NTuple{N,Int} end JoinedData(datas) = JoinedData(datas, numobs.(datas)) -numobs(data::JoinedData) = sum(data.ns) +Base.length(data::JoinedData) = sum(data.ns) -function getobs(data::JoinedData, idx) +function Base.getindex(data::JoinedData, idx) for (i, n) in enumerate(data.ns) if idx <= n return getobs(data.datas[i], idx) diff --git a/src/obsview.jl b/src/obsview.jl index 0127403..a8ba1af 100644 --- a/src/obsview.jl +++ b/src/obsview.jl @@ -178,11 +178,10 @@ end Base.IteratorEltype(::Type{<:ObsView}) = Base.EltypeUnknown() -# override AbstractDataContainer defaults Base.getindex(subset::ObsView, idx) = obsview(subset.data, subset.indices[idx]) -numobs(subset::ObsView) = length(subset.indices) +Base.length(subset::ObsView) = length(subset.indices) getobs(subset::ObsView) = getobs(subset.data, subset.indices) getobs(subset::ObsView, idx) = getobs(subset.data, subset.indices[idx])