From 0e1e38326adfe28bc5aa58f302233a3bd511016c Mon Sep 17 00:00:00 2001
From: alpayariyak
Date: Thu, 13 Jun 2024 17:48:05 +0000
Subject: [PATCH] Fix deprecated max_context_len_to_capture engine argument

---
 README.md     | 2 +-
 src/config.py | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index a001e7f..0ccbdc6 100644
--- a/README.md
+++ b/README.md
@@ -115,7 +115,7 @@ Below is a summary of the available RunPod Worker images, categorized by image s
 | `BLOCK_SIZE` | `16` | `8`, `16`, `32` |Token block size for contiguous chunks of tokens. |
 | `SWAP_SPACE` | `4` | `int` |CPU swap space size (GiB) per GPU. |
 | `ENFORCE_EAGER` | `0` | boolean as `int` |Always use eager-mode PyTorch. If False(`0`), will use eager mode and CUDA graph in hybrid for maximal performance and flexibility. |
-| `MAX_CONTEXT_LEN_TO_CAPTURE` | `8192` | `int` |Maximum context length covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode.|
+| `MAX_SEQ_LEN_TO_CAPTURE` | `8192` | `int` |Maximum context length covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode.|
 | `DISABLE_CUSTOM_ALL_REDUCE` | `0` | `int` |Enables or disables custom all reduce. |
 **Streaming Batch Size Settings**:
 | `DEFAULT_BATCH_SIZE` | `50` | `int` |Default and Maximum batch size for token streaming to reduce HTTP calls. |
diff --git a/src/config.py b/src/config.py
index 9d771cc..67b9836 100644
--- a/src/config.py
+++ b/src/config.py
@@ -47,11 +47,16 @@ def _initialize_config(self):
             "kv_cache_dtype": os.getenv("KV_CACHE_DTYPE"),
             "block_size": int(os.getenv("BLOCK_SIZE")) if os.getenv("BLOCK_SIZE") else None,
             "swap_space": int(os.getenv("SWAP_SPACE")) if os.getenv("SWAP_SPACE") else None,
-            "max_context_len_to_capture": int(os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE")) if os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE") else None,
+            "max_seq_len_to_capture": int(os.getenv("MAX_SEQ_LEN_TO_CAPTURE")) if os.getenv("MAX_SEQ_LEN_TO_CAPTURE") else None,
             "disable_custom_all_reduce": get_int_bool_env("DISABLE_CUSTOM_ALL_REDUCE", False),
             "enforce_eager": get_int_bool_env("ENFORCE_EAGER", False)
         }
 
         if args["kv_cache_dtype"] == "fp8_e5m2":
             args["kv_cache_dtype"] = "fp8"
             logging.warning("Using fp8_e5m2 is deprecated. Please use fp8 instead.")
+        if os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"):
+            args["max_seq_len_to_capture"] = int(os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE"))
+            logging.warning("Using MAX_CONTEXT_LEN_TO_CAPTURE is deprecated. Please use MAX_SEQ_LEN_TO_CAPTURE instead.")
+
+        return {k: v for k, v in args.items() if v not in [None, ""]}
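
Note for reviewers: the sketch below distills the fallback behavior this patch adds to
_initialize_config. It is illustrative only; resolve_max_seq_len_to_capture is a
hypothetical helper name, not code in src/config.py. The design choice it mirrors is
that the deprecated MAX_CONTEXT_LEN_TO_CAPTURE, when set, takes precedence over
MAX_SEQ_LEN_TO_CAPTURE and emits a warning, so existing deployments that still set the
old variable keep their configured capture length unchanged.

    # Minimal sketch (assumed names) of the env-var fallback introduced by this patch.
    import logging
    import os


    def resolve_max_seq_len_to_capture():
        """Return the CUDA-graph capture length, honoring the deprecated alias.

        Hypothetical helper for illustration; the real logic lives inline in
        _initialize_config in src/config.py.
        """
        value = os.getenv("MAX_SEQ_LEN_TO_CAPTURE")
        resolved = int(value) if value else None
        # As in the patch, the deprecated variable wins when it is set, so a
        # worker deployed with the old name behaves exactly as before.
        deprecated = os.getenv("MAX_CONTEXT_LEN_TO_CAPTURE")
        if deprecated:
            logging.warning(
                "Using MAX_CONTEXT_LEN_TO_CAPTURE is deprecated. "
                "Please use MAX_SEQ_LEN_TO_CAPTURE instead."
            )
            resolved = int(deprecated)
        return resolved


    # Usage: a deployment that still sets only the old variable.
    os.environ["MAX_CONTEXT_LEN_TO_CAPTURE"] = "8192"
    print(resolve_max_seq_len_to_capture())  # 8192, plus a deprecation warning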