From 29fba5fb0107d12b03ce2608d74b049485b7e3fb Mon Sep 17 00:00:00 2001 From: garyschulte Date: Wed, 2 Oct 2024 13:33:11 -0700 Subject: [PATCH 1/2] add parallel g1/g2 msm gnark-crypto impl Signed-off-by: garyschulte --- gnark/gnark-jni/gnark-eip-2537.go | 125 ++++++++++++++++-- .../besu/nativelib/gnark/LibGnarkEIP2537.java | 16 ++- 2 files changed, 129 insertions(+), 12 deletions(-) diff --git a/gnark/gnark-jni/gnark-eip-2537.go b/gnark/gnark-jni/gnark-eip-2537.go index 49bd8fbc..fdedf33b 100644 --- a/gnark/gnark-jni/gnark-eip-2537.go +++ b/gnark/gnark-jni/gnark-eip-2537.go @@ -10,8 +10,10 @@ import ( "math/big" "reflect" "unsafe" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" ) const ( @@ -167,6 +169,54 @@ func eip2537blsG1MultiExp(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cIn return nonMontgomeryMarshalG1(result, javaOutputBuf, errorBuf) } +//export eip2537blsG1MultiExpParallel +func eip2537blsG1MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int { + inputLen := int(cInputLen) + errorLen := int(cOutputLen) + + // Convert error C pointers to Go slices + errorBuf := castBuffer(javaErrorBuf, errorLen) + + if inputLen == 0 { + copy(errorBuf, "invalid input parameters, invalid number of pairs\x00") + return 1 + } + + if inputLen % (EIP2537PreallocateForG1 + EIP2537PreallocateForScalar) != 0 { + copy(errorBuf, "invalid input parameters, invalid input length for G1 multiplication\x00") + return 1 + } + + // Convert input C pointers to Go slice + input := castBufferToSlice(unsafe.Pointer(javaInputBuf), inputLen) + + var exprCount = inputLen / (EIP2537PreallocateForG1 + EIP2537PreallocateForScalar) + + g1Points := make([]bls12381.G1Affine, exprCount) + scalars := make([]fr.Element, exprCount) + + for i := 0 ; i < exprCount ; i++ { + _, err := g1AffineDecodeInSubGroupVal(&g1Points[i], input[i*160 : (i*160)+128]) + if err != nil { + copy(errorBuf, err.Error()) + return 1 + } + + scalars[i].SetBytes(input[(i*160)+128 : (i+1)*160]) + } + + var affineResult bls12381.G1Affine + // leave nbTasks unset, allow golang to use available cpu cores as the parallelism limit + _, err := affineResult.MultiExp(g1Points, scalars, ecc.MultiExpConfig{}) + if err != nil { + copy(errorBuf, err.Error()) + return 1 + } + + // marshal the resulting point and encode directly to the output buffer + return nonMontgomeryMarshalG1(&affineResult, javaOutputBuf, errorBuf) +} + //export eip2537blsG2Add func eip2537blsG2Add(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int { inputLen := int(cInputLen) @@ -289,6 +339,58 @@ func eip2537blsG2MultiExp(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cIn return nonMontgomeryMarshalG2(result, javaOutputBuf, errorBuf) } +//export eip2537blsG2MultiExpParallel +func eip2537blsG2MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int { + inputLen := int(cInputLen) + errorLen := int(cOutputLen) + + // Convert error C pointers to Go slices + errorBuf := castBuffer(javaErrorBuf, errorLen) + + if inputLen == 0 { + copy(errorBuf, "invalid input parameters, invalid number of pairs\x00") + return 1 + } + + if inputLen % (EIP2537PreallocateForG2 + EIP2537PreallocateForScalar) != 0 { + copy(errorBuf, "invalid input parameters, invalid input length for G2 multiplication\x00") + return 1 + } + + // Convert input C pointers to Go slice + input := castBufferToSlice(unsafe.Pointer(javaInputBuf), inputLen) + + var exprCount = inputLen / (EIP2537PreallocateForG2 + EIP2537PreallocateForScalar) + + g2Points := make([]bls12381.G2Affine, exprCount) + scalars := make([]fr.Element, exprCount) + + for i := 0 ; i < exprCount ; i++ { + _, err := g2AffineDecodeInSubGroupVal(&g2Points[i], input[i*288 : (i*288)+256]) + if err != nil { + copy(errorBuf, err.Error()) + return 1 + } + + scalars[i].SetBytes(input[(i*288)+256 : (i+1)*288]) + } + + var affineResult bls12381.G2Affine + // leave nbTasks unset, allow golang to use available cpu cores as the parallelism limit + _, err := affineResult.MultiExp(g2Points, scalars, ecc.MultiExpConfig{}) + if err != nil { + copy(errorBuf, err.Error()) + return 1 + } + + // marshal the resulting point and encode directly to the output buffer + return nonMontgomeryMarshalG2(&affineResult, javaOutputBuf, errorBuf) +} + + + + + //export eip2537blsPairing func eip2537blsPairing(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int { inputLen := int(cInputLen) @@ -439,25 +541,24 @@ func hasWrongG1Padding(input []byte) bool { func hasWrongG2Padding(input []byte) bool { return !isZero(input[:16]) || !isZero(input[64:80] )|| !isZero(input[128:144]) || !isZero(input[192:208]) } - - func g1AffineDecodeInSubGroup(input []byte) (*bls12381.G1Affine, error) { + var g1 bls12381.G1Affine + return g1AffineDecodeInSubGroupVal(&g1, input) +} + +func g1AffineDecodeInSubGroupVal(g1 *bls12381.G1Affine, input []byte) (*bls12381.G1Affine, error) { if hasWrongG1Padding(input) { return nil, ErrMalformedPointPadding } - var g1x, g1y fp.Element - err := g1x.SetBytesCanonical(input[16:64]) + err := g1.X.SetBytesCanonical(input[16:64]) if err != nil { return nil, err } - err = g1y.SetBytesCanonical(input[80:128]) + err = g1.Y.SetBytesCanonical(input[80:128]) if err != nil { return nil, err } - // construct g1affine directly rather than unmarshalling - g1 := &bls12381.G1Affine{X: g1x, Y: g1y} - // do explicit subgroup check if (!g1.IsInSubGroup()) { if (!g1.IsOnCurve()) { @@ -493,11 +594,15 @@ func g1AffineDecodeOnCurve(input []byte) (*bls12381.G1Affine, error) { } func g2AffineDecodeInSubGroup(input []byte) (*bls12381.G2Affine, error) { + var g2 bls12381.G2Affine + return g2AffineDecodeInSubGroupVal(&g2, input) +} + +func g2AffineDecodeInSubGroupVal(g2 *bls12381.G2Affine, input []byte) (*bls12381.G2Affine, error) { if hasWrongG2Padding(input) { return nil, ErrMalformedPointPadding } - var g2 bls12381.G2Affine err := g2.X.A0.SetBytesCanonical(input[16:64]) if err != nil { return nil, err @@ -522,7 +627,7 @@ func g2AffineDecodeInSubGroup(input []byte) (*bls12381.G2Affine, error) { if (!g2.IsInSubGroup()) { return nil, ErrSubgroupCheckFailed } - return &g2, nil; + return g2, nil; } func g2AffineDecodeOnCurve(input []byte) (*bls12381.G2Affine, error) { diff --git a/gnark/src/main/java/org/hyperledger/besu/nativelib/gnark/LibGnarkEIP2537.java b/gnark/src/main/java/org/hyperledger/besu/nativelib/gnark/LibGnarkEIP2537.java index 7cf245d8..bf407357 100644 --- a/gnark/src/main/java/org/hyperledger/besu/nativelib/gnark/LibGnarkEIP2537.java +++ b/gnark/src/main/java/org/hyperledger/besu/nativelib/gnark/LibGnarkEIP2537.java @@ -61,7 +61,7 @@ public static int eip2537_perform_operation( o_len.setValue(128); break; case BLS12_G1MULTIEXP_OPERATION_SHIM_VALUE: - ret = eip2537blsG1MultiExp(i, output, err, i_len, + ret = eip2537blsG1MultiExpParallel(i, output, err, i_len, EIP2537_PREALLOCATE_FOR_RESULT_BYTES, EIP2537_PREALLOCATE_FOR_ERROR_BYTES); o_len.setValue(128); @@ -79,7 +79,7 @@ public static int eip2537_perform_operation( o_len.setValue(256); break; case BLS12_G2MULTIEXP_OPERATION_SHIM_VALUE: - ret = eip2537blsG2MultiExp(i, output, err, i_len, + ret = eip2537blsG2MultiExpParallel(i, output, err, i_len, EIP2537_PREALLOCATE_FOR_RESULT_BYTES, EIP2537_PREALLOCATE_FOR_ERROR_BYTES); o_len.setValue(256); @@ -134,6 +134,12 @@ public static native int eip2537blsG1MultiExp( byte[] error, int inputSize, int output_len, int err_len); + public static native int eip2537blsG1MultiExpParallel( + byte[] input, + byte[] output, + byte[] error, + int inputSize, int output_len, int err_len); + public static native int eip2537blsG2Add( byte[] input, byte[] output, @@ -152,6 +158,12 @@ public static native int eip2537blsG2MultiExp( byte[] error, int inputSize, int output_len, int err_len); + public static native int eip2537blsG2MultiExpParallel( + byte[] input, + byte[] output, + byte[] error, + int inputSize, int output_len, int err_len); + public static native int eip2537blsPairing( byte[] input, byte[] output, From 31fb0c7e8fee2fe75232956364a8b920a4b981bc Mon Sep 17 00:00:00 2001 From: garyschulte Date: Fri, 4 Oct 2024 10:08:08 -0700 Subject: [PATCH 2/2] add a configurable NbTasks for degree-of-parallelism for msm Signed-off-by: garyschulte --- gnark/gnark-jni/gnark-eip-2537.go | 8 ++++---- .../besu/nativelib/gnark/LibGnarkEIP2537.java | 18 ++++++++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/gnark/gnark-jni/gnark-eip-2537.go b/gnark/gnark-jni/gnark-eip-2537.go index fdedf33b..55289055 100644 --- a/gnark/gnark-jni/gnark-eip-2537.go +++ b/gnark/gnark-jni/gnark-eip-2537.go @@ -170,7 +170,7 @@ func eip2537blsG1MultiExp(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cIn } //export eip2537blsG1MultiExpParallel -func eip2537blsG1MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int { +func eip2537blsG1MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int, nbTasks C.int) C.int { inputLen := int(cInputLen) errorLen := int(cOutputLen) @@ -207,7 +207,7 @@ func eip2537blsG1MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.c var affineResult bls12381.G1Affine // leave nbTasks unset, allow golang to use available cpu cores as the parallelism limit - _, err := affineResult.MultiExp(g1Points, scalars, ecc.MultiExpConfig{}) + _, err := affineResult.MultiExp(g1Points, scalars, ecc.MultiExpConfig{NbTasks: int(nbTasks)}) if err != nil { copy(errorBuf, err.Error()) return 1 @@ -340,7 +340,7 @@ func eip2537blsG2MultiExp(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cIn } //export eip2537blsG2MultiExpParallel -func eip2537blsG2MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int) C.int { +func eip2537blsG2MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen, cOutputLen, cErrorLen C.int, nbTasks C.int) C.int { inputLen := int(cInputLen) errorLen := int(cOutputLen) @@ -377,7 +377,7 @@ func eip2537blsG2MultiExpParallel(javaInputBuf, javaOutputBuf, javaErrorBuf *C.c var affineResult bls12381.G2Affine // leave nbTasks unset, allow golang to use available cpu cores as the parallelism limit - _, err := affineResult.MultiExp(g2Points, scalars, ecc.MultiExpConfig{}) + _, err := affineResult.MultiExp(g2Points, scalars, ecc.MultiExpConfig{NbTasks: int(nbTasks)}) if err != nil { copy(errorBuf, err.Error()) return 1 diff --git a/gnark/src/main/java/org/hyperledger/besu/nativelib/gnark/LibGnarkEIP2537.java b/gnark/src/main/java/org/hyperledger/besu/nativelib/gnark/LibGnarkEIP2537.java index bf407357..6d21a068 100644 --- a/gnark/src/main/java/org/hyperledger/besu/nativelib/gnark/LibGnarkEIP2537.java +++ b/gnark/src/main/java/org/hyperledger/besu/nativelib/gnark/LibGnarkEIP2537.java @@ -9,6 +9,9 @@ public class LibGnarkEIP2537 implements Library { @SuppressWarnings("WeakerAccess") public static final boolean ENABLED; + // zero implies 'default' degree of parallelism, which is the number of cpu cores available + private static int degreeOfMSMParallelism = 0; + static { boolean enabled; try { @@ -63,7 +66,8 @@ public static int eip2537_perform_operation( case BLS12_G1MULTIEXP_OPERATION_SHIM_VALUE: ret = eip2537blsG1MultiExpParallel(i, output, err, i_len, EIP2537_PREALLOCATE_FOR_RESULT_BYTES, - EIP2537_PREALLOCATE_FOR_ERROR_BYTES); + EIP2537_PREALLOCATE_FOR_ERROR_BYTES, + degreeOfMSMParallelism); o_len.setValue(128); break; case BLS12_G2ADD_OPERATION_SHIM_VALUE: @@ -81,7 +85,8 @@ public static int eip2537_perform_operation( case BLS12_G2MULTIEXP_OPERATION_SHIM_VALUE: ret = eip2537blsG2MultiExpParallel(i, output, err, i_len, EIP2537_PREALLOCATE_FOR_RESULT_BYTES, - EIP2537_PREALLOCATE_FOR_ERROR_BYTES); + EIP2537_PREALLOCATE_FOR_ERROR_BYTES, + degreeOfMSMParallelism); o_len.setValue(256); break; case BLS12_PAIR_OPERATION_SHIM_VALUE: @@ -138,7 +143,8 @@ public static native int eip2537blsG1MultiExpParallel( byte[] input, byte[] output, byte[] error, - int inputSize, int output_len, int err_len); + int inputSize, int output_len, int err_len, + int nbTasks); public static native int eip2537blsG2Add( byte[] input, @@ -162,7 +168,8 @@ public static native int eip2537blsG2MultiExpParallel( byte[] input, byte[] output, byte[] error, - int inputSize, int output_len, int err_len); + int inputSize, int output_len, int err_len, + int nbTasks); public static native int eip2537blsPairing( byte[] input, @@ -182,4 +189,7 @@ public static native int eip2537blsMapFp2ToG2( byte[] error, int inputSize, int output_len, int err_len); + public static void setDegreeOfMSMParallelism(int nbTasks) { + degreeOfMSMParallelism = nbTasks; + } }