Skip to content

Commit

Permalink
Allow regint registers as argument in exported functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Dec 4, 2024
1 parent e87000f commit 78a4b65
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 1 deletion.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ EXPORT_VM = $(patsubst %.cpp, %.o, $(wildcard Machines/export-*.cpp))

export-trunc.x: Machines/export-ring.o
export-sort.x: Machines/export-ring.o
export-msort.x: Machines/export-ring.o
export-a2b.x: GC/AtlasSecret.o Machines/SPDZ.o Machines/SPDZ2^64+64.o $(GC_SEMI) $(TINIER) $(EXPORT_VM) GC/Rep4Secret.o GC/Rep4Prep.o $(FHEOFFLINE)
export-b2a.x: Machines/export-ring.o

Expand Down
5 changes: 5 additions & 0 deletions Processor/FunctionArgument.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,8 @@ void FunctionArgument::check_type(const string& type_string)
"return type mismatch: " + get_type_string() + "/"
+ type_string);
}

bool FunctionArgument::has_reg_type(const char* reg_type)
{
return this->reg_type == string(reg_type);
}
11 changes: 11 additions & 0 deletions Processor/FunctionArgument.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ class FunctionArgument
}
}

/**
* Argument with integer array.
*/
FunctionArgument(vector<long>& values) :
data(values.data()), size(values.size()), n_bits(0),
reg_type("ci"), memory(false)
{
}

size_t get_size()
{
return size;
Expand Down Expand Up @@ -116,6 +125,8 @@ class FunctionArgument
return memory;
}

bool has_reg_type(const char* reg_type);

template<class T>
T& get_value(size_t index)
{
Expand Down
8 changes: 7 additions & 1 deletion Processor/Machine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,19 +443,25 @@ void Machine<sint, sgf2n>::run_function(const string& name,
arguments[i].get_value<vector<typename sint::bit_type>>(j).at(
k);
}
else
else if (arguments[i].has_reg_type("s"))
{
auto& value = arguments[i].get_value<sint>(j);
if (arguments[i].get_memory())
Mp.MS[arg_regs.at(i) + j] = value;
else
processor.Procp.get_S()[arg_regs.at(i) + j] = value;
}
else
{
assert(arguments[i].has_reg_type("ci"));
processor.write_Ci(arg_regs.at(i) + j, arguments[i].get_value<long>(j));
}
}

run_tape(0, tape_number, 0, N.num_players());
join_tape(0);

assert(result.has_reg_type("s"));
for (size_t j = 0; j < result.get_size(); j++)
result.get_value<sint>(j) = processor.Procp.get_S()[return_reg + j];

Expand Down
8 changes: 8 additions & 0 deletions Programs/Source/export-msort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@export
def sort(x, key_indices):
print_ln('x=%s', x.reveal())
print_ln('key_indices=%s', key_indices)
res = x.sort(key_indices=key_indices)
print_ln('res=%s', x.reveal())

sort(sint.Matrix(500, 2), regint(0, size=1))
59 changes: 59 additions & 0 deletions Utils/export-msort.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* export-sort.cpp
*
*/

#include "Machines/minimal.hpp"

int main(int argc, const char** argv)
{
assert(argc > 1);
int my_number = atoi(argv[1]);
int port_base = 9999;
Names N(my_number, 3, "localhost", port_base);

typedef Rep3Share2<64> share_type;
Machine<share_type> machine(N);

int n = 1000;

vector<share_type> inputs;
for (int i = 0; i < n; i++)
{
int j = (i % 2) ? i : (n - i);
inputs.push_back(share_type::constant(j, my_number));
}

vector<long> key_indices = {1};

vector<FunctionArgument> args = {{inputs, true}, key_indices};
FunctionArgument res;

machine.run_function("sort", res, args);

Opener<share_type> MC(machine.get_player(), machine.get_sint_mac_key());
MC.init_open();
for (auto& x : inputs)
MC.prepare_open(x);
MC.exchange();

if (my_number == 0)
{
cout << "res: ";
for (int i = 0; i < 10; i++)
cout << MC.finalize_open() << " ";
cout << endl;
}
else
{
for (int i = 0; i < n; i++)
{
auto x = MC.finalize_open();
if (not i % 2 and x != i + 2)
{
cerr << "error at " << i << ": " << x << endl;
exit(1);
}
}
}
}

0 comments on commit 78a4b65

Please sign in to comment.