Skip to content

Commit

Permalink
added more embedding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
craiglabenz committed Mar 12, 2024
1 parent 7628efe commit eabd913
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 1 deletion.
10 changes: 10 additions & 0 deletions packages/mediapipe-core/lib/src/io/containers.dart
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ class Embedding extends BaseEmbedding {
required int headIndex,
String? headName,
}) : _floatEmbedding = null,
_headIndex = headIndex,
_headName = headName,
_quantizedEmbedding = quantizedEmbedding,
_pointer = null,
type = EmbeddingType.quantized;
Expand All @@ -184,6 +186,8 @@ class Embedding extends BaseEmbedding {
required int headIndex,
String? headName,
}) : _floatEmbedding = floatEmbedding,
_headIndex = headIndex,
_headName = headName,
_quantizedEmbedding = null,
_pointer = null,
type = EmbeddingType.float;
Expand Down Expand Up @@ -233,6 +237,9 @@ class Embedding extends BaseEmbedding {
Uint8List? get quantizedEmbedding =>
_quantizedEmbedding ??= _getQuantizedEmbedding();
Uint8List? _getQuantizedEmbedding() {
if (type != EmbeddingType.quantized) {
return null;
}
if (_pointer.isNullOrNullPointer) {
throw Exception(
'Could not determine value for Embedding.quantizedEmbedding',
Expand All @@ -248,6 +255,9 @@ class Embedding extends BaseEmbedding {
@override
Float32List? get floatEmbedding => _floatEmbedding ??= _getFloatEmbedding();
Float32List? _getFloatEmbedding() {
if (type != EmbeddingType.float) {
return null;
}
if (_pointer.isNullOrNullPointer) {
throw Exception('Could not determine value for Embedding.floatEmbedding');
}
Expand Down
46 changes: 46 additions & 0 deletions packages/mediapipe-core/lib/src/io/test_utils.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// found in the LICENSE file.

import 'dart:ffi';
import 'dart:math';
import 'dart:typed_data';
import 'package:ffi/ffi.dart';

import 'package:mediapipe_core/src/io/mediapipe_core.dart';
Expand Down Expand Up @@ -56,3 +58,47 @@ void populateClassifications(
classifications.head_name = headName.copyToNative();
classifications.head_index = headIndex;
}

/// Hydrates a faked [core_bindings.Embedding] object.
void populateEmbedding(
core_bindings.Embedding embedding, {
bool quantize = false,
bool l2Normalize = false,
int length = 100,
String headName = 'response_encoding',
int headIndex = 1,
}) {
embedding.values_count = length;
embedding.head_name = headName.copyToNative();
embedding.head_index = headIndex;

Random rnd = Random();

if (quantize) {
embedding.quantized_embedding = Uint8List.fromList(
_genInts(length, rnd: rnd).toList(),
).copyToNative();
} else {
embedding.float_embedding = Float32List.fromList(
_genFloats(length, l2Normalize: l2Normalize, rnd: rnd).toList(),
).copyToNative();
}
}

Iterable<int> _genInts(int count, {required Random rnd}) sync* {
int index = 0;
while (index < count) {
yield rnd.nextInt(127);
index++;
}
}

Iterable<double> _genFloats(int count,
{required bool l2Normalize, required Random rnd}) sync* {
int index = 0;
while (index < count) {
final dbl = rnd.nextDouble();
yield l2Normalize ? dbl : dbl * 127;
index++;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import 'package:mediapipe_text/src/io/third_party/mediapipe/generated/mediapipe_
as bindings;

void main() {
group('TextClassifierResult.structToDart should', () {
group('TextClassifierResult.native should', () {
test('load an empty object', () {
final Pointer<bindings.TextClassifierResult> ptr =
calloc<bindings.TextClassifierResult>();
Expand Down
99 changes: 99 additions & 0 deletions packages/mediapipe-task-text/test/text_embedder_executor_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright 2014 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// `native-assets` tag allows test runs to opt in or out of running integration
// tests via `flutter test -x native-assets` or `flutter test -t native-assets`
@Tags(['native-assets'])

import 'dart:io' as io;
import 'package:flutter_test/flutter_test.dart';
import 'package:path/path.dart' as path;
import 'package:mediapipe_core/io.dart';
import 'package:mediapipe_text/io.dart';

void main() {
final pathToModel = path.joinAll([
io.Directory.current.absolute.path,
'example/assets/universal_sentence_encoder.tflite',
]);
final modelBytes = io.File(pathToModel).readAsBytesSync();

group('TextEmbedderExecutor should', () {
test('run a task', () {
final executor = TextEmbedderExecutor(
TextEmbedderOptions.fromAssetBuffer(modelBytes),
);
final TextEmbedderResult result = executor.embed('Hello, world!');
expect(result.embeddings, isNotEmpty);
result.dispose();
executor.dispose();
});

test('run multiple tasks', () {
final executor = TextEmbedderExecutor(
TextEmbedderOptions.fromAssetBuffer(modelBytes),
);
final TextEmbedderResult result = executor.embed('Hello, world!');
expect(result.embeddings, isNotEmpty);
final TextEmbedderResult result2 = executor.embed('Hello, world!');
expect(result2.embeddings, isNotEmpty);
result.dispose();
executor.dispose();
});

test('unpack a result', () {
final executor = TextEmbedderExecutor(
TextEmbedderOptions.fromAssetBuffer(modelBytes),
);
final TextEmbedderResult result = executor.embed('Hello, world!');
final embedding = result.embeddings.first;
expect(embedding.headName, 'response_encoding');
expect(embedding.quantizedEmbedding, isNull);
expect(embedding.floatEmbedding, isNotNull);
expect(embedding.length, 100);
expect(embedding.type, equals(EmbeddingType.float));
expect(embedding.floatEmbedding![0], closeTo(1.7475, 0.0001));
result.dispose();
executor.dispose();
});

test('quantize results when requested', () {
final executor = TextEmbedderExecutor(
TextEmbedderOptions.fromAssetBuffer(
modelBytes,
embedderOptions: EmbedderOptions(quantize: true),
),
);
final TextEmbedderResult result = executor.embed('Hello, world!');
final embedding = result.embeddings.first;
expect(embedding.headName, 'response_encoding');
expect(embedding.quantizedEmbedding, isNotNull);
expect(embedding.floatEmbedding, isNull);
expect(embedding.quantizedEmbedding![0], 127);
expect(embedding.length, 100);
expect(embedding.type, equals(EmbeddingType.quantized));
result.dispose();
executor.dispose();
});

test('normalize', () {
final executor = TextEmbedderExecutor(
TextEmbedderOptions.fromAssetBuffer(
modelBytes,
embedderOptions: EmbedderOptions(l2Normalize: true),
),
);
final TextEmbedderResult result = executor.embed('Hello, world!');
final embedding = result.embeddings.first;
expect(embedding.headName, 'response_encoding');
expect(embedding.quantizedEmbedding, isNull);
expect(embedding.floatEmbedding, isNotNull);
expect(embedding.floatEmbedding![0], closeTo(0.1560, 0.0001));
expect(embedding.length, 100);
expect(embedding.type, equals(EmbeddingType.float));
result.dispose();
executor.dispose();
});
});
}
51 changes: 51 additions & 0 deletions packages/mediapipe-task-text/test/text_embedder_result_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2014 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

import 'dart:ffi';
import 'package:ffi/ffi.dart';
import 'package:flutter_test/flutter_test.dart';
import 'package:mediapipe_core/src/io/test_utils.dart';
import 'package:mediapipe_text/io.dart';
import 'package:mediapipe_core/src/io/third_party/mediapipe/generated/mediapipe_common_bindings.dart'
as core_bindings;
import 'package:mediapipe_text/src/io/third_party/mediapipe/generated/mediapipe_text_bindings.dart'
as bindings;

void main() {
group('TextEmbedderResult.native should', () {
test('load an empty object', () {
final Pointer<bindings.TextEmbedderResult> ptr =
calloc<bindings.TextEmbedderResult>();
// These fields are provided by the real MediaPipe implementation, but
// Dart ignores them because they are meaningless in context of text tasks
ptr.ref.embeddings_count = 0;
ptr.ref.has_timestamp_ms = false;

final result = TextEmbedderResult.native(ptr);
expect(result.embeddings, isEmpty);
});

test('load a hydrated object', () {
final Pointer<bindings.TextEmbedderResult> resultPtr =
calloc<bindings.TextEmbedderResult>();

final embeddingsPtr = calloc<core_bindings.Embedding>(2);
populateEmbedding(embeddingsPtr[0], length: 50);
populateEmbedding(embeddingsPtr[1], length: 25);

resultPtr.ref.embeddings_count = 2;
resultPtr.ref.embeddings = embeddingsPtr;
resultPtr.ref.has_timestamp_ms = false;

final result = TextEmbedderResult.native(resultPtr);
expect(result.embeddings, hasLength(2));
final embedding = result.embeddings.take(1).toList().first;
expect(embedding.type, EmbeddingType.float);
expect(embedding.length, 50);
final embedding2 = result.embeddings.skip(1).take(1).toList().last;
expect(embedding2.type, EmbeddingType.float);
expect(embedding2.length, 25);
}, timeout: const Timeout(Duration(milliseconds: 10)));
});
}

0 comments on commit eabd913

Please sign in to comment.