Skip to content

Commit

Permalink
chore(gpu): restrict bindings generation
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Oct 31, 2024
1 parent fc26f2a commit 3bd7cf7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 24 deletions.
72 changes: 48 additions & 24 deletions backends/tfhe-cuda-backend/build.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::env;
use std::path::PathBuf;
use std::process::Command;

fn main() {
if let Ok(val) = env::var("DOCS_RS") {
if let Ok(val) = std::env::var("DOCS_RS") {
if val.parse::<u32>() == Ok(1) {
return;
}
Expand All @@ -28,7 +27,7 @@ fn main() {
println!("cargo::rerun-if-changed=cuda/CMakeLists.txt");
println!("cargo::rerun-if-changed=src");

if env::consts::OS == "linux" {
if std::env::consts::OS == "linux" {
let output = Command::new("./get_os_name.sh").output().unwrap();
let distribution = String::from_utf8(output.stdout).unwrap();
if distribution != "Ubuntu\n" {
Expand Down Expand Up @@ -56,30 +55,55 @@ fn main() {
println!("cargo:rustc-link-lib=stdc++");

let header_path = "wrapper.h";
println!("cargo:rerun-if-changed={}", header_path);

let headers = vec![
"wrapper.h",
"cuda/include/ciphertext.h",
"cuda/include/integer/compression/compression.h",
"cuda/include/integer/integer.h",
"cuda/include/keyswitch.h",
"cuda/include/linear_algebra.h",
"cuda/include/pbs/programmable_bootstrap.h",
"cuda/include/pbs/programmable_bootstrap_multibit.h",
];
let out_path = PathBuf::from("src").join("bindings.rs");
let bindings_modified = if out_path.exists() {
std::fs::metadata(&out_path).unwrap().modified().unwrap()
} else {
std::time::SystemTime::UNIX_EPOCH // If bindings file doesn't exist, consider it older
};
let mut headers_modified = bindings_modified;
for header in headers {
println!("cargo:rerun-if-changed={}", header);
// Check modification times
let header_modified = std::fs::metadata(header).unwrap().modified().unwrap();
if header_modified > headers_modified {
headers_modified = header_modified;
}
}

let bindings = bindgen::Builder::default()
.header(header_path)
// allow only what we are interested in, the custom types appearing in the interface
.allowlist_type("PBS_TYPE")
.allowlist_type("SHIFT_OR_ROTATE_TYPE")
// and the functions reachable from the headers included in wrapper.h
.allowlist_function(".*")
.clang_arg("-x")
.clang_arg("c++")
.clang_arg("-std=c++17")
.clang_arg("-I/usr/include")
.clang_arg("-I/usr/local/include")
.ctypes_prefix("ffi")
.raw_line("use crate::ffi;")
.generate()
.expect("Unable to generate bindings");
// Regenerate bindings only if header has been modified
if headers_modified > bindings_modified {
let bindings = bindgen::Builder::default()
.header(header_path)
// allow only what we are interested in, the custom types appearing in the interface
.allowlist_type("PBS_TYPE")
.allowlist_type("SHIFT_OR_ROTATE_TYPE")
// and the functions reachable from the headers included in wrapper.h
.allowlist_function(".*")
.clang_arg("-x")
.clang_arg("c++")
.clang_arg("-std=c++17")
.clang_arg("-I/usr/include")
.clang_arg("-I/usr/local/include")
.ctypes_prefix("ffi")
.raw_line("use crate::ffi;")
.generate()
.expect("Unable to generate bindings");

bindings
.write_to_file(&out_path)
.expect("Couldn't write bindings!");
bindings
.write_to_file(&out_path)
.expect("Couldn't write bindings!");
}
} else {
panic!(
"Error: platform not supported, tfhe-cuda-backend not built (only Linux is supported)"
Expand Down
1 change: 1 addition & 0 deletions backends/tfhe-cuda-backend/wrapper.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// These files must be added to the headers vec in the build.rs to check for file changes
#include "cuda/include/ciphertext.h"
#include "cuda/include/integer/compression/compression.h"
#include "cuda/include/integer/integer.h"
Expand Down

0 comments on commit 3bd7cf7

Please sign in to comment.