From 5799e5cae950df831f0bcd1e9901cdbfd2a2de13 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 11 Dec 2023 11:14:40 +0100 Subject: [PATCH] transformers format --- server/Makefile | 1 + .../text_generation_server/models/__init__.py | 25 +++++++- .../custom_modeling/flash_mixtral_modeling.py | 2 +- .../models/flash_mistral.py | 61 +++++++++++-------- .../models/flash_mixtral.py | 26 ++++++++ server/text_generation_server/server.py | 8 --- 6 files changed, 89 insertions(+), 34 deletions(-) create mode 100644 server/text_generation_server/models/flash_mixtral.py diff --git a/server/Makefile b/server/Makefile index 2810a528..513cb2d2 100644 --- a/server/Makefile +++ b/server/Makefile @@ -18,6 +18,7 @@ gen-server: install: gen-server pip install pip --upgrade + pip install git+https://github.com/OlivierDehaene/megablocks#33fad2b0eae7c47b8fedfb3ad415af8169386918 pip install -r requirements_cuda.txt pip install -e ".[bnb, accelerate, quantize, peft]" diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index aae81be2..72aa10b6 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -77,6 +77,18 @@ except ImportError as e: if MISTRAL: __all__.append(FlashMistral) +MIXTRAL = True +try: + from text_generation_server.models.flash_mixtral import FlashMixtral +except ImportError as e: + logger.warning(f"Could not import Mixtral model: {e}") + MIXTRAL = False + +if MIXTRAL: + __all__.append(FlashMixtral) + + + def get_model( model_id: str, revision: Optional[str], @@ -282,7 +294,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type in ["mistral", "mixtral"]: + if model_type == "mistral": if MISTRAL: return FlashMistral( model_id, @@ -294,6 +306,17 @@ def get_model( ) raise NotImplementedError("Mistral models requires flash attention v2") + if model_type == "mixtral": + if MIXTRAL: + return FlashMixtral( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks") + if model_type == "opt": return OPTSharded( model_id, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index f8de6cf7..66753d5a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -56,7 +56,7 @@ except ImportError: class MixtralConfig(PretrainedConfig): - model_type = "mistral" + model_type = "mixtral" def __init__( self, diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index a4806d16..5ce37164 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -45,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch): @classmethod def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, ) -> "FlashCausalLMBatch": global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -99,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch): # Parse batch 
for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) + zip(pb.requests, batch_tokenized_inputs) ): # request id -> idx in list mapping requests_idx_mapping[r.id] = i - tokenized_input = tokenized_input[-r.truncate :] + tokenized_input = tokenized_input[-r.truncate:] input_length = len(tokenized_input) input_lengths.append(input_length) @@ -277,15 +277,16 @@ class FlashMistralBatch(FlashCausalLMBatch): ) -class FlashMistral(FlashCausalLM): +class BaseFlashMistral(FlashCausalLM): def __init__( - self, - model_id: str, - architectures: List[str], - revision: Optional[str] = None, - quantize: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, + self, + config_cls, + model_cls, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, ): global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -305,14 +306,6 @@ class FlashMistral(FlashCausalLM): trust_remote_code=trust_remote_code, ) - if "MixtralForCausalLM" in architectures: - from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM - config_cls = MixtralConfig - model_cls = FlashMixtralForCausalLM - else: - config_cls = MistralConfig - model_cls = FlashMistralForCausalLM - config = config_cls.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) @@ -332,7 +325,7 @@ class FlashMistral(FlashCausalLM): model = model_cls(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashMistral, self).__init__( + super(BaseFlashMistral, self).__init__( model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), @@ -404,3 +397,23 @@ class FlashMistral(FlashCausalLM): if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None return logits + + +class FlashMistral(BaseFlashMistral): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + super(FlashMistral, self).__init__( + config_cls=MistralConfig, + model_cls=FlashMistralForCausalLM, + model_id=model_id, + revision=revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code + ) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py new file mode 100644 index 00000000..c45ae50f --- /dev/null +++ b/server/text_generation_server/models/flash_mixtral.py @@ -0,0 +1,26 @@ +import torch + +from typing import Optional + +from text_generation_server.models.flash_mistral import BaseFlashMistral +from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM + + +class FlashMixtral(BaseFlashMistral): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + super(FlashMixtral, self).__init__( + config_cls=MixtralConfig, + model_cls=FlashMixtralForCausalLM, + model_id=model_id, + revision=revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code + ) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index b740976b..ebe066e3 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ 
-91,12 +91,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) - # from torch.profiler import profile, ProfilerActivity - # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof: generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) - # if self.model.rank == 0: - # prefill_prof.export_chrome_trace("prefill.json") return generate_pb2.PrefillResponse( generations=[generation.to_pb() for generation in generations], @@ -122,12 +118,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): else: batch = batches[0] - # from torch.profiler import profile, ProfilerActivity - # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof: generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) - # if self.model.rank == 0: - # prefill_prof.export_chrome_trace("decode.json") return generate_pb2.DecodeResponse( generations=[generation.to_pb() for generation in generations],
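
For readers following the models/__init__.py hunk above: the new architecture is registered behind an import guard, so that missing optional dependencies (flash attention v2, stk, megablocks) turn into a clear NotImplementedError at dispatch time instead of breaking the whole package import. The snippet below is a minimal, self-contained sketch of that pattern, assuming nothing beyond the standard library; some_optional_dependency, SketchModel and get_model_sketch are hypothetical stand-ins, not names from text-generation-inference.

# Minimal sketch of the import-guard + dispatch pattern used in
# models/__init__.py above. All names here are illustrative stand-ins.
from typing import Optional


class SketchModel:
    """Placeholder for a model class that needs optional dependencies."""

    def __init__(self, model_id: str, revision: Optional[str] = None):
        self.model_id = model_id
        self.revision = revision


# Try to import the optional backend at module load time; on failure,
# remember that the architecture is unavailable instead of crashing.
SKETCH_AVAILABLE = True
try:
    import some_optional_dependency  # hypothetical extra requirement
except ImportError as e:
    print(f"Could not import sketch model: {e}")
    SKETCH_AVAILABLE = False


def get_model_sketch(model_type: str, model_id: str):
    # Dispatch on the config's model_type, mirroring get_model() above:
    # only instantiate the class when its dependencies imported cleanly.
    if model_type == "sketch":
        if SKETCH_AVAILABLE:
            return SketchModel(model_id)
        raise NotImplementedError(
            "sketch models require some_optional_dependency"
        )
    raise ValueError(f"Unsupported model_type: {model_type}")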
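
The flash_mistral.py / flash_mixtral.py changes follow a simple refactoring shape: the shared loading logic moves into BaseFlashMistral, which receives the config and model classes as constructor arguments, and FlashMistral / FlashMixtral become thin subclasses that only pin those classes. A stripped-down sketch of that shape, using hypothetical DummyConfig/DummyModel placeholders instead of the real transformers/TGI classes, might look like this:

# Stripped-down sketch of the BaseFlashMistral refactor above: the base
# class is parameterized by config_cls/model_cls, and each architecture
# pins them in a thin subclass. DummyConfig/DummyModel are placeholders.
from typing import Optional, Type


class DummyConfig:
    def __init__(self, model_id: str, revision: Optional[str] = None):
        self.model_id = model_id
        self.revision = revision


class DummyModel:
    def __init__(self, config: DummyConfig):
        self.config = config


class BaseFlashSketch:
    def __init__(
        self,
        config_cls: Type[DummyConfig],
        model_cls: Type[DummyModel],
        model_id: str,
        revision: Optional[str] = None,
    ):
        # The base class never hard-codes a concrete architecture; it only
        # wires together whatever config/model classes the subclass hands it.
        self.config = config_cls(model_id, revision=revision)
        self.model = model_cls(self.config)


class FlashASketch(BaseFlashSketch):
    def __init__(self, model_id: str, revision: Optional[str] = None):
        # Thin subclass: pin the concrete classes and forward the rest,
        # mirroring FlashMistral and FlashMixtral above.
        super().__init__(
            config_cls=DummyConfig,
            model_cls=DummyModel,
            model_id=model_id,
            revision=revision,
        )


# Usage: adding a new architecture is now a new (config_cls, model_cls)
# pair plus a small subclass, exactly as flash_mixtral.py does above.
model = FlashASketch("some/model-id")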