diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 52499b33..d159df88 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -88,6 +88,9 @@ try:
     from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
         FlashMistralForCausalLM,
     )
+    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
+        FlashMixtralForCausalLM,
+    )
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
@@ -106,7 +109,6 @@ if FLASH_ATTENTION:
     # __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
-    __all__.append(FlashMixtral)
     __all__.append(FlashDbrx)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
@@ -773,13 +775,15 @@ def get_model(
 
     if model_type == MIXTRAL:
         if FLASH_ATTENTION:
-            return FlashMixtral(
-                model_id,
-                revision,
+            return FlashMistral(
+                model_id=model_id,
+                model_class=FlashMixtralForCausalLM,
+                revision=revision,
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py
deleted file mode 100644
index 587d423f..00000000
--- a/server/text_generation_server/models/flash_mixtral.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import torch
-
-from typing import Optional
-
-from text_generation_server.models.flash_mistral import BaseFlashMistral
-from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
-    MixtralConfig,
-    FlashMixtralForCausalLM,
-)
-
-
-class FlashMixtral(BaseFlashMistral):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashMixtral, self).__init__(
-            config_cls=MixtralConfig,
-            model_cls=FlashMixtralForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py
index cd6078f1..4176aa05 100644
--- a/server/text_generation_server/models/flash_qwen2.py
+++ b/server/text_generation_server/models/flash_qwen2.py
@@ -8,8 +8,7 @@ from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
-    set_sliding_window,
+    FlashMistral,
 )
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2ForCausalLM,
@@ -24,7 +23,7 @@ from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
-class FlashQwen2(BaseFlashMistral):
+class FlashQwen2(FlashMistral):
     def __init__(
         self,
         model_id: str,
@@ -62,10 +61,6 @@ class FlashQwen2(BaseFlashMistral):
         config.quantize = quantize
         config.speculator = speculator
 
-        # Set context windows
-        if config.sliding_window is not None:
-            set_sliding_window(config.sliding_window)
-
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -78,7 +73,7 @@ class FlashQwen2(BaseFlashMistral):
         self.cuda_graphs = {}
 
         torch.distributed.barrier(group=self.process_group)
-        super(BaseFlashMistral, self).__init__(
+        super(FlashMistral, self).__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py
index 369e9e4c..16c9a8b9 100644
--- a/server/text_generation_server/models/flash_starcoder2.py
+++ b/server/text_generation_server/models/flash_starcoder2.py
@@ -7,8 +7,7 @@ from typing import Optional
 from transformers.models.gpt2 import GPT2TokenizerFast
 
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
-    set_sliding_window,
+    FlashMistral,
 )
 from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
     Starcoder2Config,
@@ -22,7 +21,7 @@ from text_generation_server.utils import (
 
 
 # Starcoder2 has the same base as Mistral
-class FlashStarcoder2(BaseFlashMistral):
+class FlashStarcoder2(FlashMistral):
     def __init__(
         self,
         model_id: str,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 1cdf37ea..708f8ac6 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -11,7 +11,7 @@ from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
+    FlashMistral,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -239,7 +239,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch
 
 
-class VlmCausalLM(BaseFlashMistral):
+class VlmCausalLM(FlashMistral):
     @property
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch
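
Illustrative usage (not part of the patch): a minimal sketch of how Mixtral is instantiated after this refactor, going through the shared FlashMistral wrapper with an explicit model_class instead of the deleted FlashMixtral subclass. The keyword arguments mirror the get_model() call in the first hunk; the checkpoint id and the empty lora_adapter_ids value are assumptions made only for this example.

from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
    FlashMixtralForCausalLM,
)

# Hypothetical call mirroring get_model(); the model id is only an example.
model = FlashMistral(
    model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    model_class=FlashMixtralForCausalLM,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
    lora_adapter_ids=[],  # assumed empty here; get_model() forwards its own list
)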