Skip to content

Commit

Permalink
Rewrite test
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshineplan committed Mar 13, 2024
1 parent a568231 commit e2aa3a5
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 164 deletions.
132 changes: 132 additions & 0 deletions ai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package ai_test

import (
"context"
"fmt"
"io"
"os"
"testing"
"time"

"github.com/sunshineplan/ai"
"github.com/sunshineplan/ai/chatgpt"
"github.com/sunshineplan/ai/gemini"
)

func init() {
if proxy := os.Getenv("AI_PROXY"); proxy != "" {
ai.SetProxy(proxy)
}
}

func testChat(ai ai.AI, prompt string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
fmt.Println(prompt)
resp, err := ai.Chat(ctx, prompt)
if err != nil {
return err
}
fmt.Println(resp.Results())
fmt.Println("---")
return nil
}

func testChatStream(ai ai.AI, prompt string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
fmt.Println(prompt)
stream, err := ai.ChatStream(ctx, prompt)
if err != nil {
return err
}
defer stream.Close()
for {
resp, err := stream.Next()
if err != nil {
if err == io.EOF {
break
}
return err
}
fmt.Println(resp.Results())
}
fmt.Println("---")
return nil
}

func testChatSession(ai ai.AI) error {
s := ai.ChatSession()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
fmt.Println("Hello, I have 2 dogs in my house.")
resp, err := s.Chat(ctx, "Hello, I have 2 dogs in my house.")
if err != nil {
return err
}
fmt.Println(resp.Results())
ctx, cancel = context.WithTimeout(context.Background(), time.Minute)
defer cancel()
fmt.Println("How many paws are in my house?")
stream, err := s.ChatStream(ctx, "How many paws are in my house?")
if err != nil {
return err
}
defer stream.Close()
for {
resp, err := stream.Next()
if err != nil {
if err == io.EOF {
break
}
return err
}
fmt.Println(resp.Results())
}
fmt.Println("---")
fmt.Println("History")
for _, i := range s.History() {
fmt.Println(i.Role, ":", i.Content)
}
fmt.Println("---")
return nil
}

func TestGemini(t *testing.T) {
apiKey := os.Getenv("GEMINI_API_KEY")
if apiKey == "" {
return
}
gemini, err := gemini.New(apiKey)
if err != nil {
t.Fatal(err)
}
defer gemini.Close()
if err := testChat(gemini, "Who are you?"); err != nil {
t.Error(err)
}
if err := testChatStream(gemini, "Who am I?"); err != nil {
t.Error(err)
}
if err := testChatSession(gemini); err != nil {
t.Error(err)
}
}

func TestChatGPT(t *testing.T) {
apiKey := os.Getenv("CHATGPT_API_KEY")
if apiKey == "" {
return
}
chatgpt := chatgpt.New(apiKey)
defer chatgpt.Close()
if err := testChat(chatgpt, "Who are you?"); err != nil {
t.Error(err)
}
if err := testChatStream(chatgpt, "Who am I?"); err != nil {
t.Error(err)
}
if err := testChatSession(chatgpt); err != nil {
t.Error(err)
}
}
78 changes: 0 additions & 78 deletions chatgpt/chatgpt_test.go

This file was deleted.

86 changes: 0 additions & 86 deletions gemini/gemini_test.go

This file was deleted.

0 comments on commit e2aa3a5

Please sign in to comment.