Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sqlite search #4

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion index.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ async function search(prompt) {
results.innerHTML = "";
for ([path, rank] of json) {
let item = document.createElement("span");
item.appendChild(document.createTextNode(path));
let a = document.createElement("a");
a.href = `files/${path}`;
a.innerHTML = path;
item.appendChild(a);
item.appendChild(document.createElement("br"));
results.appendChild(item);
}
Expand Down
92 changes: 61 additions & 31 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,34 @@ fn entry() -> Result<(), ()> {

let mut skipped = 0;

if use_sqlite_mode {
let index_path = "index.db";

if let Err(err) = fs::remove_file(index_path) {
if err.kind() != std::io::ErrorKind::NotFound {
eprintln!("ERROR: could not delete file {index_path}: {err}");
return Err(())
let (res, time) = measure_time(|| {
if use_sqlite_mode {
let index_path = "index.db";

if let Err(err) = fs::remove_file(index_path) {
if err.kind() != std::io::ErrorKind::NotFound {
eprintln!("ERROR: could not delete file {index_path}: {err}");
return Err(())
}
}

let mut model = SqliteModel::open(Path::new(index_path))?;
model.begin()?;
add_folder_to_model(Path::new(&dir_path), &mut model, &mut skipped)?;
// TODO: implement a special transaction object that implements Drop trait and commits the transaction when it goes out of scope
model.commit()?;
} else {
let index_path = "index.json";
let mut model = Default::default();
add_folder_to_model(Path::new(&dir_path), &mut model, &mut skipped)?;
save_model_as_json(&model, index_path)?;
}
Ok(())
});

let mut model = SqliteModel::open(Path::new(index_path))?;
model.begin()?;
add_folder_to_model(Path::new(&dir_path), &mut model, &mut skipped)?;
// TODO: implement a special transaction object that implements Drop trait and commits the transaction when it goes out of scope
model.commit()?;
} else {
let index_path = "index.json";
let mut model = Default::default();
add_folder_to_model(Path::new(&dir_path), &mut model, &mut skipped)?;
save_model_as_json(&model, index_path)?;
}
eprintln!("Indexed in {time} s");

res?;

println!("Skipped {skipped} files.");
Ok(())
Expand All @@ -190,11 +197,18 @@ fn entry() -> Result<(), ()> {
})?.chars().collect::<Vec<_>>();

if use_sqlite_mode {
let model = SqliteModel::open(Path::new(&index_path))?;
let (res, time) = measure_time(|| {
let model = SqliteModel::open(Path::new(&index_path))?;

for (path, rank) in model.search_query(&prompt)?.iter().take(20) {
println!("{path} {rank}", path = path.display());
}
for (path, rank) in model.search_query(&prompt)?.iter().take(20) {
println!("{path} {rank}", path = path.display());
}
Ok(())
});

res?;

eprintln!("Sqlite search time: {time} s");
} else {
let index_file = File::open(&index_path).map_err(|err| {
eprintln!("ERROR: could not open index file {index_path}: {err}");
Expand All @@ -220,19 +234,29 @@ fn entry() -> Result<(), ()> {
let address = args.next().unwrap_or("127.0.0.1:6969".to_string());

if use_sqlite_mode {
let model = SqliteModel::open(Path::new(&index_path))?;
let (res, time) = measure_time(|| {
SqliteModel::open(Path::new(&index_path))
});

eprintln!("Load index file (sqlite): {time} s");

server::start(&address, &model)
server::start(&address, &res?)
} else {
let index_file = File::open(&index_path).map_err(|err| {
eprintln!("ERROR: could not open index file {index_path}: {err}");
})?;
let (res, time) = measure_time(|| {
let index_file = File::open(&index_path).map_err(|err| {
eprintln!("ERROR: could not open index file {index_path}: {err}");
})?;

let model: InMemoryModel = serde_json::from_reader(index_file).map_err(|err| {
eprintln!("ERROR: could not parse index file {index_path}: {err}");
})?;
let model: InMemoryModel = serde_json::from_reader(index_file).map_err(|err| {
eprintln!("ERROR: could not parse index file {index_path}: {err}");
})?;

Ok(model)
});

eprintln!("Load index file: {time} s");

server::start(&address, &model)
server::start(&address, &res?)
}
}

Expand All @@ -250,3 +274,9 @@ fn main() -> ExitCode {
Err(()) => ExitCode::FAILURE,
}
}

/// Invokes `f` exactly once and times the call.
///
/// Returns a tuple of the closure's return value and the time the call took,
/// expressed in seconds as an `f64` (measured with `std::time::Instant`).
fn measure_time<T>(f: impl FnOnce() -> T) -> (T, f64) {
    let timer = std::time::Instant::now();
    let output = f();
    // Read the clock right after `f` returns so only the call itself is timed.
    let seconds = timer.elapsed().as_secs_f64();
    (output, seconds)
}
65 changes: 63 additions & 2 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,21 @@ impl SqliteModel {
self.execute("COMMIT;")
}

/// Looks up the stored frequency of `term` in the `DocFreq` table.
///
/// Returns the `freq` column of the first matching row, or `0` when the
/// table has no row for `term`.
///
/// # Errors
///
/// Returns `Err(())` after logging to stderr if the statement cannot be
/// prepared, bound, stepped, or read.
pub fn query_tf(&self, term: &str) -> Result<i64, ()> {
    let query = "SELECT freq FROM DocFreq WHERE term = :term";
    // One shared reporter so every failure of this statement logs the same way.
    let report = |err| {
        eprintln!("ERROR: Could not execute query {query}: {err}");
    };

    let mut statement = self.connection.prepare(query).map_err(report)?;
    statement
        .bind_iter::<_, (_, sqlite::Value)>([(":term", term.into())])
        .map_err(report)?;

    match statement.next().map_err(report)? {
        // No matching row: the term is unknown, which counts as frequency zero.
        sqlite::State::Done => Ok(0),
        sqlite::State::Row => statement.read::<i64, _>("freq").map_err(report),
    }
}

pub fn open(path: &Path) -> Result<Self, ()> {
let connection = sqlite::open(path).map_err(|err| {
eprintln!("ERROR: could not open sqlite database {path}: {err}", path = path.display());
Expand Down Expand Up @@ -67,8 +82,54 @@ impl SqliteModel {
}

impl Model for SqliteModel {
fn search_query(&self, _query: &[char]) -> Result<Vec<(PathBuf, f32)>, ()> {
todo!()
/// Ranks every row of the `Documents` table against `query`, best match first.
///
/// The query is tokenized with `Lexer`; for each document the rank is the sum,
/// over the distinct query terms, of `compute_tf(term, doc) * compute_idf(term, ..)`.
/// Per-document term frequencies come from the `TermFreq` table and the
/// corpus-wide ones from `query_tf`. Documents whose rank is not finite are
/// left out of the result.
///
/// # Errors
///
/// Returns `Err(())` after logging to stderr when any SQL statement fails, when
/// the document count cannot be read, or when a `Documents` row holds a NULL
/// column or a non-integer `term_count`.
fn search_query(&self, query: &[char]) -> Result<Vec<(PathBuf, f32)>, ()> {
    let mut result = Vec::new();

    let log_err = |err| {
        eprintln!("ERROR: Could not execute query {query:?}: {err}");
    };

    // Frequency of every query term as reported by `query_tf`.
    let token_freqs = Lexer::new(query).map(|term| {
        let freq = self.query_tf(&term)? as usize;
        Ok((term, freq))
    }).collect::<Result<DocFreq, _>>()?;

    let mut num_documents = None;
    self.connection.iterate("SELECT count(*) from Documents", |pairs| {
        // A NULL or unparsable count leaves `None`, which is reported below
        // instead of panicking inside the callback.
        num_documents = pairs[0].1.and_then(|count| count.parse::<usize>().ok());
        true
    }).map_err(log_err)?;
    let num_documents = num_documents.ok_or_else(|| {
        eprintln!("ERROR: could not determine the number of indexed documents");
    })?;

    self.connection.iterate("SELECT id, path, term_count FROM Documents", |row| {
        // Returning `false` aborts the iteration; the resulting error is logged
        // by the `map_err` after the closure. This replaces the panics the
        // previous `unwrap`/`expect` calls would have raised inside the callback.
        let (Some(id), Some(path), Some(term_count)) = (row[0].1, row[1].1, row[2].1) else {
            eprintln!("ERROR: unexpected NULL column in a Documents row");
            return false;
        };
        // Rows whose id is not an integer are skipped rather than treated as fatal.
        let Ok(id) = id.parse::<i32>() else { return true };
        let Ok(term_count) = term_count.parse() else {
            eprintln!("ERROR: term_count of {path} is not an integer");
            return false;
        };

        let mut rank = 0f32;
        for token in token_freqs.keys() {
            let mut doc = Doc::default();
            doc.count = term_count;
            let Ok(mut stmt) = self.connection
                .prepare("SELECT doc_id, freq FROM TermFreq WHERE term = :token AND doc_id = :id")
                .map_err(log_err) else { return false };
            let Ok(_) = stmt.bind_iter::<_, (_, sqlite::Value)>([
                (":token", token.as_str().into()),
                (":id", i64::from(id).into()),
            ]).map_err(log_err) else { return false };
            match stmt.next() {
                // Only read `freq` when the statement actually produced a row.
                Ok(sqlite::State::Row) => {
                    let Ok(freq) = stmt.read::<i64, _>("freq").map_err(log_err) else { return false };
                    doc.tf.insert(token.to_owned(), freq as usize);
                }
                // No TermFreq row: the term does not occur in this document,
                // so `doc.tf` stays without an entry for it.
                Ok(sqlite::State::Done) => {}
                Err(err) => {
                    log_err(err);
                    return false;
                }
            }

            rank += compute_tf(token, &doc) * compute_idf(token, num_documents, &token_freqs);
        }
        if rank.is_finite() {
            result.push((PathBuf::from(path), rank));
        }
        true
    }).map_err(log_err)?;

    // Only finite ranks were pushed, so the comparison cannot fail. Ascending
    // stable sort followed by `reverse` yields highest rank first.
    result.sort_by(|(_, rank1), (_, rank2)| rank1.partial_cmp(rank2).expect("Rank should be comparable"));
    result.reverse();
    Ok(result)
}

fn add_document(&mut self, path: PathBuf, content: &[char]) -> Result<(), ()> {
Expand Down
16 changes: 15 additions & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::fs::File;
use std::str;
use std::io;

use crate::measure_time;

use super::model::*;

use tiny_http::{Server, Request, Response, Header, Method, StatusCode};
Expand Down Expand Up @@ -76,14 +78,26 @@ fn serve_request(model: &impl Model, request: Request) -> io::Result<()> {

match (request.method(), request.url()) {
(Method::Post, "/api/search") => {
serve_api_search(model, request)
let (res, time) = measure_time(|| serve_api_search(model, request));
eprintln!("Search in {time}");
res
}
(Method::Get, "/index.js") => {
serve_static_file(request, "index.js", "text/javascript; charset=utf-8")
}
(Method::Get, "/") | (Method::Get, "/index.html") => {
serve_static_file(request, "index.html", "text/html; charset=utf-8")
}
(Method::Get, path) => {
if path.starts_with("/files/") {
// Make an owned copy to avoid borrow error
// TODO: sanitize URL to avoid host file system with "../", although the application shouldn't be used on the web
let rest = path["/files/".len()..].to_owned();
serve_static_file(request, &rest, "text/xml")
} else {
serve_404(request)
}
}
_ => {
serve_404(request)
}
Expand Down