diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index f528a430..ac86c211 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -49,7 +49,12 @@ class BloomCausalLMBatch(CausalLMBatch):
 
 
 class BLOOM(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         super(BLOOM, self).__init__(
             model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
         )
@@ -61,7 +66,10 @@ class BLOOM(CausalLM):
 
 class BLOOMSharded(BLOOM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -111,7 +119,7 @@ class BLOOMSharded(BLOOM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -165,7 +173,7 @@ class BLOOMSharded(BLOOM):
 
                 tensor = tensor.contiguous().to(dtype)
 
-                if quantize:
+                if quantize == "bitsandbytes":
                     if not HAS_BITS_AND_BYTES:
                         raise ImportError(
                             "bitsandbytes is not available on your machine either because it is not installed "
@@ -215,9 +223,14 @@ class BLOOMSharded(BLOOM):
                             return linear
 
                         module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     module._parameters[param_name] = tensor
                     if name == "word_embeddings.weight":
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 26a9a661..adcf4a47 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -447,7 +447,7 @@ class CausalLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -468,7 +468,7 @@ class CausalLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer.pad_token_id = (
             self.model.config.pad_token_id
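Note on the non-sharded path above: `CausalLM` folds the new string into a boolean for `from_pretrained`, so only "bitsandbytes" enables 8-bit loading; any other string, including "gptq", silently falls through to full precision here, while the sharded loaders further down reject unknown values. A minimal sketch of that mapping, using a hypothetical helper name that is not part of this patch:

from typing import Optional


def wants_8bit(quantize: Optional[str]) -> bool:
    # Same test as `load_in_8bit=quantize == "bitsandbytes"` in the hunk above.
    return quantize == "bitsandbytes"


assert wants_8bit("bitsandbytes") is True
assert wants_8bit(None) is False
assert wants_8bit("gptq") is False  # not rejected on this code path
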
now") + elif quantize is None: self.weight = nn.Parameter(self.weight.T) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") def forward(self, input: torch.Tensor) -> torch.Tensor: if self.quantized: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 71182f8d..7604ea78 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -67,8 +67,8 @@ class FastLinear(nn.Linear): self.quantized = False self.bnb_linear = None - def prepare_weights(self, quantize: bool = False): - if quantize: + def prepare_weights(self, quantize: Optional[str] = None): + if quantize == "bitsandbytes": if not HAS_BITS_AND_BYTES: raise ImportError( "bitsandbytes is not available on your machine either because it is not installed " @@ -92,8 +92,12 @@ class FastLinear(nn.Linear): # Delete reference to data self.weight = None self.bias = None - else: + elif quantize == "gptq": + raise NotImplementedError("`gptq` is not implemented for now") + elif quantize is None: self.weight = nn.Parameter(self.weight.T) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") def forward(self, input: torch.Tensor) -> torch.Tensor: if self.quantized: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 413866d1..2d06d947 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -364,7 +364,7 @@ class FlashCausalLM(Model): model_cls: Type[PreTrainedModel], model_id: str, revision: Optional[str] = None, - quantize: bool = False, + quantize: Optional[str] = None, decode_buffer: int = 3, ): self.past_pad = None @@ -382,7 +382,7 @@ class FlashCausalLM(Model): model_id, revision=revision, torch_dtype=dtype, - load_in_8bit=quantize, + load_in_8bit=quantize == "bitsandbytes", ) .eval() .to(device) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 2577f1b1..6ab82ee6 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -193,7 +193,10 @@ class Galactica(OPT): class GalacticaSharded(Galactica): def __init__( - self, model_id: str, revision: Optional[str] = None, quantize: bool = False + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, ): self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.master = self.rank == 0 @@ -242,7 +245,7 @@ class GalacticaSharded(Galactica): def load_weights( model, filenames: List[str], - quantize: bool, + quantize: Optional[str], device: torch.device, dtype: torch.dtype, rank: int, @@ -297,7 +300,7 @@ class GalacticaSharded(Galactica): tensor = tensor.contiguous().to(dtype) - if quantize: + if quantize == "bitsandbytes": if not HAS_BITS_AND_BYTES: raise ImportError( "bitsandbytes is not available on your machine either because it is not installed " @@ -347,9 +350,14 @@ class GalacticaSharded(Galactica): return linear module.linear = replace_linear(state) - - else: + elif quantize == "gptq": + raise NotImplementedError( + "`gptq` is not implemented for now" + ) + elif quantize is None: tensor = tensor.to(device) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") 
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 2577f1b1..6ab82ee6 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -193,7 +193,10 @@ class Galactica(OPT):
 
 class GalacticaSharded(Galactica):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -242,7 +245,7 @@ class GalacticaSharded(Galactica):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -297,7 +300,7 @@ class GalacticaSharded(Galactica):
 
                 tensor = tensor.contiguous().to(dtype)
 
-                if quantize:
+                if quantize == "bitsandbytes":
                     if not HAS_BITS_AND_BYTES:
                         raise ImportError(
                             "bitsandbytes is not available on your machine either because it is not installed "
@@ -347,9 +350,14 @@ class GalacticaSharded(Galactica):
                             return linear
 
                         module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     module._parameters[param_name] = tensor
                     if name == "model.decoder.embed_tokens.weight":
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index e73a3c82..d7a95222 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -32,7 +32,10 @@ except Exception as e:
 
 class GPTNeoxSharded(CausalLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -81,7 +84,7 @@ class GPTNeoxSharded(CausalLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -146,7 +149,7 @@ class GPTNeoxSharded(CausalLM):
 
                 tensor = tensor.contiguous().to(dtype)
 
-                if quantize:
+                if quantize == "bitsandbytes":
                     if not HAS_BITS_AND_BYTES:
                         raise ImportError(
                             "bitsandbytes is not available on your machine either because it is not installed "
@@ -196,9 +199,14 @@ class GPTNeoxSharded(CausalLM):
                             return linear
 
                         module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 if current_parameter_tensor is not None:
                     module._parameters[param_name] = tensor
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index a7b09a82..0370ffa1 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -14,7 +14,12 @@ EOD = "<|endoftext|>"
 
 
 class SantaCoder(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -46,7 +51,7 @@ class SantaCoder(CausalLM):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                load_in_8bit=quantize,
+                load_in_8bit=quantize == "bitsandbytes",
                 trust_remote_code=True,  # required
             )
             .to(device)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 4ac5ed3c..c6ae6c98 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -501,7 +501,7 @@ class Seq2SeqLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -519,7 +519,7 @@ class Seq2SeqLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
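In contrast to the non-sharded paths, the sharded loaders (GalacticaSharded, GPTNeoxSharded above, T5Sharded below) thread `quantize` into `load_weights`, where exactly three values are enforced per tensor; before this change any truthy string would have enabled bitsandbytes. A hedged usage sketch, not from the repository's test suite; the model id is illustrative, and the sharded classes additionally expect a torch.distributed launch:

import pytest

from text_generation_server.models.gpt_neox import GPTNeoxSharded

# quantize=None keeps fp16/fp32 weights; "bitsandbytes" loads int8 shards.
model = GPTNeoxSharded("EleutherAI/gpt-neox-20b", quantize="bitsandbytes")

# "gptq" is recognized but rejected for now:
with pytest.raises(NotImplementedError):
    GPTNeoxSharded("EleutherAI/gpt-neox-20b", quantize="gptq")

# Anything else fails fast instead of being treated as truthy:
with pytest.raises(ValueError):
    GPTNeoxSharded("EleutherAI/gpt-neox-20b", quantize="int4")
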
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 9e8c3c4c..7e7e5546 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -32,7 +32,10 @@ except Exception as e:
 
 class T5Sharded(Seq2SeqLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
    ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -81,7 +84,7 @@ class T5Sharded(Seq2SeqLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -152,7 +155,7 @@ class T5Sharded(Seq2SeqLM):
 
                 tensor = tensor.contiguous().to(dtype)
 
-                if quantize:
+                if quantize == "bitsandbytes":
                     if not HAS_BITS_AND_BYTES:
                         raise ImportError(
                             "bitsandbytes is not available on your machine either because it is not installed "
@@ -203,8 +206,14 @@ class T5Sharded(Seq2SeqLM):
 
                         module.linear = replace_linear(state)
 
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 if current_parameter_tensor is not None:
                     module._parameters[param_name] = tensor
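For reference, the `module.linear = replace_linear(state)` context lines in the sharded hunks belong to the bitsandbytes branch that this diff leaves unchanged. A condensed reconstruction of that technique, assuming bitsandbytes is installed; this is a sketch inferred from the surrounding context, not the verbatim implementation, and the helper name is hypothetical:

import bitsandbytes as bnb
import torch
from bitsandbytes.nn import Int8Params


def quantize_with_bnb(tensor: torch.Tensor, device: torch.device):
    # Wrap the fp16 weight so bitsandbytes stores it as int8 once moved to `device`.
    weight = Int8Params(tensor, requires_grad=False, has_fp16_weights=False).to(device)
    state = bnb.MatmulLtState()
    state.threshold = 6.0
    state.has_fp16_weights = False

    def linear(input, weight, bias):
        # Int8 matmul with fp16 outlier handling, standing in for F.linear.
        return bnb.matmul(input, weight, state=state, threshold=state.threshold, bias=bias)

    return weight, linear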