feat: more type annotations for the functions (ShishirPatil#283)
related to ShishirPatil#282
UponTheSky authored Mar 24, 2024
1 parent a8fe4f8 commit 84ab3b6
Showing 2 changed files with 36 additions and 13 deletions.
raft/eval.py: 11 additions & 4 deletions
@@ -1,7 +1,9 @@
from typing import Any
import string
import re
from openai import OpenAI
from openai import AzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletion
import multiprocessing as mp
import time
import argparse
@@ -16,7 +18,9 @@
api_key=api_key,
)

def get_openai_response(message):
def get_openai_response(
message: list[ChatCompletionMessageParam]
) -> str | ChatCompletion | None :
response = client.chat.completions.create(
messages=message,
model=model_name,
@@ -28,14 +32,17 @@ def get_openai_response(message):
print(e)
return response

def get_answer(input_json):
def get_answer(input_json: dict[str, Any]) -> dict[str, Any]:
message = [{"role": "user", "content": input_json['instruction']}]
result = get_openai_response(message)
input_json['model_answer'] = result
return input_json


def write_result_to_file(result, write_file_name):
def write_result_to_file(
result: dict[str, Any],
write_file_name: str
) -> None:
global file_write_lock
with file_write_lock:
with open(write_file_name, "a") as outfile:
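As an aside, here is a minimal sketch (not part of the commit) of how the newly annotated eval.py helpers fit together: the message built inside get_answer is typed as list[ChatCompletionMessageParam], while the input and return stay dict[str, Any]. The OpenAI call is stubbed out so the sketch runs without an API key; only the annotations mirror the diff above.

from typing import Any
from openai.types.chat import ChatCompletionMessageParam

def get_answer(input_json: dict[str, Any]) -> dict[str, Any]:
    # Build the typed message list from the instruction, as eval.py does.
    message: list[ChatCompletionMessageParam] = [
        {"role": "user", "content": input_json["instruction"]}
    ]
    # Stand-in for get_openai_response(message); avoids needing credentials.
    result = f"stubbed answer to {len(message)} message(s): {input_json['instruction']!r}"
    input_json["model_answer"] = result
    return input_json

print(get_answer({"instruction": "What does RAFT stand for?"}))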
raft/raft.py: 25 additions & 9 deletions
@@ -1,3 +1,4 @@
from typing import Literal, Any
import argparse
from openai import OpenAI
from datasets import Dataset, load_dataset
@@ -8,7 +9,9 @@
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai.embeddings import OpenAIEmbeddings

def get_args() -> any:
DocType = Literal["api", "pdf", "json", "txt"]

def get_args() -> argparse.Namespace:
"""
Parses and returns the arguments specified by the user's command
"""
@@ -26,7 +29,13 @@ def get_args() -> any:
args = parser.parse_args()
return args

def get_chunks(file_path: str, doctype="pdf", chunk_size=512, openai_key=None) -> list[str]:

def get_chunks(
file_path: str,
doctype: DocType = "pdf",
chunk_size: int = 512,
openai_key: str | None = None
) -> list[str]:
"""
Takes in a `file_path` and `doctype`, retrieves the document, breaks it down into chunks of size
`chunk_size`, and returns the chunks.
@@ -70,7 +79,7 @@ def get_chunks(file_path: str, doctype="pdf", chunk_size=512, openai_key=None) -

return chunks

def generate_instructions(api_call, x=5) -> list[str]:
def generate_instructions(api_call: Any, x=5) -> list[str]:
"""
Generates `x` questions / use cases for `api_call`. Used when the input document is of type `api`.
"""
@@ -91,7 +100,7 @@ def generate_instructions(api_call, x=5) -> list[str]:

return queries

def generate_instructions_gen(chunk, x=5) -> list[str]:
def generate_instructions_gen(chunk: Any, x: int = 5) -> list[str]:
"""
Generates `x` questions / use cases for `chunk`. Used when the input document is of general types
`pdf`, `json`, or `txt`.
@@ -111,7 +120,7 @@ def generate_instructions_gen(chunk, x=5) -> list[str]:

return queries

def strip_str(s) -> str:
def strip_str(s: str) -> str:
"""
Helper function for helping format strings returned by GPT-4.
"""
@@ -127,7 +136,7 @@ def strip_str(s) -> str:
r += 2
return s[l:min(r, len(s))]

def encode_question(question, api) -> list[str]:
def encode_question(question: str, api: Any) -> list[str]:
"""
Encode multiple prompt instructions into a single string for the `api` case.
"""
@@ -138,7 +147,7 @@ def encode_question(question, api) -> list[str]:
prompts.append({"role": "user", "content": prompt})
return prompts

def encode_question_gen(question, chunk) -> list[str]:
def encode_question_gen(question: str, chunk: Any) -> list[str]:
"""
Encode multiple prompt instructions into a single string for the general case (`pdf`, `json`, or `txt`).
"""
@@ -156,7 +165,7 @@ def encode_question_gen(question, chunk) -> list[str]:
prompts.append({"role": "user", "content": prompt})
return prompts

def generate_label(question, context, doctype="pdf") -> str:
def generate_label(question: str, context: Any, doctype: DocType = "pdf") -> str | None:
"""
Generates the label / answer to `question` using `context` and GPT-4.
"""
@@ -170,7 +179,7 @@ def generate_label(question, context, doctype="pdf") -> str:
response = response.choices[0].message.content
return response

def add_chunk_to_dataset(chunks: list, chunk: str, doctype: str = "api", x: int = 5, num_distract: int = 3, p: float = 1.0):
def add_chunk_to_dataset(
chunks: list[str],
chunk: str,
doctype: DocType = "api",
x: int = 5,
num_distract: int = 3,
p: float = 1.0
) -> None:
"""
Given a chunk, create {Q, A, D} triplets and add them to the dataset.
"""
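One note on the new DocType alias in raft.py: Literal["api", "pdf", "json", "txt"] restricts doctype to the four supported document types, so a static checker such as mypy can flag an unsupported value before runtime. A minimal sketch follows; the describe helper is hypothetical, for illustration only.

from typing import Literal

DocType = Literal["api", "pdf", "json", "txt"]  # same alias as in raft.py

def describe(doctype: DocType = "pdf") -> str:
    # Hypothetical helper; raft.py's get_chunks, generate_label and
    # add_chunk_to_dataset annotate their doctype parameter the same way.
    return f"processing a {doctype} document"

print(describe("json"))  # accepted
# describe("html")       # rejected by mypy: not one of the DocType literals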
