Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add knn_kwargs parameter to TSNE API #265

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
33 changes: 30 additions & 3 deletions openTSNE/affinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class PerplexityBasedNN(Affinities):
The number of neighbors to use in the kNN graph. If ``auto`` (default),
it is set to three times the perplexity.

knn_kwargs: Optional[dict]
Optional keyword arguments that will be passed to the ``knn_index``.

knn_index: Optional[nearest_neighbors.KNNIndex]
Optionally, a precomputed ``openTSNE.nearest_neighbors.KNNIndex`` object
can be specified. This option will ignore any KNN-related parameters.
Expand All @@ -150,6 +153,7 @@ def __init__(
random_state=None,
verbose=False,
k_neighbors="auto",
knn_kwargs=None,
knn_index=None,
):
# This can't work if neither data nor the knn index are specified
Expand Down Expand Up @@ -181,7 +185,7 @@ def __init__(

self.knn_index = get_knn_index(
data, method, _k_neighbors, metric, metric_params, n_jobs,
random_state, verbose
random_state, verbose, knn_index
)

else:
Expand All @@ -205,6 +209,7 @@ def __init__(
self.symmetrize = symmetrize
self.n_jobs = n_jobs
self.verbose = verbose
self.knn_kwargs = knn_kwargs

def set_perplexity(self, new_perplexity):
"""Change the perplexity of the affinity matrix.
Expand Down Expand Up @@ -352,7 +357,15 @@ def check_perplexity(perplexity, k_neighbors):


def get_knn_index(
data, method, k, metric, metric_params=None, n_jobs=1, random_state=None, verbose=False
data,
method,
k,
metric,
metric_params=None,
n_jobs=1,
random_state=None,
verbose=False,
knn_kwargs=None,
):
# If we're dealing with a precomputed distance matrix, our job is very easy,
# so we can skip all the remaining checks
Expand Down Expand Up @@ -394,6 +407,7 @@ def get_knn_index(
"of the supported methods or provide a valid `KNNIndex` instance." % method
)
else:
kwargs = dict() if knn_kwargs is None else knn_kwargs
knn_index = methods[method](
data=data,
k=k,
Expand All @@ -402,6 +416,7 @@ def get_knn_index(
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
**kwargs,
)

return knn_index
Expand Down Expand Up @@ -759,6 +774,9 @@ class MultiscaleMixture(Affinities):

verbose: bool

knn_kwargs: Optional[dict]
Optional keyword arguments that will be passed to the ``knn_index``.

knn_index: Optional[nearest_neighbors.KNNIndex]
Optionally, a precomputed ``openTSNE.nearest_neighbors.KNNIndex`` object
can be specified. This option will ignore any KNN-related parameters.
Expand All @@ -777,6 +795,7 @@ def __init__(
n_jobs=1,
random_state=None,
verbose=False,
knn_kwargs=None,
jnboehm marked this conversation as resolved.
Show resolved Hide resolved
knn_index=None,
):
# Perplexities must be specified, but has default set to none, so the
Expand Down Expand Up @@ -805,7 +824,15 @@ def __init__(
k_neighbors = min(n_samples - 1, int(3 * max_perplexity))

self.knn_index = get_knn_index(
data, method, k_neighbors, metric, metric_params, n_jobs, random_state, verbose
data,
method,
k_neighbors,
metric,
metric_params,
n_jobs,
random_state,
verbose,
knn_kwargs,
)

else:
Expand Down
24 changes: 18 additions & 6 deletions openTSNE/nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,9 @@ class Annoy(KNNIndex):
"taxicab",
]

def __init__(self, *args, **kwargs):
def __init__(self, *args, n_trees=50, **kwargs):
jnboehm marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(*args, **kwargs)
self.n_trees = n_trees

def build(self):
data, k = self.data, self.k
Expand Down Expand Up @@ -253,7 +254,7 @@ def build(self):
self.index.add_item(i, data[i])

# Number of trees. FIt-SNE uses 50 by default.
self.index.build(50, n_jobs=self.n_jobs)
self.index.build(self.n_trees, n_jobs=self.n_jobs)

# Return the nearest neighbors in the training set
distances = np.zeros((N, k))
Expand Down Expand Up @@ -417,7 +418,9 @@ class NNDescent(KNNIndex):
"yule",
]

def __init__(self, *args, **kwargs):
def __init__(
jnboehm marked this conversation as resolved.
Show resolved Hide resolved
self, *args, n_trees=None, n_iters=None, max_candidates=60, **kwargs
):
try:
import pynndescent # pylint: disable=unused-import,unused-variable
except ImportError:
Expand All @@ -426,6 +429,9 @@ def __init__(self, *args, **kwargs):
"pynndescent` or `pip install pynndescent`."
)
super().__init__(*args, **kwargs)
self.n_trees = n_trees
self.n_iters = n_iters
self.max_candidates = max_candidates

def check_metric(self, metric):
import pynndescent
Expand Down Expand Up @@ -470,8 +476,14 @@ def build(self):
timer.__enter__()

# These values were taken from UMAP, which we assume to be sensible defaults
n_trees = 5 + int(round((data.shape[0]) ** 0.5 / 20))
n_iters = max(5, int(round(np.log2(data.shape[0]))))
if self.n_trees is None:
jnboehm marked this conversation as resolved.
Show resolved Hide resolved
n_trees = 5 + int(round((data.shape[0]) ** 0.5 / 20))
else:
n_trees = self.n_trees
if self.n_iters is None:
n_iters = max(5, int(round(np.log2(data.shape[0]))))
else:
n_iters = self.n_iters

# Numba takes a while to load up, so there's little point in loading it
# unless we're actually going to use it
Expand All @@ -491,7 +503,7 @@ def build(self):
random_state=self.random_state,
n_trees=n_trees,
n_iters=n_iters,
max_candidates=60,
max_candidates=self.max_candidates,
n_jobs=self.n_jobs,
verbose=self.verbose > 1,
)
Expand Down
9 changes: 9 additions & 0 deletions openTSNE/tsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,9 @@ class TSNEEmbedding(np.ndarray):
``ints_in_interval`` parameter. Higher values provide more accurate
gradient estimations.

knn_kwargs: Optional[dict]
Optional keyword arguments that will be passed to the ``knn_index``.

random_state: Union[int, RandomState]
The random state parameter follows the convention used in scikit-learn.
If the value is an int, random_state is the seed used by the random
Expand Down Expand Up @@ -1103,6 +1106,9 @@ class TSNE(BaseEstimator):
the given metric. Otherwise it uses Pynndescent. ``auto`` uses exact
nearest neighbors for N<1000 and the same heuristic as ``approx`` for N>=1000.

knn_kwargs: Optional[dict]
Optional keyword arguments that will be passed to the ``knn_index``.

negative_gradient_method: str
Specifies the negative gradient approximation method to use. For smaller
data sets, the Barnes-Hut approximation is appropriate and can be set
Expand Down Expand Up @@ -1152,6 +1158,7 @@ def __init__(
max_step_norm=5,
n_jobs=1,
neighbors="auto",
knn_kwargs=None,
negative_gradient_method="auto",
callbacks=None,
callbacks_every_iters=50,
Expand Down Expand Up @@ -1191,6 +1198,7 @@ def __init__(
self.n_jobs = n_jobs

self.neighbors = neighbors
self.knn_kwargs = knn_kwargs
self.negative_gradient_method = negative_gradient_method

self.callbacks = callbacks
Expand Down Expand Up @@ -1331,6 +1339,7 @@ def prepare_initial(self, X=None, affinities=None, initialization=None):
n_jobs=self.n_jobs,
random_state=self.random_state,
verbose=self.verbose,
knn_kwargs=self.knn_kwargs,
)
else:
if not isinstance(affinities, Affinities):
Expand Down
Loading