
Pre-trained model tutorials don't permute image height and width channels #275

Open
adrhill opened this issue Feb 28, 2024 · 7 comments

adrhill (Contributor) commented Feb 28, 2024

The doc page Working with pre-trained models from Metalhead uses DataAugmentation.jl to crop and normalize an image; however, DataAugmentation.jl produces arrays in HWC order, whereas Flux.jl requires the WHCN format.

Related issue: FluxML/DataAugmentation.jl#56 (comment)
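For reference, a minimal way to see the mismatch (hypothetical sizes, with a non-square crop so the orientation shows up in the shape; this assumes CenterCrop sizes follow the image's (height, width) order):

using DataAugmentation, Images

im = rand(RGB, 300, 400)                          # H = 300, W = 400
tfm = CenterCrop((200, 100)) |> ImageToTensor()
size(apply(tfm, Image(im)) |> itemdata)           # (200, 100, 3): height first, i.e. HWC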

adrhill commented Feb 28, 2024

As a quick sanity check, I copied the example from the tutorial and cast the first channel to Gray. Since the Images.jl ecosystem uses HW matrices of Colorants, the cropped image is oriented "the right way", showing that the code outputs HWC data.

using Images, ImageInTerminal
using DataAugmentation
using Flux
using Flux: onecold

# Example image from the tutorial
img = Images.load(download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg"));

# ImageNet normalization constants
DATA_MEAN = (0.485, 0.456, 0.406)
DATA_STD = (0.229, 0.224, 0.225)

augmentations = CenterCrop((224, 224)) |>
                ImageToTensor() |>
                Normalize(DATA_MEAN, DATA_STD);

data = apply(augmentations, Image(img)) |> itemdata;

# Render channel 1 as a grayscale image; Images.jl treats matrices as HW,
# so an upright result means `data` is HWC.
Gray.(data[:, :, 1])
[Screenshot: the grayscale crop renders upright in the terminal, confirming HWC output.]

ToucheSir (Member) commented
This is a tricky one because interop requirements pull us in conflicting directions. You've identified how Julia image libraries expect HWC. We use WHCN because underlying libraries like cuDNN use NCHW. Technically, I'm not aware of any routines in said libraries which would behave differently if you swapped the height and width dims, so that part can be worked around.

What can't be worked around is interop on pre-trained weights. Those come from Python libraries, which use NCHW or NHWC. We might be able to apply some transformations to make them work with HWCN inputs, but without spending a bunch of time looking into that, I don't want to say for sure.

adrhill commented Feb 29, 2024

> What can't be worked around is interop on pre-trained weights. Those come from Python libraries, which use NCHW or NHWC. We might be able to apply some transformations to make them work with HWCN inputs, but without spending a bunch of time looking into that, I don't want to say for sure.

This confused me a bit. I thought all Metalhead models expect WHCN. Is this not the case?

ToucheSir commented Feb 29, 2024

To clarify, data stored as (N)CHW in Python libraries can be read in as WHC(N) in Julia without any permutations or copying. It's a simple row- to column-major swap. Reading data in as HWC(N) as required by the image libraries would require copying and/or permutation, which is why we expect WHCN.
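A quick sketch of that memory-layout point in plain Julia (hypothetical sizes): the same flat buffer viewed row-major as (N, C, H, W) and column-major as (W, H, C, N) addresses identical elements, so no copy is needed; producing HWC(N) from it is what forces a permutation.

N, C, H, W = 1, 3, 4, 5
buf = rand(Float32, N * C * H * W)    # flat buffer, e.g. raw weights from Python
a = reshape(buf, W, H, C, N)          # free WHCN view of the row-major NCHW data
b = permutedims(a, (2, 1, 3, 4))      # HWCN instead would require this copying permutation
size(a), size(b)                      # ((5, 4, 3, 1), (4, 5, 3, 1))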

In practice, some models may be able to work with HWCN because none of the operators they use distinguish between height and width, for example a model trained from scratch (no pre-trained weights) that uses only square conv kernels, pooling windows, and up-/down-sampling.

adrhill commented Feb 29, 2024

OK, then maybe I didn't state my issue clearly enough:

data, the output of the DataAugmentation pipeline in the documentation, is in HWC format but is fed into a model that expects WHCN inputs. I would consider this a "bug" in the documentation: users copying this example would pass inputs with permuted dimensions and most likely get worse predictions.

This could be addressed by adding a call to permutedims in the example.
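A minimal sketch of what that fix could look like, reusing `augmentations` and `img` from the snippet above (`reshape` just appends the batch dim):

data = apply(augmentations, Image(img)) |> itemdata   # (224, 224, 3): HWC
data = permutedims(data, (2, 1, 3))                   # HWC -> WHC
batch = reshape(data, size(data)..., 1)               # WHC -> WHCN with N = 1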
I've also opened a PR for a matching transformation in DataAugmentation.jl: FluxML/DataAugmentation.jl#89

adrhill commented Feb 29, 2024

Alternatively, an argument could be made that DataAugmentation.jl's ImageToTensor should return WHC arrays instead of the current HWC.

using DataAugmentation, Images

# Non-square image, so the orientation is visible in the output shape
h, w = 20, 30
im = rand(RGB, h, w)

item = Image(im)
tfm = ImageToTensor()
apply(tfm, item) |> itemdata |> size   # (20, 30, 3): HWC
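Until something like that lands, the item-level output can be permuted by hand (sketch, reusing `tfm` and `item` from above):

permutedims(apply(tfm, item) |> itemdata, (2, 1, 3)) |> size   # (30, 20, 3): WHC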

ToucheSir commented
Thanks, it wasn't clear to me whether the "bugfix" being proposed was simply a docs tweak or something larger.
