diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 5ea43290..52499b33 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -56,8 +56,12 @@ try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_gpt2 import FlashGPT2
     from text_generation_server.models.flash_neox import FlashNeoXSharded
-    from text_generation_server.models.flash_llama import (
-        FlashLlama,
+
+    # from text_generation_server.models.flash_llama import (
+    #     FlashLlama,
+    # )
+    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+        FlashLlamaForCausalLM,
     )
     from text_generation_server.models.flash_qwen2 import (
         FlashQwen2,
@@ -81,7 +85,9 @@ try:
     from text_generation_server.models.llava_next import LlavaNext
     from text_generation_server.models.idefics2 import Idefics2
     from text_generation_server.models.flash_mistral import FlashMistral
-    from text_generation_server.models.flash_mixtral import FlashMixtral
+    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+        FlashMistralForCausalLM,
+    )
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
@@ -97,7 +103,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
-    __all__.append(FlashLlama)
+    # __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
     __all__.append(FlashMixtral)
@@ -599,9 +605,10 @@ def get_model(
 
     elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
         if FLASH_ATTENTION:
-            return FlashLlama(
-                model_id,
-                revision,
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashLlamaForCausalLM,
+                revision=revision,
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
@@ -743,12 +750,14 @@ def get_model(
     if model_type == MISTRAL:
         if FLASH_ATTENTION:
             return FlashMistral(
-                model_id,
-                revision,
+                model_id=model_id,
+                model_class=FlashMistralForCausalLM,
+                revision=revision,
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 4f276ed4..6b7cceef 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -10,7 +10,12 @@ import numpy as np
 from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import (
+    PreTrainedTokenizerBase,
+    AutoConfig,
+    AutoTokenizer,
+    GenerationConfig,
+)
 from typing import Iterable, Optional, Tuple, List, Type, Dict
 
 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
@@ -21,6 +26,12 @@ from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+    hub,
+)
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -803,25 +814,88 @@ class FlashCausalLM(Model):
     def __init__(
         self,
         model_id: str,
-        model: torch.nn.Module,
-        tokenizer: PreTrainedTokenizerBase,
-        num_layers: int,
-        num_kv_heads: int,
-        head_size: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        rank: int = 0,
-        world_size: int = 1,
-        sliding_window: Optional[int] = None,
+        model_class,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+        lora_adapter_ids: Optional[list] = [],
+        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
+        default_dtype=torch.float16,
+        # self,
+        # model_id: str,
+        # model_class,
+        # tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
+        # num_layers: int,
+        # num_kv_heads: int,
+        # head_size: int,
+        # dtype: torch.dtype,
+        # device: torch.device,
+        # rank: int = 0,
+        # world_size: int = 1,
+        # sliding_window: Optional[int] = None,
     ):
-        self.num_layers = num_layers
-        self.num_kv_heads = num_kv_heads
-        self.head_size = head_size
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
+        else:
+            raise NotImplementedError(f"{model_class} is only available on GPU")
+
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id, revision=revision, trust_remote_code=trust_remote_code
+            )
+            if isinstance(generation_config.eos_token_id, (list, set)):
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+        except Exception:
+            pass
+
+        config = AutoConfig.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        config.quantize = quantize
+        config.speculator = speculator
+        if getattr(config, "sliding_window", None) is not None:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
+        torch.distributed.barrier(group=self.process_group)
+
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
+            weights._set_gptq_params(model_id, revision)
+
+        prefix = ""
+        model = model_class(prefix, config, weights)
+        torch.distributed.barrier(group=self.process_group)
+        self.num_layers = config.num_hidden_layers
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_size = config.hidden_size // config.num_attention_heads
 
         self.cuda_graphs = {}
         self.kv_cache = []
 
-        super(FlashCausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
@@ -830,7 +904,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
-            sliding_window=sliding_window,
+            sliding_window=config.sliding_window,
         )
 
     @property
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 0f5746de..c2482dc2 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -1,24 +1,7 @@
 import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, Dict, List
 
 from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.flash_causal_lm import set_sliding_window
-from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
-    FlashMistralForCausalLM,
-    MistralConfig,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
 
 
 ADAPTER_LAYERS = [
@@ -33,88 +16,7 @@ ADAPTER_LAYERS = [
 ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
 
 
-class BaseFlashMistral(FlashCausalLM):
-    def __init__(
-        self,
-        model_cls,
-        model_id: str,
-        config_cls=AutoConfig,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        tokenizer_class=AutoTokenizer,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashMistral is only available on GPU")
-
-        tokenizer = tokenizer_class.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = config_cls.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        # Set context windows
-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        prefix = ""
-        model = model_cls(prefix, config, weights)
-
-        self.cuda_graphs = {}
-
-        torch.distributed.barrier(group=self.process_group)
-        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
-        super().__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=num_layers,
-            num_kv_heads=num_kv_heads,
-            head_size=head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-            sliding_window=config.sliding_window,
-        )
-
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.model.layers),
-            model.model.num_key_value_heads,
-            model.model.head_size,
-        )
-
+class FlashMistral(FlashCausalLM):
     @property
     def supports_adapter_loading(self) -> bool:
         return True
@@ -183,25 +85,3 @@ class BaseFlashMistral(FlashCausalLM):
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
-
-
-class FlashMistral(BaseFlashMistral):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashMistral, self).__init__(
-            config_cls=MistralConfig,
-            model_cls=FlashMistralForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
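
Usage note: below is a minimal sketch of how the consolidated FlashCausalLM constructor introduced in this patch gets called, mirroring the refactored LLAMA branch of get_model() above. It is illustrative only, not part of the patch; the model id is a hypothetical placeholder, and it assumes a CUDA-capable text-generation-server install with the safetensors weights already cached locally.

# Illustrative sketch: instantiate the generic FlashCausalLM directly with a
# model_class, as the refactored get_model() now does for Llama-family models.
import torch

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
)

model = FlashCausalLM(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model id (assumption)
    model_class=FlashLlamaForCausalLM,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=torch.float16,
    trust_remote_code=False,
)
# The constructor now handles distributed init, tokenizer/config loading and
# weight sharding itself, so no per-architecture wrapper class is needed.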