From 9b7d9ae7da88a42828605fa59f8873ba7d5f558d Mon Sep 17 00:00:00 2001 From: Tradunsky Date: Sun, 24 Dec 2023 02:18:26 -0500 Subject: [PATCH] Add CPU device-id support --- src/insanely_fast_whisper/cli.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/insanely_fast_whisper/cli.py b/src/insanely_fast_whisper/cli.py index 36c6ead..9a29dba 100644 --- a/src/insanely_fast_whisper/cli.py +++ b/src/insanely_fast_whisper/cli.py @@ -19,7 +19,7 @@ required=False, default="0", type=str, - help='Device ID for your GPU. Just pass the device number when using CUDA, or "mps" for Macs with Apple Silicon. (default: "0")', + help='Device ID for your GPU. Just pass the device number when using CUDA, or "mps" for Macs with Apple Silicon or "cpu". (default: "0")', ) parser.add_argument( "--transcript-path", @@ -91,11 +91,13 @@ def main(): args = parser.parse_args() + dtype = torch.float32 if args.device_id == "cpu" else torch.float16 + pipe = pipeline( "automatic-speech-recognition", model=args.model_name, - torch_dtype=torch.float16, - device="mps" if args.device_id == "mps" else f"cuda:{args.device_id}", + torch_dtype=dtype, + device=f"cuda:{args.device_id}" if args.device_id.isnumeric() else args.device_id, model_kwargs={"attn_implementation": "flash_attention_2"} if args.flash else {"attn_implementation": "sdpa"}, )