Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

使用 SPU 实现主成分分析基础功能 #213

Closed
Candicepan opened this issue Jun 7, 2023 · 33 comments · Fixed by #240
Closed

使用 SPU 实现主成分分析基础功能 #213

Candicepan opened this issue Jun 7, 2023 · 33 comments · Fixed by #240
Assignees
Labels
enhancement New feature or request OSCP SecretFlow Open Source Contribution Plan

Comments

@Candicepan
Copy link
Contributor

Candicepan commented Jun 7, 2023

此 ISSUE 为 隐语开源共建计划(SecretFlow Open Source Contribution Plan,简称 SF OSCP)第一期任务 ISSUE,欢迎社区开发者参与共建~

任务介绍

  • 任务名称:使用 SPU 实现主成分分析基础功能
  • 技术方向:SPU/SML
  • 任务难度:中等

详细要求

  • 安全性(尽量少 reveal)
  • 功能性:实现算法的基本功能,包括:
    • 支持 fit, transform
    • 起码一种 svd 算法,如 full
    • 至少有一种设置主成分个数的方式和能查看主成分 variance
  • 收敛性:包含 simulator 跑出的实验数据并且证明收敛性
  • 代码规范:Python 代码需要使用 black+isort 进行格式化(流水线包含代码规范检查卡点)
  • 提交说明:关联该 isuue 并提交代码至 https://github.com/secretflow/spu/tree/main/sml
  • 特殊说明:若某个特性有特殊的限制,如需要 FM128,需要更多 fxp 等需要在注释文档中明确说明
  • 任务难点:在密态下实现 svd 算法

能力要求

  • 熟悉经典的机器学习算法
  • 熟悉 JAX 或 NumPy,可以使用 NumPy 实现算法

操作说明

@Candicepan Candicepan added enhancement New feature or request OSCP SecretFlow Open Source Contribution Plan labels Jun 7, 2023
@hacker-jerry
Copy link
Contributor

hacker-jerry give it to me.

@hacker-jerry
Copy link
Contributor

hacker-jerry commented Jul 4, 2023

您好,我使用jax实现了一个 pca 的类原型。

from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np

import unittest
import json
import jax.numpy as jnp
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2  # 


class PCA(NamedTuple):
    components: jax.Array
    means: jax.Array
    explained_variance: jax.Array


def transform(state, x):
    x = x - state.means
    return jnp.dot(x, jnp.transpose(state.components))


def recover(state, x):
    return jnp.dot(x, state.components) + state.means


def fit(x, n_components, solver="full", rng=None):
    if solver == "full":
        return _fit_full(x, n_components)
    elif solver == "randomized":
        if rng is None:
            rng = jax.random.PRNGKey(n_components)
        return _fit_randomized(x, n_components, rng)
    else:
        raise ValueError("solver parameter is not correct")

# @partial(jax.jit, static_argnums=(1,))
def fit_and_transform(x, n_components=2):
    state = fit(x, n_components)
    return transform(state, x)


# @partial(jax.jit, static_argnames=["n_components"])
def _fit_full(x, n_components):
    n_samples, n_features = x.shape

    # Subtract the mean of the input data
    means = x.mean(axis=0, keepdims=True)
    x = x - means

    # Factorize the data matrix with singular value decomposition.
    U, S, Vt = jax.scipy.linalg.svd(x, full_matrices=False)

    # Compute the explained variance
    
    explained_variance = (S[:n_components] ** 2) / (n_samples - 1)

    # Return the transformation matrix
    A = Vt[:n_components]
    return PCA(components=A, means=means, explained_variance=explained_variance)


def _fit_randomized(x, n_components, rng, n_iter=5):
    """Randomized PCA based on Halko et al [https://doi.org/10.48550/arXiv.1007.5510]."""
    n_samples, n_features = x.shape
    means = jnp.mean(x, axis=0, keepdims=True)
    x = x - means

    # Generate n_features normal vectors of the given size
    size = jnp.minimum(2 * n_components, n_features)
    Q = jax.random.normal(rng, shape=(n_features, size))

    def step_fn(q, _):
        q, _ = jax.scipy.linalg.lu(x @ q, permute_l=True)
        q, _ = jax.scipy.linalg.lu(x.T @ q, permute_l=True)
        return q, None

    Q, _ = jax.lax.scan(step_fn, init=Q, xs=None, length=n_iter)
    Q, _ = jax.scipy.linalg.qr(x @ Q, mode="economic")
    B = Q.T @ x

    _, S, Vt = jax.scipy.linalg.svd(B, full_matrices=False)

    explained_variance = (S[:n_components] ** 2) / (n_samples - 1)
    A = Vt[:n_components]
    return PCA(components=A, means=means, explained_variance=explained_variance)


算法可以通过

X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
state = fit(X ,n_components=2)
X_pca = transform(state, X)
X_recovered = recover(state, X_pca)

直接调用。也可以通过pdd 的方式进行模拟

import spu.utils.distributed as ppd
import numpy as np

# initialized the distributed environment.
ppd.init(ppd.SAMPLE_NODES_DEF, ppd.SAMPLE_DEVICES_DEF)


def make_x():
    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    return X

def fit_(X, n_components):
    return fit(X,n_components=n_components)

def transform_(state, X):
    return transform(state, X)

def get_variance(state):
    return state.explained_variance

x = ppd.device("P1")(make_x)()
pca_ = ppd.device("P1")(fit_)(x,n_components=2)
trans_x = ppd.device("SPU")(transform_)(pca_, x)
var = ppd.device("SPU")(get_variance)(pca_)

但是,我在使用spsim 进行模拟的时候,发生报错

sim  = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

result = spsim.sim_jax(sim, fit_and_transform, static_argnums=(1,))(X, 2)

File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:41, in _jax_compilation(fn, static_argnums, args, kwargs)
     37 @cached(cache=LRUCache(maxsize=128), key=_jax_compilation_key)
     38 def _jax_compilation(fn: Callable, static_argnums, args: List, kwargs: Dict):
     39     import jax
---> 41     cfn, output = jax.xla_computation(
     42         fn, return_shape=True, static_argnums=static_argnums, backend="interpreter"
     43     )(*args, **kwargs)
     44     return cfn.as_serialized_hlo_module_proto(), output

    [... skipping hidden 21 frame]

File /opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/jax/_src/interpreters/mlir.py:1155, in jaxpr_subcomp(ctx, jaxpr, tokens, consts, dim_var_values, *args)
   1153   rule = xla_fallback_lowering(eqn.primitive)
   1154 else:
-> 1155   raise NotImplementedError(
   1156       f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
   1157       f"found for platform {ctx.platform}")
   1159 eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
   1160 effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))

NotImplementedError: MLIR translation rule for primitive 'eigh' not found for platform interpreter

请问应该如何修改?

@hacker-jerry
Copy link
Contributor

By the way, 用jax.jit也可以编译通过,但是还是无法使用spsim😭

@rivertalk
Copy link

NotImplementedError: MLIR translation rule for primitive 'eigh' not found for platform interpreter

Hi @hacker-jerry,这个看上去是因为 SPU并未支持JAX所有的算子(比如 eigh),请 @anakinxc 帮忙看一眼

@anakinxc
Copy link
Contributor

anakinxc commented Jul 4, 2023

Hi @hacker-jerry

感谢提供复现代码,我们研究一下

Thanks

@deadlywing
Copy link
Contributor

By the way, 用jax.jit也可以编译通过,但是还是无法使用spsim😭

能麻烦提供一下本地使用的package版本么?(主要是spu和jax)

@hacker-jerry
Copy link
Contributor

spu 0.3.3b0
jax 0.4.8
jaxlib 0.4.7

@deadlywing
Copy link
Contributor

spu 0.3.3b0
jax 0.4.8
jaxlib 0.4.7

Thanks,我发现你使用ppd和spsim的方式不太一致:
在ppd中 fit 方法是明文计算的
但是在spsim中,fit部分是秘文下进行的

@hacker-jerry
Copy link
Contributor

谢谢,请问spsim这里的报错是什么原因呢?应该怎样修改?

@deadlywing
Copy link
Contributor

谢谢,请问spsim这里的报错是什么原因呢?应该怎样修改?

首先,我理解fit方法应该是需要能在密态下执行的,所以本质上的原因应该是SPU暂时没支持svd算子。
可能需要您自己实现一下svd算法;

  1. 建议可以先实现full solver
  2. randomized solver需要注意的点在于spu内对随机数的支持和明文下不太一致,建议把生成初始随机矩阵Q放在外面

@hacker-jerry
Copy link
Contributor

好的,谢谢!

@hacker-jerry
Copy link
Contributor

您好,我使用jacobi的方法实现了eigh算子,重构后的代码如下:

import jax
import jax.numpy as jnp
from jax import jit
from functools import partial


class PCA:
    def __init__(self, n_components=None, tol=1e-8, max_iters=100):
        self.n_components = n_components
        self.tol = tol
        self.max_iters = max_iters
        self.components_ = None
        self.explained_variance_ = None
        self.mean_ = None

    def fit_transform(self, X):
        self.mean_ = jnp.mean(X, axis=0)
        X_centered = X - self.mean_
        cov_matrix = jnp.cov(X_centered, rowvar=False)
        eigenvalues, eigenvectors = jacobi_eigh(cov_matrix, self.tol, self.max_iters)

        idx = jnp.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]

        if self.n_components is None:
            self.n_components = X.shape[1]

        self.components_ = eigenvectors[:, :self.n_components]
        self.explained_variance_ = eigenvalues[:self.n_components]

        X_transformed = jnp.dot(X_centered, self.components_)

        return X_transformed, self.explained_variance_


def jacobi_eigh(A, tol, max_iters):
    n = A.shape[0]
    Q = jnp.eye(n)

    def body_fn(i, vals):
        A, Q = vals
        p, q = jnp.unravel_index(jnp.argmax(jnp.abs(A - jnp.diag(jnp.diag(A)))), A.shape)
        phi = 0.5 * jnp.arctan(2 * A[p, q] / (A[q, q] - A[p, p]))
        rotation = jnp.eye(n)
        rotation = rotation.at[[p, q], [p, q]].set(jnp.cos(phi))
        rotation = rotation.at[q, p].set(jnp.sin(phi))
        rotation = rotation.at[p, q].set(-jnp.sin(phi))
        A_prime = rotation.T @ A @ rotation
        Q_prime = Q @ rotation

        A = jax.lax.cond(jnp.abs(A[p, q]) < tol, lambda _: A, lambda _: A_prime, None)
        Q = jax.lax.cond(jnp.abs(A[p, q]) < tol, lambda _: Q, lambda _: Q_prime, None)
        return A, Q

    A, Q = jax.lax.fori_loop(0, max_iters, body_fn, (A, Q))

    return jnp.diag(A), Q

该函数可以通过jit编译后调用

pca = PCA(n_components=2)

pca_fit_transform = jit(pca.fit_transform, static_argnums=1)

# Prepare some data
X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 使用编译后的fit_transform函数进行拟合和转换
X_transformed, explained_variance = pca_fit_transform(X)

print(explained_variance)

print(X_transformed)

但是使用spism模拟时,再次发生报错

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

def fit_transform(X, n_components=None):
    pca_fit_transform = jit(PCA(n_components=n_components).fit_transform, static_argnums=1)
    X_transformed, explained_variance = pca_fit_transform(X)
    return X_transformed, explained_variance 

result = spsim.sim_jax(sim_aby, fit_transform,  static_argnums=(1,))(X,2)

报错信息如下:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[71], line 1
----> 1 result = spsim.sim_jax(sim_aby, fit_transform,  static_argnums=(1,))(X,2)

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/simulation.py:152](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/simulation.py:152), in sim_jax..wrapper(*args, **kwargs)
    149 def outputNameGen(out_flat):
    150     return [f'out{idx}' for idx in range(len(out_flat))]
--> 152 executable, output = spu_fe.compile(
    153     spu_fe.Kind.JAX,
    154     fun,
    155     args,
    156     kwargs,
    157     in_names,
    158     in_vis,
    159     outputNameGen,
    160     static_argnums=static_argnums,
    161 )
    163 wrapper.pphlo = executable.code.decode("utf-8")
    165 out_flat = sim(executable, *args_flat)

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:177](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:177), in compile(kind, fn, args, kwargs, input_names, input_vis, outputNameGen, static_argnums)
    175     ir_type = "mhlo"
    176     name = repr(fn)
--> 177 mlir = spu_api.compile(ir_text, ir_type, input_vis)
    178 executable = spu_pb2.ExecutableProto(
    179     name=name,
    180     input_names=input_names,
    181     output_names=output_names,
    182     code=mlir,
    183 )
    184 return executable, output

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:153](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:153), in compile(ir_text, ir_type, vis)
    150 from google.protobuf.json_format import MessageToJson
    152 # todo: rename spu_pb2.XlaMeta to IrMeta?
--> 153 return _spu_compilation(
    154     ir_text, ir_type, MessageToJson(spu_pb2.XlaMeta(inputs=vis))
    155 )

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/cachetools/__init__.py:737](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/cachetools/__init__.py:737), in cached..decorator..wrapper(*args, **kwargs)
    735 except KeyError:
    736     pass  # key not found
--> 737 v = func(*args, **kwargs)
    738 try:
    739     cache[k] = v

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:136](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:136), in _spu_compilation(ir_text, ir_type, json_meta)
    133 @cached(cache=LRUCache(maxsize=128))
    134 def _spu_compilation(ir_text: str, ir_type: str, json_meta: str):
    135     pp_dir = os.getenv('SPU_IR_DUMP_DIR')
--> 136     return libspu.compile(ir_text, ir_type, json_meta, pp_dir or "")

RuntimeError: what: 
	[libspu/compiler/front_end/fe.cc:64] Run front end pipeline failed
stacktrace: 
#0 spu::compiler::FE::doit()+0x178ad5b00
#1 spu::compiler::compile()+0x178ac3ee4
#2 pybind11::cpp_function::initialize<>()::{lambda()#1}::__invoke()+0x178aa778c
#3 pybind11::cpp_function::dispatcher()+0x178a94ac4
#4 cfunction_call_varargs+0x104f243e0
#5 _PyObject_MakeTpCall+0x104f23af0
#6 call_function+0x105010158
#7 _PyEval_EvalFrameDefault+0x10500c83c
#8 function_code_fastcall+0x104f247b4
#9 PyVectorcall_Call+0x104f23fd8
#10 _PyEval_EvalFrameDefault+0x10500caf0
#11 _PyEval_EvalCodeWithName+0x1050057fc
#12 _PyFunction_Vectorcall+0x104f24918
#13 call_function+0x1050100c0
#14 _PyEval_EvalFrameDefault+0x10500c8b8
#15 function_code_fastcall+0x104f247b4

请问是什么原因?

@hacker-jerry
Copy link
Contributor

@anakinxc

@deadlywing
Copy link
Contributor

您好,我使用jacobi的方法实现了eigh算子,重构后的代码如下:

import jax
import jax.numpy as jnp
from jax import jit
from functools import partial


class PCA:
    def __init__(self, n_components=None, tol=1e-8, max_iters=100):
        self.n_components = n_components
        self.tol = tol
        self.max_iters = max_iters
        self.components_ = None
        self.explained_variance_ = None
        self.mean_ = None

    def fit_transform(self, X):
        self.mean_ = jnp.mean(X, axis=0)
        X_centered = X - self.mean_
        cov_matrix = jnp.cov(X_centered, rowvar=False)
        eigenvalues, eigenvectors = jacobi_eigh(cov_matrix, self.tol, self.max_iters)

        idx = jnp.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]

        if self.n_components is None:
            self.n_components = X.shape[1]

        self.components_ = eigenvectors[:, :self.n_components]
        self.explained_variance_ = eigenvalues[:self.n_components]

        X_transformed = jnp.dot(X_centered, self.components_)

        return X_transformed, self.explained_variance_


def jacobi_eigh(A, tol, max_iters):
    n = A.shape[0]
    Q = jnp.eye(n)

    def body_fn(i, vals):
        A, Q = vals
        p, q = jnp.unravel_index(jnp.argmax(jnp.abs(A - jnp.diag(jnp.diag(A)))), A.shape)
        phi = 0.5 * jnp.arctan(2 * A[p, q] / (A[q, q] - A[p, p]))
        rotation = jnp.eye(n)
        rotation = rotation.at[[p, q], [p, q]].set(jnp.cos(phi))
        rotation = rotation.at[q, p].set(jnp.sin(phi))
        rotation = rotation.at[p, q].set(-jnp.sin(phi))
        A_prime = rotation.T @ A @ rotation
        Q_prime = Q @ rotation

        A = jax.lax.cond(jnp.abs(A[p, q]) < tol, lambda _: A, lambda _: A_prime, None)
        Q = jax.lax.cond(jnp.abs(A[p, q]) < tol, lambda _: Q, lambda _: Q_prime, None)
        return A, Q

    A, Q = jax.lax.fori_loop(0, max_iters, body_fn, (A, Q))

    return jnp.diag(A), Q

该函数可以通过jit编译后调用

pca = PCA(n_components=2)

pca_fit_transform = jit(pca.fit_transform, static_argnums=1)

# Prepare some data
X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 使用编译后的fit_transform函数进行拟合和转换
X_transformed, explained_variance = pca_fit_transform(X)

print(explained_variance)

print(X_transformed)

但是使用spism模拟时,再次发生报错

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

def fit_transform(X, n_components=None):
    pca_fit_transform = jit(PCA(n_components=n_components).fit_transform, static_argnums=1)
    X_transformed, explained_variance = pca_fit_transform(X)
    return X_transformed, explained_variance 

result = spsim.sim_jax(sim_aby, fit_transform,  static_argnums=(1,))(X,2)

报错信息如下:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[71], line 1
----> 1 result = spsim.sim_jax(sim_aby, fit_transform,  static_argnums=(1,))(X,2)

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/simulation.py:152](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/simulation.py:152), in sim_jax..wrapper(*args, **kwargs)
    149 def outputNameGen(out_flat):
    150     return [f'out{idx}' for idx in range(len(out_flat))]
--> 152 executable, output = spu_fe.compile(
    153     spu_fe.Kind.JAX,
    154     fun,
    155     args,
    156     kwargs,
    157     in_names,
    158     in_vis,
    159     outputNameGen,
    160     static_argnums=static_argnums,
    161 )
    163 wrapper.pphlo = executable.code.decode("utf-8")
    165 out_flat = sim(executable, *args_flat)

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:177](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/utils/frontend.py:177), in compile(kind, fn, args, kwargs, input_names, input_vis, outputNameGen, static_argnums)
    175     ir_type = "mhlo"
    176     name = repr(fn)
--> 177 mlir = spu_api.compile(ir_text, ir_type, input_vis)
    178 executable = spu_pb2.ExecutableProto(
    179     name=name,
    180     input_names=input_names,
    181     output_names=output_names,
    182     code=mlir,
    183 )
    184 return executable, output

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:153](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:153), in compile(ir_text, ir_type, vis)
    150 from google.protobuf.json_format import MessageToJson
    152 # todo: rename spu_pb2.XlaMeta to IrMeta?
--> 153 return _spu_compilation(
    154     ir_text, ir_type, MessageToJson(spu_pb2.XlaMeta(inputs=vis))
    155 )

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/cachetools/__init__.py:737](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/cachetools/__init__.py:737), in cached..decorator..wrapper(*args, **kwargs)
    735 except KeyError:
    736     pass  # key not found
--> 737 v = func(*args, **kwargs)
    738 try:
    739     cache[k] = v

File [/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:136](https://file+.vscode-resource.vscode-cdn.net/opt/homebrew/Caskroom/miniforge/base/envs/sf/lib/python3.8/site-packages/spu/api.py:136), in _spu_compilation(ir_text, ir_type, json_meta)
    133 @cached(cache=LRUCache(maxsize=128))
    134 def _spu_compilation(ir_text: str, ir_type: str, json_meta: str):
    135     pp_dir = os.getenv('SPU_IR_DUMP_DIR')
--> 136     return libspu.compile(ir_text, ir_type, json_meta, pp_dir or "")

RuntimeError: what: 
	[libspu/compiler/front_end/fe.cc:64] Run front end pipeline failed
stacktrace: 
#0 spu::compiler::FE::doit()+0x178ad5b00
#1 spu::compiler::compile()+0x178ac3ee4
#2 pybind11::cpp_function::initialize<>()::{lambda()#1}::__invoke()+0x178aa778c
#3 pybind11::cpp_function::dispatcher()+0x178a94ac4
#4 cfunction_call_varargs+0x104f243e0
#5 _PyObject_MakeTpCall+0x104f23af0
#6 call_function+0x105010158
#7 _PyEval_EvalFrameDefault+0x10500c83c
#8 function_code_fastcall+0x104f247b4
#9 PyVectorcall_Call+0x104f23fd8
#10 _PyEval_EvalFrameDefault+0x10500caf0
#11 _PyEval_EvalCodeWithName+0x1050057fc
#12 _PyFunction_Vectorcall+0x104f24918
#13 call_function+0x1050100c0
#14 _PyEval_EvalFrameDefault+0x10500c8b8
#15 function_code_fastcall+0x104f247b4

请问是什么原因?

hello,不能跑的原因主要是eigh的实现里用到了三角函数,spu当前没有实现,所以报错了;

PLUS,你eigh的实现应该也有问题,我运行了你的eigh

def test_eigh():
    X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_centered = X - jnp.mean(X, axis=0)
    cov_matrix = jnp.cov(X_centered, rowvar=False)

    eigenvalues, eigenvectors = jacobi_eigh(cov_matrix, 1e-8, 1000)
    # print(eigenvalues)
    # print(eigenvectors)
    print(cov_matrix @ eigenvectors)
    print(eigenvalues * eigenvectors)

    print()

    eigenvalues, eigenvectors = eigh(cov_matrix)
    # print(eigenvalues)
    # print(eigenvectors)
    print(cov_matrix @ eigenvectors)
    print(eigenvalues * eigenvectors)
[[2.2865321e+02 5.4445304e-06 2.5582359e+02]
 [2.2865321e+02 5.4445304e-06 2.5582359e+02]
 [2.2865321e+02 5.4445304e-06 2.5582359e+02]]
[[1.4377324e+02 1.2175019e-21 2.0135764e+02]
 [2.4747901e-06 1.2621775e-29 3.4660027e-06]
 [1.6085759e+02 9.3170446e-22 2.2528442e+02]]

[[ 4.1723251e-07 -2.9802322e-07  1.5588457e+01]
 [ 4.1723251e-07 -2.9802322e-07  1.5588457e+01]
 [ 4.1723251e-07 -2.9802322e-07  1.5588457e+01]]
[[-4.0569081e-08 -1.6165853e-06  1.5588456e+01]
 [-1.1870292e-07  1.1621918e-06  1.5588454e+01]
 [ 1.5927199e-07  4.5439356e-07  1.5588456e+01]]

@deadlywing
Copy link
Contributor

jacobi需要计算旋转变换,当前spu无法支持;我这边提供两种思路,你可以尝试一下:

  1. PCA对cov矩阵计算特征值变换,而cov是半正定的,可以用cholesky分解计算特征变换
  2. 一般地,可以考虑QR分解,不需要半正定的条件

仅供参考~

@hacker-jerry
Copy link
Contributor

hacker-jerry commented Jul 12, 2023

jacobi需要计算旋转变换,当前spu无法支持;我这边提供两种思路,你可以尝试一下:

  1. PCA对cov矩阵计算特征值变换,而cov是半正定的,可以用cholesky分解计算特征变换
  2. 一般地,可以考虑QR分解,不需要半正定的条件

仅供参考~

谢谢您的建议,基于此,我重新实现了一下,代码如下:

import jax
import jax.numpy as jnp
from jax import random

class PCA:
    def __init__(self, n_components):
        self.n_components = n_components
        self.mean = None
        self.components = None
        self.variances = None

    def fit(self, X):
        self.mean = jnp.mean(X, axis=0)
        X = X - self.mean

        cov_matrix = jnp.cov(X, rowvar=False)

        L = jnp.linalg.cholesky(cov_matrix)

        q, r = jnp.linalg.qr(L)

        eigvals = jnp.diag(r)

        idx = jnp.argsort(eigvals)[::-1][:self.n_components]

        self.components = q[:, idx]

        self.variances = eigvals[idx]


    def transform(self, X):
        X = X - self.mean
        return jnp.dot(X, self.components)


def fit_and_transform(X, n_components):
    pca = PCA(n_components)
    pca.fit(X)
    return pca.transform(X)


X = random.randint(random.PRNGKey(0), (10,3), 0, 10)

fit_and_transform_jit = jit(fit_and_transform, static_argnums=1)

X_transformed = fit_and_transform_jit(X, 2)

print(X_transformed)

您看是否符合要求?
这次的代码通过了sispm模拟。

@deadlywing
Copy link
Contributor

jacobi需要计算旋转变换,当前spu无法支持;我这边提供两种思路,你可以尝试一下:

  1. PCA对cov矩阵计算特征值变换,而cov是半正定的,可以用cholesky分解计算特征变换
  2. 一般地,可以考虑QR分解,不需要半正定的条件

仅供参考~

谢谢您的建议,基于此,我重新实现了一下,代码如下:

import jax
import jax.numpy as jnp
from jax import random

class PCA:
    def __init__(self, n_components):
        self.n_components = n_components
        self.mean = None
        self.components = None
        self.variances = None

    def fit(self, X):
        self.mean = jnp.mean(X, axis=0)
        X = X - self.mean

        cov_matrix = jnp.cov(X, rowvar=False)

        L = jnp.linalg.cholesky(cov_matrix)

        q, r = jnp.linalg.qr(L)

        eigvals = jnp.diag(r)

        idx = jnp.argsort(eigvals)[::-1][:self.n_components]

        self.components = q[:, idx]

        self.variances = eigvals[idx]


    def transform(self, X):
        X = X - self.mean
        return jnp.dot(X, self.components)


def fit_and_transform(X, n_components):
    pca = PCA(n_components)
    pca.fit(X)
    return pca.transform(X)


X = random.randint(random.PRNGKey(0), (10,3), 0, 10)

fit_and_transform_jit = jit(fit_and_transform, static_argnums=1)

X_transformed = fit_and_transform_jit(X, 2)

print(X_transformed)

您看是否符合要求? 这次的代码通过了sispm模拟。

Sorry, 这应该是spsim的bug,实际上无论cholesky还是qr应该都无法真实的执行,执行到那两个函数的时候似乎python进程会被直接关闭,我们后续应该会修复这个bug。(所以我很好奇,您运行spsim真的能得到PCA transform后的矩阵么?)

所以,你也需要自己手动实现cholesky分解或qr分解。

最后,麻烦您后面提交pr的时候,用注释的方式标记一下之前的实现中,因为spu不支持算子而无法运行的实现方式,后续我们增加这些算子以后可以重新考察这些实现~

感谢!

@hacker-jerry
Copy link
Contributor

的确是运行成功了,
image
您也可以测试一下,

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

result = spsim.sim_jax(sim_aby, fit_and_transform,  static_argnums=(1,))(X,2)

@deadlywing
Copy link
Contributor

的确是运行成功了, image 您也可以测试一下,

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

result = spsim.sim_jax(sim_aby, fit_and_transform,  static_argnums=(1,))(X,2)

Thanks, 我本地运行会一直卡住,我需要check一下原因. 另外,麻烦您运行一下下面这个代码,看是否会raise除0错误.

def test_run_eigh():
    X = jnp.array(np.random.rand(6, 3))
    cov_matrix = jnp.cov(X, rowvar=False)

    sim_aby = spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )

    print(cov_matrix)
    print(jnp.linalg.det(cov_matrix))
    print(jnp.linalg.cholesky(cov_matrix))
    print(spsim.sim_jax(sim_aby, jnp.linalg.cholesky)(cov_matrix))
    print(1 / 0)

@hacker-jerry
Copy link
Contributor

我本地运行的版本是

secretflow                    0.8.2b1
sf-heu                        0.4.3b3
spu                           0.3.2b12
jax                           0.4.8
jaxlib                        0.4.7

上述代码运行结果有除 0 报错
image

@deadlywing
Copy link
Contributor

image
image

我在jupyter上运行的话,也会报错...

@hacker-jerry
Copy link
Contributor

image image

我在jupyter上运行的话,也会报错...

您看看spu降一下级试一下?

@deadlywing
Copy link
Contributor

image image
我在jupyter上运行的话,也会报错...

您看看spu降一下级试一下?

好的,我试试;
btw,您是用linux系统不?

@hacker-jerry
Copy link
Contributor

我是用的m1 mac

@deadlywing
Copy link
Contributor

我是用的m1 mac

我本地测试了一下,应该是jax版本问题,0.4.8是ok的,我本地是0.4.13才会报错...这个问题得 @anakinxc 看看
您可以先直接使用这两个api去实现pca吧~

建议您实现以后可以:

  1. 检查是否满足特征值分解的定义
  2. 检查fit后的方差等与sklearn的pca是否一致

感谢!

@tarantula-leo
Copy link
Contributor

我是用的m1 mac

我本地测试了一下,应该是jax版本问题,0.4.8是ok的,我本地是0.4.13才会报错...这个问题得 @anakinxc 看看 您可以先直接使用这两个api去实现pca吧~

建议您实现以后可以:

  1. 检查是否满足特征值分解的定义
  2. 检查fit后的方差等与sklearn的pca是否一致

感谢!

你好 想问下cholesky分解和qr分解是已经在SPU中支持了还是使用spsim模拟的bug?

@anakinxc
Copy link
Contributor

我是用的m1 mac

hi,麻烦升级 spu 到最新的或者把 jax 和 jaxlib 降级到 0.4.12

@hacker-jerry
Copy link
Contributor

hacker-jerry commented Jul 12, 2023

我是用的m1 mac

hi,麻烦升级 spu 到最新的或者把 jax 和 jaxlib 降级到 0.4.12

我把jax 和 jaxlib 升级到 0.4.12了,运行没有问题

image

@hacker-jerry
Copy link
Contributor

我是用的m1 mac

hi,麻烦升级 spu 到最新的或者把 jax 和 jaxlib 降级到 0.4.12

我把jax 和 jaxlib 升级到 0.4.12了,运行没有问题

image

但是spu还是原来版本

@deadlywing
Copy link
Contributor

那看来就是jax最新版本不太适配了,,那您就先用现在的版本先开发吧

@hacker-jerry
Copy link
Contributor

ok,我测试了一下,和sklearn的 PCA 效果是差不多的。请问如果要提交pr,需要在哪几个文件中进行修改?

@deadlywing
Copy link
Contributor

ok,我测试了一下,和sklearn的 PCA 效果是差不多的。请问如果要提交pr,需要在哪几个文件中进行修改?

感谢快速响应,可以先参考一下这个kmeans的PR;
#235
一般是一个文件用于实现算法逻辑(jax only),一个文件用于spsim模拟测试,一个文件做emulation测试;

BTW:麻烦请在spsim模拟测试的那个文件中同时提交一下和明文sklearn的结果对比(可以写在不同的unittest里)

Thanks!

@hacker-jerry
Copy link
Contributor

hacker-jerry commented Jul 14, 2023

Already solved this issue @Candicepan .
#240

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request OSCP SecretFlow Open Source Contribution Plan
Projects
No open projects
Status: Done
Development

Successfully merging a pull request may close this issue.

6 participants