Skip to content

Commit

Permalink
feat: add embed sentences mock
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloSanchi committed Jul 27, 2024
1 parent ae277c7 commit 5511b3c
Showing 1 changed file with 39 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
package jchunk.chunker.semantic;

import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.springframework.ai.embedding.EmbeddingModel;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;

public class SemanticChunkerTest {

@Mock
private EmbeddingModel embeddingModel;

private final EmbeddingModel embeddingModel;
private SemanticChunker semanticChunker;

public SemanticChunkerTest() {
this.embeddingModel = Mockito.mock(EmbeddingModel.class);
}

public void configure() {
this.semanticChunker = new SemanticChunker(embeddingModel);
}
Expand Down Expand Up @@ -44,10 +45,9 @@ public void splitSentenceDefaultStrategyTest() {
assertThat(result).isNotNull();
assertThat(result.size()).isEqualTo(expectedResult.size());

assertThat(result.get(0).getContent()).isEqualTo(expectedResult.get(0).getContent());
assertThat(result.get(1).getContent()).isEqualTo(expectedResult.get(1).getContent());
assertThat(result.get(2).getContent()).isEqualTo(expectedResult.get(2).getContent());
assertThat(result.get(3).getContent()).isEqualTo(expectedResult.get(3).getContent());
for (int i = 0; i < result.size(); i++) {
assertThat(result.get(i).getContent()).isEqualTo(expectedResult.get(i).getContent());
}
}

@Test
Expand Down Expand Up @@ -183,6 +183,35 @@ public void combineSentencesWithInputIsEmptyTest() {
.hasMessage("The list of sentences cannot be empty");
}

@Test
public void embedSentencesTest() {
configure();

Mockito.when(embeddingModel.embed(Mockito.anyList()))
.thenReturn(List.of(List.of(1.0, 2.0, 3.0), List.of(4.0, 5.0, 6.0)));

List<SemanticChunker.Sentence> sentences = List.of(
SemanticChunker.Sentence.builder().combined("This is a test sentence.").build(),
SemanticChunker.Sentence.builder().combined("How are u?").build());

List<SemanticChunker.Sentence> expectedResult = List.of(
SemanticChunker.Sentence.builder()
.combined("This is a test sentence.")
.embedding(List.of(1.0, 2.0, 3.0))
.build(),
SemanticChunker.Sentence.builder().combined("How are u?").embedding(List.of(4.0, 5.0, 6.0)).build());

List<SemanticChunker.Sentence> result = this.semanticChunker.embedSentences(sentences);

assertThat(result).isNotNull();

for (int i = 0; i < result.size(); i++) {
assertThat(result.get(i).getCombined()).isEqualTo(expectedResult.get(i).getCombined());
assertThat(result.get(i).getEmbedding()).isEqualTo(expectedResult.get(i).getEmbedding());
}

}

@Test
public void testIdenticalVectors() {
configure();
Expand Down

0 comments on commit 5511b3c

Please sign in to comment.