diff --git a/src/uapca.ts b/src/uapca.ts index 4ce97ad..5ddb439 100644 --- a/src/uapca.ts +++ b/src/uapca.ts @@ -6,15 +6,6 @@ export interface Distribution { covariance(): Matrix; } -export interface Projection { - /** - * Projects a distribution onto a lower dimensional subspace - * defined by the dimensions of the `projectionMatrix`. - * @param {Matrix} projectionMatrix - Defined as column matrix. - */ - project(projectionMatrix: Matrix): Projection; -} - export interface AffineTransformation { /** * Performs an affine transformation of the distribution @@ -22,9 +13,16 @@ export interface AffineTransformation { * @param {Matrix} b - Translation defined as row vector. */ affineTransformation(A: Matrix, b: Matrix): AffineTransformation; + + /** + * Projects a distribution onto a lower dimensional subspace + * defined by the dimensions of the `projectionMatrix`. + * @param {Matrix} projectionMatrix - Defined as column matrix. + */ + project(projectionMatrix: Matrix): AffineTransformation; } -export class MultivariateNormal implements AffineTransformation, Distribution, Projection { +export class MultivariateNormal implements AffineTransformation, Distribution { private meanVec: Matrix; private covMat: Matrix; public constructor(meanVec: Array | Matrix, covMat: Array> | Matrix) { @@ -108,25 +106,23 @@ export function arithmeticMean(matrices: Array): Matrix { return sum.div(N); } -function centering(distributions: Array): Matrix { - const v = arithmeticMean(distributions.map(d => d.mean())); - return outerProduct(v); -} - export class UaPCA { private lengths: Array; private vectors: Matrix; // row matrix! + private mean: Matrix; - private constructor(lengths: Array, vectors: Matrix) { + private constructor(lengths: Array, vectors: Matrix, mean: Matrix) { this.lengths = lengths; this.vectors = vectors; + this.mean = mean; } public static fit( distributions: Array, scale: number = 1.0, ): UaPCA { - const center: Matrix = centering(distributions); + const empiricalMean = arithmeticMean(distributions.map(d => d.mean())); + const center: Matrix = outerProduct(empiricalMean); const empericalCov: Matrix = arithmeticMean(distributions.map(d => { return outerProduct(d.mean()).add(Matrix.mul(d.covariance(), scale * scale)) .sub(center); @@ -140,7 +136,7 @@ export class UaPCA { const pairs: Array<[number, Array]> = evals.map((e, i) => [e, evecs[i]]); const comps = pairs.sort((a, b) => b[0] - a[0]); - return new UaPCA(comps.map(d => d[0]), new Matrix(comps.map(v => v[1]))); + return new UaPCA(comps.map(d => d[0]), new Matrix(comps.map(v => v[1])), empiricalMean); } public aligned(): UaPCA { @@ -150,7 +146,7 @@ export class UaPCA { vecs.setRow(i, vecs.getRowVector(i).mul(-1)); } } - return new UaPCA(this.lengths, vecs); + return new UaPCA(this.lengths, vecs, this.mean); } public eigenvalues(nDims?: number): Array { @@ -164,11 +160,15 @@ export class UaPCA { } public transform( - objects: Array, + objects: Array, components: number, - ): Array { + ): Array { const projMat = this.projectionMatrix(components); - return objects.map(d => d.project(projMat.transpose())); + const centered = objects.map(d => d.affineTransformation( + Matrix.eye(this.mean.columns, this.mean.columns), + Matrix.mul(this.mean, -1) + )); + return centered.map(d => d.project(projMat.transpose())); } }