Merge pull request #168 from yzhao062/development
V0.7.8 Bug Fixes and New Models (VAE and LODA)
yzhao062 authored Mar 17, 2020
2 parents 10a29b9 + aa47afd commit 84bad9f
Showing 25 changed files with 910 additions and 30 deletions.
4 changes: 4 additions & 0 deletions CHANGES.txt
@@ -93,6 +93,10 @@ v<0.7.7>, <12/21/2019> -- Extended combination methods by median and majority vote.
v<0.7.7>, <12/22/2019> -- Code optimization and documentation update.
v<0.7.7>, <12/22/2019> -- Enable continuous integration for Python 3.7.
v<0.7.7.1>, <12/29/2019> -- Minor update for SUOD and warning fixes.
v<0.7.8>, <01/05/2020> -- Documentation update.
v<0.7.8>, <01/30/2020> -- Bug fix for kNN (#158).
v<0.7.8>, <03/14/2020> -- Add VAE (implemented by Dr Andrij Vasylenko).
v<0.7.8>, <03/17/2020> -- Add LODA (adapted from tilitools).
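
Both new detectors follow the standard PyOD API; a minimal sketch (assumptions: TensorFlow/Keras is installed for VAE, and the data shapes mirror the bundled examples added in this commit, since VAE's default layers expect a fairly high feature count):

# Minimal sketch of the two v0.7.8 additions; both expose the usual
# PyOD fit / predict / decision_function interface.
from pyod.models.loda import LODA
from pyod.models.vae import VAE
from pyod.utils.data import generate_data

X_train, y_train, X_test, y_test = generate_data(
    n_train=500, n_test=200, n_features=300,
    contamination=0.1, random_state=42)

for clf in (LODA(), VAE(epochs=5, contamination=0.1)):
    clf.fit(X_train)
    labels = clf.predict(X_test)            # 0 = inlier, 1 = outlier
    scores = clf.decision_function(X_test)  # raw outlier scores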



8 changes: 8 additions & 0 deletions README.rst
@@ -297,7 +297,9 @@ Outlier Ensembles  IForest           Isolation Forest
Outlier Ensembles Feature Bagging 2005 [#Lazarevic2005Feature]_
Outlier Ensembles LSCP LSCP: Locally Selective Combination of Parallel Outlier Ensembles 2019 [#Zhao2019LSCP]_
Outlier Ensembles XGBOD Extreme Boosting Based Outlier Detection **(Supervised)** 2018 [#Zhao2018XGBOD]_
Outlier Ensembles LODA Lightweight On-line Detector of Anomalies 2016 [#Pevny2016Loda]_
Neural Networks AutoEncoder Fully connected AutoEncoder (use reconstruction error as the outlier score) [#Aggarwal2015Outlier]_ [Ch.3]
Neural Networks VAE Variational AutoEncoder (use reconstruction error as the outlier score) 2013 [#Kingma2013Auto]_
Neural Networks SO_GAAL Single-Objective Generative Adversarial Active Learning 2019 [#Liu2019Generative]_
Neural Networks MO_GAAL Multiple-Objective Generative Adversarial Active Learning 2019 [#Liu2019Generative]_
=================== ================ ====================================================================================================== ===== ========================================
@@ -310,6 +312,8 @@ Type                Abbr              Algorithm
=================== ================ ===================================================================================================== ===== ========================================
Outlier Ensembles Feature Bagging 2005 [#Lazarevic2005Feature]_
Outlier Ensembles LSCP LSCP: Locally Selective Combination of Parallel Outlier Ensembles 2019 [#Zhao2019LSCP]_
Outlier Ensembles XGBOD Extreme Boosting Based Outlier Detection **(Supervised)** 2018 [#Zhao2018XGBOD]_
Outlier Ensembles LODA Lightweight On-line Detector of Anomalies 2016 [#Pevny2016Loda]_
Combination Average Simple combination by averaging the scores 2015 [#Aggarwal2015Theoretical]_
Combination Weighted Average Simple combination by averaging the scores with detector weights 2015 [#Aggarwal2015Theoretical]_
Combination Maximization Simple combination by taking the maximum scores 2015 [#Aggarwal2015Theoretical]_
@@ -592,6 +596,8 @@ Reference
.. [#Janssens2012Stochastic] Janssens, J.H.M., Huszár, F., Postma, E.O. and van den Herik, H.J., 2012. Stochastic outlier selection. Technical report TiCC TR 2012-001, Tilburg University, Tilburg Center for Cognition and Communication, Tilburg, The Netherlands.
.. [#Kingma2013Auto] Kingma, D.P. and Welling, M., 2013. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114.
.. [#Kriegel2008Angle] Kriegel, H.P. and Zimek, A., 2008, August. Angle-based outlier detection in high-dimensional data. In *KDD '08*\ , pp. 444-452. ACM.
.. [#Kriegel2009Outlier] Kriegel, H.P., Kröger, P., Schubert, E. and Zimek, A., 2009, April. Outlier detection in axis-parallel subspaces of high dimensional data. In *Pacific-Asia Conference on Knowledge Discovery and Data Mining*\ , pp. 831-838. Springer, Berlin, Heidelberg.
@@ -606,6 +612,8 @@ Reference
.. [#Papadimitriou2003LOCI] Papadimitriou, S., Kitagawa, H., Gibbons, P.B. and Faloutsos, C., 2003, March. LOCI: Fast outlier detection using the local correlation integral. In *ICDE '03*, pp. 315-326. IEEE.
.. [#Pevny2016Loda] Pevný, T., 2016. Loda: Lightweight on-line detector of anomalies. *Machine Learning*, 102(2), pp.275-304.
.. [#Ramaswamy2000Efficient] Ramaswamy, S., Rastogi, R. and Shim, K., 2000, May. Efficient algorithms for mining outliers from large data sets. *ACM Sigmod Record*\ , 29(2), pp. 427-438.
.. [#Rousseeuw1999A] Rousseeuw, P.J. and Driessen, K.V., 1999. A fast algorithm for the minimum covariance determinant estimator. *Technometrics*\ , 41(3), pp.212-223.
4 changes: 4 additions & 0 deletions docs/about.rst
@@ -26,3 +26,7 @@ Yahya Almardeny (Software Systems & Machine Learning Engineer @ TSSG):
- Joined in 2019
- `LinkedIn (Yahya Almardeny) <https://www.linkedin.com/in/yahya-almardeny/>`_

Dr Andrij Vasylenko (Research Associate @ University of Liverpool):

- Joined in 2020 (implemented the VAE model)
- `Homepage (Dr Andrij Vasylenko) <https://www.liverpool.ac.uk/chemistry/staff/andrij-vasylenko/>`_
4 changes: 4 additions & 0 deletions docs/index.rst
@@ -189,7 +189,9 @@ Outlier Ensembles  IForest           Isolation Forest
Outlier Ensembles Feature Bagging 2005 :class:`pyod.models.feature_bagging.FeatureBagging` :cite:`a-lazarevic2005feature`
Outlier Ensembles LSCP LSCP: Locally Selective Combination of Parallel Outlier Ensembles 2019 :class:`pyod.models.lscp.LSCP` :cite:`a-zhao2019lscp`
Outlier Ensembles XGBOD Extreme Boosting Based Outlier Detection **(Supervised)** 2018 :class:`pyod.models.xgbod.XGBOD` :cite:`a-zhao2018xgbod`
Outlier Ensembles LODA Lightweight On-line Detector of Anomalies 2016 :class:`pyod.models.loda.LODA` :cite:`a-pevny2016loda`
Neural Networks AutoEncoder Fully connected AutoEncoder (use reconstruction error as the outlier score) 2015 :class:`pyod.models.auto_encoder.AutoEncoder` :cite:`a-aggarwal2015outlier`
Neural Networks VAE Variational AutoEncoder (use reconstruction error as the outlier score) 2013 :class:`pyod.models.vae.VAE` :cite:`a-kingma2013auto`
Neural Networks SO_GAAL Single-Objective Generative Adversarial Active Learning 2019 :class:`pyod.models.so_gaal.SO_GAAL` :cite:`a-liu2019generative`
Neural Networks MO_GAAL Multiple-Objective Generative Adversarial Active Learning 2019 :class:`pyod.models.mo_gaal.MO_GAAL` :cite:`a-liu2019generative`
=================== ================ ====================================================================================================== ===== =================================================== ======================================================
@@ -203,6 +205,8 @@ Type                Abbr              Algorithm
=================== ================ ===================================================================================================== ===== =================================================== ======================================================
Outlier Ensembles Feature Bagging 2005 :class:`pyod.models.feature_bagging.FeatureBagging` :cite:`a-lazarevic2005feature`
Outlier Ensembles LSCP LSCP: Locally Selective Combination of Parallel Outlier Ensembles 2019 :class:`pyod.models.lscp.LSCP` :cite:`a-zhao2019lscp`
Outlier Ensembles XGBOD Extreme Boosting Based Outlier Detection **(Supervised)** 2018 :class:`pyod.models.xgbod.XGBOD` :cite:`a-zhao2018xgbod`
Outlier Ensembles LODA Lightweight On-line Detector of Anomalies 2016 :class:`pyod.models.loda.LODA` :cite:`a-pevny2016loda`
Combination Average Simple combination by averaging the scores 2015 :func:`pyod.models.combination.average` :cite:`a-aggarwal2015theoretical`
Combination Weighted Average Simple combination by averaging the scores with detector weights 2015 :func:`pyod.models.combination.average` :cite:`a-aggarwal2015theoretical`
Combination Maximization Simple combination by taking the maximum scores 2015 :func:`pyod.models.combination.maximization` :cite:`a-aggarwal2015theoretical`
18 changes: 18 additions & 0 deletions docs/pyod.models.rst
@@ -94,6 +94,15 @@ pyod.models.lmdd module
:show-inheritance:
:inherited-members:

pyod.models.loda module
-----------------------

.. automodule:: pyod.models.loda
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

pyod.models.lof module
----------------------

@@ -188,6 +197,15 @@ pyod.models.sos module
:show-inheritance:
:inherited-members:

pyod.models.vae module
----------------------

.. automodule:: pyod.models.vae
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

pyod.models.xgbod module
------------------------

4 changes: 3 additions & 1 deletion docs/requirements.txt
@@ -11,5 +11,7 @@ scikit_learn>=0.19.1
six
sphinxcontrib-bibtex
suod
tensorflow
tensorflow==1.13.2
# tensorflow pinned to 1.13.2 to limit memory consumption on Read the Docs
# https://github.com/readthedocs/readthedocs.org/issues/6537
xgboost
20 changes: 19 additions & 1 deletion docs/zreferences.bib
@@ -200,7 +200,7 @@ @article{liu2019generative
}

@article{zhao2019pyod,
title={PyOD: A python toolbox for scalable outlier detection},
title={{PyOD}: A python toolbox for scalable outlier detection},
author={Zhao, Yue and Nasrullah, Zain and Li, Zheng},
journal={Journal of Machine Learning Research},
volume={20},
@@ -294,4 +294,22 @@ @inproceedings{arning1996linear
number={50},
pages={972--981},
year={1996}
}

@article{kingma2013auto,
title={Auto-encoding variational bayes},
author={Kingma, Diederik P and Welling, Max},
journal={arXiv preprint arXiv:1312.6114},
year={2013}
}

@article{pevny2016loda,
title={Loda: Lightweight on-line detector of anomalies},
author={Pevn{\'y}, Tom{\'a}{\v{s}}},
journal={Machine Learning},
volume={102},
number={2},
pages={275--304},
year={2016},
publisher={Springer}
}
57 changes: 57 additions & 0 deletions examples/loda_example.py
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
"""Example of using LODA for outlier detection
"""
# Author: Yue Zhao <zhaoy@cmu.edu>
# License: BSD 2 clause

from __future__ import division
from __future__ import print_function

import os
import sys

# temporary solution for relative imports in case pyod is not installed
# if pyod is installed, no need to use the following line
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pyod.models.loda import LODA
from pyod.utils.data import generate_data
from pyod.utils.data import evaluate_print
from pyod.utils.example import visualize

if __name__ == "__main__":
contamination = 0.1 # percentage of outliers
n_train = 200 # number of training points
n_test = 100 # number of testing points

# Generate sample data
X_train, y_train, X_test, y_test = \
generate_data(n_train=n_train,
n_test=n_test,
n_features=2,
contamination=contamination,
random_state=42)

    # train LODA detector
clf_name = 'LODA'
clf = LODA()
clf.fit(X_train)

# get the prediction labels and outlier scores of the training data
y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers)
y_train_scores = clf.decision_scores_ # raw outlier scores

# get the prediction on the test data
y_test_pred = clf.predict(X_test) # outlier labels (0 or 1)
y_test_scores = clf.decision_function(X_test) # outlier scores

# evaluate and print the results
print("\nOn Training Data:")
evaluate_print(clf_name, y_train, y_train_scores)
print("\nOn Test Data:")
evaluate_print(clf_name, y_test, y_test_scores)

# visualize the results
visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred,
y_test_pred, show_figure=True, save_figure=False)
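
If the defaults need tuning, LODA exposes two main knobs (a hedged sketch -- n_bins and n_random_cuts are assumed to carry over from the tilitools implementation this model is adapted from):

    # Hypothetical tuning sketch; parameter names assumed, not verified.
    clf = LODA(n_bins=10,          # resolution of each 1-D score histogram
               n_random_cuts=100)  # number of random projections
    clf.fit(X_train)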
9 changes: 5 additions & 4 deletions examples/pca_example.py
@@ -29,13 +29,13 @@
X_train, y_train, X_test, y_test = \
generate_data(n_train=n_train,
n_test=n_test,
n_features=2,
n_features=20,
contamination=contamination,
random_state=42)

# train PCA detector
clf_name = 'PCA'
clf = PCA()
clf = PCA(n_components=3)
clf.fit(X_train)

# get the prediction labels and outlier scores of the training data
@@ -53,5 +53,6 @@
evaluate_print(clf_name, y_test, y_test_scores)

# visualize the results
visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred,
y_test_pred, show_figure=True, save_figure=False)
    # Note: visualization requires 2-D input, so it is disabled here (n_features=20)
# visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred,
# y_test_pred, show_figure=True, save_figure=False)
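
If a plot is still wanted, one workaround (a sketch, not part of this commit; it assumes the visualize import is still present) is to project the 20-dimensional data down to two components purely for display:

    # Hedged sketch: reduce to 2-D with scikit-learn's PCA just for plotting.
    # The alias avoids clashing with pyod.models.pca.PCA used above.
    from sklearn.decomposition import PCA as SkPCA

    proj = SkPCA(n_components=2).fit(X_train)
    visualize(clf_name, proj.transform(X_train), y_train,
              proj.transform(X_test), y_test,
              y_train_pred, y_test_pred,
              show_figure=True, save_figure=False)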
53 changes: 53 additions & 0 deletions examples/vae_example.py
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
"""Example of using Variational Auto Encoder for outlier detection
"""
# Author: Andrij Vasylenko <andrij@liverpool.ac.uk>
# License: BSD 2 clause

from __future__ import division
from __future__ import print_function

import os
import sys

# temporary solution for relative imports in case pyod is not installed
# if pyod is installed, no need to use the following line
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pyod.models.vae import VAE
from pyod.utils.data import generate_data
from pyod.utils.data import evaluate_print

if __name__ == "__main__":
contamination = 0.1 # percentage of outliers
n_train = 20000 # number of training points
n_test = 2000 # number of testing points
n_features = 300 # number of features

# Generate sample data
X_train, y_train, X_test, y_test = \
generate_data(n_train=n_train,
n_test=n_test,
n_features=n_features,
contamination=contamination,
random_state=42)

# train VAE detector
clf_name = 'VAE'
clf = VAE(epochs=30, contamination=contamination)
clf.fit(X_train)

# get the prediction labels and outlier scores of the training data
y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers)
y_train_scores = clf.decision_scores_ # raw outlier scores

# get the prediction on the test data
y_test_pred = clf.predict(X_test) # outlier labels (0 or 1)
y_test_scores = clf.decision_function(X_test) # outlier scores

# evaluate and print the results
print("\nOn Training Data:")
evaluate_print(clf_name, y_train, y_train_scores)
print("\nOn Test Data:")
evaluate_print(clf_name, y_test, y_test_scores)
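
For lower-dimensional data, the network can likely be slimmed down (a hedged sketch: encoder_neurons/decoder_neurons are assumed parameter names for the VAE implementation added in this commit, whose defaults are sized for high-dimensional input like the 300 features above):

    # Hypothetical architecture tweak -- parameter names assumed, not verified.
    clf = VAE(encoder_neurons=[64, 32], decoder_neurons=[32, 64],
              epochs=30, contamination=contamination)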
4 changes: 2 additions & 2 deletions pyod/models/base.py
@@ -187,7 +187,7 @@ def predict_proba(self, X, method='linear'):
Returns
-------
outlier_labels : numpy array of shape (n_samples,)
outlier_probability : numpy array of shape (n_samples,)
For each observation, tells whether or not
it should be considered as an outlier according to the
            fitted model. Return the outlier probability, ranging in [0, 1].
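
A one-line usage sketch of this method, assuming a fitted detector clf:

    proba = clf.predict_proba(X_test, method='linear')  # probabilities in [0, 1]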
@@ -414,7 +414,7 @@ def get_params(self, deep=True):
Parameters
----------
deep : boolean, optional
deep : bool, optional (default=True)
If True, will return the parameters for this estimator and
contained subobjects that are estimators.
2 changes: 1 addition & 1 deletion pyod/models/cblof.py
@@ -24,7 +24,7 @@


class CBLOF(BaseDetector):
"""The CBLOF operator calculates the outlier score based on cluster-based
r"""The CBLOF operator calculates the outlier score based on cluster-based
local outlier factor.
    CBLOF takes as an input the data set and the cluster model that was
    generated by a clustering algorithm.
2 changes: 1 addition & 1 deletion pyod/models/feature_bagging.py
@@ -154,7 +154,7 @@ class FeatureBagging(BaseDetector):
RandomState instance used by `np.random`.
combination : str, optional (default='average')
the method of combination:
The method of combination:
- if 'average': take the average of all detectors
- if 'max': take the maximum scores of all detectors
2 changes: 1 addition & 1 deletion pyod/models/iforest.py
@@ -91,7 +91,7 @@ class IForest(BaseDetector):
- If int, then draw `max_features` features.
- If float, then draw `max_features * X.shape[1]` features.
bootstrap : boolean, optional (default=False)
bootstrap : bool, optional (default=False)
If True, individual trees are fit on random subsets of the training
data sampled with replacement. If False, sampling without replacement
is performed.
28 changes: 15 additions & 13 deletions pyod/models/knn.py
@@ -191,19 +191,21 @@ def fit(self, X, y=None):
self._set_n_classes(y)

self.neigh_.fit(X)
# TODO: code cleanup
# if self.neigh_._tree is not None:
self.tree_ = self.neigh_._tree

# The code below may not be necessary
# else:
# if self.metric_params is not None:
# self.tree_ = BallTree(X, leaf_size=self.leaf_size,
# metric=self.metric,
# **self.metric_params)
# else:
# self.tree_ = BallTree(X, leaf_size=self.leaf_size,
# metric=self.metric)

# In certain cases, _tree does not exist for NearestNeighbors
# See Issue #158 (https://github.com/yzhao062/pyod/issues/158)
# n_neighbors = 100
if self.neigh_._tree is not None:
self.tree_ = self.neigh_._tree

else:
if self.metric_params is not None:
self.tree_ = BallTree(X, leaf_size=self.leaf_size,
metric=self.metric,
**self.metric_params)
else:
self.tree_ = BallTree(X, leaf_size=self.leaf_size,
metric=self.metric)

dist_arr, _ = self.neigh_.kneighbors(n_neighbors=self.n_neighbors,
return_distance=True)
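
The else-branch above can be exercised directly (a sketch, assuming scikit-learn's NearestNeighbors internals as reported in issue #158):

    # Hypothetical reproduction: with the brute-force algorithm no tree is
    # built, so neigh_._tree is None and fit() falls back to an explicit BallTree.
    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    X = np.random.rand(120, 5)
    nn = NearestNeighbors(n_neighbors=100, algorithm='brute')
    nn.fit(X)
    assert nn._tree is None  # this is the condition the new code handles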
