From c616b47a8104736d2b533afb2a92da3abb8a4c40 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Thu, 29 Feb 2024 15:55:56 +0100 Subject: [PATCH 1/4] Add `PermuteDims` transformation --- src/DataAugmentation.jl | 1 + src/preprocessing.jl | 33 +++++++++++++++++++++++++++++++++ test/preprocessing.jl | 10 ++++++++++ 3 files changed, 44 insertions(+) diff --git a/src/DataAugmentation.jl b/src/DataAugmentation.jl index d2df407e..419dd418 100644 --- a/src/DataAugmentation.jl +++ b/src/DataAugmentation.jl @@ -56,6 +56,7 @@ export Item, Polygon, ToEltype, ImageToTensor, + PermuteDims, Normalize, NormalizeIntensity, MaskMulti, diff --git a/src/preprocessing.jl b/src/preprocessing.jl index c24e27fa..a976ce13 100644 --- a/src/preprocessing.jl +++ b/src/preprocessing.jl @@ -219,6 +219,39 @@ function _colorview(C::Type{<:Color}, img) where T return colorview(C, img) end +# ### [`PermuteDims`](#) + +""" + PermuteDims(perm) + +Permute the dimensions of an `ArrayItem`. +`perm` is a vector or a tuple of length `ndims(A)` specifying the permutation. + +Refer to the `permutedims` documentation for examples for permutation vectors `perm`. + +## Examples + +Preprocessing an image with 3 color channels. + +{cell=Normalize} +```julia +using DataAugmentation, Images +image = Image(rand(RGB, 20, 20)) +tfms = ImageToTensor() |> PermuteDims(2, 1, 3) # HWC to WHC +apply(tfms, image) +``` + +""" +struct PermuteDims{N} <: Transform + perm::NTuple{N, Int} +end +PermuteDims(perm...) = PermuteDims(perm) + +function apply(tfm::PermuteDims, item::ArrayItem; randstate = nothing) + data = permutedims(itemdata(item), tfm.perm) + return ArrayItem(data) +end + # OneHot encoding """ diff --git a/test/preprocessing.jl b/test/preprocessing.jl index 13f8d3dd..00ddf137 100644 --- a/test/preprocessing.jl +++ b/test/preprocessing.jl @@ -62,6 +62,16 @@ end end end +@testset ExtendedTestSet "PermuteDims" begin + tfm = PermuteDims(2, 1, 4, 3) + A = rand(3, 4, 5, 6) + item = ArrayItem(A) + @test_nowarn apply(tfm, item) + B = itemdata(apply(tfm, item)) + @test size(B) == (4, 3, 6, 5) + testapply(tfm, item) +end + @testset ExtendedTestSet "OneHot" begin tfm = OneHot() mask = rand(1:4, 10, 10) From 2c02a6e76b712dd52762a2c808a59cd42eba9a7f Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Thu, 29 Feb 2024 16:24:51 +0100 Subject: [PATCH 2/4] Support `apply!` --- src/preprocessing.jl | 7 +++++++ test/preprocessing.jl | 14 ++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/preprocessing.jl b/src/preprocessing.jl index a976ce13..4793faf9 100644 --- a/src/preprocessing.jl +++ b/src/preprocessing.jl @@ -229,6 +229,8 @@ Permute the dimensions of an `ArrayItem`. Refer to the `permutedims` documentation for examples for permutation vectors `perm`. +Supports `apply!`. + ## Examples Preprocessing an image with 3 color channels. @@ -252,6 +254,11 @@ function apply(tfm::PermuteDims, item::ArrayItem; randstate = nothing) return ArrayItem(data) end +function apply!(buf, tfm::PermuteDims, item::ArrayItem; randstate = nothing) + permutedims!(itemdata(buf), itemdata(item), tfm.perm) + return buf +end + # OneHot encoding """ diff --git a/test/preprocessing.jl b/test/preprocessing.jl index 00ddf137..5d8a43a4 100644 --- a/test/preprocessing.jl +++ b/test/preprocessing.jl @@ -64,12 +64,14 @@ end @testset ExtendedTestSet "PermuteDims" begin tfm = PermuteDims(2, 1, 4, 3) - A = rand(3, 4, 5, 6) - item = ArrayItem(A) - @test_nowarn apply(tfm, item) - B = itemdata(apply(tfm, item)) - @test size(B) == (4, 3, 6, 5) - testapply(tfm, item) + item1 = ArrayItem(rand(3, 4, 5, 6)) + item2 = ArrayItem(rand(3, 4, 5, 6)) + @test_nowarn apply(tfm, item1) + a = itemdata(apply(tfm, item1)) + @test size(a) == (4, 3, 6, 5) + + testapply(tfm, item1) + testapply!(tfm, item1, item2) end @testset ExtendedTestSet "OneHot" begin From 7df7582d41c75000c65daea92457d9cd1ba87d51 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Thu, 29 Feb 2024 16:32:34 +0100 Subject: [PATCH 3/4] Minor fixes to docstring --- src/preprocessing.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/preprocessing.jl b/src/preprocessing.jl index 4793faf9..0a981601 100644 --- a/src/preprocessing.jl +++ b/src/preprocessing.jl @@ -225,9 +225,9 @@ end PermuteDims(perm) Permute the dimensions of an `ArrayItem`. -`perm` is a vector or a tuple of length `ndims(A)` specifying the permutation. +`perm` is a vector or a tuple specifying the permutation, whose length has to match the dimensionality of the `ArrayItem`s data. -Refer to the `permutedims` documentation for examples for permutation vectors `perm`. +Refer to the `permutedims` documentation for examples of permutation vectors `perm`. Supports `apply!`. @@ -235,11 +235,14 @@ Supports `apply!`. Preprocessing an image with 3 color channels. -{cell=Normalize} +{cell=PermuteDims} ```julia using DataAugmentation, Images image = Image(rand(RGB, 20, 20)) -tfms = ImageToTensor() |> PermuteDims(2, 1, 3) # HWC to WHC + +# Turn image to tensor and permute dimensions 2 and 1 +# to convert HWC (height, width, channel) array to WHC (width, height, channel) +tfms = ImageToTensor() |> PermuteDims(2, 1, 3) apply(tfms, image) ``` From 3e44d1d150a27756e7802ef03a67e3b9ad52bab8 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Sat, 9 Mar 2024 21:12:37 +0100 Subject: [PATCH 4/4] Use `PermuteDimsArray` instead of `permutedims` --- src/preprocessing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/preprocessing.jl b/src/preprocessing.jl index 0a981601..83072663 100644 --- a/src/preprocessing.jl +++ b/src/preprocessing.jl @@ -253,7 +253,7 @@ end PermuteDims(perm...) = PermuteDims(perm) function apply(tfm::PermuteDims, item::ArrayItem; randstate = nothing) - data = permutedims(itemdata(item), tfm.perm) + data = PermutedDimsArray(itemdata(item), tfm.perm) return ArrayItem(data) end