Commit

Use ML Server for prediction
maxivhuber committed Jan 17, 2025
1 parent 2a38f7e commit a86a923
Showing 3 changed files with 69 additions and 33 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -25,6 +25,7 @@ rmp-serde = "1.3.0"
serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.135"
csv = "1.3.1"
+lazy_static = "1.5.0"

[features]
handle-ctrlc = ["ctrlc"]
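Note: the new lazy_static dependency backs the process-wide TcpStream handle introduced in src/ml_selector.rs below. On Rust 1.70+, std::sync::OnceLock can hold the same connect-once global without an extra crate. A minimal sketch of that alternative (not part of this commit; it gives up the ability to overwrite the connection that Mutex<Option<TcpStream>> allows):

use std::io;
use std::net::TcpStream;
use std::sync::{Mutex, OnceLock};

// Process-wide connection handle, initialized at most once (OnceLock replaces lazy_static!).
static GLOBAL_TCP_STREAM: OnceLock<Mutex<TcpStream>> = OnceLock::new();

fn set_global_connection(address: &str) -> io::Result<()> {
    let stream = TcpStream::connect(address)?;
    // `set` only fails if a connection was already stored; that case is ignored here.
    let _ = GLOBAL_TCP_STREAM.set(Mutex::new(stream));
    Ok(())
}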
100 changes: 67 additions & 33 deletions src/ml_selector.rs
@@ -4,21 +4,31 @@ use arboretum_td::heuristic_elimination_order::{
};
use arboretum_td::solver::AtomSolver;
use arboretum_td::{graph::HashMapGraph, io::PaceReader};
+use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::fs;
use std::io::BufWriter;
use std::io::Write;
+use std::net::TcpStream;
+use std::sync::Mutex;
use std::time::Instant;
use std::{
    convert::TryFrom,
    fs::File,
    io::{self, BufReader, Read},
    path::Path,
};
-use std::{
-    io::Write,
-    process::{Command, Stdio},
-};

+lazy_static! {
+    static ref GLOBAL_TCP_STREAM: Mutex<Option<TcpStream>> = Mutex::new(None);
+}
+
+const END_MARKER: &[u8] = b"<END>";
+
+fn set_global_connection(address: &str) {
+    let stream = TcpStream::connect(address).expect("Failed to connect to the server");
+    let mut global_stream = GLOBAL_TCP_STREAM.lock().unwrap();
+    *global_stream = Some(stream);
+}

pub struct MLSelector {
    graph: HashMapGraph,
    cache: Vec<i64>,
@@ -63,6 +73,8 @@ struct Tuple(i64, i64);
pub type MLDecomposer = HeuristicEliminationDecomposer<MLSelector>;

fn main() -> io::Result<()> {
+    set_global_connection("127.0.0.1:5001");

    let file = File::create("output.csv")?;
    let buf_writer = BufWriter::new(file);

@@ -172,42 +184,64 @@ fn main() -> io::Result<()> {
    Ok(())
}

-fn ml_values(graph: &HashMapGraph, cache: &mut [i64]) -> io::Result<()> {
-    let mut child = Command::new("uv")
-        .arg("run")
-        .arg("--directory")
-        .arg("/home/mhbr96/Python/tw_bnb/") // change this to correct path
-        .arg("deserialize_msgpack.py")
-        .stdin(Stdio::piped()) // write to stdin
-        .stdout(Stdio::piped()) // read from stdout
-        .spawn()
-        .expect("Failed to start Python process");
-
-    if let Some(mut stdin) = child.stdin.take() {
-        stdin.write_all(&graph.serialize())?; // Write serialized graph data to stdin
-        stdin.flush()?;
-    }
-
-    let mut output = Vec::new();
-    if let Some(ref mut stdout) = child.stdout {
-        stdout.read_to_end(&mut output)?; // Read all data from stdout
-    }
-
-    let results: Vec<Tuple> = rmp_serde::from_slice(&output).expect("Failed to deserialize output");
-
-    let status = child
-        .wait()
-        .expect("Failed to wait for Python process to exit");
-
-    if status.success() {
-        for t in results.iter().cloned() {
-            cache[t.0 as usize] = t.1
-        }
-        Ok(())
-    } else {
-        Err(io::Error::new(
-            io::ErrorKind::Other,
-            "Python program failed",
-        ))
-    }
-}
+fn read_until_marker(mut stream: &mut TcpStream) -> Vec<u8> {
+    let mut reader = BufReader::new(&mut stream);
+    let mut buffer = Vec::new();
+    let mut chunk = [0; 4096];
+
+    loop {
+        let bytes_read = reader.read(&mut chunk).unwrap();
+        if bytes_read == 0 {
+            break;
+        }
+
+        // Append the read data into the main buffer
+        buffer.extend_from_slice(&chunk[..bytes_read]);
+
+        // Check if the end marker exists in the buffer
+        if buffer
+            .windows(END_MARKER.len())
+            .any(|window| window == END_MARKER)
+        {
+            // Remove the end marker from the data
+            let marker_pos = buffer
+                .windows(END_MARKER.len())
+                .position(|window| window == END_MARKER)
+                .unwrap();
+            buffer.truncate(marker_pos);
+            break;
+        }
+    }
+
+    buffer // Return the message buffer without the end marker
+}
+
+fn ml_values(graph: &HashMapGraph, cache: &mut [i64]) -> io::Result<()> {
+    let mut global_stream = GLOBAL_TCP_STREAM.lock().unwrap();
+    if let Some(ref mut stream) = *global_stream {
+        let mut serialized_graph = graph.serialize();
+        serialized_graph.extend_from_slice(END_MARKER);
+
+        stream.write_all(&serialized_graph)?;
+        stream.flush()?;
+
+        let output = read_until_marker(stream);
+        // let mut output = Vec::new();
+        // stream.read_to_end(&mut output)?;
+
+        let results: Vec<Tuple> = rmp_serde::from_slice(&output).map_err(|_| {
+            io::Error::new(io::ErrorKind::InvalidData, "Failed to deserialize output")
+        })?;
+
+        for t in results {
+            cache[t.0 as usize] = t.1;
+        }
+
+        Ok(())
+    } else {
+        Err(io::Error::new(
+            io::ErrorKind::NotConnected,
+            "No global TCP stream available",
+        ))
+    }
+}
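
Note: this change assumes an ML server already listening on 127.0.0.1:5001 that accepts a MessagePack-encoded graph terminated by <END> and replies with a MessagePack-encoded list of (index, value) pairs, also terminated by <END>. That server (the former deserialize_msgpack.py side) is not part of this diff. As a hypothetical stand-in for testing the framing locally, a minimal Rust server using the same rmp-serde crate could look like this; it answers every request with an empty prediction list, which deserializes as an empty Vec<Tuple>:

use std::io::{Read, Write};
use std::net::TcpListener;

const END_MARKER: &[u8] = b"<END>";

fn main() -> std::io::Result<()> {
    // Address matches the one passed to set_global_connection in main() above.
    let listener = TcpListener::bind("127.0.0.1:5001")?;
    for stream in listener.incoming() {
        let mut stream = stream?;
        let mut buffer = Vec::new();
        let mut chunk = [0u8; 4096];
        loop {
            let n = stream.read(&mut chunk)?;
            if n == 0 {
                break; // client closed the connection
            }
            buffer.extend_from_slice(&chunk[..n]);
            // A request is complete once the end marker arrives.
            while let Some(pos) = buffer.windows(END_MARKER.len()).position(|w| w == END_MARKER) {
                buffer.drain(..pos + END_MARKER.len());
                // A real server would deserialize the graph and run the model here;
                // this stub just replies with an empty prediction list.
                let reply = rmp_serde::to_vec(&Vec::<(i64, i64)>::new()).unwrap();
                stream.write_all(&reply)?;
                stream.write_all(END_MARKER)?;
                stream.flush()?;
            }
        }
    }
    Ok(())
}

Delimiter framing like this works only as long as the <END> byte sequence never occurs inside the MessagePack payload; a length prefix would avoid that edge case.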
