Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade type hint and others to Python 3.9 #8814

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions torchvision/datasets/_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset

T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
T1 = tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = tuple[Image.Image, Image.Image, Optional[np.ndarray]]


__all__ = (
Expand All @@ -37,8 +37,8 @@ def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None
super().__init__(root=root)
self.transforms = transforms

self._flow_list: List[str] = []
self._image_list: List[List[str]] = []
self._flow_list: list[str] = []
self._image_list: list[list[str]] = []

def _read_img(self, file_name: str) -> Image.Image:
img = Image.open(file_name)
Expand Down Expand Up @@ -225,7 +225,7 @@ def __getitem__(self, index: int) -> Union[T1, T2]:
"""
return super().__getitem__(index)

def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
def _read_flow(self, file_name: str) -> tuple[np.ndarray, np.ndarray]:
return _read_16bits_png_with_flow_and_valid_mask(file_name)


Expand Down Expand Up @@ -443,7 +443,7 @@ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Opt
"Could not find the HD1K images. Please make sure the directory structure is correct."
)

def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
def _read_flow(self, file_name: str) -> tuple[np.ndarray, np.ndarray]:
return _read_16bits_png_with_flow_and_valid_mask(file_name)

def __getitem__(self, index: int) -> Union[T1, T2]:
Expand Down Expand Up @@ -479,7 +479,7 @@ def _read_flo(file_name: str) -> np.ndarray:
return data.reshape(h, w, 2).transpose(2, 0, 1)


def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> tuple[np.ndarray, np.ndarray]:

flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
Expand Down
36 changes: 18 additions & 18 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset

T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
T1 = tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
T2 = tuple[Image.Image, Image.Image, Optional[np.ndarray]]

__all__ = ()

Expand Down Expand Up @@ -65,11 +65,11 @@ def _scan_pairs(
self,
paths_left_pattern: str,
paths_right_pattern: Optional[str] = None,
) -> List[Tuple[str, Optional[str]]]:
) -> list[tuple[str, Optional[str]]]:

left_paths = list(sorted(glob(paths_left_pattern)))

right_paths: List[Union[None, str]]
right_paths: list[Union[None, str]]
if paths_right_pattern:
right_paths = list(sorted(glob(paths_right_pattern)))
else:
Expand All @@ -92,7 +92,7 @@ def _scan_pairs(
return paths

@abstractmethod
def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
def _read_disparity(self, file_path: str) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
# function that returns a disparity map and an occlusion map
pass

Expand Down Expand Up @@ -178,7 +178,7 @@ def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities = disparities

def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None
Expand Down Expand Up @@ -257,7 +257,7 @@ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Opt
else:
self._disparities = list((None, None) for _ in self._images)

def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
def _read_disparity(self, file_path: str) -> tuple[Optional[np.ndarray], None]:
# test split has no disparity maps
if file_path is None:
return None, None
Expand Down Expand Up @@ -345,7 +345,7 @@ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Opt
else:
self._disparities = list((None, None) for _ in self._images)

def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
def _read_disparity(self, file_path: str) -> tuple[Optional[np.ndarray], None]:
# test split has no disparity maps
if file_path is None:
return None, None
Expand Down Expand Up @@ -549,7 +549,7 @@ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]``
as the right image.
"""
ambient_file_paths: List[Union[str, Path]] # make mypy happy
ambient_file_paths: list[Union[str, Path]] # make mypy happy

if not isinstance(file_path, Path):
file_path = Path(file_path)
Expand All @@ -565,7 +565,7 @@ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
file_path = random.choice(ambient_file_paths) # type: ignore
return super()._read_img(file_path)

def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
def _read_disparity(self, file_path: str) -> Union[tuple[None, None], tuple[np.ndarray, np.ndarray]]:
# test split has not disparity maps
if file_path is None:
return None, None
Expand Down Expand Up @@ -694,7 +694,7 @@ def __init__(
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities += disparities

def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze the disparity map into (C, H, W) format
disparity_map = disparity_map[None, :, :] / 32.0
Expand Down Expand Up @@ -788,13 +788,13 @@ def __init__(self, root: Union[str, Path], variant: str = "single", transforms:
right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
# (H, W) image
depth = np.asarray(Image.open(file_path))
# as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
# in order to extract disparity from depth maps
camera_settings_path = Path(file_path).parent / "_camera_settings.json"
with open(camera_settings_path, "r") as f:
with open(camera_settings_path) as f:
# inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constant)
intrinsics = json.load(f)
focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"]
Expand Down Expand Up @@ -911,7 +911,7 @@ def __init__(
right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None
Expand Down Expand Up @@ -999,7 +999,7 @@ def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms:
disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png")
self._disparities += self._scan_pairs(disparity_pattern, None)

def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]:
def _get_occlussion_mask_paths(self, file_path: str) -> tuple[str, str]:
# helper function to get the occlusion mask paths
# a path will look like .../.../.../training/disparities/scene1/img1.png
# we want to get something like .../.../.../training/occlusions/scene1/img1.png
Expand All @@ -1020,7 +1020,7 @@ def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]:

return occlusion_path, outofframe_path

def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
def _read_disparity(self, file_path: str) -> Union[tuple[None, None], tuple[np.ndarray, np.ndarray]]:
if file_path is None:
return None, None

Expand Down Expand Up @@ -1101,7 +1101,7 @@ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Opt
right_disparity_pattern = str(root / "*" / "right_disp.png")
self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze disparity to (C, H, W)
disparity_map = disparity_map[None, :, :] / 1024.0
Expand Down Expand Up @@ -1195,7 +1195,7 @@ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Opt
disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
self._disparities = self._scan_pairs(disparity_pattern, None)

def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
def _read_disparity(self, file_path: str) -> Union[tuple[None, None], tuple[np.ndarray, np.ndarray]]:
# test split has no disparity maps
if file_path is None:
return None, None
Expand Down
10 changes: 5 additions & 5 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Caltech101(VisionDataset):
def __init__(
self,
root: Union[str, Path],
target_type: Union[List[str], str] = "category",
target_type: Union[list[str], str] = "category",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
Expand Down Expand Up @@ -71,14 +71,14 @@ def __init__(
}
self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))

self.index: List[int] = []
self.index: list[int] = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])

def __getitem__(self, index: int) -> Tuple[Any, Any]:
def __getitem__(self, index: int) -> tuple[Any, Any]:
"""
Args:
index (int): Index
Expand Down Expand Up @@ -181,7 +181,7 @@ def __init__(
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index: List[int] = []
self.index: list[int] = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(
Expand All @@ -194,7 +194,7 @@ def __init__(
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])

def __getitem__(self, index: int) -> Tuple[Any, Any]:
def __getitem__(self, index: int) -> tuple[Any, Any]:
"""
Args:
index (int): Index
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self,
root: Union[str, Path],
split: str = "train",
target_type: Union[List[str], str] = "attr",
target_type: Union[list[str], str] = "attr",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
Expand Down Expand Up @@ -155,7 +155,7 @@ def download(self) -> None:

extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))

def __getitem__(self, index: int) -> Tuple[Any, Any]:
def __getitem__(self, index: int) -> tuple[Any, Any]:
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

target: Any = []
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _load_meta(self) -> None:
self.classes = data[self.meta["key"]]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

def __getitem__(self, index: int) -> Tuple[Any, Any]:
def __getitem__(self, index: int) -> tuple[Any, Any]:
"""
Args:
index (int): Index
Expand Down
6 changes: 3 additions & 3 deletions torchvision/datasets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
root: Union[str, Path],
split: str = "train",
mode: str = "fine",
target_type: Union[List[str], str] = "instance",
target_type: Union[list[str], str] = "instance",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
Expand Down Expand Up @@ -172,7 +172,7 @@ def __init__(
self.images.append(os.path.join(img_dir, file_name))
self.targets.append(target_types)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
def __getitem__(self, index: int) -> tuple[Any, Any]:
"""
Args:
index (int): Index
Expand Down Expand Up @@ -206,7 +206,7 @@ def extra_repr(self) -> str:
lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
return "\n".join(lines).format(**self.__dict__)

def _load_json(self, path: str) -> Dict[str, Any]:
def _load_json(self, path: str) -> dict[str, Any]:
with open(path) as file:
data = json.load(file)
return data
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/clevr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(

self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))

self._labels: List[Optional[int]]
self._labels: list[Optional[int]]
if self._split != "test":
with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
content = json.load(file)
Expand All @@ -61,7 +61,7 @@ def __init__(
def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
def __getitem__(self, idx: int) -> tuple[Any, Any]:
image_file = self._image_files[idx]
label = self._labels[idx]

Expand Down
6 changes: 3 additions & 3 deletions torchvision/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def _load_image(self, id: int) -> Image.Image:
path = self.coco.loadImgs(id)[0]["file_name"]
return Image.open(os.path.join(self.root, path)).convert("RGB")

def _load_target(self, id: int) -> List[Any]:
def _load_target(self, id: int) -> list[Any]:
return self.coco.loadAnns(self.coco.getAnnIds(id))

def __getitem__(self, index: int) -> Tuple[Any, Any]:
def __getitem__(self, index: int) -> tuple[Any, Any]:

if not isinstance(index, int):
raise ValueError(f"Index must be of type integer, got {type(index)} instead.")
Expand Down Expand Up @@ -105,5 +105,5 @@ class CocoCaptions(CocoDetection):

"""

def _load_target(self, id: int) -> List[str]:
def _load_target(self, id: int) -> list[str]:
return [ann["caption"] for ann in super()._load_target(id)]
2 changes: 1 addition & 1 deletion torchvision/datasets/dtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
def __getitem__(self, idx: int) -> tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")

Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/fakedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class FakeData(VisionDataset):
def __init__(
self,
size: int = 1000,
image_size: Tuple[int, int, int] = (3, 224, 224),
image_size: tuple[int, int, int] = (3, 224, 224),
num_classes: int = 10,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand All @@ -37,7 +37,7 @@ def __init__(
self.image_size = image_size
self.random_offset = random_offset

def __getitem__(self, index: int) -> Tuple[Any, Any]:
def __getitem__(self, index: int) -> tuple[Any, Any]:
"""
Args:
index (int): Index
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/fer2013.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_label(row):
else:
return None

with open(data_file, "r", newline="") as file:
with open(data_file, newline="") as file:
rows = (row for row in csv.DictReader(file))

if use_fer_file or use_icml_file:
Expand All @@ -104,7 +104,7 @@ def get_label(row):
def __len__(self) -> int:
return len(self._samples)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
def __getitem__(self, idx: int) -> tuple[Any, Any]:
image_tensor, target = self._samples[idx]
image = Image.fromarray(image_tensor.numpy())

Expand Down
Loading
Loading