Merge pull request #20 from pinecone-io/rshah/formatVectors
Reformat sparse vectors
rohanshah18 authored Oct 25, 2023
2 parents 9455bb0 + 623c6a5 commit 6150c65
Showing 13 changed files with 220 additions and 65 deletions.
64 changes: 27 additions & 37 deletions src/it/resources/sample.jsonl
@@ -7,27 +7,25 @@
       2,
       3
     ],
-    "metadata": "",
-    "sparse_id": [
-      0,
-      2
-    ],
-    "sparse_values": [
-      5.5,
-      5
-    ]
+    "metadata": "{\"hello\": [\"world\", \"you\"], \"numbers\": \"or not\", \"actual_number\": 5.2, \"round\": 3}",
+    "sparse_values": {
+      "indices": [
+        0,
+        2
+      ],
+      "values": [
+        5.5,
+        5
+      ]
+    }
   },
   {
     "id": "v2",
     "namespace": "default",
     "values": [
       3,
       2,
       1
-    ],
-    "metadata": "",
-    "sparse_id": [],
-    "sparse_values": []
+    ]
   },
   {
     "namespace": "default",
@@ -37,21 +35,17 @@
       9
     ],
     "id": "v3",
-    "metadata": "",
-    "sparse_id": [],
-    "sparse_values": []
+    "metadata": ""
   },
   {
     "id": "v4",
-    "namespace": "default",
+    "namespace": "",
     "values": [
       1,
       1,
       2
     ],
-    "metadata": "{\"key\": \"value\"}",
-    "sparse_id": [],
-    "sparse_values": []
+    "metadata": "{\"key\": \"value\"}"
   },
   {
     "id": "v5",
@@ -62,8 +56,14 @@
       8
     ],
     "metadata": "{\"key\": \"value\"}",
-    "sparse_id": [],
-    "sparse_values": []
+    "sparse_values": {
+      "indices": [
+        1
+      ],
+      "values": [
+        4
+      ]
+    }
   },
   {
     "id": "v6",
@@ -73,26 +73,16 @@
       21,
       34
     ],
-    "metadata": "",
-    "sparse_id": [],
-    "sparse_values": []
+    "metadata": ""
   },
   {
     "id": "v7",
     "namespace": "default",
     "values": [
       5,
-      5,
-      5
+      6,
+      7
     ],
-    "metadata": "{\"hello\": [\"world\", \"you\"], \"numbers\": \"or not\", \"actual_number\": 5.2, \"round\": 3}",
-    "sparse_id": [
-      0,
-      2
-    ],
-    "sparse_values": [
-      5.5,
-      5
-    ]
+    "metadata": "{\"hello\": [\"world\", \"you\"], \"numbers\": \"or not\", \"actual_number\": 5.2, \"round\": 3}"
   }
 ]
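
Note: the integration sample now nests sparse data in a single `sparse_values` object with parallel `indices` and `values` arrays, replacing the old top-level `sparse_id`/`sparse_values` pair. A minimal sketch of loading the reformatted file with the connector's `COMMON_SCHEMA` (session setup and app name are illustrative; the read options mirror the ParseCommonSchemaTest added below):

import io.pinecone.spark.pinecone._
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("ReadReformattedSample") // illustrative app name
  .master("local[2]")
  .getOrCreate()

// multiLine is needed because the fixture is a pretty-printed JSON array,
// not one JSON object per line
val df = spark.read
  .option("multiLine", value = true)
  .schema(COMMON_SCHEMA)
  .json("src/it/resources/sample.jsonl")

df.printSchema()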
30 changes: 18 additions & 12 deletions src/main/scala/io/pinecone/spark/pinecone/PineconeDataWriter.scala
@@ -38,11 +38,8 @@ case class PineconeDataWriter(
   override def write(record: InternalRow): Unit = {
     try {
       val id = record.getUTF8String(0).toString
-      val namespace = record.getUTF8String(1).toString
+      val namespace = if(!record.isNullAt(1)) record.getUTF8String(1).toString else ""
       val values = record.getArray(2).toFloatArray().map(float2Float).toIterable
-      val metadata = record.getUTF8String(3).toString
-      val sparseId = record.getArray(4).toIntArray().map(int2Integer).toIterable
-      val sparseValues = record.getArray(5).toFloatArray().map(float2Float).toIterable
 
       if (id.length > MAX_ID_LENGTH) {
         throw VectorIdTooLongException(id)
@@ -56,16 +53,25 @@
         vectorBuilder.addAllValues(values.asJava)
       }
 
-      if (sparseId.nonEmpty && sparseValues.nonEmpty) {
-        val sparseDataBuilder = SparseValues.newBuilder()
-          .addAllIndices(sparseId.asJava)
-          .addAllValues(sparseValues.asJava)
-
-        vectorBuilder.setSparseValues(sparseDataBuilder.build())
+      if (!record.isNullAt(3)) {
+        val metadata = record.getUTF8String(3).toString
+        val metadataStruct = parseAndValidateMetadata(id, metadata)
+        vectorBuilder.setMetadata(metadataStruct)
       }
 
-      val metadataStruct = parseAndValidateMetadata(id, metadata)
-      vectorBuilder.setMetadata(metadataStruct)
+      if (!record.isNullAt(4)) {
+        val sparseVectorStruct = record.getStruct(4, 2)
+        if (!sparseVectorStruct.isNullAt(0) && !sparseVectorStruct.isNullAt(1)) {
+          val sparseId = sparseVectorStruct.getArray(0).toIntArray().map(int2Integer).toIterable
+          val sparseValues = sparseVectorStruct.getArray(1).toFloatArray().map(float2Float).toIterable
+
+          val sparseDataBuilder = SparseValues.newBuilder()
+            .addAllIndices(sparseId.asJava)
+            .addAllValues(sparseValues.asJava)
+
+          vectorBuilder.setSparseValues(sparseDataBuilder.build())
+        }
+      }
 
       val vector = vectorBuilder
         .build()
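
With these null checks, only `id` and `values` are required per row; a null `namespace` falls back to the empty string, and null `metadata` or `sparse_values` columns are simply skipped. A hedged upsert sketch continuing from the DataFrame read above (the option keys are the PineconeOptions constants used in the new test; the credential and index values are placeholders):

import io.pinecone.spark.pinecone._
import org.apache.spark.sql.SaveMode

val pineconeOptions = Map(
  PineconeOptions.PINECONE_API_KEY_CONF -> "your-api-key",      // placeholder
  PineconeOptions.PINECONE_ENVIRONMENT_CONF -> "us-east4-gcp",  // example environment
  PineconeOptions.PINECONE_PROJECT_NAME_CONF -> "your-project", // placeholder
  PineconeOptions.PINECONE_INDEX_NAME_CONF -> "your-index"      // placeholder
)

df.write
  .options(pineconeOptions)
  .format("io.pinecone.spark.pinecone.Pinecone")
  .mode(SaveMode.Append)
  .save()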
19 changes: 11 additions & 8 deletions src/main/scala/io/pinecone/spark/pinecone/package.scala
@@ -3,20 +3,23 @@ package io.pinecone.spark
 import com.fasterxml.jackson.databind.ObjectMapper
 import com.google.protobuf.{ListValue, Struct, Value}
 import org.apache.spark.sql.connector.write.WriterCommitMessage
-import org.apache.spark.sql.types.{ArrayType, FloatType, IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{ArrayType, FloatType, IntegerType, StringType, StructField, StructType}
 
 import scala.collection.JavaConverters._
 
 package object pinecone {
-  // TODO: Fill out the schema (mainly solve the whole metadata issue)
   val COMMON_SCHEMA: StructType =
     new StructType()
-      .add("id", StringType)
-      .add("namespace", StringType)
-      .add("values", ArrayType(FloatType))
-      .add("metadata", StringType)
-      .add("sparse_id", ArrayType(IntegerType, containsNull = true), nullable = true)
-      .add("sparse_values", ArrayType(FloatType, containsNull = true), nullable = true)
+      .add("id", StringType, nullable = false)
+      .add("namespace", StringType, nullable = true)
+      .add("values", ArrayType(FloatType, containsNull = false), nullable = false)
+      .add("metadata", StringType, nullable = true)
+      .add("sparse_values", StructType(
+        Array(
+          StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = false),
+          StructField("values", ArrayType(FloatType, containsNull = false), nullable = false)
+        )
+      ), nullable = true)
 
   private[pinecone] val MAX_ID_LENGTH = 512
   private[pinecone] val MAX_METADATA_SIZE = 5 * math.pow(10, 3) // 5KB
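
In tree form, the revised COMMON_SCHEMA declares the following shape (a sketch in Spark's printSchema notation; nullability is shown as declared, and Spark may relax it when reading from file sources, which the writer's runtime null checks above account for):

root
 |-- id: string (nullable = false)
 |-- namespace: string (nullable = true)
 |-- values: array (nullable = false)
 |    |-- element: float (containsNull = false)
 |-- metadata: string (nullable = true)
 |-- sparse_values: struct (nullable = true)
 |    |-- indices: array (nullable = false)
 |    |    |-- element: integer (containsNull = false)
 |    |-- values: array (nullable = false)
 |    |    |-- element: float (containsNull = false)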
9 changes: 9 additions & 0 deletions src/test/resources/invalidUpsertInput1.jsonl
@@ -0,0 +1,9 @@
+[
+  {
+    "values": [
+      1,
+      2,
+      3
+    ]
+  }
+]
5 changes: 5 additions & 0 deletions src/test/resources/invalidUpsertInput2.jsonl
@@ -0,0 +1,5 @@
+[
+  {
+    "id": "v1"
+  }
+]
12 changes: 12 additions & 0 deletions src/test/resources/invalidUpsertInput3.jsonl
@@ -0,0 +1,12 @@
+[
+  {
+    "id": "v1",
+    "values": [
+      3,
+      2,
+      1
+    ],
+    "sparse_values": {
+    }
+  }
+]
16 changes: 16 additions & 0 deletions src/test/resources/invalidUpsertInput4.jsonl
@@ -0,0 +1,16 @@
+[
+  {
+    "id": "v1",
+    "values": [
+      3,
+      2,
+      1
+    ],
+    "sparse_values": {
+      "values": [
+        100,
+        101
+      ]
+    }
+  }
+]
16 changes: 16 additions & 0 deletions src/test/resources/invalidUpsertInput5.jsonl
@@ -0,0 +1,16 @@
+[
+  {
+    "id": "v1",
+    "values": [
+      3,
+      2,
+      1
+    ],
+    "sparse_values": {
+      "indices": [
+        1,
+        2
+      ]
+    }
+  }
+]
18 changes: 18 additions & 0 deletions src/test/resources/invalidUpsertInput6.jsonl
@@ -0,0 +1,18 @@
+[
+  {
+    "id": "v1",
+    "values": [
+      3,
+      2,
+      1
+    ],
+    "sparse_values": {
+      "indices": [
+        null
+      ],
+      "values": [
+        1
+      ]
+    }
+  }
+]
18 changes: 18 additions & 0 deletions src/test/resources/invalidUpsertInput7.jsonl
@@ -0,0 +1,18 @@
+[
+  {
+    "id": "v1",
+    "values": [
+      3,
+      2,
+      1
+    ],
+    "sparse_values": {
+      "indices": [
+        1
+      ],
+      "values": [
+        null
+      ]
+    }
+  }
+]
7 changes: 0 additions & 7 deletions src/test/resources/sample.jsonl

This file was deleted.

69 changes: 69 additions & 0 deletions src/test/scala/io/pinecone/spark/pinecone/ParseCommonSchemaTest.scala
@@ -0,0 +1,69 @@
+package io.pinecone.spark.pinecone
+
+import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should
+
+class ParseCommonSchemaTest extends AnyFlatSpec with should.Matchers {
+  private val spark: SparkSession = SparkSession.builder()
+    .appName("SchemaValidationTest")
+    .master("local[2]")
+    .getOrCreate()
+
+  private val inputFilePath = System.getProperty("user.dir") + "/src/test/resources"
+
+  private val apiKey = "some_api_key"
+  private val environment = "us-east4-gcp"
+  private val projectName = "f8e8d52"
+  private val indexName = "step-test"
+
+  private val pineconeOptions: Map[String, String] = Map(
+    PineconeOptions.PINECONE_API_KEY_CONF -> apiKey,
+    PineconeOptions.PINECONE_ENVIRONMENT_CONF -> environment,
+    PineconeOptions.PINECONE_PROJECT_NAME_CONF -> projectName,
+    PineconeOptions.PINECONE_INDEX_NAME_CONF -> indexName
+  )
+
+  def afterAll(): Unit = {
+    if (spark != null) {
+      spark.stop()
+    }
+  }
+
+  def testInvalidJSON(file: String, testName: String): Unit = {
+    it should testName in {
+      val sparkException = intercept[org.apache.spark.SparkException] {
+        val df = spark.read
+          .option("multiLine", value = true)
+          .option("mode", "PERMISSIVE")
+          .schema(COMMON_SCHEMA)
+          .json(file)
+          .repartition(2)
+
+        df.write
+          .options(pineconeOptions)
+          .format("io.pinecone.spark.pinecone.Pinecone")
+          .mode(SaveMode.Append)
+          .save()
+      }
+      sparkException
+        .getCause
+        .toString should include("java.lang.NullPointerException: Null value appeared in non-nullable field:")
+    }
+  }
+
+  testInvalidJSON(s"$inputFilePath/invalidUpsertInput1.jsonl",
+    "throw exception for missing id")
+  testInvalidJSON(s"$inputFilePath/invalidUpsertInput2.jsonl",
+    "throw exception for missing values")
+  testInvalidJSON(s"$inputFilePath/invalidUpsertInput3.jsonl",
+    "throw exception for missing sparse vector indices and values if sparse_values is defined")
+  testInvalidJSON(s"$inputFilePath/invalidUpsertInput4.jsonl",
+    "throw exception for missing sparse vector indices if sparse_values and its values are defined")
+  testInvalidJSON(s"$inputFilePath/invalidUpsertInput5.jsonl",
+    "throw exception for missing sparse vector values if sparse_values and its indices are defined")
+  testInvalidJSON(s"$inputFilePath/invalidUpsertInput6.jsonl",
+    "throw exception for null in sparse vector indices")
+  testInvalidJSON(s"$inputFilePath/invalidUpsertInput7.jsonl",
+    "throw exception for null in sparse vector values")
+}
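
New failure cases follow the same pattern: drop a fixture under src/test/resources and register it with one more testInvalidJSON call. For example, a hypothetical eighth fixture covering a null dense value could be added as follows (the file name and test text are illustrative, not part of this commit):

// Hypothetical fixture: invalidUpsertInput8.jsonl containing "values": [1, null, 3]
testInvalidJSON(s"$inputFilePath/invalidUpsertInput8.jsonl",
  "throw exception for null in dense vector values")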
2 changes: 1 addition & 1 deletion version.sbt
@@ -1 +1 @@
-ThisBuild / version := "0.2.0"
+ThisBuild / version := "0.2.1"
