typing
hollymandel committed Oct 10, 2024
1 parent 919e003 commit d00cd93
Showing 11 changed files with 61 additions and 60 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -125,6 +125,7 @@ module = [
"netCDF4.*",
"netcdftime.*",
"opt_einsum.*",
+ "pandas.*",
"pint.*",
"pooch.*",
"pyarrow.*",
@@ -178,7 +179,7 @@ module = [
"xarray.tests.test_units",
"xarray.tests.test_utils",
"xarray.tests.test_variable",
- "xarray.tests.test_weighted",
+ "xarray.tests.test_weighted"
]

# Use strict = true whenever namedarray has become standalone. In the meantime
2 changes: 1 addition & 1 deletion xarray/coding/cftimeindex.py
@@ -517,7 +517,7 @@ def contains(self, key: Any) -> bool:
"""Needed for .loc based partial-string indexing"""
return self.__contains__(key)

- def shift( # type: ignore[override] # freq is typed Any, we are more precise
+ def shift( # freq is typed Any, we are more precise
self,
periods: int | float,
freq: str | timedelta | BaseCFTimeOffset | None = None,
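The dropped ignore[override] is consistent with mypy's rules: a subclass may narrow a parameter that the base class types as Any without tripping the override check. A minimal self-contained sketch (hypothetical classes, not xarray code):

from typing import Any

class Base:
    def shift(self, periods: int, freq: Any = None) -> "Base":
        return self

class Narrow(Base):
    # Narrowing freq from Any to a precise union passes mypy's
    # override check, so no "type: ignore[override]" is needed.
    def shift(self, periods: int, freq: str | None = None) -> "Narrow":
        return self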
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
@@ -3032,7 +3032,7 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data
if not isinstance(idx, pd.MultiIndex):
raise ValueError(f"'{dim}' is not a stacked coordinate")

- level_number = idx._get_level_number(level) # type: ignore[attr-defined]
+ level_number = idx._get_level_number(level)
variables = idx.levels[level_number]
variable_dim = idx.names[level_number]

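For context, _get_level_number is private pandas API (hence the old attr-defined ignore); a hedged sketch of what the call does, with behavior inferred from this usage rather than documented API:

import pandas as pd

idx = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"])
# A level name resolves to its position; integer positions pass through.
assert idx._get_level_number("y") == 1
assert idx._get_level_number(0) == 0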
10 changes: 5 additions & 5 deletions xarray/core/dataset.py
@@ -6629,7 +6629,7 @@ def interpolate_na(
| None
) = None,
**kwargs: Any,
- ) -> Self:
+ ) -> Dataset:
"""Fill in NaNs by interpolating according to different methods.
Parameters
@@ -6760,7 +6760,7 @@
)
return new

- def ffill(self, dim: Hashable, limit: int | None = None) -> Self:
+ def ffill(self, dim: Hashable, limit: int | None = None) -> Dataset:
"""Fill NaN values by propagating values forward
*Requires bottleneck.*
@@ -6824,7 +6824,7 @@ def ffill(self, dim: Hashable, limit: int | None = None) -> Self:
new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit)
return new

- def bfill(self, dim: Hashable, limit: int | None = None) -> Self:
+ def bfill(self, dim: Hashable, limit: int | None = None) -> Dataset:
"""Fill NaN values by propagating values backward
*Requires bottleneck.*
@@ -7523,7 +7523,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:

if isinstance(idx, pd.MultiIndex):
dims = tuple(
- name if name is not None else "level_%i" % n # type: ignore[redundant-expr]
+ name if name is not None else "level_%i" % n
for n, name in enumerate(idx.names)
)
for dim, lev in zip(dims, idx.levels, strict=True):
@@ -9829,7 +9829,7 @@ def eval(
c (x) float64 40B 0.0 1.25 2.5 3.75 5.0
"""

- return pd.eval( # type: ignore[return-value]
+ return pd.eval(
statement,
resolvers=[self],
target=self,
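On the Self -> Dataset switches above: Self promises that subclasses get their own type back from ffill/bfill/interpolate_na, while a concrete Dataset return type drops that claim. A minimal sketch of the difference, assuming Python >= 3.11 for typing.Self (hypothetical classes, not xarray code):

from typing import Self

class Container:
    def fill_self(self) -> Self:  # subclass-preserving return type
        return self

    def fill_concrete(self) -> "Container":  # always the base type
        return self

class MyContainer(Container): ...

a = MyContainer().fill_self()      # mypy infers MyContainer
b = MyContainer().fill_concrete()  # mypy infers Container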
10 changes: 6 additions & 4 deletions xarray/core/extension_array.py
@@ -45,7 +45,7 @@ def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int):
def __extension_duck_array__concatenate(
arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None
) -> T_ExtensionArray:
- return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined]
+ return type(arrays[0])._concat_same_type(arrays)


@implements(np.where)
@@ -57,8 +57,8 @@ def __extension_duck_array__where(
and isinstance(y, pd.Categorical)
and x.dtype != y.dtype
):
- x = x.add_categories(set(y.categories).difference(set(x.categories))) # type: ignore[assignment]
- y = y.add_categories(set(x.categories).difference(set(y.categories))) # type: ignore[assignment]
+ x = x.add_categories(set(y.categories).difference(set(x.categories)))
+ y = y.add_categories(set(x.categories).difference(set(y.categories)))
return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array)
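The add_categories calls above align two Categoricals on a shared category set before mixing them; a self-contained sketch of the same pattern:

import pandas as pd

x = pd.Categorical(["a", "b"])
y = pd.Categorical(["b", "c"])
# Give each side the categories it is missing so where() can mix them.
x = x.add_categories(set(y.categories).difference(set(x.categories)))
y = y.add_categories(set(x.categories).difference(set(y.categories)))
result = pd.Series(x).where(pd.Series([True, False]), pd.Series(y)).array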


Expand Down Expand Up @@ -116,7 +116,9 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
if is_extension_array_dtype(item):
return type(self)(item)
if np.isscalar(item):
- return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed
+ return type(self)(
+ type(self.array)([item])
+ ) # only subclasses with proper __init__ allowed
return item

def __setitem__(self, key, val):
12 changes: 7 additions & 5 deletions xarray/core/indexes.py
@@ -740,7 +740,7 @@ def isel(
# scalar indexer: drop index
return None

- return self._replace(self.index[indxr]) # type: ignore[index]
+ return self._replace(self.index[indxr])

def sel(
self, labels: dict[Any, Any], method=None, tolerance=None
@@ -926,7 +926,7 @@ def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex:
return cast(T_PDIndex, new_index)

if isinstance(index, pd.CategoricalIndex):
- return index.remove_unused_categories() # type: ignore[attr-defined]
+ return index.remove_unused_categories()

return index
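remove_unused_categories is documented pandas API, so with pandas typing visible the attr-defined ignore becomes unnecessary; a small sketch of the branch above:

import pandas as pd

idx = pd.CategoricalIndex(["a", "a"], categories=["a", "b", "c"])
# Unused categories ("b", "c") are dropped; the values are unchanged.
assert list(idx.remove_unused_categories().categories) == ["a"]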

@@ -1164,7 +1164,7 @@ def create_variables(
dtype = None
else:
level = name
- dtype = self.level_coords_dtype[name] # type: ignore[index] # TODO: are Hashables ok?
+ dtype = self.level_coords_dtype[name] # TODO: are Hashables ok?

var = variables.get(name, None)
if var is not None:
@@ -1174,7 +1174,9 @@
attrs = {}
encoding = {}

- data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) # type: ignore[arg-type] # TODO: are Hashables ok?
+ data = PandasMultiIndexingAdapter(
+ self.index, dtype=dtype, level=level
+ ) # TODO: are Hashables ok?
index_vars[name] = IndexVariable(
self.dim,
data,
@@ -1671,7 +1673,7 @@ def copy_indexes(
convert_new_idx = False
xr_idx = idx

- new_idx = xr_idx._copy(deep=deep, memo=memo) # type: ignore[assignment]
+ new_idx = xr_idx._copy(deep=deep, memo=memo)
idx_vars = xr_idx.create_variables(coords)

if convert_new_idx:
66 changes: 31 additions & 35 deletions xarray/core/missing.py
@@ -29,7 +29,6 @@
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
- from xarray.core.variable import IndexVariable


def _get_nan_block_lengths(
@@ -146,7 +145,7 @@ def __init__(
copy: bool = False,
bounds_error: bool = False,
order: Optional[int] = None,
- axis=-1,
+ axis: int = -1,
**kwargs,
):
from scipy.interpolate import interp1d
@@ -167,8 +166,6 @@ def __init__(
self.cons_kwargs = kwargs
self.call_kwargs = {}

- nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j
-
self.f = interp1d(
xi,
yi,
@@ -192,13 +189,13 @@ class SplineInterpolator(BaseInterpolator):

def __init__(
self,
- xi,
- yi,
- method="spline",
- fill_value=None,
- order=3,
- nu=0,
- ext=None,
+ xi: Variable,
+ yi: np.ndarray,
+ method: Optional[str | int] = "spline",
+ fill_value: Optional[float | complex] = None,
+ order: int = 3,
+ nu: Optional[float] = 0,
+ ext: Optional[int | str] = None,
**kwargs,
):
from scipy.interpolate import UnivariateSpline
@@ -216,7 +213,9 @@ def __init__(
self.f = UnivariateSpline(xi, yi, k=order, **self.cons_kwargs)


- def _apply_over_vars_with_dim(func, self, dim=None, **kwargs):
+ def _apply_over_vars_with_dim(
+ func: Callable, self: Dataset, dim: Optional[Hashable] = None, **kwargs
+ ) -> Dataset:
"""Wrapper for datasets"""
ds = type(self)(coords=self.coords, attrs=self.attrs)
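The hunk shows only the head of this wrapper; a hedged, self-contained sketch of the per-variable dispatch it implements (the loop body is an assumption based on the ffill/bfill call sites, not the verbatim xarray code):

from typing import Callable, Hashable, Optional

import xarray as xr

def apply_over_vars_with_dim(
    func: Callable, ds: xr.Dataset, dim: Optional[Hashable] = None, **kwargs
) -> xr.Dataset:
    out = xr.Dataset(coords=ds.coords, attrs=ds.attrs)
    for name, var in ds.data_vars.items():
        # Apply func only to variables that actually carry dim.
        out[name] = func(var, dim=dim, **kwargs) if dim in var.dims else var
    return out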

@@ -606,7 +605,7 @@ def _floatize_x(x, new_x):

def interp(
var: Variable,
- indexes_coords: dict[str, IndexVariable],
+ indexes_coords: dict[Hashable, tuple[Any, Any]],
method: InterpOptions,
**kwargs,
) -> Variable:
@@ -671,9 +670,9 @@


def interp_func(
- var: np.ndarray,
- x: list[IndexVariable],
- new_x: list[IndexVariable],
+ var: DataArray,
+ x: tuple[Variable, ...],
+ new_x: tuple[Variable, ...],
method: InterpOptions,
kwargs: dict,
) -> np.ndarray:
@@ -683,13 +682,10 @@
Parameters
----------
- var : np.ndarray or dask.array.Array
- Array to be interpolated. The final dimension is interpolated.
- x : a list of 1d array.
- Original coordinates. Should not contain NaN.
- new_x : a list of 1d array
- New coordinates. Should not contain NaN.
- method : string
+ var : Array to be interpolated. The final dimension is interpolated.
+ x : Original coordinates. Should not contain NaN.
+ new_x : New coordinates. Should not contain NaN.
+ method :
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima',
'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation.
{'linear', 'nearest'} for multidimensional interpolation
@@ -710,7 +706,7 @@
scipy.interpolate.interp1d
"""
if not x:
- return var.copy()
+ return var.data.copy()

if len(x) == 1:
func, kwargs = _get_interpolator(method, vectorizeable_only=True, **kwargs)
@@ -727,11 +723,11 @@

# blockwise args format
x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)]
- x_arginds = [item for pair in x_arginds for item in pair]
+ x_arginds = [item for pair in x_arginds for item in pair] # type: ignore[misc]
new_x_arginds = [
[_x, [ndim + index for index in range(_x.ndim)]] for _x in new_x
]
- new_x_arginds = [item for pair in new_x_arginds for item in pair]
+ new_x_arginds = [item for pair in new_x_arginds for item in pair] # type: ignore[misc]

args = (var, range(ndim), *x_arginds, *new_x_arginds)

@@ -741,13 +737,13 @@
elem for pair in zip(rechunked, args[1::2], strict=True) for elem in pair
)

- new_x = rechunked[1 + (len(rechunked) - 1) // 2 :]
+ new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] # type: ignore[assignment]

new_x0_chunks = new_x[0].chunks
new_x0_shape = new_x[0].shape
new_x0_chunks_is_not_none = new_x0_chunks is not None
new_axes = {
- ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i]
+ ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i] # type: ignore[index]
for i in range(new_x[0].ndim)
}

@@ -757,7 +753,7 @@
# scipy.interpolate.interp1d always forces to float.
# Use the same check for blockwise as well:
if not issubclass(var.dtype.type, np.inexact):
- dtype = float
+ dtype = np.dtype(float)
else:
dtype = var.dtype

@@ -772,18 +768,18 @@
localize=localize,
concatenate=True,
dtype=dtype,
- new_axes=new_axes,
+ new_axes=new_axes, # type: ignore[arg-type]
meta=meta,
align_arrays=False,
)

- return _interpnd(var, x, new_x, func, kwargs)
+ return _interpnd(var.data, x, new_x, func, kwargs)


def _interp1d(
var: np.ndarray,
- x: IndexVariable,
- new_x: IndexVariable,
+ x: Variable,
+ new_x: Variable,
func: Callable,
kwargs: dict,
) -> np.ndarray:
@@ -798,8 +794,8 @@

def _interpnd(
var: np.ndarray,
- x: list[IndexVariable],
- new_x: list[IndexVariable],
+ x: tuple[Variable, ...],
+ new_x: tuple[Variable, ...],
func: Callable,
kwargs: dict,
) -> np.ndarray:
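These helpers sit behind xarray's public interpolation API; a short usage sketch of the 1-D path they implement (assumes SciPy is installed):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4.0), dims="x", coords={"x": np.arange(4.0)})
# 1-D linear interpolation at new coordinates.
out = da.interp(x=[0.5, 2.5], method="linear")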
4 changes: 2 additions & 2 deletions xarray/core/utils.py
@@ -132,7 +132,7 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index) -> np.dtype:
if not is_valid_numpy_dtype(array.dtype):
return np.dtype("O")

- return array.dtype # type: ignore[return-value]
+ return array.dtype


def maybe_coerce_to_str(index, original_coords):
@@ -180,7 +180,7 @@ def equivalent(first: T, second: T) -> bool:
return duck_array_ops.array_equiv(first, second)
if isinstance(first, list) or isinstance(second, list):
return list_equiv(first, second) # type: ignore[arg-type]
- return (first == second) or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload]
+ return (first == second) or (pd.isnull(first) and pd.isnull(second))


def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool:
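The scalar fallback in equivalent treats two null values as equal; a minimal sketch of that branch:

import pandas as pd

def scalar_equivalent(first, second) -> bool:
    # Plain equality, with NaN/NaT-like pairs also counted as equivalent.
    return (first == second) or (pd.isnull(first) and pd.isnull(second))

assert scalar_equivalent(float("nan"), float("nan"))
assert not scalar_equivalent(1, 2)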
2 changes: 1 addition & 1 deletion xarray/core/variable.py
@@ -151,7 +151,7 @@ def as_variable(
) from error
elif utils.is_scalar(obj):
obj = Variable([], obj)
- elif isinstance(obj, pd.Index | IndexVariable) and obj.name is not None:
+ elif isinstance(obj, pd.Index | IndexVariable) and obj.name is not None: # type: ignore[redundant-expr]
obj = Variable(obj.name, obj)
elif isinstance(obj, set | dict):
raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}")
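The isinstance call above passes a union type directly, which Python 3.10+ supports; a minimal sketch with stand-in types:

import pandas as pd

obj = pd.Index([1, 2], name="x")
# An X | Y union works directly as the second isinstance argument.
if isinstance(obj, pd.Index | list) and obj.name is not None:
    print(obj.name)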
2 changes: 1 addition & 1 deletion xarray/groupers.py
@@ -296,7 +296,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:

data = np.asarray(group.data) # Cast _DummyGroup data to array

- binned, self.bins = pd.cut( # type: ignore [call-overload]
+ binned, self.bins = pd.cut(
data.ravel(),
bins=self.bins,
right=self.right,
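The tuple unpacking above implies pd.cut runs with retbins=True (an assumption; the trailing kwargs are cut off in the hunk). A hedged sketch of the call pattern:

import numpy as np
import pandas as pd

data = np.array([1.0, 7.0, 5.0, 4.0])
# With retbins=True, pd.cut returns both the binned values and the bin edges.
binned, bins = pd.cut(data, bins=3, right=True, retbins=True)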