diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 32c2168f..853968e1 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -11,17 +11,16 @@ from pathlib import Path

 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
-from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
 )
-from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
-from text_generation_server.models.galactica import GalacticaSharded
+from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
     GPTNeoxForCausalLM,
 )
@@ -169,6 +168,11 @@ class ModelType(enum.Enum):
         "name": "Gemma",
         "url": "https://huggingface.co/google/gemma-7b",
     }
+    PALIGEMMA = {
+        "type": "paligemma",
+        "name": "PaliGemma",
+        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
+    }
     GEMMA2 = {
         "type": "gemma2",
         "name": "Gemma2",
@@ -466,14 +470,16 @@ def get_model(
         )

     if model_id.startswith("facebook/galactica"):
-        return GalacticaSharded(
+        return CausalLM(
             model_id=model_id,
+            # Yes, Galactica is just an OPT model.
             model_class=OPTForCausalLM,
             revision=revision,
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            batch_class=GalacticaCausalLMBatch,
         )

     if (
@@ -509,7 +515,7 @@
         )

     if model_type == BLOOM:
-        return BLOOMSharded(
+        return CausalLM(
             model_id=model_id,
             model_class=BloomForCausalLM,
             revision=revision,
@@ -517,6 +523,7 @@
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            batch_class=CausalLMBatchKeysLast,
         )
     elif model_type == MPT:
         return CausalLM(
@@ -527,6 +534,7 @@
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            batch_class=CausalLMBatchKeysLast,
         )
     elif model_type == GPT2:
         if FLASH_ATTENTION:
@@ -666,6 +674,8 @@
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                # Works better for these models.
+                default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
@@ -689,6 +699,8 @@
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                # Works better for these models.
+                default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
@@ -737,6 +749,8 @@
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                # Dbrx works better in bfloat16.
+                default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 config_class=DbrxConfig,
@@ -765,6 +779,10 @@ def get_model(
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                aliases={
+                    "lm_head.weight": ["transformer.word_embeddings.weight"],
+                    "transformer.word_embeddings.weight": ["lm_head.weight"],
+                },
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 config_class=RWConfig,
@@ -947,7 +965,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == "paligemma":
+    if model_type == PALIGEMMA:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
@@ -956,6 +974,8 @@
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                # Works better for these models.
+                default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 batch_class=PaliGemmaBatch,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index dce793f5..ea0e4ae4 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -489,6 +489,11 @@ class CausalLMBatch(Batch):
         return len(self.requests)


+@dataclass
+class CausalLMBatchKeysLast(CausalLMBatch):
+    keys_head_dim_last: bool = False
+
+
 class CausalLM(Model):
     def __init__(
         self,
@@ -498,14 +503,25 @@ class CausalLM(Model):
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
         config_class=AutoConfig,
+        batch_class=CausalLMBatch,
     ):
+        self.batch_class = batch_class
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
@@ -612,6 +628,7 @@ class CausalLM(Model):
         self = cls.__new__(
             cls,
         )
+        self.batch_class = CausalLMBatch
         super().__init__(
             self,
             model_id=model_id,
@@ -625,7 +642,7 @@

     @property
     def batch_type(self) -> Type[CausalLMBatch]:
-        return CausalLMBatch
+        return self.batch_class

     # This is not used anymore
     # def decode(self, generated_ids: List[int]) -> str:
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 30c92d90..2d43244a 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
         )
-
-
-class GalacticaSharded(CausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
-            tp_parallel=True,
-            trust_remote_code=trust_remote_code,
-        )
-        config.quantize = quantize
-        tokenizer.pad_token_id = config.pad_token_id
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
-        )
-        if config.quantize in ["gptq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = OPTForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
-
-    @property
-    def batch_type(self) -> Type[CausalLMBatch]:
-        return GalacticaCausalLMBatch
-
-    def decode(self, generated_ids: List[int]) -> str:
-        # Do not skip special tokens as they are used for custom parsing rules of the generated text
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-        )
-
-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ):
-        outputs, speculative_logits = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, speculative_logits, outputs.past_key_values
diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py
index 533a47ea..3994ac70 100644
--- a/server/text_generation_server/models/pali_gemma.py
+++ b/server/text_generation_server/models/pali_gemma.py
@@ -74,19 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
         else:
             image_inputs = None
         return batch_tokenized_inputs, image_inputs
-
-
-class PaliGemma(VlmCausalLM):
-    @property
-    def batch_type(self):
-        return PaliGemmaBatch
-
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.text_model.model.layers),
-            model.text_model.model.num_key_value_heads,
-            model.text_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.text_model, "max_past", None)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 5d16c364..bc30404e 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -547,6 +547,7 @@ class Seq2SeqLM(Model):
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
         trust_remote_code: bool = False,
         config_class=AutoConfig,
         tokenizer_class=AutoTokenizer,
@@ -555,7 +556,15 @@ class Seq2SeqLM(Model):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
diff --git a/server/text_generation_server/models/sharded_seq2seq_lm.py b/server/text_generation_server/models/sharded_seq2seq_lm.py
deleted file mode 100644
index b73df83a..00000000
--- a/server/text_generation_server/models/sharded_seq2seq_lm.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import torch
-import torch.distributed
-
-from typing import List, Optional, Tuple
-
-from transformers import (
-    AutoTokenizer,
-    AutoConfig,
-)
-
-from text_generation_server.models import Seq2SeqLM
-from text_generation_server.models.custom_modeling.t5_modeling import (
-    T5ForConditionalGeneration,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-
-
-class ShardedSeq2SeqLM(Seq2SeqLM):
-    def __init__(
-        self,
-        model_id: str,
-        model_class,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        config_class=AutoConfig,
-        tokenizer_class=AutoTokenizer,
-        aliases=None,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
-
-        config = config_class.from_pretrained(
-            model_id,
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        tokenizer = tokenizer_class.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        tokenizer.bos_token_id = config.decoder_start_token_id
-
-        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames,
-            device=device,
-            dtype=dtype,
-            process_group=self.process_group,
-            aliases=aliases,
-        )
-
-        model = model_class(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(Seq2SeqLM, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
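
Note for reviewers: with `GalacticaSharded`, `BLOOMSharded`, `PaliGemma`, and `ShardedSeq2SeqLM` gone, per-architecture behavior is expressed entirely through constructor arguments on the generic classes. The sketch below spells out what `get_model` now effectively does for BLOOM. It is illustrative only and not part of the patch; the checkpoint id is an arbitrary example, and instantiating `CausalLM` still loads weights and initializes the process group exactly as before.

```python
# Illustrative sketch (not part of this patch): the dedicated BLOOMSharded
# subclass is replaced by the generic CausalLM plus a few constructor knobs.
from text_generation_server.models.causal_lm import (
    CausalLM,
    CausalLMBatchKeysLast,
)
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)

model = CausalLM(
    model_id="bigscience/bloom-560m",   # example checkpoint id, assumed available locally
    model_class=BloomForCausalLM,       # custom modeling code to instantiate
    batch_class=CausalLMBatchKeysLast,  # BLOOM's KV-cache keys are not head-dim-last
)

# batch_type now resolves through the instance attribute, so batches built
# from incoming requests carry keys_head_dim_last=False for this architecture.
assert model.batch_type is CausalLMBatchKeysLast
```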