diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 83bbfe20..98841df3 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -40,10 +40,10 @@ def _yacl(): http_archive, name = "yacl", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3.tar.gz", + "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b1.tar.gz", ], - strip_prefix = "yacl-0.4.5b3", - sha256 = "bd89d63312e5e83eff5e001e2cf2135baff321c4b72a309f7d00cc53ce02e1a1", + strip_prefix = "yacl-0.4.5b1", + sha256 = "28064053b9add0db8e1e8e648421a0579f1d3e7ee8a4bbd7bd5959cb59598088", ) def _libpsi(): @@ -169,10 +169,10 @@ def _com_github_pybind11(): http_archive, name = "pybind11", build_file = "@pybind11_bazel//:pybind11.BUILD", - sha256 = "bf8f242abd1abcd375d516a7067490fb71abd79519a282d22b6e4d19282185a7", - strip_prefix = "pybind11-2.12.0", + sha256 = "51631e88960a8856f9c497027f55c9f2f9115cafb08c0005439838a05ba17bfc", + strip_prefix = "pybind11-2.13.1", urls = [ - "https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.tar.gz", + "https://github.com/pybind/pybind11/archive/refs/tags/v2.13.1.tar.gz", ], ) diff --git a/docs/reference/np_op_status.json b/docs/reference/np_op_status.json index 1272847c..20ed81a7 100644 --- a/docs/reference/np_op_status.json +++ b/docs/reference/np_op_status.json @@ -1 +1 @@ -[{"name": "abs", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "add", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "nan"}, {"name": "arcsin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arcsinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "arctanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "[-1, 1] nan"}, {"name": "argmax", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "argmin", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equiv", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_1d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_2d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_3d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_and", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_not", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_or", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_xor", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cbrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.cbrt"}, {"name": "ceil", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conjugate", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "copysign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "shift"}, {"name": "cos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "deg2rad", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "divmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "ediff1d", "dtypes": ["int32"], "status": "Status.Pass", "note": ""}, {"name": "equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "expm1", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fabs", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "fix", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "float_power", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "floor", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "floor_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "gcd", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "greater", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "greater_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "heaviside", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "hypot", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "i0", "dtypes": ["float32"], "status": "Status.Failed", "note": "accuracy"}, {"name": "imag", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "invert", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isclose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "iscomplex", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isfinite", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isnan", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isneginf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isposinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isreal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isrealobj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "kron", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "lcm", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "ldexp", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "left_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log10", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logical_and", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_not", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_or", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_xor", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "maximum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32"], "status": "Status.PassNoGen", "note": ""}, {"name": "minimum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "modf", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "multiply", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nanargmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanargmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanmean", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanprod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nansum", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "negative", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nextafter", "dtypes": ["float32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "not_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "outer", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "polyval", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "positive", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "power", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "prod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rad2deg", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "ravel", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "real", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "reciprocal", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "remainder", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "right_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rint", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "sign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "signbit", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sqrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "subtract", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sum", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "tan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.tan"}, {"name": "tanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "transpose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "true_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "trunc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "unwrap", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}] \ No newline at end of file +[{"name": "abs", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "add", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "angle", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arccosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "nan"}, {"name": "arcsin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arcsinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "arctan2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "arctanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "[-1, 1] nan"}, {"name": "argmax", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "argmin", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "array_equiv", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_1d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_2d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "atleast_3d", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_and", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_count", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_not", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_or", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "bitwise_xor", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cbrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.cbrt"}, {"name": "ceil", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "conjugate", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "copysign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "shift"}, {"name": "cos", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "cosh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "deg2rad", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "divmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "ediff1d", "dtypes": ["int32"], "status": "Status.Pass", "note": ""}, {"name": "equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "exp2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "expm1", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fabs", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "fix", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "float_power", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "floor", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "floor_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "fmod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "gcd", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "greater", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "greater_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "heaviside", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "hypot", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "i0", "dtypes": ["float32"], "status": "Status.Failed", "note": "accuracy"}, {"name": "imag", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "invert", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isclose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "iscomplex", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isfinite", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isnan", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isneginf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isposinf", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "isreal", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "isrealobj", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "kron", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "lcm", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "secret while"}, {"name": "ldexp", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "left_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "less_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log10", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log1p", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "log2", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logaddexp2", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "logical_and", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_not", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_or", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "logical_xor", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "maximum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mean", "dtypes": ["float32"], "status": "Status.PassNoGen", "note": ""}, {"name": "minimum", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "mod", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "modf", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "multiply", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nanargmax", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanargmin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanmean", "dtypes": ["float32"], "status": "Status.UnSupport", "note": ""}, {"name": "nanprod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "nansum", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.UnSupport", "note": ""}, {"name": "negative", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "nextafter", "dtypes": ["float32"], "status": "Status.Failed", "note": "IEEE-754"}, {"name": "not_equal", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "outer", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "polyval", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "positive", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "power", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "prod", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rad2deg", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "ravel", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "real", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "reciprocal", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "remainder", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "right_shift", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "rint", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}, {"name": "sign", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "signbit", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sin", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sinh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sqrt", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "square", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "subtract", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "sum", "dtypes": ["int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "tan", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.SysError", "note": "stablehlo.tan"}, {"name": "tanh", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "transpose", "dtypes": ["bool_", "float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "true_divide", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "trunc", "dtypes": ["float32", "int16", "int32", "uint16", "uint32"], "status": "Status.Pass", "note": ""}, {"name": "unwrap", "dtypes": ["float32"], "status": "Status.Pass", "note": ""}] \ No newline at end of file diff --git a/docs/reference/np_op_status.md b/docs/reference/np_op_status.md index d69f20a0..ab3539a4 100644 --- a/docs/reference/np_op_status.md +++ b/docs/reference/np_op_status.md @@ -293,6 +293,20 @@ Please check *Supported Dtypes* as well. - uint16 - uint32 +## bitwise_count + +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_count.html +### Status + +**PASS** +Please check *Supported Dtypes* as well. +### Supported Dtypes + +- int16 +- int32 +- uint16 +- uint32 + ## bitwise_not JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_not.html diff --git a/docs/reference/pphlo_doc.rst b/docs/reference/pphlo_doc.rst index 751541e1..148d417e 100644 --- a/docs/reference/pphlo_doc.rst +++ b/docs/reference/pphlo_doc.rst @@ -1,9 +1,9 @@ -PPHlo API reference +PPHLO API reference =================== -PPHlo is short for (SPU High level ops), it's the assembly language of SPU. +PPHLO is short for (Privacy-Preserving High-Level Operations), it's the assembly language of SPU. -PPHlo is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `. +PPHLO is built on `MLIR `_ infrastructure, the concrete ops definition could be found :spu_code_host:`here `. Op List ~~~~~~~ diff --git a/docs/reference/pphlo_op_doc.md b/docs/reference/pphlo_op_doc.md index f6afc268..2839c9a1 100644 --- a/docs/reference/pphlo_op_doc.md +++ b/docs/reference/pphlo_op_doc.md @@ -747,7 +747,7 @@ Ref https://www.tensorflow.org/xla/operation_semantics#dot. Traits: `AlwaysSpeculatableImplTrait` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` @@ -1626,55 +1626,63 @@ Effects: `MemoryEffects::Effect{}` | :----: | ----------- | | `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values values -### `pphlo.power` (spu::pphlo::PowOp) +### `pphlo.popcnt` (spu::pphlo::PopcntOp) -_Power operator_ +_Popcnt operator, ties away from zero_ Syntax: ``` -operation ::= `pphlo.power` $lhs `,` $rhs attr-dict - `:` custom(type($lhs), type($rhs), type($result)) +operation ::= `pphlo.popcnt` $operand attr-dict `:` custom(type($operand), type($result)) ``` -Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor. +Performs element-wise count of the number of bits set in the `operand` tensor and produces a `result` tensor. -Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power +Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt -Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape` +Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType` -Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` +#### Attributes: + + + + +
AttributeMLIR TypeDescription
bits::mlir::IntegerAttr64-bit signless integer attribute
+ #### Operands: | Operand | Description | | :-----: | ----------- | -| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values -| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `operand` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values #### Results: | Result | Description | | :----: | ----------- | -| `result` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values -### `pphlo.prefer_a` (spu::pphlo::PreferAOp) +### `pphlo.power` (spu::pphlo::PowOp) -_Prefer AShare operator_ +_Power operator_ Syntax: ``` -operation ::= `pphlo.prefer_a` $operand attr-dict `:` custom(type($operand), type($result)) +operation ::= `pphlo.power` $lhs `,` $rhs attr-dict + `:` custom(type($lhs), type($rhs), type($result)) ``` -Convert input to AShare if possible. +Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor. -Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType` +Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power + +Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape` Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` @@ -1684,7 +1692,8 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `operand` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values +| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values #### Results: @@ -2270,12 +2279,21 @@ Returns the sign of the `operand` element-wise and produces a `result` tensor. Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign +PPHLO Extension: when `ignore_zero` is set to true, sign does not enforce sign(0) to 0 + Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType` Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` +#### Attributes: + + + + +
AttributeMLIR TypeDescription
ignore_zero::mlir::BoolAttrbool attribute
+ #### Operands: | Operand | Description | @@ -2377,7 +2395,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` @@ -2551,7 +2569,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)` +Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)` Effects: `MemoryEffects::Effect{}` diff --git a/docs/reference/runtime_config.md b/docs/reference/runtime_config.md index f556bce4..836095cc 100644 --- a/docs/reference/runtime_config.md +++ b/docs/reference/runtime_config.md @@ -179,8 +179,9 @@ The SPU runtime configuration. | Field | Type | Description | | ----- | ---- | ----------- | | server_host | [ string](#string) | TrustedThirdParty beaver server's remote ip:port or load-balance uri. | -| session_id | [ string](#string) | if empty, use link id as session id. | | adjust_rank | [ int32](#int32) | which rank do adjust rpc call, usually choose the rank closer to the server. | +| asym_crypto_schema | [ string](#string) | asym_crypto_schema: support ["SM2"] Will support 25519 in the future, after yacl supported it. | +| server_public_key | [ bytes](#bytes) | server's public key | diff --git a/libspu/compiler/common/compilation_context.cc b/libspu/compiler/common/compilation_context.cc index ef22871c..553fe913 100644 --- a/libspu/compiler/common/compilation_context.cc +++ b/libspu/compiler/common/compilation_context.cc @@ -23,7 +23,7 @@ namespace { void SPUErrorHandler(void * /*use_data*/, const char *reason, bool /*gen_crash_diag*/) { - SPU_THROW(reason); + SPU_THROW("{}", reason); } } // namespace diff --git a/libspu/compiler/common/ir_printer_config.cc b/libspu/compiler/common/ir_printer_config.cc index 47dff64a..f3c15ff2 100644 --- a/libspu/compiler/common/ir_printer_config.cc +++ b/libspu/compiler/common/ir_printer_config.cc @@ -51,6 +51,7 @@ void IRPrinterConfig::printBeforeIfEnabled(Pass *pass, Operation *, if (ec.value() != 0) { spdlog::error("Open file {} failed, error = {}", file_name.c_str(), ec.message()); + return; } print_callback(f); } @@ -64,6 +65,7 @@ void IRPrinterConfig::printAfterIfEnabled(Pass *pass, Operation *, if (ec.value() != 0) { spdlog::error("Open file {} failed, error = {}", file_name.c_str(), ec.message()); + return; } print_callback(f); } diff --git a/libspu/compiler/front_end/fe.cc b/libspu/compiler/front_end/fe.cc index 682f1c5b..e560c772 100644 --- a/libspu/compiler/front_end/fe.cc +++ b/libspu/compiler/front_end/fe.cc @@ -54,6 +54,8 @@ mlir::OwningOpRef FE::doit(const CompilationSource &source) { module = mlir::parseSourceString(source.ir_txt(), ctx_->getMLIRContext()); + SPU_ENFORCE(module, "MLIR parser failure"); + // Convert stablehlo to mhlo first mlir::PassManager pm(ctx_->getMLIRContext()); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc index 5845af57..363bd76c 100644 --- a/libspu/compiler/front_end/hlo_importer.cc +++ b/libspu/compiler/front_end/hlo_importer.cc @@ -196,12 +196,12 @@ HloImporter::parseXlaModuleFromString(const std::string &content) { auto module_config = xla::HloModule::CreateModuleConfigFromProto(hlo_module, debug_options); if (!module_config.status().ok()) { - SPU_THROW(module_config.status().message()); + SPU_THROW("{}", module_config.status().message()); } auto module = xla::HloModule::CreateFromProto(hlo_module, *module_config); if (!module.status().ok()) { - SPU_THROW(module.status().message()); + SPU_THROW("{}", module.status().message()); } xla::runHloPasses((*module).get()); @@ -214,7 +214,7 @@ HloImporter::parseXlaModuleFromString(const std::string &content) { auto status = importer.Import(**module); if (!status.ok()) { - SPU_THROW(status.message()); + SPU_THROW("{}", status.message()); } return mlir_hlo; diff --git a/libspu/device/api.cc b/libspu/device/api.cc index 5535ad85..1b27f881 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -229,7 +229,7 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) { (void)use_data; (void)gen_crash_diag; - SPU_THROW(reason); + SPU_THROW("{}", reason); } std::mutex ErrorHandlerMutex; diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc index 657b81fc..1c67c997 100644 --- a/libspu/kernel/hal/permute.cc +++ b/libspu/kernel/hal/permute.cc @@ -19,6 +19,7 @@ #include "libspu/core/bit_utils.h" #include "libspu/core/context.h" #include "libspu/core/trace.h" +#include "libspu/core/vectorize.h" #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/polymorphic.h" #include "libspu/kernel/hal/prot_wrapper.h" @@ -43,6 +44,12 @@ inline bool _has_same_owner(const Value &x, const Value &y) { return _get_owner(x) == _get_owner(y); } +void _hint_nbits(const Value &a, size_t nbits) { + if (a.storage_type().isa()) { + const_cast(a.storage_type()).as()->setNbits(nbits); + } +} + // generate inverse permutation Index _inverse_index(const Index &p) { Index q(p.size()); @@ -531,20 +538,29 @@ spu::Value _opt_apply_perm_ss(SPUContext *ctx, const spu::Value &perm, std::vector _bit_decompose(SPUContext *ctx, const spu::Value &x, int64_t valid_bits) { auto x_bshare = _prefer_b(ctx, x); - const auto k1 = _constant(ctx, 1U, x.shape()); - std::vector rets; size_t nbits = valid_bits != -1 ? static_cast(valid_bits) : x_bshare.storage_type().as()->nbits(); - rets.reserve(nbits); + _hint_nbits(x_bshare, nbits); + if (ctx->hasKernel("b2a_disassemble")) { + auto ret = + dynDispatch>(ctx, "b2a_disassemble", x_bshare); + return ret; + } + + const auto k1 = _constant(ctx, 1U, x.shape()); + std::vector rets_b; + rets_b.reserve(nbits); for (size_t bit = 0; bit < nbits; ++bit) { auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit); - auto lowest_bit = _and(ctx, x_bshare_shift, k1); - rets.emplace_back(_prefer_a(ctx, lowest_bit)); + rets_b.push_back(_and(ctx, x_bshare_shift, k1)); } - return rets; + std::vector rets_a; + vmap(rets_b.begin(), rets_b.end(), std::back_inserter(rets_a), + [&](const Value &x) { return _prefer_a(ctx, x); }); + return rets_a; } // Generate vector of bit decomposition of sorting keys diff --git a/libspu/mpc/kernel.cc b/libspu/mpc/kernel.cc index 45bed576..90d51818 100644 --- a/libspu/mpc/kernel.cc +++ b/libspu/mpc/kernel.cc @@ -233,6 +233,17 @@ void ConcateKernel::evaluate(KernelEvalContext* ctx) const { ctx->pushOutput(WrapValue(z)); } +void DisassembleKernel::evaluate(KernelEvalContext* ctx) const { + const auto& in = ctx->getParam(0); + auto z = proc(ctx, UnwrapValue(in)); + + std::vector wrapped(z.size()); + for (size_t idx = 0; idx < z.size(); ++idx) { + wrapped[idx] = WrapValue(z[idx]); + } + ctx->pushOutput(wrapped); +}; + void OramOneHotKernel::evaluate(KernelEvalContext* ctx) const { auto target = ctx->getParam(0); auto s = ctx->getParam(1); diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h index 4a4c29d8..12383391 100644 --- a/libspu/mpc/kernel.h +++ b/libspu/mpc/kernel.h @@ -217,4 +217,12 @@ class ConcateKernel : public Kernel { int64_t axis) const = 0; }; +class DisassembleKernel : public Kernel { + public: + void evaluate(KernelEvalContext* ctx) const override; + + virtual std::vector proc(KernelEvalContext* ctx, + const NdArrayRef& in) const = 0; +}; + } // namespace spu::mpc diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc index 35d8c701..dce1857f 100644 --- a/libspu/mpc/semi2k/conversion.cc +++ b/libspu/mpc/semi2k/conversion.cc @@ -42,6 +42,26 @@ static NdArrayRef wrap_and_bb(SPUContext* ctx, const NdArrayRef& x, return UnwrapValue(and_bb(ctx, WrapValue(x), WrapValue(y))); } +// TODO: Move to some common place +PtType getBacktype(size_t nbits) { + if (nbits <= 8) { + return PT_U8; + } + if (nbits <= 16) { + return PT_U16; + } + if (nbits <= 32) { + return PT_U32; + } + if (nbits <= 64) { + return PT_U64; + } + if (nbits <= 128) { + return PT_U128; + } + SPU_THROW("invalid number of bits={}", nbits); +} + NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const auto field = x.eltype().as()->field(); auto* comm = ctx->getState(); @@ -90,6 +110,9 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { return r_a; } +// TODO(jimi): pack {numel * nbits} to fully make use of undelying storage to +// save communications. If implemented, B2A_Disassemble kernel is also no longer +// needed NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { const auto field = x.eltype().as()->field(); @@ -105,6 +128,7 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, const auto numel = x.numel(); const auto rand_numel = numel * static_cast(nbits); + const PtType backtype = getBacktype(nbits); auto randbits = beaver->RandBit(field, rand_numel); SPU_ENFORCE(static_cast(randbits.size()) == @@ -119,32 +143,125 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, // algorithm begins. // Ref: III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) - std::vector x_xor_r(numel); - - pforeach(0, numel, [&](int64_t idx) { - // use _r[i*nbits, (i+1)*nbits) to construct rb[i] - U mask = 0; - for (int64_t bit = 0; bit < nbits; ++bit) { - mask += (_randbits[idx * nbits + bit] & 0x1) << bit; - } - x_xor_r[idx] = _x[idx] ^ mask; + DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + using V = ScalarT; + std::vector x_xor_r(numel); + + pforeach(0, numel, [&](int64_t idx) { + // use _r[i*nbits, (i+1)*nbits) to construct rb[i] + V mask = 0; + for (int64_t bit = 0; bit < nbits; ++bit) { + mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit; + } + x_xor_r[idx] = _x[idx] ^ mask; + }); + + // open c = x ^ r + x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); + + NdArrayView _res(res); + pforeach(0, numel, [&](int64_t idx) { + _res[idx] = 0; + for (int64_t bit = 0; bit < nbits; bit++) { + auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1; + if (comm->getRank() == 0) { + _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]) + << bit; + } else { + _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit; + } + } + }); }); + }); + + return res; +} - // open c = x ^ r - x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); - - NdArrayView _res(res); - pforeach(0, numel, [&](int64_t idx) { - _res[idx] = 0; - for (int64_t bit = 0; bit < nbits; bit++) { - auto c_i = (x_xor_r[idx] >> bit) & 0x1; - if (comm->getRank() == 0) { - _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]) - << bit; - } else { - _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit; +// Reference: +// III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) +// +// Analysis: +// Online Latency: 1 (x_xor_r reveal) +// Communication: one element bits for one element +// Vectorization: yes +// +// HighLevel Intuition: +// Since: X = sum: Xi * 2^i +// If we have A, then we can construct A = sum: A * 2^i. +// +// The problem is that we only have B in hand. Details for how to +// construct A from B: +// - trusted third party choose a random bit r, where r == 0 or r == 1. +// - trusted third party send A to parties +// - parties compute B from A +// - parties xor_open c = Xi ^ r = open(B ^ B), Xi is still safe due +// to protection from r. +// - parties compute: = c + (1-2c)* +// A = 1 - A if c == 1, i.e. Xi != r +// A = A if c == 0, i.e. Xi == r +// i.e. A = c + (1-2c) * A +// +// Online Communication: +// = 1 (xor open) + +// Disassemble BShr to AShr bit-by-bit +// Input: BShr +// Return: a vector of k AShr, k is the valid bits of BShr +std::vector B2A_Disassemble::proc(KernelEvalContext* ctx, + const NdArrayRef& x) const { + const auto field = x.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* beaver = ctx->getState()->beaver(); + + const int64_t nbits = x.eltype().as()->nbits(); + SPU_ENFORCE((size_t)nbits > 0 && (size_t)nbits <= SizeOf(field) * 8, + "invalid nbits={}", nbits); + + const auto numel = x.numel(); + const auto rand_numel = numel * static_cast(nbits); + const PtType backtype = getBacktype(nbits); + + auto randbits = beaver->RandBit(field, rand_numel); + + std::vector res; + res.reserve(nbits); + for (int64_t idx = 0; idx < nbits; ++idx) { + res.emplace_back(makeType(field), x.shape()); + } + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using U = ring2k_t; + + absl::Span _randbits(randbits.data(), rand_numel); + NdArrayView _x(x); + + DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { + using V = ScalarT; + std::vector x_xor_r(numel); + + pforeach(0, numel, [&](int64_t idx) { + // use _r[i*nbits, (i+1)*nbits) to construct rb[i] + V mask = 0; + for (int64_t bit = 0; bit < nbits; ++bit) { + mask += (static_cast(_randbits[idx * nbits + bit]) & 0x1) << bit; } - } + x_xor_r[idx] = _x[idx] ^ mask; + }); + + // open c = x ^ r + x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); + + pforeach(0, numel, [&](int64_t idx) { + pforeach(0, nbits, [&](int64_t bit) { + NdArrayView _res(res[bit]); + auto c_i = static_cast(x_xor_r[idx] >> bit) & 0x1; + if (comm->getRank() == 0) { + _res[idx] = (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]); + } else { + _res[idx] = ((1 - c_i * 2) * _randbits[idx * nbits + bit]); + } + }); + }); }); }); diff --git a/libspu/mpc/semi2k/conversion.h b/libspu/mpc/semi2k/conversion.h index bc249998..891a23cd 100644 --- a/libspu/mpc/semi2k/conversion.h +++ b/libspu/mpc/semi2k/conversion.h @@ -73,6 +73,21 @@ class B2A_Randbit : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override; }; +class B2A_Disassemble : public DisassembleKernel { + public: + static constexpr char kBindName[] = "b2a_disassemble"; + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { + return ce::K() * (ce::N() - 1) // Open bit masked value + ; + } + + std::vector proc(KernelEvalContext* ctx, + const NdArrayRef& x) const override; +}; + // Note: current only for 2PC. class MsbA2B : public UnaryKernel { public: diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc index a7d3b7e1..3acd8344 100644 --- a/libspu/mpc/semi2k/protocol.cc +++ b/libspu/mpc/semi2k/protocol.cc @@ -50,21 +50,23 @@ void regSemi2kProtocol(SPUContext* ctx, ctx->prot()->addState(ctx->config(), lctx); ctx->prot() ->regKernel< - semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, // - semi2k::NotA, // - semi2k::AddAP, semi2k::AddAA, // - semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, // - semi2k::MatMulAP, semi2k::MatMulAA, // - semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB, - semi2k::ARShiftB, // - semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, // - semi2k::B2P, semi2k::P2B, semi2k::A2B, semi2k::B2A_Randbit, // - semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB, - semi2k::BitrevB, // - semi2k::BitIntlB, semi2k::BitDeintlB, // - semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP, - semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, // - semi2k::EqualAA, semi2k::EqualAP, semi2k::BeaverCacheKernel>(); + semi2k::P2A, semi2k::A2P, semi2k::A2V, semi2k::V2A, // + semi2k::NotA, // + semi2k::AddAP, semi2k::AddAA, // + semi2k::MulAP, semi2k::MulAA, semi2k::SquareA, // + semi2k::MatMulAP, semi2k::MatMulAA, // + semi2k::LShiftA, semi2k::LShiftB, semi2k::RShiftB, // + semi2k::ARShiftB, // + semi2k::CommonTypeB, semi2k::CommonTypeV, semi2k::CastTypeB, // + semi2k::B2P, semi2k::P2B, // + semi2k::A2B, semi2k::B2A_Randbit, semi2k::B2A_Disassemble, // + semi2k::AndBP, semi2k::AndBB, semi2k::XorBP, semi2k::XorBB, // + semi2k::BitrevB, // + semi2k::BitIntlB, semi2k::BitDeintlB, // + semi2k::RandA, semi2k::RandPermM, semi2k::PermAM, semi2k::PermAP, // + semi2k::InvPermAM, semi2k::InvPermAP, semi2k::InvPermAV, // + semi2k::EqualAA, semi2k::EqualAP, // + semi2k::BeaverCacheKernel>(); if (ctx->config().trunc_allow_msb_error()) { ctx->prot()->regKernel(); diff --git a/requirements.txt b/requirements.txt index bea5a7d1..ac8fd6cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ grpcio>=1.42.0,!=1.48.0 -numpy>=1.22.0, < 2 +numpy>=1.22.0 protobuf>=4, <5 cloudpickle>=2.0.0 multiprocess>=0.70.12.2 cachetools>=5.0.0 -jax[cpu]>=0.4.16, <=0.4.26 # FIXME +jax[cpu]>=0.4.16, <=0.4.26 # FIXME: Jax 0.4.26+ select perf issue termcolor>=2.0.0 diff --git a/spu/utils/distributed_impl.py b/spu/utils/distributed_impl.py index 79b13426..ebd6a703 100644 --- a/spu/utils/distributed_impl.py +++ b/spu/utils/distributed_impl.py @@ -723,7 +723,10 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]): fn_name = repr(fn) - import jax.extend.linear_util as lu + try: + import jax.extend.linear_util as lu + except ImportError: + import jax.linear_util as lu # fallback from jax._src import api_util as japi_util from jax.tree_util import tree_map, tree_flatten diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index e1a4c24c..fc20cd0c 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -115,13 +115,32 @@ def _jax_compilation( register_backend_factory('interpreter', xla_back, priority=-100) - fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs) + jax_version = jax.__version_info__ + + if jax_version[0] > 1 or jax_version[1] > 4 or jax_version[2] > 29: + # xla_computation is deprecated since 0.4.30, move to new api + lowered = ( + jax.jit( + fn, + static_argnums=static_argnums, + static_argnames=static_argnames, + keep_unused=True, + ) + .trace(*args, **kwargs) + .lower(lowering_platforms=('interpreter',)) + ) + return ( + lowered.compiler_ir('hlo').as_serialized_hlo_module_proto(), + lowered.out_info, + ) + else: + fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs) - cfn, output = jax.xla_computation( - fn, return_shape=True, static_argnums=static_argnums, backend="interpreter" - )(*args, **kwargs) + cfn, output = jax.xla_computation( + fn, return_shape=True, static_argnums=static_argnums, backend="interpreter" + )(*args, **kwargs) - return cfn.as_serialized_hlo_module_proto(), output + return cfn.as_serialized_hlo_module_proto(), output ## Frontend patches diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py index 8df6185c..592ccaff 100644 --- a/spu/utils/simulation.py +++ b/spu/utils/simulation.py @@ -17,7 +17,11 @@ from typing import Callable import jax -import jax.extend.linear_util as jax_lu # Moved in jax 0.4.16 + +try: + import jax.extend.linear_util as jax_lu +except ImportError: + import jax.linear_util as jax_lu # fallback import jax.numpy as jnp import numpy as np from jax._src import api_util as japi_util