Skip to content

Commit

Permalink
Merge pull request #153 from cirnoooo123/main
Browse files Browse the repository at this point in the history
implement secret sharing multiply
  • Loading branch information
SongY123 authored May 8, 2024
2 parents f9270eb + ee8b794 commit cc2a5e7
Show file tree
Hide file tree
Showing 16 changed files with 418 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.hufudb.openhufu.core.config;

public class OpenHuFuConfig {
public static final int CLIENT_THREAD_NUM = 4; // client side thread pool size
public static final int CLIENT_THREAD_NUM = 20; // client side thread pool size
public static final int SERVER_THREAD_NUM = 4; // server side thread pool size
public static final long RPC_TIME_OUT = 60000; // time out when waiting for response in ms
public static final int ZK_TIME_OUT = 6000; // time out of zk
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ private boolean isPrivacyRangeJoin(BinaryPlan plan) {
if (plan.getJoinCond().getModifier().equals(Modifier.PUBLIC)) {
return false;
}
if (!plan.getJoinCond().getCondition().getIn(0).getModifier().equals(Modifier.PUBLIC)) {
if (!plan.getJoinCond().getCondition().getInList().isEmpty()
&& !plan.getJoinCond().getCondition().getIn(0).getModifier().equals(Modifier.PUBLIC)) {
throw new OpenHuFuException(ErrorCode.RANGE_JOIN_LEFT_TABLE_NOT_PUBLIC);
}
return plan.getJoinCond().getCondition().getStr().equals("dwithin");
Expand All @@ -122,7 +123,8 @@ private boolean isPrivacyKNNJoin(BinaryPlan plan) {
if (plan.getJoinCond().getModifier().equals(Modifier.PUBLIC)) {
return false;
}
if (!plan.getJoinCond().getCondition().getIn(0).getModifier().equals(Modifier.PUBLIC)) {
if (!plan.getJoinCond().getCondition().getInList().isEmpty()
&& !plan.getJoinCond().getCondition().getIn(0).getModifier().equals(Modifier.PUBLIC)) {
throw new OpenHuFuException(ErrorCode.RANGE_JOIN_LEFT_TABLE_NOT_PUBLIC);
}
return plan.getJoinCond().getCondition().getStr().equals("knn");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public static List<Pair<OwnerClient, QueryPlanProto>> generateLeafOwnerPlans(Ope
}
List<Pair<OwnerClient, String>> tableClients = client.getTableClients(plan.getTableName());
List<Pair<OwnerClient, QueryPlanProto>> ownerContext = new ArrayList<>();
TaskInfo.Builder taskInfo = TaskInfo.newBuilder().setTaskId(client.getTaskId());
for (Pair<OwnerClient, String> entry : tableClients) {
taskInfo.addParties(entry.getLeft().getParty().getPartyId());
}
builder.setTaskInfo(taskInfo);
for (Pair<OwnerClient, String> entry : tableClients) {
builder.setTableName(entry.getRight());
ownerContext.add(MutablePair.of(entry.getLeft(), builder.build()));
Expand Down
4 changes: 3 additions & 1 deletion mpc/src/main/java/com/hufudb/openhufu/mpc/ProtocolType.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ public enum ProtocolType {
SS("SS", 101, true),
HASH_PSI("PSI", 200, true),
ABY("ABY", 300, true),
SECRET_UNION("SECRET_UNION", 400, true);
SECRET_UNION("SECRET_UNION", 400, true),
SECRET_MULTIPLY("SECRET_MULTIPLY", 401, true);

private static final ImmutableMap<Integer, ProtocolType> MAP;

static {
Expand Down
54 changes: 54 additions & 0 deletions mpc/src/main/java/com/hufudb/openhufu/mpc/multiply/MulCache.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.hufudb.openhufu.mpc.multiply;

import java.util.Random;

public class MulCache {
private final int[][] ran;
private long[][] val;
private boolean isInit;
private final static int TIME_OUT = 60000;

public MulCache(int n) {
ran = new int[n][2];
val = new long[n][2];
isInit = false;
Random random = new Random(System.currentTimeMillis());
for (int i = 0; i < n; ++i) {
ran[i][0] = random.nextInt(128) + 1;
ran[i][1] = random.nextInt(128) + 1;
}
}

public int getRan(int idx, boolean isFirst) {
return isFirst ? ran[idx][0] : ran[idx][1];
}

public synchronized void setVal(long[][] val) {
this.val = val;
isInit = true;
this.notifyAll();
}

public int ranSum(int idx) {
int sum = 0;
for (int i = 0; i < ran.length; ++i) {
if (i != idx) {
sum += ran[i][0] * ran[i][1];
}
}
return sum;
}

public synchronized long getVal(int idx, boolean isFirst) {
int i = 0;
while (!isInit && i < 20) {
try {
this.wait(TIME_OUT);
} catch (InterruptedException e) {
e.printStackTrace();
}
++i;
}
return isFirst ? val[idx][0] : val[idx][1];
}
}
142 changes: 142 additions & 0 deletions mpc/src/main/java/com/hufudb/openhufu/mpc/multiply/SecretMultiply.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package com.hufudb.openhufu.mpc.multiply;

import com.google.common.collect.ImmutableList;
import com.hufudb.openhufu.mpc.ProtocolException;
import com.hufudb.openhufu.mpc.ProtocolType;
import com.hufudb.openhufu.mpc.RpcProtocolExecutor;
import com.hufudb.openhufu.mpc.codec.OpenHuFuCodec;
import com.hufudb.openhufu.rpc.Rpc;
import com.hufudb.openhufu.rpc.utils.DataPacket;
import com.hufudb.openhufu.rpc.utils.DataPacketHeader;

import java.util.List;

//Dan Bogdanov, Sven Laur, and Jan Willemson. 2008. Sharemind: A Framework for Fast Privacy-Preserving Computations. In ESORICS. 192–206.

public class SecretMultiply extends RpcProtocolExecutor {
private MulCache mc;
private List<Integer> parties;
private long taskId;
private long localSum;
private boolean isLeader;
private int idx;
private long[][] vals;
public SecretMultiply(Rpc rpc) {
super(rpc, ProtocolType.SS);
}

private void fillVals(long u, long v) {
// LOG.info("fillVals start in {}", ownId);
for (int i = 0; i < parties.size(); i++) {
if (i == idx) {
continue;
}
int t = i == (idx - 1 + parties.size()) % parties.size()?
(idx - 2 + parties.size()) % parties.size(): (idx - 1 + parties.size()) % parties.size();
send(mc.getRan(i, true), 0 + generateStepID(t, i), parties.get(t));
send(mc.getRan(i, false), 1 + generateStepID(i, t), parties.get(i));
}
for (int i = 0; i < parties.size(); i++) {
if (i == idx) {
continue;
}
vals[i][1] = u + receive(0 + generateStepID(idx, i), parties.get(getThirdEndpoint(idx, i, parties.size())));
vals[i][0] = v + receive(1 + generateStepID(idx, i), parties.get(getThirdEndpoint(i, idx, parties.size())));
}
mc.setVal(vals);
}
private int getThirdEndpoint(int i, int j, int n) {
return (((i + 1) % n) == j) ? (i + 2) % n : (i + 1) % n;
}

private int generateStepID(int i, int j) {
return (i * parties.size() + j) * 2;
}

private void calShares(long u, long v) {
LOG.info("calShares start in {}", ownId);
for (int i = 0; i < parties.size(); i++) {
if (i == idx) {
continue;
}
send(mc.getVal(i, true), mc.getVal(i, false), 2, parties.get(i));
}
for (int i = 0; i < parties.size(); i++) {
if (i == idx) {
continue;
}
List<Long> res = receive2(2, parties.get(i));
localSum += -mc.getVal(i, false) * res.get(0) + u * res.get(0) + v * res.get(1);
}
}

private long sumShares() {
LOG.info("sumShares start in {}", ownId);
long globalSum = 0;
if (isLeader) {
for (int i = 0; i < parties.size(); i++) {
if (i == idx) {
globalSum += localSum;
}
else {
globalSum += receive(3, parties.get(i));
}
}
}
else {
send(localSum, 3, parties.get(0));
}
return globalSum;
}



private void send(long value, int stepID, int partyID) {
LOG.info("send to {}, {}", partyID, stepID);
DataPacketHeader header = new DataPacketHeader(taskId, getProtocolTypeId(), stepID, ownId, partyID);
rpc.send(DataPacket.fromByteArrayList(header, ImmutableList.of(OpenHuFuCodec.encodeLong(value))));
}

private void send(long value1, long value2, int stepID, int partyID) {
DataPacketHeader header = new DataPacketHeader(taskId, getProtocolTypeId(), stepID, ownId, partyID);
rpc.send(DataPacket.fromByteArrayList(header,
ImmutableList.of(OpenHuFuCodec.encodeLong(value1), OpenHuFuCodec.encodeLong(value2))));
}

private long receive(int stepID, int partyID) {
LOG.info("receive {}, {}", partyID, stepID);
final DataPacketHeader expect = new DataPacketHeader(taskId, getProtocolTypeId(), stepID, partyID, ownId);
DataPacket packet = rpc.receive(expect);
return OpenHuFuCodec.decodeLong(packet.getPayload().get(0));
// return 0;
}

private List<Long> receive2(int stepID, int partyID) {
final DataPacketHeader expect = new DataPacketHeader(taskId, getProtocolTypeId(), stepID, partyID, ownId);
DataPacket packet = rpc.receive(expect);
return ImmutableList.of(OpenHuFuCodec.decodeLong(packet.getPayload().get(0)),
OpenHuFuCodec.decodeLong(packet.getPayload().get(1)));
}

/**
* @param args[0] input value1
* @param args[1] input value2
* @return result of ColumnType for the first party, 0 for other parties
*/
@Override
public Object run(long taskId, List<Integer> parties, Object... args) throws ProtocolException {
long u = (long) args[0];
long v = (long) args[1];
this.mc = new MulCache(parties.size());
this.parties = parties;
this.taskId = taskId;
this.localSum = 0;
this.isLeader = ownId == parties.get(0);
this.idx = parties.indexOf(ownId);
this.vals = new long[parties.size()][2];
localSum += u * v + mc.ranSum(idx);
fillVals(u, v);
calShares(u, v);
return sumShares();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,6 @@ public Object run(long taskId, List<Integer> parties, Object... args) throws Pro
} else {
followerProcedure(localDataSet);
}
return null;
return EmptyDataSet.INSTANCE;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package com.hufudb.openhufu.mpc.secretMultiply;

import com.google.common.collect.ImmutableList;
import com.hufudb.openhufu.mpc.multiply.SecretMultiply;
import com.hufudb.openhufu.rpc.Party;
import com.hufudb.openhufu.rpc.grpc.OpenHuFuOwnerInfo;
import com.hufudb.openhufu.rpc.grpc.OpenHuFuRpc;
import com.hufudb.openhufu.rpc.grpc.OpenHuFuRpcManager;
import io.grpc.Channel;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.testing.GrpcCleanupRule;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.locationtech.jts.geom.GeometryFactory;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;

@RunWith(JUnit4.class)
public class SecretMultiplyTest {
public final static GeometryFactory geoFactory = new GeometryFactory();
@Rule
public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();

OpenHuFuRpcManager manager;
ExecutorService threadpool = Executors.newFixedThreadPool(5);

@Before
public void setUp() throws IOException {
String ownerName0 = InProcessServerBuilder.generateName();
String ownerName1 = InProcessServerBuilder.generateName();
String ownerName2 = InProcessServerBuilder.generateName();
String ownerName3 = InProcessServerBuilder.generateName();
String ownerName4 = InProcessServerBuilder.generateName();
Party owner0 = new OpenHuFuOwnerInfo(0, ownerName0);
Party owner1 = new OpenHuFuOwnerInfo(1, ownerName1);
Party owner2 = new OpenHuFuOwnerInfo(2, ownerName2);
Party owner3 = new OpenHuFuOwnerInfo(3, ownerName3);
Party owner4 = new OpenHuFuOwnerInfo(4, ownerName4);
List<Party> parties = ImmutableList.of(owner0, owner1, owner2, owner3, owner4);
List<Channel> channels = Arrays.asList(
grpcCleanup.register(InProcessChannelBuilder.forName(ownerName0).directExecutor().build()),
grpcCleanup.register(InProcessChannelBuilder.forName(ownerName1).directExecutor().build()),
grpcCleanup.register(InProcessChannelBuilder.forName(ownerName2).directExecutor().build()),
grpcCleanup.register(InProcessChannelBuilder.forName(ownerName3).directExecutor().build()),
grpcCleanup.register(InProcessChannelBuilder.forName(ownerName4).directExecutor().build()));
manager = new OpenHuFuRpcManager(parties, channels);
OpenHuFuRpc rpc0 = (OpenHuFuRpc) manager.getRpc(0);
OpenHuFuRpc rpc1 = (OpenHuFuRpc) manager.getRpc(1);
OpenHuFuRpc rpc2 = (OpenHuFuRpc) manager.getRpc(2);
OpenHuFuRpc rpc3 = (OpenHuFuRpc) manager.getRpc(3);
OpenHuFuRpc rpc4 = (OpenHuFuRpc) manager.getRpc(4);
rpc0.connect();
rpc1.connect();
rpc2.connect();
rpc3.connect();
rpc4.connect();
Server server0 = InProcessServerBuilder.forName(ownerName0).directExecutor()
.addService(rpc0.getgRpcService()).build().start();
Server server1 = InProcessServerBuilder.forName(ownerName1).directExecutor()
.addService(rpc1.getgRpcService()).build().start();
Server server2 = InProcessServerBuilder.forName(ownerName2).directExecutor()
.addService(rpc2.getgRpcService()).build().start();
Server server3 = InProcessServerBuilder.forName(ownerName3).directExecutor()
.addService(rpc3.getgRpcService()).build().start();
Server server4 = InProcessServerBuilder.forName(ownerName4).directExecutor()
.addService(rpc4.getgRpcService()).build().start();
grpcCleanup.register(server0);
grpcCleanup.register(server1);
grpcCleanup.register(server2);
grpcCleanup.register(server3);
grpcCleanup.register(server4);
}

void testMultiply(long taskId, List<SecretMultiply> executors, List<Integer> integers, long ans)
throws InterruptedException, ExecutionException {
List<Integer> parties = executors.stream().map(e -> e.getOwnId()).collect(Collectors.toList());
List<Future<Object>> futures = new ArrayList<>();
for (int i = 0; i < executors.size(); ++i) {
final SecretMultiply s = executors.get(i);
final int int1 = integers.get(2 * i);
final int int2 = integers.get(2 * i + 1);
futures.add(threadpool.submit(new Callable<Object>() {
@Override
public Object call() throws Exception {
return s.run(taskId, parties, (long) int1, (long) int2);
}
}));
}
long res = (long) futures.get(0).get();
assertEquals(ans, res);
}

@Test
public void testSecretMultiply() throws InterruptedException, ExecutionException {
Random random = new Random();
OpenHuFuRpc rpc0 = (OpenHuFuRpc) manager.getRpc(0);
OpenHuFuRpc rpc1 = (OpenHuFuRpc) manager.getRpc(1);
OpenHuFuRpc rpc2 = (OpenHuFuRpc) manager.getRpc(2);
OpenHuFuRpc rpc3 = (OpenHuFuRpc) manager.getRpc(3);
OpenHuFuRpc rpc4 = (OpenHuFuRpc) manager.getRpc(4);
List<OpenHuFuRpc> rpcs = ImmutableList.of(rpc0, rpc1, rpc2, rpc3, rpc4);
List<SecretMultiply> executors =
rpcs.stream().map(rpc -> new SecretMultiply(rpc)).collect(Collectors.toList());
List<Integer> integers = new ArrayList<>();
long u = 0;
long v = 0;
for (int i = 0; i < 5; i++) {
int int1 = random.nextInt(128);
int int2 = random.nextInt(128);
integers.add(int1);
integers.add(int2);
u += int1;
v += int2;
}
testMultiply(0, executors, integers, u * v);
}
}
Loading

0 comments on commit cc2a5e7

Please sign in to comment.