Skip to content

Commit

Permalink
Merge pull request #90 from adrhill/ah/whc
Browse files Browse the repository at this point in the history
Make `ImageToTensor` compatible with Flux's `WHC` format
  • Loading branch information
CarloLucibello authored Mar 20, 2024
2 parents 8a39938 + fd3573c commit 38e4be5
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 33 deletions.
58 changes: 58 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,64 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.3.0]

### Added

- Add `PermuteDims` transformation

### Changed

- `ImageToTensor` now returns arrays in Flux-compatible `WHC` format (width, height, channels) instead of `HWC`

## [0.2.12]

### Added

- Added optional `clamp` flag to `AdjustBrightness` and `AdjustContrast`

### Changed

- Transferred repository to FluxML

## [0.2.11]

### Fixed

- Bump Setfield compat

## [0.2.10]

### Fixed

- Fix `RandomCrop` transform

## [0.2.9]

### Added

- Set up Pollen.jl documentation

### Fixed

- Fix deprecated call to `warp`

## [0.2.8]

### Fixed

- Fix compatibility with ImageTransformations 0.9

## [0.2.7]

### Changed

- Replace Images.jl dependency with ImageCore.jl

### Fixed

- Fix crop-projective composition

## [0.2.6]

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DataAugmentation"
uuid = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
authors = ["lorenzoh <lorenz.ohly@gmail.com>"]
version = "0.2.12"
version = "0.3.0"

[deps]
ColorBlendModes = "60508b50-96e1-4007-9d6c-f475c410f16b"
Expand Down
49 changes: 18 additions & 31 deletions src/preprocessing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ end
"""
ImageToTensor()
Expands an `Image{N, T}` of size `sz` to an `ArrayItem{N+1}` with
size `(sz..., ch)` where `ch` is the number of color channels of `T`.
Expands an `Image{N, T}` of size `(height, width, ...)` to an `ArrayItem{N+1}` with
size `(width, height, ..., ch)` where `ch` is the number of color channels of `T`.
Supports `apply!`.
Expand All @@ -144,9 +144,10 @@ Supports `apply!`.
```julia
using DataAugmentation, Images
image = Image(rand(RGB, 50, 50))
h, w = 40, 50
image = Image(rand(RGB, h, w))
tfm = ImageToTensor()
apply(tfm, image)
apply(tfm, image) # ArrayItem in WHC format of size (50, 40, 3)
```
"""
Expand All @@ -165,24 +166,12 @@ function apply!(buf, ::ImageToTensor, image::Image; randstate = nothing)
return buf
end

function imagetotensor(image::AbstractArray{C, N}, T = Float32) where {C<:Color, N}
T.(PermutedDimsArray(_channelview(image), ((i for i in 2:N+1)..., 1)))
function imagetotensor(image::AbstractArray{C, N}, T = Float32) where {C<:Colorant, N}
T.(PermutedDimsArray(_channelview(image), (3, 2, 4:N+1..., 1)))
end

#=
function imagetotensor(image::AbstractArray{C, N}, T = Float32) where {TC, C<:Color{TC, 1}, N}
return T.(_channelview(image))
end
=#


# TODO: relax color type constraint, implement for other colors
# single-channel colors need a `channelview` that also expands the array
function imagetotensor!(buf, image::AbstractArray{<:Color, N}) where N
permutedims!(
buf,
_channelview(image),
(2:N+1..., 1))
function imagetotensor!(buf, image::AbstractArray{<:Colorant, N}) where N
permutedims!(buf, _channelview(image), (3, 2, 4:N+1..., 1))
end

function tensortoimage(a::AbstractArray)
Expand All @@ -197,22 +186,20 @@ function tensortoimage(a::AbstractArray)
end
end

function tensortoimage(C::Type{<:Color}, a::AbstractArray{T, N}) where {T, N}
perm = (N, 1:N-1...)
function tensortoimage(C::Type{<:Colorant}, a::AbstractArray{T, N}) where {T, N}
perm = (N, 2, 1, 3:N-1...)
return _colorview(C, PermutedDimsArray(a, perm))
end


function _channelview(img)
chview = channelview(img)
# for single-channel colors, expand the color dimension anyway
if size(img) == size(chview)
chview = reshape(chview, 1, size(chview)...)
end
return chview
# For single-channel colors, expand the color dimension anyway
# such that the output is always of size (channels, height, width, ...)
_channelview(img::AbstractArray{<:Colorant{T, N}}) where {T, N} = channelview(img)
function _channelview(img::AbstractArray{<:Colorant{T, 1}}) where T
cv = channelview(img)
return reshape(cv, 1, size(cv)...)
end

function _colorview(C::Type{<:Color}, img)
function _colorview(C::Type{<:Colorant}, img)
if size(img, 1) == 1
img = reshape(img, size(img)[2:end])
end
Expand Down
2 changes: 1 addition & 1 deletion test/preprocessing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ end

res = apply(tfm, item1)
a = itemdata(res)
@test size(a) == (32, 48, 3)
@test size(a) == (48, 32, 3)
@test eltype(a) == Float32

testapply(tfm, item1)
Expand Down

0 comments on commit 38e4be5

Please sign in to comment.