diff --git a/traja/core.py b/traja/core.py new file mode 100644 index 00000000..ec274fe0 --- /dev/null +++ b/traja/core.py @@ -0,0 +1,5 @@ +import pandas as pd + +# Check whether pandas series is datetime or timedelta +def is_datetime_or_timedelta_dtype(series: pd.Series) -> bool: + return pd.api.types.is_datetime64_dtype(series) or pd.api.types.is_timedelta64_dtype(series) \ No newline at end of file diff --git a/traja/plotting.py b/traja/plotting.py index 391d75bc..bb9fc6e0 100644 --- a/traja/plotting.py +++ b/traja/plotting.py @@ -14,15 +14,14 @@ from matplotlib.axes import Axes from matplotlib.collections import PathCollection from matplotlib.figure import Figure -from mpl_toolkits.mplot3d import Axes3D from pandas.core.dtypes.common import ( - is_datetime_or_timedelta_dtype, is_datetime64_any_dtype, is_timedelta64_dtype, ) import traja from traja.frame import TrajaDataFrame +from traja.core import is_datetime_or_timedelta_dtype from traja.trajectory import coords_to_flow __all__ = [ diff --git a/traja/trajectory.py b/traja/trajectory.py index 51f672e1..26a4d320 100644 --- a/traja/trajectory.py +++ b/traja/trajectory.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd from pandas.core.dtypes.common import ( - is_datetime_or_timedelta_dtype, is_datetime64_any_dtype, is_timedelta64_dtype, ) @@ -15,6 +14,7 @@ import traja from traja import TrajaDataFrame +from traja.core import is_datetime_or_timedelta_dtype __all__ = [ "_bins_to_tuple",