Skip to content

Commit

Permalink
Added support for text chat
Browse files Browse the repository at this point in the history
  • Loading branch information
MananGandhi1810 committed May 20, 2024
1 parent 681ec26 commit 90b12de
Show file tree
Hide file tree
Showing 10 changed files with 253 additions and 8 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,9 @@

## 1.2.0

- Added support Language Translation APIs
- Made response results nullable, to handle errors better.
- Added support for Language Translation APIs
- Made response results nullable, to handle errors better.

## 1.3.0

- Added support for Text Chat APIs
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,4 @@ Supported Models:
- [x] Image Generation
- [x] Text Classification
- [x] Language Translation
- [x] Text Chat
30 changes: 28 additions & 2 deletions example/cloudflare_ai_example.dart
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import 'dart:io';
import 'dart:typed_data';
import 'package:cloudflare_ai/cloudflare_ai.dart';
import 'package:cloudflare_ai/src/text_chat/text_chat.dart';

void main() async {
String accountId = "";
String apiKey = "";
String accountId = "Your Account ID";
String apiKey = "Your API Key";

// Text Generation
// Initialize a TextGenerationModel
Expand Down Expand Up @@ -108,4 +109,29 @@ void main() async {
} else {
print(languageTranslationRes..errors.map((e) => e.toJson()).toList());
}

// Text Chat
// Initialize a TextChatModel
TextChatModel textChatModel = TextChatModel(
accountId: accountId,
apiKey: apiKey,
model: TextChatModels.GEMMA_7B_IT,
);

// Load any previous conversations
textChatModel.loadMessages([
{
"role": "user",
"content": "Hello!",
},
{
"role": "assistant",
"content": "Hello! How may I help you?",
},
]);

// Send a new message
ChatMessage chatRes = await textChatModel.chat("Who are you?");

print(chatRes.content);
}
39 changes: 39 additions & 0 deletions lib/src/text_chat/chat_message.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
enum Role {
system,
user,
assistant;
}

class ChatMessage {
late Role role;
late String content;

ChatMessage({
required this.role,
required this.content,
});

ChatMessage.fromJson(json) {
switch (json['role']) {
case 'system':
role = Role.assistant;
break;
case 'user':
role = Role.user;
break;
case 'assistant':
role = Role.assistant;
break;
default:
throw Exception("Invalid role");
}
content = json['content'];
}

Map toJson() {
Map res = {};
res['role'] = role.name;
res['content'] = content;
return res;
}
}
39 changes: 39 additions & 0 deletions lib/src/text_chat/models.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Defines the models, and specifies path to them in the API Endpoint
enum TextChatModels {
LLAMA_2_7B("@cf/meta/llama-2-7b-chat-fp16"),
LLAMA_2_7B_INT8("@cf/meta/llama-2-7b-chat-int8"),
MISTRAL_7B("@cf/mistral/mistral-7b-instruct-v0.1"),
CODE_LLAMA_7B("@hf/thebloke/codellama-7b-instruct-awq"),
CODE_LLAMA_2_13B("@hf/thebloke/llama-2-13b-chat-awq"),
ZEPHYR_7B("@hf/thebloke/zephyr-7b-beta-awq"),
MISTRAL_7B_AWQ_V01("@hf/thebloke/mistral-7b-instruct-v0.1-awq"),
MISTRAL_7B_AWQ_V02("@hf/mistral/mistral-7b-instruct-v0.2"),
OPENHERMES_MISTRAL_7B("@hf/thebloke/openhermes-2.5-mistral-7b-awq"),
NEURAL_CHAT_7B("@hf/thebloke/neural-chat-7b-v3-1-awq"),
LLAMA_GUARD_7B("@hf/thebloke/llamaguard-7b-awq"),
DEEPSEEK_CODER_6_7_BASE("@hf/thebloke/deepseek-coder-6.7b-base-awq"),
DEEPSEEK_CODER_6_7_INSTRUCT("@hf/thebloke/deepseek-coder-6.7b-instruct-awq"),
DEEPSEEK_MATH_7B_BASE("@cf/deepseek-ai/deepseek-math-7b-base"),
DEEPSEEK_MATH_7B_INSTRUCT("@cf/deepseek-ai/deepseek-math-7b-instruct"),
OPENCHAT_3_5("@cf/openchat/openchat-3.5-0106"),
PHI_2("@cf/phi/phi-2"),
TINYLAMA_1_1B("@cf/tinyllama/tinyllama-1.1b-chat-v1.0"),
DISCOLM_GERMAN_7B("@cf/thebloke/discolm-german-7b-v1-awq"),
QWEN_1_5_0_5B_CHAT("@cf/qwen/qwen-1.5.0.5b-chat"),
QWEN1_5_1_8B_CHAT("@cf/qwen/qwen1.5-1.8b-chat"),
QWEN_1_5_7B_CHAT_AWQ("@cf/qwen/qwen1.5-7b-chat-awq"),
QWEN_1_5_14B_CHAT_AWQ("@cf/qwen/qwen1.5-14b-chat-awq"),
FALCON_7B_INSTRUCT("@cf/tiiuae/falcon-7b-instruct"),
GEMMA_2B_IT_LORA("@cf/google/gemma-2b-it-lora"),
GEMMA_7B_IT("@hf/google/gemma-7b-it"),
GEMMA_7B_IT_LORA("@cf/google/gemma-7b-it-lora"),
HERMES_2_PRO_7B("@hf/nousresearch/hermes-2-pro-mistral-7b"),
LLAMA_2_7B_CHAT_HF_LORA("@cf/meta-llama/llama-2-7b-chat-hf-lora"),
LLAMA_3_8B_INSTRUCT("@hf/meta-llama/meta-llama-3-8b-instruct"),
UNA_CYBERTRON_7B_V2_BF16("@cf/fblgit/una-cybertron-7b-v2-bf16"),
STARLING_LM_7B_BETA("@hf/nexusflow/starling-lm-7b-beta"),
SQL_CODER_7B_2("@cf/defog/sqlcoder-7b-2");

const TextChatModels(this.value);
final String value;
}
52 changes: 52 additions & 0 deletions lib/src/text_chat/response.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Defines Repsonse received from the API
import '../models/error.dart';

class TextChatResponse {
late TextChatResult? result;
late bool success;
late List<ErrorModel> errors;
late List messages;

TextChatResponse({
required this.result,
required this.success,
errors,
messages,
}) {
this.errors = errors ?? [];
this.messages = messages ?? [];
}

TextChatResponse.fromJson(Map<String, dynamic> json) {
result = TextChatResult.fromJson(json['result']);
success = json['success'];
errors =
(json['errors'] as List).map((e) => ErrorModel.fromJson(e)).toList();
messages = json['messages'];
}

Map<String, dynamic> toJson() {
final Map<String, dynamic> data = <String, dynamic>{};
data['result'] = result?.toJson() ?? {};
data['success'] = success;
data['errors'] = errors.map((e) => e.toJson()).toList();
data['messages'] = messages;
return data;
}
}

class TextChatResult {
String? response;

TextChatResult({this.response});

TextChatResult.fromJson(Map<String, dynamic> json) {
response = json['response'];
}

Map<String, dynamic> toJson() {
final Map<String, dynamic> data = <String, dynamic>{};
data['response'] = response;
return data;
}
}
66 changes: 66 additions & 0 deletions lib/src/text_chat/text_chat.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import 'chat_message.dart';
import 'models.dart';
import 'response.dart';
import '../services/network_service.dart';
export 'models.dart';
export 'response.dart';
export 'chat_message.dart';

class TextChatModel {
late String accountId;
late String apiKey;
late TextChatModels model;
late bool raw;
List<ChatMessage> _messages = [];
NetworkService networkService = NetworkService();
late String baseUrl;

TextChatModel({
required this.accountId,
required this.apiKey,
required this.model,
this.raw = true,
}) {
baseUrl = "https://api.cloudflare.com/client/v4/accounts/$accountId/ai/run";
if (accountId.trim() == "") {
throw Exception("Account ID cannot be empty");
}
if (apiKey.trim() == "") {
throw Exception("API Key cannot be empty");
}
}

// Get all chat messages
List<ChatMessage> get allMessages => _messages;

// Load a previous chat
void loadMessages(List<Map<String, dynamic>> messages) {
_messages =
messages.map((message) => ChatMessage.fromJson(message)).toList();
}

// Asynchronous function which returns the classification labels with their confidence of the text
Future<ChatMessage> chat(String message) async {
_messages.add(
ChatMessage(
role: Role.user,
content: message,
),
);
Map<String, dynamic> res =
await networkService.post("$baseUrl/${model.value}", apiKey, {
"messages": _messages,
});
TextChatResponse response = TextChatResponse.fromJson(res['data']);
if (!response.success || response.result == null) {
throw Exception(response.errors);
}
_messages.add(
ChatMessage(
role: Role.assistant,
content: response.result?.response ?? "",
),
);
return _messages.last;
}
}
2 changes: 1 addition & 1 deletion lib/src/text_classification/text_classification.dart
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import 'package:cloudflare_ai/src/text_classification/models.dart';
import 'models.dart';
import 'response.dart';
import '../services/network_service.dart';
export 'models.dart';
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: cloudflare_ai
description: This is a Dart package for Cloudflare Workers AI. It currently supports Text Generation and Image Generation, Text Summarization, and Image Generation Models.
version: 1.2.0
version: 1.3.0
repository: https://github.com/MananGandhi1810/cloudflare-ai-dart

environment:
Expand Down
22 changes: 20 additions & 2 deletions test/cloudflare_ai_test.dart
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import 'dart:io';
import 'dart:typed_data';
import 'package:cloudflare_ai/cloudflare_ai.dart';
import 'package:cloudflare_ai/src/text_chat/text_chat.dart';
import 'package:test/test.dart';

void main() {
final env = Platform.environment;
String accountId = env['ACCOUNTID'] ?? "1cce28ed87f1c4865e511c1abd3f0101";
String apiKey = env['APIKEY'] ?? "TxzHZU9gfVuHMVlC_M-pW1H7d9RpDE7OD_OIW2rY";
String accountId = env['ACCOUNTID'] ?? "";
String apiKey = env['APIKEY'] ?? "";
group('Text Generation:', () {
test(
"Gemma 7b IT",
Expand Down Expand Up @@ -189,4 +190,21 @@ And so, with a newfound appreciation for the beauty of Earth and its inhabitants
),
);
});

group("Text Chat: ", () {
test("Gemma 7B IT", () async {
TextChatModel model = TextChatModel(
accountId: accountId,
apiKey: apiKey,
model: TextChatModels.GEMMA_7B_IT,
);
model.loadMessages([
{"role": "user", "content": "Hello!"},
{"role": "system", "content": "Hello! How may I help you?"},
]);
ChatMessage message = await model.chat("Who are you?");
expect(message.content, isNotNull);
expect(message.role, Role.assistant);
});
});
}

0 comments on commit 90b12de

Please sign in to comment.