diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0f98b90..963c512 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,5 +26,9 @@
 
 ## 1.2.0
 
-- Added support Language Translation APIs
-- Made response results nullable, to handle errors better.
\ No newline at end of file
+- Added support for Language Translation APIs
+- Made response results nullable, to handle errors better.
+
+## 1.3.0
+
+- Added support for Text Chat APIs
\ No newline at end of file
diff --git a/README.md b/README.md
index c99fdfd..cb67ec9 100644
--- a/README.md
+++ b/README.md
@@ -189,3 +189,4 @@ Supported Models:
 - [x] Image Generation
 - [x] Text Classification
 - [x] Language Translation
+- [x] Text Chat
\ No newline at end of file
diff --git a/example/cloudflare_ai_example.dart b/example/cloudflare_ai_example.dart
index 47c7e62..9331dc3 100644
--- a/example/cloudflare_ai_example.dart
+++ b/example/cloudflare_ai_example.dart
@@ -1,10 +1,11 @@
 import 'dart:io';
 import 'dart:typed_data';
 import 'package:cloudflare_ai/cloudflare_ai.dart';
+import 'package:cloudflare_ai/src/text_chat/text_chat.dart';
 
 void main() async {
-  String accountId = "";
-  String apiKey = "";
+  String accountId = "Your Account ID";
+  String apiKey = "Your API Key";
 
   // Text Generation
   // Initialize a TextGenerationModel
@@ -108,4 +109,29 @@
   } else {
     print(languageTranslationRes..errors.map((e) => e.toJson()).toList());
   }
+
+  // Text Chat
+  // Initialize a TextChatModel
+  TextChatModel textChatModel = TextChatModel(
+    accountId: accountId,
+    apiKey: apiKey,
+    model: TextChatModels.GEMMA_7B_IT,
+  );
+
+  // Load any previous conversations
+  textChatModel.loadMessages([
+    {
+      "role": "user",
+      "content": "Hello!",
+    },
+    {
+      "role": "assistant",
+      "content": "Hello! How may I help you?",
+    },
+  ]);
+
+  // Send a new message
+  ChatMessage chatRes = await textChatModel.chat("Who are you?");
+
+  print(chatRes.content);
 }
diff --git a/lib/src/text_chat/chat_message.dart b/lib/src/text_chat/chat_message.dart
new file mode 100644
index 0000000..8ddec29
--- /dev/null
+++ b/lib/src/text_chat/chat_message.dart
@@ -0,0 +1,39 @@
+enum Role {
+  system,
+  user,
+  assistant;
+}
+
+class ChatMessage {
+  late Role role;
+  late String content;
+
+  ChatMessage({
+    required this.role,
+    required this.content,
+  });
+
+  ChatMessage.fromJson(json) {
+    switch (json['role']) {
+      case 'system':
+        role = Role.system;
+        break;
+      case 'user':
+        role = Role.user;
+        break;
+      case 'assistant':
+        role = Role.assistant;
+        break;
+      default:
+        throw Exception("Invalid role");
+    }
+    content = json['content'];
+  }
+
+  Map<String, dynamic> toJson() {
+    Map<String, dynamic> res = {};
+    res['role'] = role.name;
+    res['content'] = content;
+    return res;
+  }
+}
diff --git a/lib/src/text_chat/models.dart b/lib/src/text_chat/models.dart
new file mode 100644
index 0000000..4672560
--- /dev/null
+++ b/lib/src/text_chat/models.dart
@@ -0,0 +1,39 @@
+// Defines the available models and the path to each one in the API endpoint
+enum TextChatModels {
+  LLAMA_2_7B("@cf/meta/llama-2-7b-chat-fp16"),
+  LLAMA_2_7B_INT8("@cf/meta/llama-2-7b-chat-int8"),
+  MISTRAL_7B("@cf/mistral/mistral-7b-instruct-v0.1"),
+  CODE_LLAMA_7B("@hf/thebloke/codellama-7b-instruct-awq"),
+  CODE_LLAMA_2_13B("@hf/thebloke/llama-2-13b-chat-awq"),
+  ZEPHYR_7B("@hf/thebloke/zephyr-7b-beta-awq"),
+  MISTRAL_7B_AWQ_V01("@hf/thebloke/mistral-7b-instruct-v0.1-awq"),
+  MISTRAL_7B_AWQ_V02("@hf/mistral/mistral-7b-instruct-v0.2"),
+  OPENHERMES_MISTRAL_7B("@hf/thebloke/openhermes-2.5-mistral-7b-awq"),
+  NEURAL_CHAT_7B("@hf/thebloke/neural-chat-7b-v3-1-awq"),
+  LLAMA_GUARD_7B("@hf/thebloke/llamaguard-7b-awq"),
+  DEEPSEEK_CODER_6_7_BASE("@hf/thebloke/deepseek-coder-6.7b-base-awq"),
+  DEEPSEEK_CODER_6_7_INSTRUCT("@hf/thebloke/deepseek-coder-6.7b-instruct-awq"),
+  DEEPSEEK_MATH_7B_BASE("@cf/deepseek-ai/deepseek-math-7b-base"),
+  DEEPSEEK_MATH_7B_INSTRUCT("@cf/deepseek-ai/deepseek-math-7b-instruct"),
+  OPENCHAT_3_5("@cf/openchat/openchat-3.5-0106"),
+  PHI_2("@cf/phi/phi-2"),
+  TINYLAMA_1_1B("@cf/tinyllama/tinyllama-1.1b-chat-v1.0"),
+  DISCOLM_GERMAN_7B("@cf/thebloke/discolm-german-7b-v1-awq"),
+  QWEN_1_5_0_5B_CHAT("@cf/qwen/qwen-1.5.0.5b-chat"),
+  QWEN1_5_1_8B_CHAT("@cf/qwen/qwen1.5-1.8b-chat"),
+  QWEN_1_5_7B_CHAT_AWQ("@cf/qwen/qwen1.5-7b-chat-awq"),
+  QWEN_1_5_14B_CHAT_AWQ("@cf/qwen/qwen1.5-14b-chat-awq"),
+  FALCON_7B_INSTRUCT("@cf/tiiuae/falcon-7b-instruct"),
+  GEMMA_2B_IT_LORA("@cf/google/gemma-2b-it-lora"),
+  GEMMA_7B_IT("@hf/google/gemma-7b-it"),
+  GEMMA_7B_IT_LORA("@cf/google/gemma-7b-it-lora"),
+  HERMES_2_PRO_7B("@hf/nousresearch/hermes-2-pro-mistral-7b"),
+  LLAMA_2_7B_CHAT_HF_LORA("@cf/meta-llama/llama-2-7b-chat-hf-lora"),
+  LLAMA_3_8B_INSTRUCT("@hf/meta-llama/meta-llama-3-8b-instruct"),
+  UNA_CYBERTRON_7B_V2_BF16("@cf/fblgit/una-cybertron-7b-v2-bf16"),
+  STARLING_LM_7B_BETA("@hf/nexusflow/starling-lm-7b-beta"),
+  SQL_CODER_7B_2("@cf/defog/sqlcoder-7b-2");
+
+  const TextChatModels(this.value);
+  final String value;
+}
diff --git a/lib/src/text_chat/response.dart b/lib/src/text_chat/response.dart
new file mode 100644
index 0000000..46383e6
--- /dev/null
+++ b/lib/src/text_chat/response.dart
@@ -0,0 +1,52 @@
+// Defines the response received from the API
+import '../models/error.dart';
+
+class TextChatResponse {
+  late TextChatResult? result;
+  late bool success;
+  late List<ErrorModel> errors;
+  late List messages;
+
+  TextChatResponse({
+    required this.result,
+    required this.success,
+    errors,
+    messages,
+  }) {
+    this.errors = errors ?? [];
+    this.messages = messages ?? [];
+  }
+
+  TextChatResponse.fromJson(Map<String, dynamic> json) {
+    result = TextChatResult.fromJson(json['result']);
+    success = json['success'];
+    errors =
+        (json['errors'] as List).map((e) => ErrorModel.fromJson(e)).toList();
+    messages = json['messages'];
+  }
+
+  Map<String, dynamic> toJson() {
+    final Map<String, dynamic> data = {};
+    data['result'] = result?.toJson() ?? {};
+    data['success'] = success;
+    data['errors'] = errors.map((e) => e.toJson()).toList();
+    data['messages'] = messages;
+    return data;
+  }
+}
+
+class TextChatResult {
+  String? response;
+
+  TextChatResult({this.response});
+
+  TextChatResult.fromJson(Map<String, dynamic> json) {
+    response = json['response'];
+  }
+
+  Map<String, dynamic> toJson() {
+    final Map<String, dynamic> data = {};
+    data['response'] = response;
+    return data;
+  }
+}
diff --git a/lib/src/text_chat/text_chat.dart b/lib/src/text_chat/text_chat.dart
new file mode 100644
index 0000000..1db634d
--- /dev/null
+++ b/lib/src/text_chat/text_chat.dart
@@ -0,0 +1,66 @@
+import 'chat_message.dart';
+import 'models.dart';
+import 'response.dart';
+import '../services/network_service.dart';
+export 'models.dart';
+export 'response.dart';
+export 'chat_message.dart';
+
+class TextChatModel {
+  late String accountId;
+  late String apiKey;
+  late TextChatModels model;
+  late bool raw;
+  List<ChatMessage> _messages = [];
+  NetworkService networkService = NetworkService();
+  late String baseUrl;
+
+  TextChatModel({
+    required this.accountId,
+    required this.apiKey,
+    required this.model,
+    this.raw = true,
+  }) {
+    baseUrl = "https://api.cloudflare.com/client/v4/accounts/$accountId/ai/run";
+    if (accountId.trim() == "") {
+      throw Exception("Account ID cannot be empty");
+    }
+    if (apiKey.trim() == "") {
+      throw Exception("API Key cannot be empty");
+    }
+  }
+
+  // Get all chat messages
+  List<ChatMessage> get allMessages => _messages;
+
+  // Load a previous chat
+  void loadMessages(List<Map<String, dynamic>> messages) {
+    _messages =
+        messages.map((message) => ChatMessage.fromJson(message)).toList();
+  }
+
+  // Asynchronous function which sends a message to the model and returns its reply
+  Future<ChatMessage> chat(String message) async {
+    _messages.add(
+      ChatMessage(
+        role: Role.user,
+        content: message,
+      ),
+    );
+    Map<String, dynamic> res =
+        await networkService.post("$baseUrl/${model.value}", apiKey, {
+      "messages": _messages,
+    });
+    TextChatResponse response = TextChatResponse.fromJson(res['data']);
+    if (!response.success || response.result == null) {
+      throw Exception(response.errors);
+    }
+    _messages.add(
+      ChatMessage(
+        role: Role.assistant,
+        content: response.result?.response ?? "",
+      ),
+    );
+    return _messages.last;
+  }
+}
diff --git a/lib/src/text_classification/text_classification.dart b/lib/src/text_classification/text_classification.dart
index 87d211c..ace015c 100644
--- a/lib/src/text_classification/text_classification.dart
+++ b/lib/src/text_classification/text_classification.dart
@@ -1,4 +1,4 @@
-import 'package:cloudflare_ai/src/text_classification/models.dart';
+import 'models.dart';
 import 'response.dart';
 import '../services/network_service.dart';
 export 'models.dart';
diff --git a/pubspec.yaml b/pubspec.yaml
index a25c9f5..6984bb1 100644
--- a/pubspec.yaml
+++ b/pubspec.yaml
@@ -1,6 +1,6 @@
 name: cloudflare_ai
 description: This is a Dart package for Cloudflare Workers AI. It currently supports Text Generation and Image Generation, Text Summarization, and Image Generation Models.
-version: 1.2.0
+version: 1.3.0
 repository: https://github.com/MananGandhi1810/cloudflare-ai-dart
 
 environment:
diff --git a/test/cloudflare_ai_test.dart b/test/cloudflare_ai_test.dart
index b3c0d8c..3b208c0 100644
--- a/test/cloudflare_ai_test.dart
+++ b/test/cloudflare_ai_test.dart
@@ -1,12 +1,13 @@
 import 'dart:io';
 import 'dart:typed_data';
 import 'package:cloudflare_ai/cloudflare_ai.dart';
+import 'package:cloudflare_ai/src/text_chat/text_chat.dart';
 import 'package:test/test.dart';
 
 void main() {
   final env = Platform.environment;
-  String accountId = env['ACCOUNTID'] ?? "1cce28ed87f1c4865e511c1abd3f0101";
-  String apiKey = env['APIKEY'] ?? "TxzHZU9gfVuHMVlC_M-pW1H7d9RpDE7OD_OIW2rY";
"TxzHZU9gfVuHMVlC_M-pW1H7d9RpDE7OD_OIW2rY"; + String accountId = env['ACCOUNTID'] ?? ""; + String apiKey = env['APIKEY'] ?? ""; group('Text Generation:', () { test( "Gemma 7b IT", @@ -189,4 +190,21 @@ And so, with a newfound appreciation for the beauty of Earth and its inhabitants ), ); }); + + group("Text Chat: ", () { + test("Gemma 7B IT", () async { + TextChatModel model = TextChatModel( + accountId: accountId, + apiKey: apiKey, + model: TextChatModels.GEMMA_7B_IT, + ); + model.loadMessages([ + {"role": "user", "content": "Hello!"}, + {"role": "system", "content": "Hello! How may I help you?"}, + ]); + ChatMessage message = await model.chat("Who are you?"); + expect(message.content, isNotNull); + expect(message.role, Role.assistant); + }); + }); }