diff --git a/Makefile b/Makefile index ff04f1e1e..448bd3b5a 100644 --- a/Makefile +++ b/Makefile @@ -167,6 +167,7 @@ EXPORT_VM = $(patsubst %.cpp, %.o, $(wildcard Machines/export-*.cpp)) export-trunc.x: Machines/export-ring.o export-sort.x: Machines/export-ring.o +export-msort.x: Machines/export-ring.o export-a2b.x: GC/AtlasSecret.o Machines/SPDZ.o Machines/SPDZ2^64+64.o $(GC_SEMI) $(TINIER) $(EXPORT_VM) GC/Rep4Secret.o GC/Rep4Prep.o $(FHEOFFLINE) export-b2a.x: Machines/export-ring.o diff --git a/Processor/FunctionArgument.cpp b/Processor/FunctionArgument.cpp index d1e1f5685..0a3b173b1 100644 --- a/Processor/FunctionArgument.cpp +++ b/Processor/FunctionArgument.cpp @@ -42,3 +42,8 @@ void FunctionArgument::check_type(const string& type_string) "return type mismatch: " + get_type_string() + "/" + type_string); } + +bool FunctionArgument::has_reg_type(const char* reg_type) +{ + return this->reg_type == string(reg_type); +} diff --git a/Processor/FunctionArgument.h b/Processor/FunctionArgument.h index eaf35c17c..65fa2d508 100644 --- a/Processor/FunctionArgument.h +++ b/Processor/FunctionArgument.h @@ -73,6 +73,15 @@ class FunctionArgument } } + /** + * Argument with integer array. + */ + FunctionArgument(vector& values) : + data(values.data()), size(values.size()), n_bits(0), + reg_type("ci"), memory(false) + { + } + size_t get_size() { return size; @@ -116,6 +125,8 @@ class FunctionArgument return memory; } + bool has_reg_type(const char* reg_type); + template T& get_value(size_t index) { diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 2b31b950f..5c72ffc3d 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -443,7 +443,7 @@ void Machine::run_function(const string& name, arguments[i].get_value>(j).at( k); } - else + else if (arguments[i].has_reg_type("s")) { auto& value = arguments[i].get_value(j); if (arguments[i].get_memory()) @@ -451,11 +451,17 @@ void Machine::run_function(const string& name, else processor.Procp.get_S()[arg_regs.at(i) + j] = value; } + else + { + assert(arguments[i].has_reg_type("ci")); + processor.write_Ci(arg_regs.at(i) + j, arguments[i].get_value(j)); + } } run_tape(0, tape_number, 0, N.num_players()); join_tape(0); + assert(result.has_reg_type("s")); for (size_t j = 0; j < result.get_size(); j++) result.get_value(j) = processor.Procp.get_S()[return_reg + j]; diff --git a/Programs/Source/export-msort.py b/Programs/Source/export-msort.py new file mode 100644 index 000000000..659824439 --- /dev/null +++ b/Programs/Source/export-msort.py @@ -0,0 +1,8 @@ +@export +def sort(x, key_indices): + print_ln('x=%s', x.reveal()) + print_ln('key_indices=%s', key_indices) + res = x.sort(key_indices=key_indices) + print_ln('res=%s', x.reveal()) + +sort(sint.Matrix(500, 2), regint(0, size=1)) diff --git a/Utils/export-msort.cpp b/Utils/export-msort.cpp new file mode 100644 index 000000000..7a1c45cd5 --- /dev/null +++ b/Utils/export-msort.cpp @@ -0,0 +1,59 @@ +/* + * export-sort.cpp + * + */ + +#include "Machines/minimal.hpp" + +int main(int argc, const char** argv) +{ + assert(argc > 1); + int my_number = atoi(argv[1]); + int port_base = 9999; + Names N(my_number, 3, "localhost", port_base); + + typedef Rep3Share2<64> share_type; + Machine machine(N); + + int n = 1000; + + vector inputs; + for (int i = 0; i < n; i++) + { + int j = (i % 2) ? i : (n - i); + inputs.push_back(share_type::constant(j, my_number)); + } + + vector key_indices = {1}; + + vector args = {{inputs, true}, key_indices}; + FunctionArgument res; + + machine.run_function("sort", res, args); + + Opener MC(machine.get_player(), machine.get_sint_mac_key()); + MC.init_open(); + for (auto& x : inputs) + MC.prepare_open(x); + MC.exchange(); + + if (my_number == 0) + { + cout << "res: "; + for (int i = 0; i < 10; i++) + cout << MC.finalize_open() << " "; + cout << endl; + } + else + { + for (int i = 0; i < n; i++) + { + auto x = MC.finalize_open(); + if (not i % 2 and x != i + 2) + { + cerr << "error at " << i << ": " << x << endl; + exit(1); + } + } + } +}