Skip to content

Commit

Permalink
Merge pull request #89 from adrhill/ah/flip-hw
Browse files Browse the repository at this point in the history
Add `PermuteDims` transformation
  • Loading branch information
darsnack authored Mar 11, 2024
2 parents 0a1448b + 3e44d1d commit 8a39938
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/DataAugmentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ export Item,
Polygon,
ToEltype,
ImageToTensor,
PermuteDims,
Normalize,
NormalizeIntensity,
MaskMulti,
Expand Down
43 changes: 43 additions & 0 deletions src/preprocessing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,49 @@ function _colorview(C::Type{<:Color}, img)
return colorview(C, img)
end

# ### [`PermuteDims`](#)

"""
PermuteDims(perm)
Permute the dimensions of an `ArrayItem`.
`perm` is a vector or a tuple specifying the permutation, whose length has to match the dimensionality of the `ArrayItem`s data.
Refer to the `permutedims` documentation for examples of permutation vectors `perm`.
Supports `apply!`.
## Examples
Preprocessing an image with 3 color channels.
{cell=PermuteDims}
```julia
using DataAugmentation, Images
image = Image(rand(RGB, 20, 20))
# Turn image to tensor and permute dimensions 2 and 1
# to convert HWC (height, width, channel) array to WHC (width, height, channel)
tfms = ImageToTensor() |> PermuteDims(2, 1, 3)
apply(tfms, image)
```
"""
struct PermuteDims{N} <: Transform
perm::NTuple{N, Int}
end
PermuteDims(perm...) = PermuteDims(perm)

function apply(tfm::PermuteDims, item::ArrayItem; randstate = nothing)
data = PermutedDimsArray(itemdata(item), tfm.perm)
return ArrayItem(data)
end

function apply!(buf, tfm::PermuteDims, item::ArrayItem; randstate = nothing)
permutedims!(itemdata(buf), itemdata(item), tfm.perm)
return buf
end

# OneHot encoding

"""
Expand Down
12 changes: 12 additions & 0 deletions test/preprocessing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ end
end
end

@testset ExtendedTestSet "PermuteDims" begin
tfm = PermuteDims(2, 1, 4, 3)
item1 = ArrayItem(rand(3, 4, 5, 6))
item2 = ArrayItem(rand(3, 4, 5, 6))
@test_nowarn apply(tfm, item1)
a = itemdata(apply(tfm, item1))
@test size(a) == (4, 3, 6, 5)

testapply(tfm, item1)
testapply!(tfm, item1, item2)
end

@testset ExtendedTestSet "OneHot" begin
tfm = OneHot()
mask = rand(1:4, 10, 10)
Expand Down

0 comments on commit 8a39938

Please sign in to comment.