diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 9f8d215b..f419fa39 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1,4 +1,4 @@
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
@@ -16,6 +16,26 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 
 mod env_runtime;
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Quantization {
+    Bitsandbytes,
+    Gptq,
+}
+
+impl std::fmt::Display for Quantization {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep in sync with the values accepted by `server`.
+        match self {
+            Quantization::Bitsandbytes => {
+                write!(f, "bitsandbytes")
+            }
+            Quantization::Gptq => {
+                write!(f, "gptq")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -46,10 +66,10 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,
 
-    /// Wether you want the model to be quantized or not. This will use bitsandbytes for
-    /// quantization on the fly.
-    #[clap(long, env)]
-    quantize: bool,
+    /// Whether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// quantization on the fly, or `gptq`.
+    #[clap(long, env, value_enum)]
+    quantize: Option<Quantization>,
 
     /// The maximum amount of concurrent requests for this particular deployment.
     /// Having a low limit will refuse clients requests instead of having them
@@ -218,7 +238,7 @@ enum ShardStatus {
 fn shard_manager(
     model_id: String,
     revision: Option<String>,
-    quantize: bool,
+    quantize: Option<Quantization>,
     uds_path: String,
     rank: usize,
     world_size: usize,
@@ -257,8 +277,9 @@ fn shard_manager(
         shard_argv.push("--sharded".to_string());
     }
 
-    if quantize {
-        shard_argv.push("--quantize".to_string())
+    if let Some(quantize) = quantize {
+        shard_argv.push("--quantize".to_string());
+        shard_argv.push(quantize.to_string())
     }
 
     // Model optional revision
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 92482a94..9564bb07 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -5,17 +5,23 @@ import typer
 
 from pathlib import Path
 from loguru import logger
 from typing import Optional
+from enum import Enum
 
 app = typer.Typer()
 
 
+class Quantization(str, Enum):
+    bitsandbytes = "bitsandbytes"
+    gptq = "gptq"
+
+
 @app.command()
 def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
-    quantize: bool = False,
+    quantize: Optional[Quantization] = None,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -55,7 +61,7 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
 
-    server.serve(model_id, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize.value if quantize else None, uds_path)
 
 
 @app.command()
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 221c9139..e02be3de 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -91,7 +91,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index de9b22da..b4e7985c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -105,7 +105,7 @@ class FastLinear(nn.Linear):
         self.bnb_linear = None
 
     def prepare_weights(self, quantize: bool = False):
-        if quantize:
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -129,8 +129,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 105ff519..d191ef96 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -154,7 +154,10 @@ class FlashLlama(FlashCausalLM):
 
 class FlashLlamaSharded(FlashLlama):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 70f08ed7..d715207b 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -100,14 +100,14 @@ def serve(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
-    quantize: bool,
+    quantize: Optional[str],
     uds_path: Path,
 ):
     async def serve_inner(
         model_id: str,
         revision: Optional[str],
         sharded: bool = False,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
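For illustration only (not part of the patch): a minimal, self-contained Python sketch of the dispatch this diff introduces, assuming the launcher forwards the enum's string form (`bitsandbytes` / `gptq`) and the server receives it as a plain `Optional[str]`, as in the `FastLinear.prepare_weights` hunk above.

```python
# Hypothetical stand-alone sketch; names mirror the diff but this file is not part of it.
from enum import Enum
from typing import Optional


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"


def prepare_weights(quantize: Optional[str]) -> str:
    # Same branching as the FastLinear.prepare_weights change: a known backend
    # string selects a quantization path, None keeps full precision, and
    # anything else is rejected loudly.
    if quantize == "bitsandbytes":
        return "8-bit weights via bitsandbytes"
    elif quantize == "gptq":
        raise NotImplementedError("`gptq` is not implemented for now")
    elif quantize is None:
        return "full-precision weights"
    else:
        raise ValueError(f"Unexpected quantize `{quantize}`")


if __name__ == "__main__":
    # The launcher passes `--quantize bitsandbytes` through to each shard,
    # which hands the string down to the model layers.
    print(prepare_weights(Quantization.bitsandbytes.value))
    print(prepare_weights(None))
```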