diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index dd116c9e..eb5a8de7 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -391,21 +391,6 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)

-    transformers_causal_lm_class = CausalLM
-
-    # Fast transformers path
-    transformers_model_class = getattr(
-        transformers,
-        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
-        None,
-    )
-    if (
-        FLASH_TRANSFORMERS_BACKEND
-        and transformers_model_class is not None
-        and transformers_model_class._supports_flex_attn
-    ):
-        transformers_causal_lm_class = TransformersFlashCausalLM
-
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
@@ -649,7 +634,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -756,7 +741,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -767,7 +752,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -792,7 +777,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -803,7 +788,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -840,7 +825,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -863,7 +848,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -887,7 +872,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -913,12 +898,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )

-    elif (
-        model_type == LLAMA
-        or model_type == BAICHUAN
-        or model_type == PHI3
-        or model_type == GRANITE
-    ):
+    elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -931,12 +911,8 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        # elif sharded:
-        #     raise NotImplementedError(
-        #         FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
-        #     )
-        else:
-            return transformers_causal_lm_class.fallback(
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -944,6 +920,47 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
+    elif model_type == BAICHUAN:
+        if FLASH_ATTENTION:
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashLlamaForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                kv_cache_dtype=kv_cache_dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+            )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     if model_type == GEMMA:
         if FLASH_ATTENTION:
             return FlashCausalLM(
@@ -959,10 +976,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -988,7 +1014,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1010,10 +1036,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1041,7 +1076,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1091,7 +1126,7 @@ def get_model(
                     config_class=RWConfig,
                 )
             else:
-                return transformers_causal_lm_class.fallback(
+                return CausalLM.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -1113,10 +1148,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1138,10 +1182,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1163,12 +1216,21 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1190,10 +1252,19 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1339,8 +1410,6 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
-    if sharded:
-        raise NotImplementedError("sharded is not supported for AutoModel")
     if quantize == "gptq":
         raise NotImplementedError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
@@ -1353,8 +1422,19 @@ def get_model(
         raise NotImplementedError("Eetq quantization is not supported for AutoModel")
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return transformers_causal_lm_class.fallback(
+
+    # Fast transformers if available
+    transformers_model_class = getattr(
+        transformers,
+        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
+        None,
+    )
+    if (
+        FLASH_TRANSFORMERS_BACKEND
+        and transformers_model_class is not None
+        and transformers_model_class._supports_flex_attn
+    ):
+        return TransformersFlashCausalLM.fallback(
             model_id,
             revision,
             quantize=quantize,
@@ -1362,6 +1442,10 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+
+    if sharded:
+        raise NotImplementedError("sharded is not supported for AutoModel")
+
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
         return Seq2SeqLM.fallback(
             model_id,
             revision,
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -1375,7 +1459,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
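Editorial note, not part of the patch: the __init__.py hunks above all repeat the same per-architecture dispatch order, so a condensed sketch may make the pattern easier to follow. The helper name select_model_for_arch and its reduced signature are invented for this illustration; the real code inlines these branches inside get_model with the full keyword arguments shown in the hunks, and Baichuan keeps only branches 1, 3 and 4.

    # Illustrative sketch only. Mirrors the dispatch order the diff sets up for
    # Llama/Phi3/Granite, Gemma, Cohere, Mistral, Mixtral, Starcoder2 and Qwen2
    # (the names FLASH_ATTENTION, FLASH_TRANSFORMERS_BACKEND, FlashCausalLM,
    # TransformersFlashCausalLM, CausalLM and FLASH_ATT_ERROR_MESSAGE are the
    # module's own globals).
    def select_model_for_arch(model_id, revision, model_type, sharded, **kwargs):
        if FLASH_ATTENTION:
            # 1) custom flash kernels remain the preferred path
            return FlashCausalLM(model_id=model_id, revision=revision, **kwargs)
        elif FLASH_TRANSFORMERS_BACKEND:
            # 2) new: flex-attention transformers backend for supported archs
            return TransformersFlashCausalLM.fallback(model_id, revision, **kwargs)
        elif sharded:
            # 3) sharding without flash support is an explicit error
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
            )
        else:
            # 4) plain eager transformers implementation as the last resort
            return CausalLM.fallback(model_id, revision, **kwargs)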
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index c2669dba..36de89b4 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -260,4 +260,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
             hidden_states = hidden_states[lm_head_indices]
         logits = self.model.lm_head(hidden_states)

+        # For Granite while next transformers version is released and we can use `lm_head_indices` natively
+        if hasattr(self.model.config, "logits_scaling"):
+            logits = logits / self.model.config.logits_scaling
+        # For Cohere for similar reasons
+        elif hasattr(self.model, "logit_scale"):
+            logits = logits * self.model.logit_scale
+
        return logits, None
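Editorial note, not part of the patch: the logits post-processing added at the end of transformers_flash_causal_lm.py can be exercised in isolation. The dummy classes below are invented for this example; only the attribute names logits_scaling (Granite config) and logit_scale (Cohere model) come from the patch.

    import torch

    class _DummyGraniteConfig:
        logits_scaling = 16.0  # Granite divides the raw lm_head output by this

    class _DummyGraniteModel:
        config = _DummyGraniteConfig()
        # a Cohere-style model would instead expose `logit_scale` directly

    def scale_logits(model, logits):
        # mirrors the hunk above: divide for Granite, multiply for Cohere
        if hasattr(model.config, "logits_scaling"):
            return logits / model.config.logits_scaling
        elif hasattr(model, "logit_scale"):
            return logits * model.logit_scale
        return logits

    print(scale_logits(_DummyGraniteModel(), torch.ones(2, 4)))  # 0.0625 everywhere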