diff --git a/bindings/python/vectorlite_py/test/vectorlite_test.py b/bindings/python/vectorlite_py/test/vectorlite_test.py index 4007f1f..3df96c0 100644 --- a/bindings/python/vectorlite_py/test/vectorlite_test.py +++ b/bindings/python/vectorlite_py/test/vectorlite_test.py @@ -121,59 +121,60 @@ def remove_quote(s: str): file_path = os.path.join(tempdir, 'index.bin') file_paths = [f'\"{file_path}\"', f'\'{file_path}\''] - for index_file_path in file_paths: - assert not os.path.exists(remove_quote(index_file_path)) - - conn = get_connection() - cur = conn.cursor() - cur.execute(f'create virtual table my_table using vectorlite(my_embedding float32[{DIM}], hnsw(max_elements={NUM_ELEMENTS}), {index_file_path})') - - for i in range(NUM_ELEMENTS): - cur.execute('insert into my_table (rowid, my_embedding) values (?, ?)', (i, random_vectors[i].tobytes())) - - result = cur.execute('select rowid, distance from my_table where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() - assert len(result) == 10 - - conn.close() - # The index file should be created - index_file_size = os.path.getsize(remove_quote(index_file_path)) - assert os.path.exists(remove_quote(index_file_path)) and index_file_size > 0 - - # test if the index file could be loaded with the same parameters without inserting data again - conn = get_connection() - cur = conn.cursor() - cur.execute(f'create virtual table my_table using vectorlite(my_embedding float32[{DIM}], hnsw(max_elements={NUM_ELEMENTS}), {index_file_path})') - result = cur.execute('select rowid, distance from my_table where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() - assert len(result) == 10 - conn.close() - # The index file should be created - assert os.path.exists(remove_quote(index_file_path)) and os.path.getsize(remove_quote(index_file_path)) == index_file_size - - # test if the index file could be loaded with different hnsw parameters and distance type without inserting data again - # But hnsw parameters can't be changed even if different values are set, they will be owverwritten by the value from the index file - # todo: test whether hnsw parameters are overwritten after more functions are introduced to provide runtime stats. - conn = get_connection() - cur = conn.cursor() - cur.execute(f'create virtual table my_table2 using vectorlite(my_embedding float32[{DIM}] cosine, hnsw(max_elements={NUM_ELEMENTS},ef_construction=32,M=32), {index_file_path})') - result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() - assert len(result) == 10 - - # test searching with ef_search = 30, which defaults to 10 - result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?, ?))', (random_vectors[0].tobytes(), 10, 30)).fetchall() - assert len(result) == 10 - conn.close() - assert os.path.exists(remove_quote(index_file_path)) and os.path.getsize(remove_quote(index_file_path)) == index_file_size - - - # test if `drop table` deletes the index file - conn = get_connection() - cur = conn.cursor() - cur.execute(f'create virtual table my_table2 using vectorlite(my_embedding float32[{DIM}] cosine, hnsw(max_elements={NUM_ELEMENTS},ef_construction=64,M=32), {index_file_path})') - result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() - assert len(result) == 10 - - cur.execute(f'drop table my_table2') - assert not os.path.exists(remove_quote(index_file_path)) - conn.close() - - + for vector_type in ['float32', 'bfloat16']: + for index_file_path in file_paths: + assert not os.path.exists(remove_quote(index_file_path)) + + conn = get_connection() + cur = conn.cursor() + cur.execute(f'create virtual table my_table using vectorlite(my_embedding {vector_type}[{DIM}], hnsw(max_elements={NUM_ELEMENTS}), {index_file_path})') + + for i in range(NUM_ELEMENTS): + cur.execute('insert into my_table (rowid, my_embedding) values (?, ?)', (i, random_vectors[i].tobytes())) + + result = cur.execute('select rowid, distance from my_table where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() + assert len(result) == 10 + + conn.close() + # The index file should be created + index_file_size = os.path.getsize(remove_quote(index_file_path)) + assert os.path.exists(remove_quote(index_file_path)) and index_file_size > 0 + + # test if the index file could be loaded with the same parameters without inserting data again + conn = get_connection() + cur = conn.cursor() + cur.execute(f'create virtual table my_table using vectorlite(my_embedding {vector_type}[{DIM}], hnsw(max_elements={NUM_ELEMENTS}), {index_file_path})') + result = cur.execute('select rowid, distance from my_table where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() + assert len(result) == 10 + conn.close() + # The index file should be created + assert os.path.exists(remove_quote(index_file_path)) and os.path.getsize(remove_quote(index_file_path)) == index_file_size + + # test if the index file could be loaded with different hnsw parameters and distance type without inserting data again + # But hnsw parameters can't be changed even if different values are set, they will be owverwritten by the value from the index file + # todo: test whether hnsw parameters are overwritten after more functions are introduced to provide runtime stats. + conn = get_connection() + cur = conn.cursor() + cur.execute(f'create virtual table my_table2 using vectorlite(my_embedding {vector_type}[{DIM}] cosine, hnsw(max_elements={NUM_ELEMENTS},ef_construction=32,M=32), {index_file_path})') + result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() + assert len(result) == 10 + + # test searching with ef_search = 30, which defaults to 10 + result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?, ?))', (random_vectors[0].tobytes(), 10, 30)).fetchall() + assert len(result) == 10 + conn.close() + assert os.path.exists(remove_quote(index_file_path)) and os.path.getsize(remove_quote(index_file_path)) == index_file_size + + + # test if `drop table` deletes the index file + conn = get_connection() + cur = conn.cursor() + cur.execute(f'create virtual table my_table2 using vectorlite(my_embedding {vector_type}[{DIM}] cosine, hnsw(max_elements={NUM_ELEMENTS},ef_construction=64,M=32), {index_file_path})') + result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() + assert len(result) == 10 + + cur.execute(f'drop table my_table2') + assert not os.path.exists(remove_quote(index_file_path)) + conn.close() + + diff --git a/format.sh b/format.sh index 69966de..06830aa 100644 --- a/format.sh +++ b/format.sh @@ -1 +1 @@ -clang-format -style=file -i src/*.h src/*.cpp \ No newline at end of file +clang-format -style=file -i vectorlite/*.h vectorlite/*.cpp \ No newline at end of file diff --git a/vectorlite/CMakeLists.txt b/vectorlite/CMakeLists.txt index c9e47aa..9ebb666 100644 --- a/vectorlite/CMakeLists.txt +++ b/vectorlite/CMakeLists.txt @@ -10,7 +10,7 @@ message(STATUS "Compiling on ${CMAKE_SYSTEM_PROCESSOR}") add_subdirectory(ops) -add_library(vectorlite SHARED vectorlite.cpp virtual_table.cpp vector.cpp vector_view.cpp util.cpp vector_space.cpp index_options.cpp sqlite_functions.cpp constraint.cpp) +add_library(vectorlite SHARED vectorlite.cpp virtual_table.cpp util.cpp vector_space.cpp index_options.cpp sqlite_functions.cpp constraint.cpp quantization.cpp) # remove the lib prefix to make the shared library name consistent on all platforms. set_target_properties(vectorlite PROPERTIES PREFIX "") target_include_directories(vectorlite PUBLIC ${RAPIDJSON_INCLUDE_DIRS} ${HNSWLIB_INCLUDE_DIRS} ${PROJECT_BINARY_DIR}) diff --git a/vectorlite/constraint.cpp b/vectorlite/constraint.cpp index 0357938..6e7acde 100644 --- a/vectorlite/constraint.cpp +++ b/vectorlite/constraint.cpp @@ -13,6 +13,7 @@ #include "absl/strings/str_join.h" #include "hnswlib/hnswlib.h" #include "macros.h" +#include "quantization.h" #include "sqlite3ext.h" #include "util.h" #include "vector.h" @@ -195,20 +196,39 @@ absl::StatusOr QueryExecutor::Execute() const { index_.setEf(*knn_param->ef_search); } try { - if (!space_.normalize) { - return index_.searchKnnCloserFirst( - knn_param->query_vector.data().data(), knn_param->k, - rowid_filter.get()); + if (space_.vector_type == VectorType::Float32) { + if (!space_.normalize) { + return index_.searchKnnCloserFirst( + knn_param->query_vector.data().data(), knn_param->k, + rowid_filter.get()); + } + + VECTORLITE_ASSERT(space_.normalize); + // Copy the query vector and normalize it. + Vector normalized_vector = Vector::Normalize(knn_param->query_vector); + + auto result = index_.searchKnnCloserFirst( + normalized_vector.data().data(), knn_param->k, rowid_filter.get()); + return result; + } else if (space_.vector_type == VectorType::BFloat16) { + BF16Vector quantized_vector = Quantize(knn_param->query_vector); + + if (!space_.normalize) { + return index_.searchKnnCloserFirst(quantized_vector.data().data(), + knn_param->k, rowid_filter.get()); + } + + VECTORLITE_ASSERT(space_.normalize); + BF16Vector normalized_vector = quantized_vector.Normalize(); + + auto result = index_.searchKnnCloserFirst( + normalized_vector.data().data(), knn_param->k, rowid_filter.get()); + return result; + } else { + return absl::InternalError( + absl::StrFormat("Unknown vector type: %d", space_.vector_type)); } - VECTORLITE_ASSERT(space_.normalize); - // Copy the query vector and normalize it. - Vector normalized_vector = Vector::Normalize(knn_param->query_vector); - - auto result = index_.searchKnnCloserFirst( - normalized_vector.data().data(), knn_param->k, rowid_filter.get()); - return result; - } catch (const std::runtime_error& e) { return absl::InternalError(e.what()); } diff --git a/vectorlite/constraint.h b/vectorlite/constraint.h index 90af01c..014217f 100644 --- a/vectorlite/constraint.h +++ b/vectorlite/constraint.h @@ -12,8 +12,8 @@ #include "hnswlib/hnswlib.h" #include "macros.h" #include "sqlite3.h" -#include "vector_view.h" #include "vector_space.h" +#include "vector_view.h" namespace vectorlite { diff --git a/vectorlite/distance.h b/vectorlite/distance.h index d431e7b..4e68d79 100644 --- a/vectorlite/distance.h +++ b/vectorlite/distance.h @@ -1,6 +1,8 @@ #pragma once #include "hnswlib/hnswlib.h" +#include "hwy/base.h" +#include "macros.h" #include "ops/ops.h" // This file implements hnswlib::SpaceInterface using vectorlite @@ -9,12 +11,13 @@ // PC(i5-12600KF with AVX2 support) namespace vectorlite { -class InnerProductSpace : public hnswlib::SpaceInterface { +template +class GenericInnerProductSpace : public hnswlib::SpaceInterface { public: - explicit InnerProductSpace(size_t dim) - : dim_(dim), func_(InnerProductSpace::InnerProductDistanceFunc) {} + explicit GenericInnerProductSpace(size_t dim) + : dim_(dim), func_(GenericInnerProductSpace::InnerProductDistanceFunc) {} - size_t get_data_size() override { return dim_ * sizeof(float); } + size_t get_data_size() override { return dim_ * sizeof(T); } void* get_dist_func_param() override { return &dim_; } @@ -26,18 +29,22 @@ class InnerProductSpace : public hnswlib::SpaceInterface { static float InnerProductDistanceFunc(const void* v1, const void* v2, const void* dim) { - return ops::InnerProductDistance(static_cast(v1), - static_cast(v2), + return ops::InnerProductDistance(static_cast(v1), + static_cast(v2), *reinterpret_cast(dim)); } }; -class L2Space : public hnswlib::SpaceInterface { +using InnerProductSpace = GenericInnerProductSpace; +using InnerProductSpaceBF16 = GenericInnerProductSpace; + +template +class GenericL2Space : public hnswlib::SpaceInterface { public: - explicit L2Space(size_t dim) - : dim_(dim), func_(L2Space::L2DistanceSquaredFunc) {} + explicit GenericL2Space(size_t dim) + : dim_(dim), func_(GenericL2Space::L2DistanceSquaredFunc) {} - size_t get_data_size() override { return dim_ * sizeof(float); } + size_t get_data_size() override { return dim_ * sizeof(T); } void* get_dist_func_param() override { return &dim_; } @@ -49,10 +56,13 @@ class L2Space : public hnswlib::SpaceInterface { static float L2DistanceSquaredFunc(const void* v1, const void* v2, const void* dim) { - return ops::L2DistanceSquared(static_cast(v1), - static_cast(v2), + return ops::L2DistanceSquared(static_cast(v1), + static_cast(v2), *reinterpret_cast(dim)); } }; +using L2Space = GenericL2Space; +using L2SpaceBF16 = GenericL2Space; + } // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/macros.h b/vectorlite/macros.h index 5559e80..81b1b67 100644 --- a/vectorlite/macros.h +++ b/vectorlite/macros.h @@ -1,5 +1,9 @@ #pragma once +#include + +#include "hwy/base.h" + #if defined(_WIN32) || defined(__WIN32__) #define VECTORLITE_EXPORT __declspec(dllexport) #else @@ -11,3 +15,11 @@ #include #define VECTORLITE_ASSERT(x) assert(x) #endif + +#define VECTORLITE_IF_FLOAT_SUPPORTED(T) \ + std::enable_if_t || \ + std::is_same_v>* = nullptr + +#define VECTORLITE_IF_FLOAT_SUPPORTED_FWD_DECL(T) \ + std::enable_if_t || \ + std::is_same_v>* diff --git a/vectorlite/ops/ops.h b/vectorlite/ops/ops.h index dd7c4ba..9ca76fe 100644 --- a/vectorlite/ops/ops.h +++ b/vectorlite/ops/ops.h @@ -16,8 +16,6 @@ namespace vectorlite { namespace ops { -using DistanceFunc = float (*)(const float*, const float*, size_t); - // v1 and v2 MUST not be nullptr but can point to the same array. HWY_DLLEXPORT float InnerProduct(const float* v1, const float* v2, size_t num_elements); diff --git a/vectorlite/util.h b/vectorlite/util.h index e56c359..1e126a0 100644 --- a/vectorlite/util.h +++ b/vectorlite/util.h @@ -2,6 +2,7 @@ #include #include +#include #include #include "hnswlib/hnswlib.h" @@ -25,4 +26,84 @@ std::optional DetectSIMD(); bool IsRowidInIndex(const hnswlib::HierarchicalNSW& index, hnswlib::labeltype rowid); +// Below *Base classes are taken from +// https://github.com/abseil/abseil-cpp/blob/20240722.0/absl/status/internal/statusor_internal.h#L368 +// to allow implicitly deleted constructors and assignment +// operators in a Derived class. For example, `CopyCtorBase` will explicitly +// delete the copy constructor when T is not copy constructible and `Derived` +// class will inherit that behavior implicitly. +template ::value> +struct CopyCtorBase { + CopyCtorBase() = default; + CopyCtorBase(const CopyCtorBase&) = default; + CopyCtorBase(CopyCtorBase&&) = default; + CopyCtorBase& operator=(const CopyCtorBase&) = default; + CopyCtorBase& operator=(CopyCtorBase&&) = default; +}; + +template +struct CopyCtorBase { + CopyCtorBase() = default; + CopyCtorBase(const CopyCtorBase&) = delete; + CopyCtorBase(CopyCtorBase&&) = default; + CopyCtorBase& operator=(const CopyCtorBase&) = default; + CopyCtorBase& operator=(CopyCtorBase&&) = default; +}; + +template ::value> +struct MoveCtorBase { + MoveCtorBase() = default; + MoveCtorBase(const MoveCtorBase&) = default; + MoveCtorBase(MoveCtorBase&&) = default; + MoveCtorBase& operator=(const MoveCtorBase&) = default; + MoveCtorBase& operator=(MoveCtorBase&&) = default; +}; + +template +struct MoveCtorBase { + MoveCtorBase() = default; + MoveCtorBase(const MoveCtorBase&) = default; + MoveCtorBase(MoveCtorBase&&) = delete; + MoveCtorBase& operator=(const MoveCtorBase&) = default; + MoveCtorBase& operator=(MoveCtorBase&&) = default; +}; + +template ::value&& + std::is_copy_assignable::value> +struct CopyAssignBase { + CopyAssignBase() = default; + CopyAssignBase(const CopyAssignBase&) = default; + CopyAssignBase(CopyAssignBase&&) = default; + CopyAssignBase& operator=(const CopyAssignBase&) = default; + CopyAssignBase& operator=(CopyAssignBase&&) = default; +}; + +template +struct CopyAssignBase { + CopyAssignBase() = default; + CopyAssignBase(const CopyAssignBase&) = default; + CopyAssignBase(CopyAssignBase&&) = default; + CopyAssignBase& operator=(const CopyAssignBase&) = delete; + CopyAssignBase& operator=(CopyAssignBase&&) = default; +}; + +template ::value&& + std::is_move_assignable::value> +struct MoveAssignBase { + MoveAssignBase() = default; + MoveAssignBase(const MoveAssignBase&) = default; + MoveAssignBase(MoveAssignBase&&) = default; + MoveAssignBase& operator=(const MoveAssignBase&) = default; + MoveAssignBase& operator=(MoveAssignBase&&) = default; +}; + +template +struct MoveAssignBase { + MoveAssignBase() = default; + MoveAssignBase(const MoveAssignBase&) = default; + MoveAssignBase(MoveAssignBase&&) = default; + MoveAssignBase& operator=(const MoveAssignBase&) = default; + MoveAssignBase& operator=(MoveAssignBase&&) = delete; +}; + } // end namespace vectorlite diff --git a/vectorlite/vector.cpp b/vectorlite/vector.cpp deleted file mode 100644 index a2a4808..0000000 --- a/vectorlite/vector.cpp +++ /dev/null @@ -1,115 +0,0 @@ -#include "vector.h" - -#include -#include - -#include -#include -#include - -#include "hnswlib/hnswlib.h" -#include "hnswlib/space_l2.h" -#include "macros.h" -#include "rapidjson/document.h" -#include "rapidjson/error/en.h" -#include "rapidjson/stringbuffer.h" -#include "rapidjson/writer.h" -#include "vector_space.h" -#include "vector_view.h" -#include "ops/ops.h" - -namespace vectorlite { - -absl::StatusOr Vector::FromJSON(std::string_view json) { - rapidjson::Document doc; - doc.Parse(json.data(), json.size()); - auto err = doc.GetParseError(); - if (err != rapidjson::ParseErrorCode::kParseErrorNone) { - return absl::InvalidArgumentError(rapidjson::GetParseError_En(err)); - } - - Vector result; - - if (doc.IsArray()) { - for (auto& v : doc.GetArray()) { - if (v.IsNumber()) { - result.data_.push_back(v.GetFloat()); - } else { - return absl::InvalidArgumentError( - "JSON array contains non-numeric value."); - } - } - return result; - } - - return absl::InvalidArgumentError("Input JSON is not an array."); -} - -absl::StatusOr Vector::FromBlob(std::string_view blob) { - auto vector_view = VectorView::FromBlob(blob); - if (vector_view.ok()) { - return Vector(*vector_view); - } - return vector_view.status(); -} - -std::string Vector::ToJSON() const { - VectorView vector_view(*this); - - return vector_view.ToJSON(); -} - -absl::StatusOr Distance(VectorView v1, VectorView v2, - DistanceType distance_type) { - if (v1.dim() != v2.dim()) { - std::string err = - absl::StrFormat("Dimension mismatch: %d != %d", v1.dim(), v2.dim()); - return absl::InvalidArgumentError(err); - } - - ops::DistanceFunc distance_func = nullptr; - - switch (distance_type) { - case DistanceType::L2: - distance_func = ops::L2DistanceSquared; - break; - case DistanceType::InnerProduct: - distance_func = ops::InnerProductDistance; - break; - case DistanceType::Cosine: - distance_func = ops::InnerProductDistance; - break; - default: - return absl::InvalidArgumentError("Invalid distance type"); - } - - bool normalize = distance_type == DistanceType::Cosine; - - if (!normalize) { - return distance_func(v1.data().data(), v2.data().data(), v1.dim()); - } - - Vector lhs = Vector::Normalize(v1); - Vector rhs = Vector::Normalize(v2); - return distance_func(lhs.data().data(), rhs.data().data(), v1.dim()); -} - -std::string_view Vector::ToBlob() const { - VectorView vector_view(*this); - - return vector_view.ToBlob(); -} - -Vector Vector::Normalize() const { - VectorView vector_view(*this); - - return Vector::Normalize(vector_view); -} - -Vector Vector::Normalize(VectorView vector_view) { - Vector v(vector_view); - ops::Normalize(v.data_.data(), vector_view.dim()); - return v; -} - -} // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/vector.h b/vectorlite/vector.h index d833335..fb78204 100644 --- a/vectorlite/vector.h +++ b/vectorlite/vector.h @@ -1,52 +1,148 @@ #pragma once +#include + #include #include #include #include "absl/status/statusor.h" #include "macros.h" +#include "ops/ops.h" +#include "rapidjson/document.h" +#include "rapidjson/error/en.h" +#include "rapidjson/stringbuffer.h" +#include "rapidjson/writer.h" +#include "util.h" #include "vector_space.h" #include "vector_view.h" namespace vectorlite { -class Vector { +template +class GenericVector : private CopyAssignBase, + private CopyCtorBase, + private MoveCtorBase, + private MoveAssignBase { public: - Vector() = default; - Vector(const Vector&) = default; - Vector(Vector&&) = default; + GenericVector() = default; + GenericVector(const GenericVector&) = default; + GenericVector(GenericVector&&) = default; - explicit Vector(std::vector&& data) : data_(std::move(data)) {} - explicit Vector(const std::vector& data) : data_(data) {} - explicit Vector(VectorView vector_view) + explicit GenericVector(std::vector&& data) : data_(std::move(data)) {} + explicit GenericVector(const std::vector& data) : data_(data) {} + explicit GenericVector(GenericVectorView vector_view) : data_(vector_view.data().begin(), vector_view.data().end()) {} - Vector& operator=(const Vector&) = default; - Vector& operator=(Vector&&) = default; + GenericVector& operator=(const GenericVector&) = default; + GenericVector& operator=(GenericVector&&) = default; + + static absl::StatusOr> FromJSON(std::string_view json) { + rapidjson::Document doc; + doc.Parse(json.data(), json.size()); + auto err = doc.GetParseError(); + if (err != rapidjson::ParseErrorCode::kParseErrorNone) { + return absl::InvalidArgumentError(rapidjson::GetParseError_En(err)); + } + + GenericVector result; + + if (doc.IsArray()) { + for (auto& v : doc.GetArray()) { + if (v.IsNumber()) { + result.data_.push_back(hwy::ConvertScalarTo(v.GetFloat())); + } else { + return absl::InvalidArgumentError( + "JSON array contains non-numeric value."); + } + } + return result; + } - static absl::StatusOr FromJSON(std::string_view json); + return absl::InvalidArgumentError("Input JSON is not an array."); + } - static absl::StatusOr FromBlob(std::string_view blob); + static absl::StatusOr> FromBlob(std::string_view blob) { + auto vector_view = GenericVectorView::FromBlob(blob); + if (vector_view.ok()) { + return GenericVector(*vector_view); + } + return vector_view.status(); + } - std::string ToJSON() const; + std::string ToJSON() const { + GenericVectorView vector_view(*this); - std::string_view ToBlob() const; + return vector_view.ToJSON(); + } - const std::vector& data() const { return data_; } + std::string_view ToBlob() const { + GenericVectorView vector_view(*this); + + return vector_view.ToBlob(); + }; + + const std::vector& data() const { return data_; } std::size_t dim() const { return data_.size(); } - Vector Normalize() const; + GenericVector Normalize() const { + GenericVectorView vector_view(*this); - static Vector Normalize(VectorView vector_view); + return GenericVector::Normalize(vector_view); + } + + static GenericVector Normalize(GenericVectorView vector_view) { + GenericVector v(vector_view); + ops::Normalize(v.data_.data(), vector_view.dim()); + return v; + } private: - std::vector data_; + std::vector data_; }; +template +using DistanceFunc = float (*)(const T*, const T*, size_t); + // Calculate the distance between two vectors. -absl::StatusOr Distance(VectorView v1, VectorView v2, - DistanceType space); +template +absl::StatusOr Distance(GenericVectorView v1, GenericVectorView v2, + DistanceType distance_type) { + if (v1.dim() != v2.dim()) { + std::string err = + absl::StrFormat("Dimension mismatch: %d != %d", v1.dim(), v2.dim()); + return absl::InvalidArgumentError(err); + } + + DistanceFunc distance_func = nullptr; + + switch (distance_type) { + case DistanceType::L2: + distance_func = ops::L2DistanceSquared; + break; + case DistanceType::InnerProduct: + distance_func = ops::InnerProductDistance; + break; + case DistanceType::Cosine: + distance_func = ops::InnerProductDistance; + break; + default: + return absl::InvalidArgumentError("Invalid distance type"); + } + + bool normalize = distance_type == DistanceType::Cosine; + + if (!normalize) { + return distance_func(v1.data().data(), v2.data().data(), v1.dim()); + } + + GenericVector lhs = GenericVector::Normalize(v1); + GenericVector rhs = GenericVector::Normalize(v2); + return distance_func(lhs.data().data(), rhs.data().data(), v1.dim()); +} + +using Vector = GenericVector; +using BF16Vector = GenericVector; } // namespace vectorlite diff --git a/vectorlite/vector_space.cpp b/vectorlite/vector_space.cpp index c9b1226..bd69e56 100644 --- a/vectorlite/vector_space.cpp +++ b/vectorlite/vector_space.cpp @@ -1,8 +1,12 @@ #include "vector_space.h" +#include + +#include #include #include +#include "absl/base/optimization.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" #include "distance.h" @@ -27,9 +31,42 @@ std::optional ParseVectorType(std::string_view vector_type) { if (vector_type == "float32") { return VectorType::Float32; } + + if (vector_type == "bfloat16") { + return VectorType::BFloat16; + } + return std::nullopt; } +static std::unique_ptr> CreateL2Space( + size_t dim, VectorType vector_type) { + switch (vector_type) { + case VectorType::Float32: + return std::make_unique(dim); + case VectorType::BFloat16: + return std::make_unique(dim); + default: + // This should never happen, but we include it for completeness + ABSL_UNREACHABLE(); + return nullptr; + } +} + +static std::unique_ptr> CreateInnerProductSpace( + size_t dim, VectorType vector_type) { + switch (vector_type) { + case VectorType::Float32: + return std::make_unique(dim); + case VectorType::BFloat16: + return std::make_unique(dim); + default: + // This should never happen, but we include it for completeness + ABSL_UNREACHABLE(); + return nullptr; + } +} + absl::StatusOr VectorSpace::Create(size_t dim, DistanceType distance_type, VectorType vector_type) { @@ -43,13 +80,13 @@ absl::StatusOr VectorSpace::Create(size_t dim, result.vector_type = vector_type; switch (distance_type) { case DistanceType::L2: - result.space = std::make_unique(dim); + result.space = CreateL2Space(dim, vector_type); break; case DistanceType::InnerProduct: - result.space = std::make_unique(dim); + result.space = CreateInnerProductSpace(dim, vector_type); break; case DistanceType::Cosine: - result.space = std::make_unique(dim); + result.space = CreateInnerProductSpace(dim, vector_type); break; default: std::string err_msg = diff --git a/vectorlite/vector_space.h b/vectorlite/vector_space.h index c8df5a7..cef25e4 100644 --- a/vectorlite/vector_space.h +++ b/vectorlite/vector_space.h @@ -19,6 +19,7 @@ std::optional ParseDistanceType(std::string_view distance_type); enum class VectorType { Float32, + BFloat16, }; std::optional ParseVectorType(std::string_view vector_type); diff --git a/vectorlite/vector_space_test.cpp b/vectorlite/vector_space_test.cpp index 065008d..c3f0a90 100644 --- a/vectorlite/vector_space_test.cpp +++ b/vectorlite/vector_space_test.cpp @@ -1,5 +1,6 @@ #include "vector_space.h" +#include "absl/strings/str_format.h" #include "gtest/gtest.h" TEST(ParseDistanceType, ShouldSupport_L2_InnerProduct_Cosine) { @@ -35,91 +36,114 @@ TEST(ParseVectorType, ShouldReturnNullOptForInvalidVectorType) { EXPECT_FALSE(uint8); } +TEST(ParseVectorType, ShouldSupportBFloat16) { + auto float16 = vectorlite::ParseVectorType("bfloat16"); + EXPECT_TRUE(float16); +} + TEST(CreateVectorSpace, ShouldWorkWithValidInput) { - auto l2 = vectorlite::CreateNamedVectorSpace(3, vectorlite::DistanceType::L2, - "my_vector", - vectorlite::VectorType::Float32); - ASSERT_TRUE(l2.ok()); - EXPECT_EQ(l2->distance_type, vectorlite::DistanceType::L2); - EXPECT_EQ(l2->normalize, false); - EXPECT_NE(l2->space, nullptr); - EXPECT_EQ(l2->dimension(), 3); - EXPECT_EQ(l2->vector_type, vectorlite::VectorType::Float32); - - auto ip = vectorlite::CreateNamedVectorSpace( - 4, vectorlite::DistanceType::InnerProduct, "my_vector", - vectorlite::VectorType::Float32); - ASSERT_TRUE(ip.ok()); - EXPECT_EQ(ip->distance_type, vectorlite::DistanceType::InnerProduct); - EXPECT_EQ(ip->normalize, false); - EXPECT_NE(ip->space, nullptr); - EXPECT_EQ(ip->dimension(), 4); - EXPECT_EQ(ip->vector_type, vectorlite::VectorType::Float32); - - auto cosine = vectorlite::CreateNamedVectorSpace( - 5, vectorlite::DistanceType::Cosine, "my_vector", - vectorlite::VectorType::Float32); - ASSERT_TRUE(cosine.ok()); - EXPECT_EQ(cosine->distance_type, vectorlite::DistanceType::Cosine); - EXPECT_EQ(cosine->normalize, true); - EXPECT_NE(cosine->space, nullptr); - EXPECT_EQ(cosine->dimension(), 5); - EXPECT_EQ(cosine->vector_type, vectorlite::VectorType::Float32); + for (auto vector_type : + {vectorlite::VectorType::Float32, vectorlite::VectorType::BFloat16}) { + auto l2 = vectorlite::CreateNamedVectorSpace( + 3, vectorlite::DistanceType::L2, "my_vector", vector_type); + ASSERT_TRUE(l2.ok()); + EXPECT_EQ(l2->distance_type, vectorlite::DistanceType::L2); + EXPECT_EQ(l2->normalize, false); + EXPECT_NE(l2->space, nullptr); + EXPECT_EQ(l2->dimension(), 3); + EXPECT_EQ(l2->vector_type, vector_type); + + auto ip = vectorlite::CreateNamedVectorSpace( + 4, vectorlite::DistanceType::InnerProduct, "my_vector", vector_type); + ASSERT_TRUE(ip.ok()); + EXPECT_EQ(ip->distance_type, vectorlite::DistanceType::InnerProduct); + EXPECT_EQ(ip->normalize, false); + EXPECT_NE(ip->space, nullptr); + EXPECT_EQ(ip->dimension(), 4); + EXPECT_EQ(ip->vector_type, vector_type); + + auto cosine = vectorlite::CreateNamedVectorSpace( + 5, vectorlite::DistanceType::Cosine, "my_vector", vector_type); + ASSERT_TRUE(cosine.ok()); + EXPECT_EQ(cosine->distance_type, vectorlite::DistanceType::Cosine); + EXPECT_EQ(cosine->normalize, true); + EXPECT_NE(cosine->space, nullptr); + EXPECT_EQ(cosine->dimension(), 5); + EXPECT_EQ(cosine->vector_type, vector_type); + } } TEST(CreateNamedVectorSpace, ShouldReturnErrorForDimOfZero) { - auto l2 = vectorlite::CreateNamedVectorSpace(0, vectorlite::DistanceType::L2, - "my_vector", - vectorlite::VectorType::Float32); - EXPECT_FALSE(l2.ok()); - - auto ip = vectorlite::CreateNamedVectorSpace( - 0, vectorlite::DistanceType::InnerProduct, "my_vector", - vectorlite::VectorType::Float32); - EXPECT_FALSE(ip.ok()); - - auto cosine = vectorlite::CreateNamedVectorSpace( - 0, vectorlite::DistanceType::Cosine, "my_vector", - vectorlite::VectorType::Float32); - EXPECT_FALSE(cosine.ok()); + for (auto vector_type : + {vectorlite::VectorType::Float32, vectorlite::VectorType::BFloat16}) { + auto l2 = vectorlite::CreateNamedVectorSpace( + 0, vectorlite::DistanceType::L2, "my_vector", vector_type); + EXPECT_FALSE(l2.ok()); + + auto ip = vectorlite::CreateNamedVectorSpace( + 0, vectorlite::DistanceType::InnerProduct, "my_vector", vector_type); + EXPECT_FALSE(ip.ok()); + + auto cosine = vectorlite::CreateNamedVectorSpace( + 0, vectorlite::DistanceType::Cosine, "my_vector", vector_type); + EXPECT_FALSE(cosine.ok()); + } +} + +static std::string VectorTypeToString(vectorlite::VectorType type) { + switch (type) { + case vectorlite::VectorType::Float32: + return "float32"; + case vectorlite::VectorType::BFloat16: + return "bfloat16"; + default: + return "unknown"; + } } TEST(NamedVectorSpace_FromString, ShouldWorkWithValidInput) { - // If distance type is not specifed, it should default to L2 - auto space = vectorlite::NamedVectorSpace::FromString("my_vec float32[3]"); - ASSERT_TRUE(space.ok()); - EXPECT_EQ(space->normalize, false); - EXPECT_NE(space->space, nullptr); - EXPECT_EQ(space->distance_type, vectorlite::DistanceType::L2); - EXPECT_EQ(3, space->dimension()); - EXPECT_EQ("my_vec", space->vector_name); - EXPECT_EQ(vectorlite::VectorType::Float32, space->vector_type); - - space = vectorlite::NamedVectorSpace::FromString("my_vec float32[3] l2"); - ASSERT_TRUE(space.ok()); - EXPECT_EQ(space->normalize, false); - EXPECT_NE(space->space, nullptr); - EXPECT_EQ(space->distance_type, vectorlite::DistanceType::L2); - EXPECT_EQ(3, space->dimension()); - EXPECT_EQ("my_vec", space->vector_name); - EXPECT_EQ(vectorlite::VectorType::Float32, space->vector_type); - - space = - vectorlite::NamedVectorSpace::FromString("my_vec float32[10086] cosine"); - ASSERT_TRUE(space.ok()); - EXPECT_EQ(space->normalize, true); - EXPECT_NE(space->space, nullptr); - EXPECT_EQ(space->distance_type, vectorlite::DistanceType::Cosine); - EXPECT_EQ(10086, space->dimension()); - EXPECT_EQ("my_vec", space->vector_name); - EXPECT_EQ(vectorlite::VectorType::Float32, space->vector_type); - - space = vectorlite::NamedVectorSpace::FromString("my_vec float32[42] ip"); - ASSERT_TRUE(space.ok()); - EXPECT_EQ(space->normalize, false); - EXPECT_NE(space->space, nullptr); - EXPECT_EQ(space->distance_type, vectorlite::DistanceType::InnerProduct); - EXPECT_EQ(42, space->dimension()); - EXPECT_EQ("my_vec", space->vector_name); - EXPECT_EQ(vectorlite::VectorType::Float32, space->vector_type); + for (auto vector_type : + {vectorlite::VectorType::Float32, vectorlite::VectorType::BFloat16}) { + // If distance type is not specifed, it should default to L2 + std::string vector_type_str = VectorTypeToString(vector_type); + auto space = vectorlite::NamedVectorSpace::FromString( + absl::StrFormat("my_vec %s[3]", vector_type_str)); + ASSERT_TRUE(space.ok()); + EXPECT_EQ(space->normalize, false); + EXPECT_NE(space->space, nullptr); + EXPECT_EQ(space->distance_type, vectorlite::DistanceType::L2); + EXPECT_EQ(3, space->dimension()); + EXPECT_EQ("my_vec", space->vector_name); + EXPECT_EQ(vector_type, space->vector_type); + + space = vectorlite::NamedVectorSpace::FromString( + absl::StrFormat("my_vec %s[3] l2", vector_type_str)); + ASSERT_TRUE(space.ok()); + EXPECT_EQ(space->normalize, false); + EXPECT_NE(space->space, nullptr); + EXPECT_EQ(space->distance_type, vectorlite::DistanceType::L2); + EXPECT_EQ(3, space->dimension()); + EXPECT_EQ("my_vec", space->vector_name); + EXPECT_EQ(vector_type, space->vector_type); + + space = vectorlite::NamedVectorSpace::FromString( + absl::StrFormat("my_vec %s[10086] cosine", vector_type_str)); + ASSERT_TRUE(space.ok()); + EXPECT_EQ(space->normalize, true); + EXPECT_NE(space->space, nullptr); + EXPECT_EQ(space->distance_type, vectorlite::DistanceType::Cosine); + EXPECT_EQ(10086, space->dimension()); + EXPECT_EQ("my_vec", space->vector_name); + EXPECT_EQ(vector_type, space->vector_type); + + space = vectorlite::NamedVectorSpace::FromString( + absl::StrFormat("my_vec %s[42] ip", vector_type_str)); + ASSERT_TRUE(space.ok()); + EXPECT_EQ(space->normalize, false); + EXPECT_NE(space->space, nullptr); + EXPECT_EQ(space->distance_type, vectorlite::DistanceType::InnerProduct); + EXPECT_EQ(42, space->dimension()); + EXPECT_EQ("my_vec", space->vector_name); + EXPECT_EQ(vector_type, space->vector_type); + } } diff --git a/vectorlite/vector_test.cpp b/vectorlite/vector_test.cpp index 2cdaf77..acd79a7 100644 --- a/vectorlite/vector_test.cpp +++ b/vectorlite/vector_test.cpp @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "vector_space.h" +#include "vector_view.h" TEST(VectorTest, FromJSON) { // Test valid JSON input @@ -69,15 +70,19 @@ TEST(VectorDistance, ShouldWork) { // Test valid input vectorlite::Vector v1({1.0, 2.0, 3.0}); vectorlite::Vector v2({4.0, 5.0, 6.0}); - auto distance = Distance(v1, v2, vectorlite::DistanceType::L2); + + vectorlite::VectorView v1_view(v1); + vectorlite::VectorView v2_view(v2); + + auto distance = Distance(v1_view, v2_view, vectorlite::DistanceType::L2); EXPECT_TRUE(distance.ok()); EXPECT_FLOAT_EQ(*distance, 27); - distance = Distance(v2, v1, vectorlite::DistanceType::InnerProduct); + distance = Distance(v2_view, v1_view, vectorlite::DistanceType::InnerProduct); EXPECT_TRUE(distance.ok()); EXPECT_FLOAT_EQ(*distance, -31); - distance = Distance(v1, v2, vectorlite::DistanceType::Cosine); + distance = Distance(v1_view, v2_view, vectorlite::DistanceType::Cosine); EXPECT_TRUE(distance.ok()); // On osx arm64, no vectoration is used and the following test fails. // EXPECT_FLOAT_EQ(*distance, 0.025368214); @@ -87,9 +92,12 @@ TEST(VectorDistance, ShouldWork) { // Test 0 dimension vectorlite::Vector v3; vectorlite::Vector v4; + + vectorlite::VectorView v3_view(v3); + vectorlite::VectorView v4_view(v4); for (auto space : {vectorlite::DistanceType::L2, vectorlite::DistanceType::InnerProduct}) { - distance = Distance(v3, v4, space); + distance = Distance(v3_view, v4_view, space); EXPECT_TRUE(distance.ok()); } } diff --git a/vectorlite/vector_view.cpp b/vectorlite/vector_view.cpp deleted file mode 100644 index adb20e8..0000000 --- a/vectorlite/vector_view.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include "vector_view.h" - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "rapidjson/document.h" -#include "rapidjson/error/en.h" -#include "rapidjson/stringbuffer.h" -#include "rapidjson/writer.h" -#include "vector.h" - -namespace vectorlite { - -VectorView::VectorView(const Vector& vector) : data_(vector.data()) {} - -absl::StatusOr VectorView::FromBlob(std::string_view blob) { - if (blob.size() % sizeof(float) != 0) { - return absl::InvalidArgumentError("Blob size is not a multiple of float"); - } - return VectorView(absl::MakeSpan(reinterpret_cast(blob.data()), - blob.size() / sizeof(float))); -} - -std::string VectorView::ToJSON() const { - rapidjson::Document doc; - doc.SetArray(); - - auto& allocator = doc.GetAllocator(); - for (float v : data_) { - doc.PushBack(v, allocator); - } - - rapidjson::StringBuffer buf; - rapidjson::Writer writer(buf); - doc.Accept(writer); - - return buf.GetString(); -} - -std::string_view VectorView::ToBlob() const { - return std::string_view(reinterpret_cast(data_.data()), - data_.size() * sizeof(float)); -} - -} // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/vector_view.h b/vectorlite/vector_view.h index dfadfd6..61dfbf6 100644 --- a/vectorlite/vector_view.h +++ b/vectorlite/vector_view.h @@ -5,37 +5,76 @@ #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "hwy/base.h" +#include "macros.h" +#include "rapidjson/document.h" +#include "rapidjson/error/en.h" +#include "rapidjson/stringbuffer.h" +#include "rapidjson/writer.h" +#include "util.h" namespace vectorlite { -class Vector; +template +class GenericVector; -// VectorView is a read-only view of a vector, like std::string_view is to -// std::string. -class VectorView { +// GenericVectorView is a read-only view of a vector, like what std::string_view +// is to std::string. +template +class GenericVectorView : private CopyAssignBase, + private CopyCtorBase, + private MoveCtorBase, + private MoveAssignBase { public: - VectorView() = default; - VectorView(const VectorView&) = default; - VectorView(VectorView&&) = default; + GenericVectorView() = default; + GenericVectorView(const GenericVectorView&) = default; + GenericVectorView(GenericVectorView&&) = default; - VectorView(const Vector& vector); - explicit VectorView(absl::Span data) : data_(data) {} + GenericVectorView(const GenericVector& vector) + : data_(vector.data()) {} + explicit GenericVectorView(absl::Span data) : data_(data) {} - VectorView& operator=(const VectorView&) = default; - VectorView& operator=(VectorView&&) = default; + GenericVectorView& operator=(const GenericVectorView&) = default; + GenericVectorView& operator=(GenericVectorView&&) = default; - static absl::StatusOr FromBlob(std::string_view blob); + static absl::StatusOr> FromBlob(std::string_view blob) { + if (blob.size() % sizeof(T) != 0) { + return absl::InvalidArgumentError("Blob size is not a multiple of float"); + } + return GenericVectorView(absl::MakeSpan( + reinterpret_cast(blob.data()), blob.size() / sizeof(T))); + }; - std::string ToJSON() const; + std::string ToJSON() const { + rapidjson::Document doc; + doc.SetArray(); - std::string_view ToBlob() const; + auto& allocator = doc.GetAllocator(); + for (T v : data_) { + doc.PushBack(hwy::ConvertScalarTo(v), allocator); + } + + rapidjson::StringBuffer buf; + rapidjson::Writer writer(buf); + doc.Accept(writer); + + return buf.GetString(); + }; + + std::string_view ToBlob() const { + return std::string_view(reinterpret_cast(data_.data()), + data_.size() * sizeof(T)); + }; std::size_t dim() const { return data_.size(); } - absl::Span data() const { return data_; } + absl::Span data() const { return data_; } private: - absl::Span data_; + absl::Span data_; }; +using VectorView = GenericVectorView; +using BF16VectorView = GenericVectorView; + } // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/vector_view_test.cpp b/vectorlite/vector_view_test.cpp index 4b46321..c8ec1ae 100644 --- a/vectorlite/vector_view_test.cpp +++ b/vectorlite/vector_view_test.cpp @@ -21,6 +21,9 @@ TEST(VectorViewTest, Reversible_ToBinary_FromBinary) { TEST(VectorViewTest, FromBinaryShouldFailWithInvalidInput) { auto v1 = vectorlite::VectorView::FromBlob(std::string_view("aaa")); EXPECT_FALSE(v1.ok()); + + auto v2 = vectorlite::BF16VectorView::FromBlob(std::string_view("aaa")); + EXPECT_FALSE(v2.ok()); } TEST(VectorViewTest, ToJSON) { diff --git a/vectorlite/virtual_table.cpp b/vectorlite/virtual_table.cpp index e8ad984..c8920bd 100644 --- a/vectorlite/virtual_table.cpp +++ b/vectorlite/virtual_table.cpp @@ -20,8 +20,10 @@ #include "hnswlib/hnswlib.h" #include "index_options.h" #include "macros.h" +#include "quantization.h" #include "sqlite3ext.h" #include "util.h" +#include "vector.h" #include "vector_space.h" #include "vector_view.h" @@ -599,7 +601,42 @@ constexpr bool IsRowidOutOfRange(sqlite3_int64 rowid) { std::numeric_limits::max()); } -// Only insert is supported for now +int VirtualTable::InsertOrUpdateVector(VectorView vector, Cursor::Rowid rowid) { + try { + if (space_.vector_type == vectorlite::VectorType::Float32) { + if (!space_.normalize) { + index_->addPoint(vector.data().data(), rowid, + index_->allow_replace_deleted_); + } else { + Vector normalized_vector = Vector::Normalize(vector); + index_->addPoint(normalized_vector.data().data(), rowid, + index_->allow_replace_deleted_); + } + } else if (space_.vector_type == vectorlite::VectorType::BFloat16) { + BF16Vector bf16_vector = Quantize(vector); + if (!space_.normalize) { + index_->addPoint(bf16_vector.data().data(), rowid, + index_->allow_replace_deleted_); + } else { + BF16Vector normalized_vector = bf16_vector.Normalize(); + index_->addPoint(normalized_vector.data().data(), rowid, + index_->allow_replace_deleted_); + } + + } else { + SetZErrMsg(&this->zErrMsg, "Unrecognized vector type %d", + space_.vector_type); + return SQLITE_ERROR; + } + + } catch (const std::runtime_error& e) { + SetZErrMsg(&this->zErrMsg, "Failed to insert row %lld due to: %s", rowid, + e.what()); + return SQLITE_ERROR; + } + return SQLITE_OK; +} + int VirtualTable::Update(sqlite3_vtab* pVTab, int argc, sqlite3_value** argv, sqlite_int64* pRowid) { VirtualTable* vtab = static_cast(pVTab); @@ -646,20 +683,7 @@ int VirtualTable::Update(sqlite3_vtab* pVTab, int argc, sqlite3_value** argv, return SQLITE_ERROR; } - try { - if (!vtab->space_.normalize) { - vtab->index_->addPoint(vector->data().data(), rowid, true); - } else { - Vector normalized_vector = Vector::Normalize(*vector); - vtab->index_->addPoint(normalized_vector.data().data(), rowid, true); - } - - } catch (const std::runtime_error& e) { - SetZErrMsg(&vtab->zErrMsg, "Failed to insert row %lld due to: %s", - rowid, e.what()); - return SQLITE_ERROR; - } - return SQLITE_OK; + return vtab->InsertOrUpdateVector(*vector, rowid); } else { SetZErrMsg(&vtab->zErrMsg, "Failed to perform insertion due to: %s", absl::StatusMessageAsCStr(vector.status())); @@ -729,23 +753,7 @@ int VirtualTable::Update(sqlite3_vtab* pVTab, int argc, sqlite3_value** argv, return SQLITE_ERROR; } - try { - if (!vtab->space_.normalize) { - vtab->index_->addPoint(vector->data().data(), rowid, - vtab->index_->allow_replace_deleted_); - } else { - Vector normalized_vector = Vector::Normalize(*vector); - vtab->index_->addPoint(normalized_vector.data().data(), rowid, - vtab->index_->allow_replace_deleted_); - } - - } catch (const std::runtime_error& e) { - SetZErrMsg(&vtab->zErrMsg, "Failed to update row %lld due to: %s", - rowid, e.what()); - return SQLITE_ERROR; - } - - return SQLITE_OK; + return vtab->InsertOrUpdateVector(*vector, rowid); } else { SetZErrMsg(&vtab->zErrMsg, "Failed to perform row %lld due to: %s", rowid, absl::StatusMessageAsCStr(vector.status())); diff --git a/vectorlite/virtual_table.h b/vectorlite/virtual_table.h index 021a509..7e4d624 100644 --- a/vectorlite/virtual_table.h +++ b/vectorlite/virtual_table.h @@ -17,9 +17,6 @@ namespace vectorlite { -// Note there shouldn't be any virtual functions in this class. -// Because VirtualTable* is expected to be static_cast-ed to sqlite3_vtab*. -// vptr could cause UB. class VirtualTable : public sqlite3_vtab { public: // No virtual function @@ -102,6 +99,7 @@ class VirtualTable : public sqlite3_vtab { private: absl::StatusOr GetVectorByRowid(int64_t rowid) const; + int InsertOrUpdateVector(VectorView vector, Cursor::Rowid rowid); NamedVectorSpace space_; std::unique_ptr> index_;