Skip to content

Commit

Permalink
Update subgen.py
Browse files Browse the repository at this point in the history
Added ASR webhook for Bazarr
  • Loading branch information
McCloudS authored Oct 31, 2023
1 parent 2b4956c commit 59802f9
Showing 1 changed file with 193 additions and 22 deletions.
215 changes: 193 additions & 22 deletions subgen/subgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import queue
import logging
import gc
import io
from array import array
from typing import Union, Any
from typing import BinaryIO, Union, Any
import random

# List of packages to install
packages_to_install = [
Expand All @@ -20,22 +22,18 @@
'faster-whisper',
'uvicorn',
'python-multipart',
'whisper',
# Add more packages as needed
]

# Install every entry of packages_to_install at startup, one pip3 invocation
# per package, printing progress for each.
for package in packages_to_install:
    print(f"Installing {package}...")
    try:
        # check=True makes a non-zero pip3 exit status raise CalledProcessError.
        subprocess.run(['pip3', 'install', package], check=True)
        print(f"{package} has been successfully installed.")
    except subprocess.CalledProcessError as e:
        # A failed install is only reported; the loop moves on to the next package.
        print(f"Failed to install {package}: {e}")

from fastapi import FastAPI, File, UploadFile, Query, Header, Body, Form, Request
from fastapi.responses import StreamingResponse, RedirectResponse
from fastapi.responses import StreamingResponse, RedirectResponse
import numpy as np
import stable_whisper
import requests
import av
import ffmpeg
import whisper

def convert_to_bool(in_bool):
if isinstance(in_bool, bool):
Expand Down Expand Up @@ -169,6 +167,86 @@ def receive_emby_webhook(

return ""

@app.post("/asr")
async def asr(
    task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
    language: Union[str, None] = Query(default=None),
    initial_prompt: Union[str, None] = Query(default=None),  # not used by Bazarr
    audio_file: UploadFile = File(...),
    encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),  # not used by Bazarr/always False
    output: Union[str, None] = Query(default="srt", enum=["txt", "vtt", "srt", "tsv", "json"]),
    word_timestamps: bool = Query(default=False, description="Word level timestamps")  # not used by Bazarr
):
    """Transcribe or translate uploaded audio for the Bazarr/ASR webhook.

    The upload is read fully into memory and interpreted as 16-bit integer
    samples, scaled to float32 in [-1.0, 1.0), and handed to the model with
    ``input_sr=16000``.
    NOTE(review): presumably the client sends raw mono PCM at 16 kHz -- confirm.

    Args:
        task: "transcribe" or "translate"; forwarded to the model.
        language: Optional language hint; forwarded to the model only when
            the caller supplies one.
        initial_prompt: Accepted for API compatibility; not used here.
        audio_file: The uploaded audio payload.
        encode: Accepted for API compatibility; not used here.
        output: Accepted for API compatibility; not used here -- the response
            is always produced by ``result.to_srt_vtt``.
        word_timestamps: Accepted for API compatibility; not used here.

    Returns:
        A ``StreamingResponse`` (``text/plain``) over the subtitle text.

    Raises:
        Exception: Any error raised while loading the model or transcribing
            is logged and re-raised after the queue entry has been removed,
            instead of failing later on an unbound ``result``.
    """
    # Give the 'process' a random name so multiple Bazarr transcribes can
    # operate at the same time. Joined into a plain string so the label does
    # not embed a list repr.
    random_name = "".join(
        random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
    )
    queue_label = f"Bazarr-detect-langauge-{random_name}"

    # Register before loading the model so the finally-block can always undo
    # the registration, whichever step fails.
    files_to_transcribe.insert(0, queue_label)
    try:
        print("Transcribing file from Bazarr/ASR webhook")
        start_time = time.time()
        start_model()

        samples = np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0

        # Only pass a language when one was requested, so the default call is
        # unchanged for clients that omit it.
        extra_options = {}
        if language is not None:
            extra_options["language"] = language

        result = model.transcribe_stable(samples, task=task, input_sr=16000, **extra_options)

        elapsed_time = time.time() - start_time
        minutes, seconds = divmod(int(elapsed_time), 60)
        print(f"Bazarr transcription is completed, it took {minutes} minutes and {seconds} seconds to complete.")
    except Exception as e:
        print(f"Error processing or transcribing Bazarr {audio_file.filename}: {e}")
        raise
    finally:
        # Always drop our queue entry and give the model a chance to be
        # released, on success and on failure alike.
        files_to_transcribe.remove(queue_label)
        delete_model()

    return StreamingResponse(
        iter(result.to_srt_vtt(filepath=None, word_level=word_level_highlight)),
        media_type="text/plain",
        headers={
            'Source': 'Transcribed using stable-ts, faster-whisper from Subgen!',
        })

@app.post("/detect-language")
async def detect_language(
    audio_file: UploadFile = File(...),
    # encode: bool = Query(default=True, description="Encode audio first through ffmpeg")  # This is always false from Bazarr
):
    """Detect the spoken language of uploaded audio for Bazarr.

    The upload is read fully into memory and interpreted as 16-bit integer
    samples scaled to float32, passed through ``whisper.pad_or_trim`` and
    then to the model with ``input_sr=16000``; the ``language`` attribute of
    the result is reported.
    NOTE(review): presumably the client sends raw mono PCM at 16 kHz -- confirm.

    Args:
        audio_file: The uploaded audio payload.

    Returns:
        A dict with ``detected_language`` (the result of ``get_lang_pair``
        for the detected code) and ``language_code`` (the detected code).

    Raises:
        Exception: Errors from model loading or transcription propagate to
            the caller, but only after the queue entry has been removed.
    """
    # Give the 'process' a random name so multiple Bazarr transcribes can
    # operate at the same time. Joined into a plain string so the label does
    # not embed a list repr.
    random_name = "".join(
        random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", k=6)
    )
    queue_label = f"Bazarr-detect-langauge-{random_name}"

    files_to_transcribe.insert(0, queue_label)
    try:
        start_model()
        samples = np.frombuffer(audio_file.file.read(), np.int16).flatten().astype(np.float32) / 32768.0
        detected_lang_code = model.transcribe_stable(whisper.pad_or_trim(samples), input_sr=16000).language
    finally:
        # Without this, a failed detection would leave the label in the queue
        # forever and keep delete_model() from ever releasing the model.
        files_to_transcribe.remove(queue_label)
        delete_model()

    return {"detected_language": get_lang_pair(whisper_languages, detected_lang_code), "language_code": detected_lang_code}

def start_model():
    """Ensure the module-level ``model`` is loaded.

    Does nothing when ``model`` already holds a value. When it is ``None``,
    a faster-whisper model is created through ``stable_whisper`` using the
    module-level settings (``whisper_model``, ``model_location``,
    ``transcribe_device``, ``whisper_threads``, ``concurrent_transcriptions``)
    and stored in ``model``.
    """
    global model
    if model is not None:
        return

    logging.debug("Model was purged, need to re-create")
    model = stable_whisper.load_faster_whisper(
        whisper_model,
        download_root=model_location,
        device=transcribe_device,
        cpu_threads=whisper_threads,
        num_workers=concurrent_transcriptions,
    )

def delete_model():
    """Release the module-level ``model`` once the transcription queue is empty.

    When ``files_to_transcribe`` has no entries, ``model`` is rebound to
    ``None`` and a garbage collection pass is requested. The name is rebound
    rather than deleted so that a later ``model is None`` check still finds
    the global instead of raising ``NameError``.
    """
    global model
    if len(files_to_transcribe) == 0:
        logging.debug("Queue is empty, clearing/releasing VRAM")
        model = None
        gc.collect()

def get_lang_pair(whisper_languages, key):
    """Return the value paired with ``key`` in the Whisper languages dictionary.

    Args:
        whisper_languages: A dictionary of Whisper languages.
        key: The key to look up in the dictionary.

    Returns:
        The value stored for ``key``. If ``key`` is missing (or its stored
        value is ``None``), ``key`` itself is returned unchanged.
    """
    other_side = whisper_languages.get(key)
    if other_side is None:
        return key
    # Return the looked-up value directly; using it as a key for a second
    # lookup raises KeyError when the dictionary only maps one way.
    return other_side

def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True) -> None:
"""Generates subtitles for a video file.
Expand All @@ -177,7 +255,6 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
transcription_or_translation: The type of transcription or translation to perform.
front: Whether to add the file to the front of the transcription queue.
"""
global model

try:
if not is_video_file(file_path):
Expand All @@ -202,9 +279,7 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
print(f"{len(files_to_transcribe)} files in the queue for transcription")
print(f"Transcribing file: {os.path.basename(file_path)}")
start_time = time.time()
if model is None:
logging.debug("Model was purged, need to re-create")
model = stable_whisper.load_faster_whisper(whisper_model, download_root=model_location, device=transcribe_device, cpu_threads=whisper_threads, num_workers=concurrent_transcriptions)
start_model()

result = model.transcribe_stable(file_path, task=transcribe_or_translate_str)
result.to_srt_vtt(file_path.rsplit('.', 1)[0] + subextension, word_level=word_level_highlight)
Expand All @@ -216,15 +291,9 @@ def gen_subtitles(file_path: str, transcribe_or_translate_str: str, front=True)
print(f"File {os.path.basename(file_path)} is already in the transcription list. Skipping.")

except Exception as e:
print(f"Error processing or transcribing {file_path}: {e}")
print(f"Error processing or transcribing {video_file_path}: {e}")
finally:
if len(files_to_transcribe) == 0:
logging.debug("Queue is empty, clearing/releasing VRAM")
try:
del model
except Exception as e:
None
gc.collect()
delete_model()

def has_subtitle_language(video_file, target_language):
try:
Expand Down Expand Up @@ -346,6 +415,108 @@ def transcribe_existing():
transcribe_folders = transcribe_folders.split(",")
transcribe_existing()

# Lookup table from short language code (e.g. "en") to lowercase English
# language name (e.g. "english"). Passed to get_lang_pair() by the
# /detect-language endpoint to turn the detected code into a name.
# NOTE(review): presumably mirrors the language list shipped with Whisper --
# confirm against the installed whisper package before adding or removing entries.
whisper_languages = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}

print("Starting webhook!")
if __name__ == "__main__":
import uvicorn
Expand Down

0 comments on commit 59802f9

Please sign in to comment.