diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 9ff72569..8dd1e6e8 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -29,6 +29,12 @@ class Dtype(str, Enum):
     bloat16 = "bfloat16"
 
 
+class KVDtype(str, Enum):
+    auto = "auto"
+    fp8 = "fp8"
+    fp8_e5m2 = "fp8_e5m2"
+
+
 @app.command()
 def serve(
     model_id: str,
@@ -37,7 +43,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
-    kv_cache_dtype: str = "auto",
+    kv_cache_dtype: KVDtype = "auto",
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",