Skip to content

Commit

Permalink
Migrate avif and heic decoders to torchvision-extra-decoders repo (py…
Browse files Browse the repository at this point in the history
…torch#8671)

Co-authored-by: atalman <atalman@fb.com>
  • Loading branch information
NicolasHug and atalman committed Dec 12, 2024
1 parent 2f3a43f commit 28942b4
Show file tree
Hide file tree
Showing 15 changed files with 183 additions and 433 deletions.
5 changes: 5 additions & 0 deletions .github/scripts/setup-env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ echo '::group::Install TorchVision'
python setup.py develop
echo '::endgroup::'

echo '::group::Install torchvision-extra-decoders'
# This can be done after torchvision was built
pip install torchvision-extra-decoders
echo '::endgroup::'

echo '::group::Collect environment information'
conda list
python -m torch.utils.collect_env
Expand Down
15 changes: 9 additions & 6 deletions docs/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ images and videos.
Image Decoding
--------------

Torchvision currently supports decoding JPEG, PNG, WEBP and GIF images. JPEG
decoding can also be done on CUDA GPUs.
Torchvision currently supports decoding JPEG, PNG, WEBP, GIF, AVIF, and HEIC
images. JPEG decoding can also be done on CUDA GPUs.

The main entry point is the :func:`~torchvision.io.decode_image` function, which
you can use as an alternative to ``PIL.Image.open()``. It will decode images
Expand All @@ -30,9 +30,10 @@ run transforms/preproc natively on tensors.
:func:`~torchvision.io.decode_image` will automatically detect the image format,
and call the corresponding decoder. You can also use the lower-level
format-specific decoders which can be more powerful, e.g. if you want to
encode/decode JPEGs on CUDA.
and call the corresponding decoder (except for HEIC and AVIF images, see details
in :func:`~torchvision.io.decode_avif` and :func:`~torchvision.io.decode_heic`).
You can also use the lower-level format-specific decoders which can be more
powerful, e.g. if you want to encode/decode JPEGs on CUDA.

.. autosummary::
:toctree: generated/
Expand All @@ -41,8 +42,10 @@ encode/decode JPEGs on CUDA.
decode_image
decode_jpeg
encode_png
decode_gif
decode_webp
decode_avif
decode_heic
decode_gif

.. autosummary::
:toctree: generated/
Expand Down
2 changes: 2 additions & 0 deletions packaging/post_build_script.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
#!/bin/bash
LD_LIBRARY_PATH="/usr/local/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH" python packaging/wheel/relocate.py

pip install torchvision-extra-decoders
34 changes: 0 additions & 34 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "0") == "1" # TODO enable by default!
USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default!
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
# Note: the GPU video decoding stuff used to be called "video codec", which
Expand Down Expand Up @@ -51,8 +49,6 @@
print(f"{USE_PNG = }")
print(f"{USE_JPEG = }")
print(f"{USE_WEBP = }")
print(f"{USE_HEIC = }")
print(f"{USE_AVIF = }")
print(f"{USE_NVJPEG = }")
print(f"{NVCC_FLAGS = }")
print(f"{USE_CPU_VIDEO_DECODER = }")
Expand Down Expand Up @@ -336,36 +332,6 @@ def make_image_extension():
else:
warnings.warn("Building torchvision without WEBP support")

if USE_HEIC:
heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif.h")
if heic_found:
print("Building torchvision with HEIC support")
print(f"{heic_include_dir = }")
print(f"{heic_library_dir = }")
if heic_include_dir is not None and heic_library_dir is not None:
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
include_dirs.append(heic_include_dir)
library_dirs.append(heic_library_dir)
libraries.append("heif")
define_macros += [("HEIC_FOUND", 1)]
else:
warnings.warn("Building torchvision without HEIC support")

if USE_AVIF:
avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h")
if avif_found:
print("Building torchvision with AVIF support")
print(f"{avif_include_dir = }")
print(f"{avif_library_dir = }")
if avif_include_dir is not None and avif_library_dir is not None:
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
include_dirs.append(avif_include_dir)
library_dirs.append(avif_library_dir)
libraries.append("avif")
define_macros += [("AVIF_FOUND", 1)]
else:
warnings.warn("Building torchvision without AVIF support")

if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA):
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()

Expand Down
35 changes: 34 additions & 1 deletion test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
import torchvision
from torchvision.io import decode_image, decode_jpeg, decode_webp, read_file
from torchvision.io import decode_avif, decode_heic, decode_image, decode_jpeg, read_file
from torchvision.models import resnet50, ResNet50_Weights


Expand All @@ -24,13 +24,46 @@ def smoke_test_torchvision_read_decode() -> None:
img_jpg = decode_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")

img_png = decode_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
if img_png.shape != (4, 471, 354):
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")

img_webp = decode_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
if img_webp.shape != (3, 100, 100):
raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}")

if sys.platform == "linux":
pass
# TODO: Fix/uncomment below (the TODO below is mostly accurate but we're
# still observing some failures on some CUDA jobs. Most are working.)
# if torch.cuda.is_available():
# # TODO: For whatever reason this only passes on the runners that
# # support CUDA.
# # Strangely, on the CPU runners where this fails, the AVIF/HEIC
# # tests (ran with pytest) are passing. This is likely related to a
# # libcxx symbol thing, and the proper libstdc++.so get loaded only
# # with pytest? Ugh.
# img_avif = decode_avif(read_file(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.avif")))
# if img_avif.shape != (3, 100, 100):
# raise RuntimeError(f"Unexpected shape of img_avif: {img_avif.shape}")

# img_heic = decode_heic(
# read_file(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic"))
# )
# if img_heic.shape != (3, 100, 100):
# raise RuntimeError(f"Unexpected shape of img_heic: {img_heic.shape}")
else:
try:
decode_avif(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.avif"))
except RuntimeError as e:
assert "torchvision-extra-decoders" in str(e)

try:
decode_heic(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic"))
except RuntimeError as e:
assert "torchvision-extra-decoders" in str(e)


def smoke_test_torchvision_decode_jpeg(device: str = "cpu"):
img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
Expand Down
98 changes: 33 additions & 65 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import re
import sys
from contextlib import nullcontext
from pathlib import Path

import numpy as np
Expand All @@ -14,11 +13,10 @@
import torchvision.transforms.v2.functional as F
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision._internally_replaced_utils import IN_FBCODE
from torchvision.io.image import (
_decode_avif,
_decode_heic,
decode_avif,
decode_gif,
decode_heic,
decode_image,
decode_jpeg,
decode_png,
Expand All @@ -43,22 +41,11 @@
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")
IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"
IS_LINUX = sys.platform == "linux"
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "")
# See https://github.com/pytorch/vision/pull/8724#issuecomment-2503964558
ROCM_WEBP_MESSAGE = "ROCM not built with webp support."

# Hacky way of figuring out whether we compiled with libavif/libheif (those are
# currenlty disabled by default)
try:
_decode_avif(torch.arange(10, dtype=torch.uint8))
except Exception as e:
DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e)

try:
_decode_heic(torch.arange(10, dtype=torch.uint8))
except Exception as e:
DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e)
HEIC_AVIF_MESSAGE = "AVIF and HEIF only available on linux."


def _get_safe_image_name(name):
Expand Down Expand Up @@ -866,19 +853,23 @@ def test_decode_gif(tmpdir, name, scripted):
torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)


decode_fun_and_match = [
(decode_png, "Content is not png"),
(decode_jpeg, "Not a JPEG file"),
(decode_gif, re.escape("DGifOpenFileName() failed - 103")),
(decode_webp, "WebPGetFeatures failed."),
]
if DECODE_AVIF_ENABLED:
decode_fun_and_match.append((_decode_avif, "BMFF parsing failed"))
if DECODE_HEIC_ENABLED:
decode_fun_and_match.append((_decode_heic, "Invalid input: No 'ftyp' box"))


@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match)
@pytest.mark.parametrize(
"decode_fun, match",
[
(decode_png, "Content is not png"),
(decode_jpeg, "Not a JPEG file"),
(decode_gif, re.escape("DGifOpenFileName() failed - 103")),
(decode_webp, "WebPGetFeatures failed."),
pytest.param(
decode_avif, "BMFF parsing failed", marks=pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE)
),
pytest.param(
decode_heic,
"Invalid input: No 'ftyp' box",
marks=pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE),
),
],
)
def test_decode_bad_encoded_data(decode_fun, match):
encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8)
with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"):
Expand Down Expand Up @@ -934,13 +925,10 @@ def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename)
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


@pytest.mark.skipif(not DECODE_AVIF_ENABLED, reason="AVIF support not enabled.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_avif(decode_fun, scripted):
@pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE)
@pytest.mark.parametrize("decode_fun", (decode_avif,))
def test_decode_avif(decode_fun):
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".avif")))
if scripted:
decode_fun = torch.jit.script(decode_fun)
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
Expand All @@ -949,16 +937,8 @@ def test_decode_avif(decode_fun, scripted):

# Note: decode_image fails because some of these files have a (valid) signature
# we don't recognize. We should probably use libmagic....
decode_funs = []
if DECODE_AVIF_ENABLED:
decode_funs.append(_decode_avif)
if DECODE_HEIC_ENABLED:
decode_funs.append(_decode_heic)


@pytest.mark.skipif(not decode_funs, reason="Built without avif and heic support.")
@pytest.mark.parametrize("decode_fun", decode_funs)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE)
@pytest.mark.parametrize("decode_fun", (decode_avif, decode_heic))
@pytest.mark.parametrize(
"mode, pil_mode",
(
Expand All @@ -970,7 +950,7 @@ def test_decode_avif(decode_fun, scripted):
@pytest.mark.parametrize(
"filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name
)
def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, filename):
def test_decode_avif_heic_against_pil(decode_fun, mode, pil_mode, filename):
if "reversed_dimg_order" in str(filename):
# Pillow properly decodes this one, but we don't (order of parts of the
# image is wrong). This is due to a bug that was recently fixed in
Expand All @@ -980,8 +960,6 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file
import pillow_avif # noqa

encoded_bytes = read_file(filename)
if scripted:
decode_fun = torch.jit.script(decode_fun)
try:
img = decode_fun(encoded_bytes, mode=mode)
except RuntimeError as e:
Expand All @@ -994,6 +972,7 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file
"no 'ispe' property",
"'iref' has double references",
"Invalid image grid",
"decode_heif failed: Invalid input: No 'meta' box",
)
):
pytest.skip(reason="Expected failure, that's OK")
Expand All @@ -1010,7 +989,7 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file
try:
from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
except RuntimeError as e:
if "Invalid image grid" in str(e):
if any(s in str(e) for s in ("Invalid image grid", "Failed to decode image: Not implemented")):
pytest.skip(reason="PIL failure")
else:
raise e
Expand All @@ -1021,7 +1000,7 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file
g = make_grid([img, from_pil])
F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png"))

is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "decode_heic"
if mode == ImageReadMode.RGB and not is_decode_heic:
# We don't compare torchvision's AVIF against PIL for RGB because
# results look pretty different on RGBA images (other images are fine).
Expand All @@ -1035,13 +1014,10 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file
torch.testing.assert_close(img, from_pil, rtol=0, atol=3)


@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_heic(decode_fun, scripted):
@pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE)
@pytest.mark.parametrize("decode_fun", (decode_heic,))
def test_decode_heic(decode_fun):
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".heic")))
if scripted:
decode_fun = torch.jit.script(decode_fun)
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
Expand Down Expand Up @@ -1080,13 +1056,5 @@ def test_mode_str():
assert decode_image(path, mode="RGBA").shape[0] == 4


def test_avif_heic_fbcode():
cm = nullcontext() if IN_FBCODE else pytest.raises(ImportError, match="cannot import")
with cm:
from torchvision.io import decode_heic # noqa
with cm:
from torchvision.io import decode_avif # noqa


if __name__ == "__main__":
pytest.main([__file__])
Loading

0 comments on commit 28942b4

Please sign in to comment.