transformers format

OlivierDehaene 2023-12-11 11:14:40 +01:00
parent e69eed8ea3
commit 5799e5cae9
6 changed files with 89 additions and 34 deletions

server/Makefile

@@ -18,6 +18,7 @@ gen-server:
 install: gen-server
 	pip install pip --upgrade
+	pip install git+https://github.com/OlivierDehaene/megablocks#33fad2b0eae7c47b8fedfb3ad415af8169386918
 	pip install -r requirements_cuda.txt
 	pip install -e ".[bnb, accelerate, quantize, peft]"
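Note (illustrative, not part of the diff): the new install step pins megablocks to a specific commit of OlivierDehaene's fork. A minimal sketch to confirm the optional Mixtral dependencies resolve after make install; the module names megablocks and stk are assumptions taken from the error message added in get_model below.

    import importlib

    # Check that the optional Mixtral dependencies installed above can be imported.
    for name in ("megablocks", "stk"):
        try:
            importlib.import_module(name)
            print(f"{name}: available")
        except ImportError as err:
            print(f"{name}: missing ({err})")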

server/text_generation_server/models/__init__.py

@@ -77,6 +77,18 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)
 
+MIXTRAL = True
+try:
+    from text_generation_server.models.flash_mixtral import FlashMixtral
+except ImportError as e:
+    logger.warning(f"Could not import Mixtral model: {e}")
+    MIXTRAL = False
+
+if MIXTRAL:
+    __all__.append(FlashMixtral)
+
 
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -282,7 +294,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    if model_type in ["mistral", "mixtral"]:
+    if model_type == "mistral":
         if MISTRAL:
             return FlashMistral(
                 model_id,
@@ -294,6 +306,17 @@ def get_model(
             )
         raise NotImplementedError("Mistral models requires flash attention v2")
 
+    if model_type == "mixtral":
+        if MIXTRAL:
+            return FlashMixtral(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
+
     if model_type == "opt":
         return OPTSharded(
             model_id,
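Note (illustrative, not part of the diff): get_model now routes model_type == "mixtral" to the new FlashMixtral class whenever the optional import above succeeded. A minimal usage sketch that constructs FlashMixtral directly with the same arguments get_model passes; the checkpoint id is only an example, and the call assumes flash attention v2, stk and megablocks are installed.

    import torch

    from text_generation_server.models.flash_mixtral import FlashMixtral

    # Example checkpoint id; any checkpoint whose config declares model_type "mixtral" works.
    model = FlashMixtral(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        revision=None,
        quantize=None,
        dtype=torch.float16,
        trust_remote_code=False,
    )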

server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

@@ -56,7 +56,7 @@ except ImportError:
 class MixtralConfig(PretrainedConfig):
-    model_type = "mistral"
+    model_type = "mixtral"
 
     def __init__(
         self,
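Note (illustrative, not part of the diff): the one-line fix above matters because dispatch keys off the model_type string. A PretrainedConfig subclass advertises its model_type and serializes it with the config, so the Mixtral config must report "mixtral" rather than "mistral". A minimal sketch with a stand-in config class:

    from transformers import PretrainedConfig


    class DemoMixtralConfig(PretrainedConfig):
        # Stand-in for the MixtralConfig above; only model_type is relevant here.
        model_type = "mixtral"


    cfg = DemoMixtralConfig()
    print(cfg.model_type)               # "mixtral"
    print(cfg.to_dict()["model_type"])  # model_type is serialized with the config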

server/text_generation_server/models/flash_mistral.py

@@ -45,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         global SLIDING_WINDOW
         global SLIDING_WINDOW_BLOCKS
@@ -99,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
         ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenized_input[-r.truncate :]
+            tokenized_input = tokenized_input[-r.truncate:]
 
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
@@ -277,15 +277,16 @@ class FlashMistralBatch(FlashCausalLMBatch):
     )
 
 
-class FlashMistral(FlashCausalLM):
+class BaseFlashMistral(FlashCausalLM):
     def __init__(
         self,
+        config_cls,
+        model_cls,
         model_id: str,
-        architectures: List[str],
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         global SLIDING_WINDOW
         global SLIDING_WINDOW_BLOCKS
@@ -305,14 +306,6 @@ class FlashMistral(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-        if "MixtralForCausalLM" in architectures:
-            from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
-            config_cls = MixtralConfig
-            model_cls = FlashMixtralForCausalLM
-        else:
-            config_cls = MistralConfig
-            model_cls = FlashMistralForCausalLM
-
         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
@@ -332,7 +325,7 @@ class FlashMistral(FlashCausalLM):
         model = model_cls(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(FlashMistral, self).__init__(
+        super(BaseFlashMistral, self).__init__(
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
@@ -404,3 +397,23 @@ class FlashMistral(FlashCausalLM):
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits
+
+
+class FlashMistral(BaseFlashMistral):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        super(FlashMistral, self).__init__(
+            config_cls=MistralConfig,
+            model_cls=FlashMistralForCausalLM,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code
+        )
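Note (illustrative, not part of the diff): the refactor above turns FlashMistral into a thin subclass of BaseFlashMistral that only pins config_cls and model_cls; FlashMixtral in the next file follows the same pattern. A self-contained toy sketch of that delegation, using stand-in classes rather than the real TGI ones:

    class ToyConfig:
        @classmethod
        def from_pretrained(cls, model_id):
            return cls()


    class ToyModel:
        def __init__(self, config):
            self.config = config


    class BaseToy:
        # Mirrors BaseFlashMistral: the base does the loading, subclasses pick classes.
        def __init__(self, config_cls, model_cls, model_id):
            config = config_cls.from_pretrained(model_id)
            self.model = model_cls(config)


    class ConcreteToy(BaseToy):
        # Mirrors FlashMistral / FlashMixtral: a thin wrapper pinning the classes.
        def __init__(self, model_id):
            super(ConcreteToy, self).__init__(
                config_cls=ToyConfig,
                model_cls=ToyModel,
                model_id=model_id,
            )


    assert isinstance(ConcreteToy("some/model-id").model, ToyModel)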

server/text_generation_server/models/flash_mixtral.py

@@ -0,0 +1,26 @@
+import torch
+
+from typing import Optional
+
+from text_generation_server.models.flash_mistral import BaseFlashMistral
+from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
+
+
+class FlashMixtral(BaseFlashMistral):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        super(FlashMixtral, self).__init__(
+            config_cls=MixtralConfig,
+            model_cls=FlashMixtralForCausalLM,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code
+        )

server/text_generation_server/server.py

@@ -91,12 +91,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
 
-        # from torch.profiler import profile, ProfilerActivity
-        # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
-        # if self.model.rank == 0:
-        #     prefill_prof.export_chrome_trace("prefill.json")
 
         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
@@ -122,12 +118,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         else:
             batch = batches[0]
 
-        # from torch.profiler import profile, ProfilerActivity
-        # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
-        # if self.model.rank == 0:
-        #     prefill_prof.export_chrome_trace("decode.json")
 
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
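Note (illustrative, not part of the diff): the lines removed above were commented-out torch.profiler scaffolding around the prefill and decode paths. If traces are needed again, a standalone sketch along the same lines; the profiled function and the trace path are placeholders.

    import torch
    from torch.profiler import profile, ProfilerActivity


    def step():
        # Placeholder for self.model.generate_token(batch).
        x = torch.randn(256, 256)
        return x @ x


    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities) as prof:
        step()

    # Open the trace in chrome://tracing or Perfetto.
    prof.export_chrome_trace("prefill.json")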