From ed34cf0222a8879aa45839ebb89bc70a5cfe5134 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 2 Jul 2024 14:17:57 +0200
Subject: [PATCH] Remove a lot of duplicated code.

---
 .../text_generation_server/models/__init__.py | 202 +++++++++++-------
 .../flash_santacoder_modeling.py              |   1 +
 .../models/flash_causal_lm.py                 |  20 +-
 .../models/flash_dbrx.py                      | 100 ---------
 .../models/flash_gemma.py                     |  83 -------
 .../models/flash_gemma2.py                    |  83 -------
 .../models/flash_gpt2.py                      |  82 -------
 .../models/flash_neox.py                      |  82 -------
 .../models/flash_phi.py                       | 111 ----------
 .../models/flash_qwen2.py                     |  88 --------
 .../text_generation_server/models/flash_rw.py |  91 --------
 .../models/flash_santacoder.py                |  99 ---------
 .../text_generation_server/models/idefics2.py |  51 -----
 .../models/llava_next.py                      |  46 ----
 .../models/pali_gemma.py                      |  26 ---
 .../models/vlm_causal_lm.py                   |  28 ++-
 16 files changed, 156 insertions(+), 1037 deletions(-)
 delete mode 100644 server/text_generation_server/models/flash_dbrx.py
 delete mode 100644 server/text_generation_server/models/flash_gemma.py
 delete mode 100644 server/text_generation_server/models/flash_gemma2.py
 delete mode 100644 server/text_generation_server/models/flash_gpt2.py
 delete mode 100644 server/text_generation_server/models/flash_neox.py
 delete mode 100644 server/text_generation_server/models/flash_phi.py
 delete mode 100644 server/text_generation_server/models/flash_qwen2.py
 delete mode 100644 server/text_generation_server/models/flash_rw.py
 delete mode 100644 server/text_generation_server/models/flash_santacoder.py
 delete mode 100644 server/text_generation_server/models/idefics2.py
 delete mode 100644 server/text_generation_server/models/llava_next.py

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index d159df88..ddc1d461 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -53,47 +53,62 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
-    from text_generation_server.models.flash_rw import FlashRWSharded
-    from text_generation_server.models.flash_gpt2 import FlashGPT2
-    from text_generation_server.models.flash_neox import FlashNeoXSharded
-
-    # from text_generation_server.models.flash_llama import (
-    #     FlashLlama,
-    # )
+    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.custom_modeling.flash_llama_modeling import (
         FlashLlamaForCausalLM,
     )
-    from text_generation_server.models.flash_qwen2 import (
-        FlashQwen2,
+    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
+        FlashCohereForCausalLM,
     )
-    from text_generation_server.models.flash_cohere import (
-        FlashCohere,
+    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+        FlashGemmaForCausalLM,
     )
-    from text_generation_server.models.flash_gemma import (
-        FlashGemma,
+    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
+        FlashGemma2ForCausalLM,
     )
-    from text_generation_server.models.flash_gemma2 import (
-        FlashGemma2,
+    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
+        FlashDbrxForCausalLM,
+        DbrxConfig,
+    )
+    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
+        RWConfig,
+        FlashRWForCausalLM,
+    )
+    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+        FlashGPTNeoXForCausalLM,
     )
     from
text_generation_server.models.pali_gemma import ( - PaliGemma, + PaliGemmaBatch, ) - from text_generation_server.models.flash_santacoder import ( - FlashSantacoderSharded, + from text_generation_server.models.custom_modeling.flash_phi_modeling import ( + FlashPhiForCausalLM, ) from text_generation_server.models.idefics import IDEFICSSharded - from text_generation_server.models.llava_next import LlavaNext - from text_generation_server.models.idefics2 import Idefics2 + from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, + ) from text_generation_server.models.flash_mistral import FlashMistral + from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( + FlashSantacoderForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( + FlashStarcoder2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2ForCausalLM, + ) from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, ) from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( FlashMixtralForCausalLM, ) - from text_generation_server.models.flash_phi import FlashPhi - from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 - from text_generation_server.models.flash_dbrx import FlashDbrx + from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( + FlashGPT2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.idefics2 import ( + Idefics2ForConditionalGeneration, + ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: logger.warning(f"Could not import Flash Attention enabled models: {e}") @@ -102,20 +117,8 @@ except ImportError as e: if FLASH_ATTENTION: __all__.append(FlashCausalLM) - __all__.append(FlashGPT2) - __all__.append(FlashNeoXSharded) - __all__.append(FlashRWSharded) - __all__.append(FlashSantacoderSharded) - # __all__.append(FlashLlama) __all__.append(IDEFICSSharded) __all__.append(FlashMistral) - __all__.append(FlashDbrx) - __all__.append(FlashPhi) - __all__.append(FlashQwen2) - __all__.append(FlashStarcoder2) - __all__.append(FlashGemma) - __all__.append(FlashGemma2) - __all__.append(FlashCohere) MAMBA_AVAILABLE = True try: @@ -468,13 +471,16 @@ def get_model( and model_id.startswith("bigcode/") ): if FLASH_ATTENTION: - return FlashSantacoderSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashSantacoderForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, ) elif sharded: raise NotImplementedError( @@ -511,13 +517,15 @@ def get_model( elif model_type == GPT2: if FLASH_ATTENTION: try: - return FlashGPT2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPT2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) except RuntimeError as e: # Lots of legacy models with various weight names. 
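The get_model() hunks above and below all apply the same refactor: each per-architecture wrapper class (FlashGPT2, FlashNeoXSharded, FlashPhi, FlashGemma, ...) is deleted, and the branch instead constructs a plain FlashCausalLM that is told which custom_modeling class to load. A minimal sketch of that dispatch shape, using a hypothetical NEW_ARCH / FlashNewArchForCausalLM pair as a stand-in for the real branches:

if model_type == NEW_ARCH:  # hypothetical; real branches: GPT2, GPT_NEOX, PHI, GEMMA, DBRX, ...
    if FLASH_ATTENTION:
        return FlashCausalLM(
            model_id=model_id,
            model_class=FlashNewArchForCausalLM,  # custom_modeling implementation to instantiate
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
            # Some branches additionally pass config_class (e.g. DbrxConfig, RWConfig)
            # or weight aliases such as {"transformer.wte.weight": ["lm_head.weight"]}.
        )

Mistral-style and vision-language models follow the same pattern through FlashMistral(model_class=...) and VlmCausalLM(model_class=..., batch_class=..., processor_kwargs=...), as the later hunks show.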
@@ -543,13 +551,15 @@ def get_model( ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: - return FlashNeoXSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTNeoXForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: return GPTNeoxSharded( @@ -572,13 +582,15 @@ def get_model( elif model_type == PHI: if FLASH_ATTENTION: - return FlashPhi( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashPhiForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) else: return CausalLM( @@ -630,13 +642,15 @@ def get_model( ) if model_type == GEMMA: if FLASH_ATTENTION: - return FlashGemma( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemmaForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) @@ -651,13 +665,15 @@ def get_model( ) elif model_type == GEMMA2: if FLASH_ATTENTION: - return FlashGemma2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemma2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) @@ -673,13 +689,15 @@ def get_model( if model_type == COHERE: if FLASH_ATTENTION: - return FlashCohere( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashCohereForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) @@ -695,13 +713,16 @@ def get_model( if model_type == DBRX: if FLASH_ATTENTION: - return FlashDbrx( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashDbrxForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DbrxConfig, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) @@ -720,24 +741,30 @@ def get_model( if FLASH_ATTENTION: if config_dict.get("alibi", False): raise NotImplementedError("sharded is not supported for this model") - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) else: return 
RW( @@ -799,12 +826,15 @@ def get_model( if model_type == STARCODER2: if FLASH_ATTENTION: - return FlashStarcoder2( - model_id, - revision, + return FlashMistral( + model_id=model_id, + model_class=FlashStarcoder2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError( @@ -822,12 +852,15 @@ def get_model( if model_type == QWEN2: if FLASH_ATTENTION: - return FlashQwen2( - model_id, - revision, + return FlashMistral( + model_id=model_id, + model_class=Qwen2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) @@ -874,34 +907,43 @@ def get_model( raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == IDEFICS2: if FLASH_ATTENTION: - return Idefics2( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=Idefics2ForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == "paligemma": if FLASH_ATTENTION: - return PaliGemma( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == LLAVA_NEXT: if FLASH_ATTENTION: - return LlavaNext( - model_id, - revision, + return VlmCausalLM( + model_class=LlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 30989a37..a77a7655 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -466,6 +466,7 @@ class FlashSantacoderModel(nn.Module): class FlashSantacoderForCausalLM(nn.Module): def __init__(self, config, weights): super().__init__() + config.transpose = config.architectures[0].startswith("GPT2") self.transformer = FlashSantacoderModel(config, weights) self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 6b7cceef..f2e66d56 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -822,19 +822,9 @@ class FlashCausalLM(Model): trust_remote_code: bool = False, lora_adapter_ids: Optional[list] = [], tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, + config_class: PreTrainedTokenizerBase = AutoConfig, default_dtype=torch.float16, - # self, - # 
model_id: str, - # model_class, - # tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, - # num_layers: int, - # num_kv_heads: int, - # head_size: int, - # dtype: torch.dtype, - # device: torch.device, - # rank: int = 0, - # world_size: int = 1, - # sliding_window: Optional[int] = None, + aliases=None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -868,7 +858,7 @@ class FlashCausalLM(Model): except Exception: pass - config = AutoConfig.from_pretrained( + config = config_class.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize @@ -881,7 +871,9 @@ class FlashCausalLM(Model): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, device, dtype, process_group=self.process_group, aliases=aliases + ) if config.quantize in ["awq", "exl2", "gptq", "marlin"]: weights._set_gptq_params(model_id, revision) diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py deleted file mode 100644 index 2aba6a00..00000000 --- a/server/text_generation_server/models/flash_dbrx.py +++ /dev/null @@ -1,100 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoTokenizer -from transformers.models.gpt2 import GPT2TokenizerFast - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( - FlashDbrxForCausalLM, - DbrxConfig, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class FlashDbrx(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashDBRX is only available on GPU") - - try: - tokenizer = GPT2TokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - # FIXME: change back to model id once the tokenizer.json is merged - tokenizer = GPT2TokenizerFast.from_pretrained( - "Xenova/dbrx-instruct-tokenizer", - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - - config = DbrxConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, 
process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashDbrxForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashDbrx, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py deleted file mode 100644 index 7e2b8780..00000000 --- a/server/text_generation_server/models/flash_gemma.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( - FlashGemmaForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemmaForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py deleted file mode 100644 index 86cfc7e2..00000000 --- a/server/text_generation_server/models/flash_gemma2.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry 
import trace -from typing import Optional -from transformers import PretrainedConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( - FlashGemma2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = PretrainedConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py deleted file mode 100644 index 323fcafa..00000000 --- a/server/text_generation_server/models/flash_gpt2.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from transformers.models.gpt2 import GPT2Tokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( - FlashGPT2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGPT2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, 
world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGPT2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = FlashGPT2ForCausalLM(prefix, config, weights) - torch.distributed.barrier(group=self.process_group) - super(FlashGPT2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py deleted file mode 100644 index ac1fd573..00000000 --- a/server/text_generation_server/models/flash_neox.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_neox_modeling import ( - FlashGPTNeoXForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashNeoXSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashNeoX is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - 
torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashGPTNeoXForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashNeoXSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.gpt_neox.layers), - num_kv_heads=model.gpt_neox.num_heads, - head_size=model.gpt_neox.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py deleted file mode 100644 index a530d1c3..00000000 --- a/server/text_generation_server/models/flash_phi.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_phi_modeling import ( - FlashPhiForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashPhi(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashPhi is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashPhiForCausalLM(config, weights) - if speculator: - from text_generation_server.utils.medusa import MedusaModel - from huggingface_hub import hf_hub_download - import json - import os - from pathlib import Path - - is_local_model = ( - Path(speculator).exists() and Path(speculator).is_dir() - ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - - if not is_local_model: - medusa_config = hf_hub_download( - speculator, revision=revision, filename="config.json" - ) - medusa_head = hf_hub_download( - speculator, revision=revision, 
filename="medusa_lm_head.pt" - ) - else: - medusa_config = str(Path(speculator) / "config.json") - medusa_head = str(Path(speculator) / "medusa_lm_head.pt") - - with open(medusa_config, "r") as f: - config = json.load(f) - medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" - weights = Weights( - [medusa_sf], device, dtype, process_group=self.process_group - ) - lm_head = model.lm_head - model.lm_head = MedusaModel(config, weights, lm_head) - - torch.distributed.barrier(group=self.process_group) - super(FlashPhi, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py deleted file mode 100644 index 4176aa05..00000000 --- a/server/text_generation_server/models/flash_qwen2.py +++ /dev/null @@ -1,88 +0,0 @@ -import math - -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models.flash_mistral import ( - FlashMistral, -) -from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( - Qwen2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashQwen2(FlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashQwen2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = Qwen2ForCausalLM(config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - super(FlashMistral, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) diff --git 
a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py deleted file mode 100644 index b1f75adc..00000000 --- a/server/text_generation_server/models/flash_rw.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_rw_modeling import ( - RWConfig, - FlashRWForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashRWSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashRW is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = RWConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - aliases={ - "lm_head.weight": ["transformer.word_embeddings.weight"], - "transformer.word_embeddings.weight": ["lm_head.weight"], - }, - ) - - config.quantize = quantize - config.speculator = speculator - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashRWForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashRWSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=model.transformer.cache_size, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py deleted file mode 100644 index e1a7b36e..00000000 --- a/server/text_generation_server/models/flash_santacoder.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional, List -import json -import os - -from huggingface_hub import hf_hub_download -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( - FlashSantacoderForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - 
weight_files, - Weights, -) - -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashSantacoderSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashSantacoderSharded is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=True, - ) - config.quantize = quantize - config.speculator = speculator - config.transpose = config.architectures[0].startswith("GPT2") - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - aliases={"transformer.wte.weight": ["lm_head.weight"]}, - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashSantacoderForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashSantacoderSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=1, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py deleted file mode 100644 index 314c0500..00000000 --- a/server/text_generation_server/models/idefics2.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.idefics2 import ( - Idefics2ForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class Idefics2(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - # XXX: Extremely important to cap resolution in order to limit - # VRAM usage. 
- size={"longest_edge": 448, "shortest_edge": 378}, - ) - super().__init__( - model_cls=Idefics2ForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.text_model.model.layers), - model.text_model.model.num_key_value_heads, - model.text_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py deleted file mode 100644 index effe8b91..00000000 --- a/server/text_generation_server/models/llava_next.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class LlavaNext(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - super().__init__( - model_cls=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.language_model.model.layers), - model.language_model.model.num_key_value_heads, - model.language_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.language_model, "max_past", None) diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index a167e467..533a47ea 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -77,32 +77,6 @@ class PaliGemmaBatch(VlmCausalLMBatch): class PaliGemma(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - - super().__init__( - config_cls=AutoConfig, - model_cls=PaliGemmaForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - @property def batch_type(self): return PaliGemmaBatch diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 708f8ac6..bee848e1 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.flash_mistral import ( FlashMistral, ) +from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -240,9 +241,34 @@ 
class VlmCausalLMBatch(FlashCausalLMBatch):
 class VlmCausalLM(FlashMistral):
+    def __init__(
+        self,
+        model_id: str,
+        *,
+        processor_class=AutoProcessor,
+        processor_kwargs=None,
+        batch_class=VlmCausalLMBatch,
+        revision,
+        trust_remote_code: bool,
+        **kwargs,
+    ):
+        if processor_kwargs is None:
+            processor_kwargs = {}
+        self.processor = processor_class.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            **processor_kwargs,
+        )
+        self.batch_class = batch_class
+        super().__init__(**kwargs)
+
     @property
     def batch_type(self) -> Type[VlmCausalLMBatch]:
-        return VlmCausalLMBatch
+        return self.batch_class
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.text_model, "max_past", None)
 
     def forward(
         self,