Skip to content

Commit

Permalink
update app
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriel CHUA committed Oct 20, 2024
1 parent 14ff1d7 commit 506f934
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 58 deletions.
4 changes: 1 addition & 3 deletions constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@

# Fireworks API-related constants
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
FIREWORKS_BASE_URL = "https://api.fireworks.ai/inference/v1"
FIREWORKS_MAX_TOKENS = 16_384
FIREWORKS_MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
FIREWORKS_TEMPERATURE = 0.1
FIREWORKS_JSON_RETRY_ATTEMPTS = 3

# MeloTTS
MELO_API_NAME = "/synthesize"
Expand Down Expand Up @@ -80,7 +78,7 @@
Generate Podcasts from PDFs using open-source AI.
Built with:
- [Llama 3.1 405B 🦙](https://huggingface.co/meta-llama/Llama-3.1-405B) via [Fireworks AI 🎆](https://fireworks.ai/)
- [Llama 3.1 405B 🦙](https://huggingface.co/meta-llama/Llama-3.1-405B) via [Fireworks AI 🎆](https://fireworks.ai/) and [Instructor 📐](https://github.com/instructor-ai/instructor)
- [MeloTTS 🐚](https://huggingface.co/myshell-ai/MeloTTS-English)
- [Bark 🐶](https://huggingface.co/suno/bark)
- [Jina Reader 🔍](https://jina.ai/reader/)
Expand Down
12 changes: 9 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ click==8.1.7
contourpy==1.3.0
cycler==0.12.1
distro==1.9.0
docstring_parser==0.16
einops==0.8.0
encodec==0.1.1
fastapi==0.115.0
ffmpy==0.4.0
filelock==3.16.1
fireworks-ai==0.15.6
fonttools==4.54.1
frozenlist==1.4.1
fsspec==2024.9.0
Expand All @@ -28,10 +30,13 @@ granian==1.4.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.2
httpx-sse==0.4.0
httpx-ws==0.6.2
huggingface-hub==0.25.1
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.4.5
instructor==1.6.2
Jinja2==3.1.4
jiter==0.5.0
jmespath==1.0.1
Expand All @@ -55,8 +60,8 @@ pandas==2.2.3
pillow==10.4.0
promptic==0.7.5
psutil==5.9.8
pydantic==2.7.0
pydantic_core==2.18.1
pydantic==2.9.2
pydantic_core==2.23.4
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.4
Expand Down Expand Up @@ -85,7 +90,7 @@ spaces==0.30.2
starlette==0.38.6
suno-bark @ git+https://github.com/suno-ai/bark.git@f4f32d4cd480dfec1c245d258174bc9bde3c2148
sympy==1.13.3
tenacity==8.3.0
tenacity==9.0.0
tiktoken==0.7.0
tokenizers==0.20.0
tomlkit==0.12.0
Expand All @@ -100,5 +105,6 @@ urllib3==2.2.3
uvicorn==0.31.0
uvloop==0.18.0
websockets==12.0
wsproto==1.2.0
yarl==1.13.1
zipp==3.20.2
67 changes: 15 additions & 52 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,29 @@
- call_llm: Call the LLM with the given prompt and dialogue format.
- parse_url: Parse the given URL and return the text content.
- generate_podcast_audio: Generate audio for podcast using TTS or advanced audio models.
- _use_suno_model: Generate advanced audio using Bark.
- _use_melotts_api: Generate audio using TTS model.
- _get_melo_tts_params: Get TTS parameters based on speaker and language.
"""

# Standard library imports
import time
from typing import Any, Union

# Third-party imports
import instructor
import requests
from bark import SAMPLE_RATE, generate_audio, preload_models
from fireworks.client import Fireworks
from gradio_client import Client
from openai import OpenAI
from pydantic import ValidationError
from scipy.io.wavfile import write as write_wav

# Local imports
from constants import (
FIREWORKS_API_KEY,
FIREWORKS_BASE_URL,
FIREWORKS_MODEL_ID,
FIREWORKS_MAX_TOKENS,
FIREWORKS_TEMPERATURE,
FIREWORKS_JSON_RETRY_ATTEMPTS,
MELO_API_NAME,
MELO_TTS_SPACES_ID,
MELO_RETRY_ATTEMPTS,
Expand All @@ -38,8 +39,11 @@
)
from schema import ShortDialogue, MediumDialogue

# Initialize clients
fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
# Initialize Fireworks client, with Instructor patch
fw_client = Fireworks(api_key=FIREWORKS_API_KEY)
fw_client = instructor.from_fireworks(fw_client)

# Initialize Hugging Face client
hf_client = Client(MELO_TTS_SPACES_ID)

# Download and load all models for Bark
Expand All @@ -53,51 +57,13 @@ def generate_script(
) -> Union[ShortDialogue, MediumDialogue]:
"""Get the dialogue from the LLM."""

# Call the LLM
response = call_llm(system_prompt, input_text, output_model)
response_json = response.choices[0].message.content

# Validate the response
for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
try:
first_draft_dialogue = output_model.model_validate_json(response_json)
break
except ValidationError as e:
if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
raise ValueError(
f"Failed to parse dialogue JSON after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
) from e
error_message = (
f"Failed to parse dialogue JSON (attempt {attempt + 1}): {e}"
)
# Re-call the LLM with the error message
system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
response = call_llm(system_prompt_with_error, input_text, output_model)
response_json = response.choices[0].message.content
first_draft_dialogue = output_model.model_validate_json(response_json)
# Call the LLM for the first time
first_draft_dialogue = call_llm(system_prompt, input_text, output_model)

# Call the LLM a second time to improve the dialogue
system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue}."
system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue.model_dump_json()}."
final_dialogue = call_llm(system_prompt_with_dialogue, "Please improve the dialogue. Make it more natural and engaging.", output_model)

# Validate the response
for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
try:
response = call_llm(
system_prompt_with_dialogue,
"Please improve the dialogue. Make it more natural and engaging.",
output_model,
)
final_dialogue = output_model.model_validate_json(
response.choices[0].message.content
)
break
except ValidationError as e:
if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
raise ValueError(
f"Failed to improve dialogue after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
) from e
error_message = f"Failed to improve dialogue (attempt {attempt + 1}): {e}"
system_prompt_with_dialogue += f"\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
return final_dialogue


Expand All @@ -111,10 +77,7 @@ def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
model=FIREWORKS_MODEL_ID,
max_tokens=FIREWORKS_MAX_TOKENS,
temperature=FIREWORKS_TEMPERATURE,
response_format={
"type": "json_object",
"schema": dialogue_format.model_json_schema(),
},
response_model=dialogue_format,
)
return response

Expand Down

0 comments on commit 506f934

Please sign in to comment.