Switch the compression Python bindings from CLIF to pybind11.

NOTE(review): the angle-bracket text in this patch (#include targets and
template argument lists) was missing and has been reconstructed from usage:
<functional> for std::invoke, <stdexcept> for std::domain_error, <string> for
std::string, `auto Func` for the member-pointer template argument,
array_t<float> to match the sizeof(float) stride check, and class_<SbsWriter>.
The <string>/<vector> context lines in compression_clif_aux.h and the leading
<pybind11/pybind11.h> include are assumptions; confirm against the tree before
applying.

diff --git a/MODULE.bazel b/MODULE.bazel
index 58faa0d..ee63456 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -9,6 +9,7 @@ bazel_dep(name = "googletest", version = "1.15.2")
 bazel_dep(name = "highway", version = "1.1.0")
 bazel_dep(name = "nlohmann_json", version = "3.11.3")
 bazel_dep(name = "platforms", version = "0.0.10")
+bazel_dep(name = "pybind11_bazel", version = "2.12.0")
 bazel_dep(name = "rules_cc", version = "0.0.9")
 bazel_dep(name = "rules_license", version = "0.0.7")
 bazel_dep(name = "google_benchmark", version = "1.8.5")
diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel
index 89e2222..6b451bd 100644
--- a/compression/python/BUILD.bazel
+++ b/compression/python/BUILD.bazel
@@ -1,5 +1,5 @@
-load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc")
 # [internal] load strict.bzl
+load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
 
 package(
     default_applicable_licenses = [
@@ -12,8 +12,9 @@ cc_library(
     name = "compression_clif_aux",
     srcs = ["compression_clif_aux.cc"],
     hdrs = ["compression_clif_aux.h"],
+    visibility = ["//visibility:private"],
     deps = [
-        "//third_party/absl/types:span",
+        "@abseil-cpp//absl/types:span",
         "//compression:compress",
         "//compression:io",
         "@highway//:hwy",
@@ -21,12 +22,12 @@ cc_library(
     ],
 )
 
-py_clif_cc(
+pybind_extension(
     name = "compression",
-    srcs = ["compression.clif"],
+    srcs = ["compression_extension.cc"],
     deps = [
         ":compression_clif_aux",
-        "//third_party/absl/python/numpy:span_clif_lib",
+        "@abseil-cpp//absl/types:span",
     ],
 )
 
diff --git a/compression/python/compression.clif b/compression/python/compression.clif
deleted file mode 100644
index 69dfc9b..0000000
--- a/compression/python/compression.clif
+++ /dev/null
@@ -1,14 +0,0 @@
-from "third_party/absl/python/numpy/span.h" import *
-from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
-  namespace `gcpp`:
-    class SbsWriter:
-      # NOTE: Individual compression backends may impose constraints on the
-      # array length, such as a minimum of (say) 32 elements.
-      def `Insert` as insert(self, name: str, weights: NumpyArray)
-      def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray)
-      def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray)
-      def `InsertFloat` as insert_float(self, name: str, weights: NumpyArray)
-
-      def `AddScales` as add_scales(self, scales: list)
-
-      def `Write` as write(self, path: str)
diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc
index a9d3894..ba91781 100644
--- a/compression/python/compression_clif_aux.cc
+++ b/compression/python/compression_clif_aux.cc
@@ -20,7 +20,7 @@
 #ifndef GEMMA_ONCE
 #define GEMMA_ONCE
 
-#include "third_party/absl/types/span.h"
+#include "absl/types/span.h"
 #include "compression/io.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h
index 8dc7a9d..fd4efc8 100644
--- a/compression/python/compression_clif_aux.h
+++ b/compression/python/compression_clif_aux.h
@@ -5,7 +5,7 @@
 #include <string>
 #include <vector>
 
-#include "third_party/absl/types/span.h"
+#include "absl/types/span.h"
 
 namespace gcpp {
 
diff --git a/compression/python/compression_extension.cc b/compression/python/compression_extension.cc
new file mode 100644
index 0000000..c2916a8
--- /dev/null
+++ b/compression/python/compression_extension.cc
@@ -0,0 +1,38 @@
+#include <pybind11/pybind11.h>
+
+#include <functional>
+#include <stdexcept>
+#include <string>
+
+#include "absl/types/span.h"
+#include "compression/python/compression_clif_aux.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
+using gcpp::SbsWriter;
+
+namespace py = pybind11;
+
+namespace {
+template <auto Func>
+void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
+  if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
+    throw std::domain_error("Input array must be 1D and densely packed.");
+  }
+  std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
+}
+}  // namespace
+
+PYBIND11_MODULE(compression, m) {
+  py::class_<SbsWriter>(m, "SbsWriter")
+      .def(py::init<>())
+      // NOTE: Individual compression backends may impose constraints on the
+      // array length, such as a minimum of (say) 32 elements.
+      .def("insert", wrap_span<&SbsWriter::Insert>)
+      .def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
+      .def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
+      .def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
+      .def("add_scales", &SbsWriter::AddScales)
+      .def("write", &SbsWriter::Write);
+}