diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 8221068b..fc4a59b9 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -2,10 +2,6 @@ import os
 import math
 import torch
 from torch import nn
-
-# Inverse dim formula to find dim based on number of rotations
-
-
 from text_generation_server.utils.import_utils import SYSTEM

 if SYSTEM == "cuda":
diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py
index 0d313e15..9516182c 100644
--- a/server/text_generation_server/layers/tensor_parallel.py
+++ b/server/text_generation_server/layers/tensor_parallel.py
@@ -69,12 +69,13 @@ class TensorParallelHead(SuperLayer):

         # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
         if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
-            pass
+            # Local variable `quantize` is assigned to but never used
+            quantize = None  # noqa F841
         # See above, exl2 LM head can be quantized or not.
         elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
-            pass
+            quantize = None  # noqa F841
         else:
-            pass
+            quantize = config.quantize  # noqa F841

         return TensorParallelHead(
             get_linear(weight, bias=None),
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e6afe9e6..3dc24159 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -1,3 +1,6 @@
+# ruff: noqa: F821
+# the above line disables the `undefined-name` rule for the model type variables
+
 import torch
 import enum
 import os
@@ -298,12 +301,10 @@ class ModelType(enum.Enum):
         "multimodal": True,
     }

-    @classmethod
-    def from_str(cls, model_type: str) -> "ModelType":
-        for model in cls:
-            if model.value["type"] == model_type:
-                return model
-        raise ValueError(f"Unknown model type {model_type}")
+
+__GLOBALS = locals()
+for data in ModelType:
+    __GLOBALS[data.name] = data.value["type"]


 def get_model(
@@ -492,10 +493,7 @@ def get_model(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
         )

-    # convert model_type to ModelType enum
-    model_type = ModelType.from_str(model_type)
-
-    if model_type == ModelType.DEEPSEEK_V2:
+    if model_type == DEEPSEEK_V2:
         if FLASH_ATTENTION:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
@@ -528,7 +526,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.MAMBA:
+    elif model_type == MAMBA:
         return Mamba(
             model_id,
             revision,
@@ -552,8 +550,8 @@ def get_model(
         )

     if (
-        model_type == ModelType.GPT_BIGCODE
-        or model_type == ModelType.GPT2
+        model_type == GPT_BIGCODE
+        or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
         if FLASH_ATTENTION:
@@ -583,7 +581,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == ModelType.BLOOM:
+    if model_type == BLOOM:
         return CausalLM(
             model_id=model_id,
             model_class=BloomForCausalLM,
@@ -594,7 +592,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
             batch_class=BloomCausalLMBatch,
         )
-    elif model_type == ModelType.MPT:
+    elif model_type == MPT:
         return CausalLM(
             model_id=model_id,
             model_class=MPTForCausalLM,
@@ -605,7 +603,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
             batch_class=CausalLMBatchKeysLast,
         )
-    elif model_type == ModelType.GPT2:
+    elif model_type == GPT2:
         if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
@@ -640,7 +638,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.GPT_NEOX:
+    elif model_type == GPT_NEOX:
         if FLASH_ATTENTION:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
@@ -677,7 +675,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    elif model_type == ModelType.PHI:
+    elif model_type == PHI:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -716,11 +714,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    elif (
-        model_type == ModelType.LLAMA
-        or model_type == ModelType.BAICHUAN
-        or model_type == ModelType.PHI3
-    ):
+    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
         print(f">>> model_type: {model_type}")
         if FLASH_ATTENTION:
             return FlashCausalLM(
@@ -744,7 +738,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == ModelType.GEMMA:
+    if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -769,7 +763,7 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == ModelType.GEMMA2:
+    elif model_type == GEMMA2:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -795,7 +789,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == ModelType.COHERE:
+    if model_type == COHERE:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -819,7 +813,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == ModelType.DBRX:
+    if model_type == DBRX:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -846,7 +840,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type in ["RefinedWeb", "RefinedWebModel", ModelType.FALCON]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
             if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
@@ -894,7 +888,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == ModelType.MISTRAL:
+    if model_type == MISTRAL:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -918,7 +912,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == ModelType.MIXTRAL:
+    if model_type == MIXTRAL:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -942,7 +936,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == ModelType.STARCODER2:
+    if model_type == STARCODER2:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -968,7 +962,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == ModelType.QWEN2:
+    if model_type == QWEN2:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -992,7 +986,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == ModelType.OPT:
+    if model_type == OPT:
         return CausalLM(
             model_id=model_id,
             model_class=OPTForCausalLM,
@@ -1003,7 +997,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == ModelType.T5:
+    if model_type == T5:
         return Seq2SeqLM(
             model_id=model_id,
             model_class=T5ForConditionalGeneration,
@@ -1019,7 +1013,7 @@ def get_model(
                 ]
             },
         )
-    if model_type == ModelType.IDEFICS:
+    if model_type == IDEFICS:
         if FLASH_ATTENTION:
             return IDEFICSSharded(
                 model_id,
@@ -1031,7 +1025,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == ModelType.IDEFICS2:
+    if model_type == IDEFICS2:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
@@ -1048,7 +1042,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == ModelType.PALIGEMMA:
+    if model_type == PALIGEMMA:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
@@ -1066,7 +1060,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

-    if model_type == ModelType.LLAVA_NEXT:
+    if model_type == LLAVA_NEXT:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index d8b514ac..212ab7a9 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
         ]

         # Ensure that past_key_values tensors can be updated in-place
-        if isinstance(self.past_key_values[0], tuple):
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [list(layer) for layer in self.past_key_values]

         # Update tensors in-place to allow incremental garbage collection
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 3561bb0a..8a80ed68 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch):
         image_hidden_states = self.image_hidden_states[keep_indices]

         # Ensure that past_key_values tensors can be updated in-place
-        if isinstance(self.past_key_values[0], tuple):
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [list(layer) for layer in self.past_key_values]

         # Update tensors in-place to allow incremental garbage collection
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 56dc0dc1..79c001b0 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -253,7 +253,7 @@ class Seq2SeqLMBatch(Batch):
         ]

         # Ensure that past_key_values tensors can be updated in-place
-        if isinstance(self.past_key_values[0], tuple):
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [
                 [t for t in layer] for layer in self.past_key_values
             ]