transformers format

This commit is contained in:
OlivierDehaene 2023-12-11 11:14:40 +01:00
parent e69eed8ea3
commit 5799e5cae9
6 changed files with 89 additions and 34 deletions

server/Makefile

@@ -18,6 +18,7 @@ gen-server:
 install: gen-server
 	pip install pip --upgrade
+	pip install git+https://github.com/OlivierDehaene/megablocks#33fad2b0eae7c47b8fedfb3ad415af8169386918
 	pip install -r requirements_cuda.txt
 	pip install -e ".[bnb, accelerate, quantize, peft]"
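The added line pins a megablocks fork to an exact commit; together with stk it provides the MoE kernels the Mixtral path needs (the NotImplementedError below names both). A hypothetical import guard mirroring the MISTRAL/MIXTRAL flags used elsewhere in this commit:

# Hypothetical availability check for the optional Mixtral dependencies;
# not part of this commit.
try:
    import megablocks  # MoE kernels, installed from the pinned fork above
    import stk         # sparse toolkit used by megablocks
    HAS_MIXTRAL_DEPS = True
except ImportError:
    HAS_MIXTRAL_DEPS = False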

server/text_generation_server/models/__init__.py

@@ -77,6 +77,18 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)
 
+MIXTRAL = True
+try:
+    from text_generation_server.models.flash_mixtral import FlashMixtral
+except ImportError as e:
+    logger.warning(f"Could not import Mixtral model: {e}")
+    MIXTRAL = False
+
+if MIXTRAL:
+    __all__.append(FlashMixtral)
+
 
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -282,7 +294,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type in ["mistral", "mixtral"]:
+    if model_type == "mistral":
         if MISTRAL:
             return FlashMistral(
                 model_id,
@@ -294,6 +306,17 @@ def get_model(
             )
         raise NotImplementedError("Mistral models require flash attention v2")
 
+    if model_type == "mixtral":
+        if MIXTRAL:
+            return FlashMixtral(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        raise NotImplementedError("Mixtral models require flash attention v2, stk and megablocks")
+
     if model_type == "opt":
         return OPTSharded(
             model_id,
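Mistral and Mixtral now take separate branches, each guarded by its own import flag. For illustration, a call that would reach the new branch (the checkpoint id is an example, and parameters beyond the first two shown in this diff are assumptions about the surrounding signature):

# Hypothetical usage sketch; raises NotImplementedError unless flash
# attention v2, stk and megablocks are all importable (MIXTRAL == True).
model = get_model(
    model_id="mistralai/Mixtral-8x7B-v0.1",  # example checkpoint id
    revision=None,
    sharded=False,  # assumed parameter, not shown in this diff
    quantize=None,
    dtype=None,
    trust_remote_code=False,
)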

server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

@@ -56,7 +56,7 @@ except ImportError:
 
 
 class MixtralConfig(PretrainedConfig):
-    model_type = "mistral"
+    model_type = "mixtral"
 
     def __init__(
         self,
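The corrected model_type is the string that links a checkpoint's config.json to this config class and to the "mixtral" branch in get_model above. A small illustration of the lookup it affects (the config.json contents are hypothetical):

# Hypothetical illustration: dispatch keys off the model_type string
# stored in the checkpoint's config.json.
config_json = {"model_type": "mixtral"}  # as shipped with a Mixtral checkpoint
assert config_json["model_type"] == MixtralConfig.model_type  # no longer "mistral"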

server/text_generation_server/models/flash_mistral.py

@@ -277,11 +277,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
 
 
-class FlashMistral(FlashCausalLM):
+class BaseFlashMistral(FlashCausalLM):
     def __init__(
         self,
+        config_cls,
+        model_cls,
         model_id: str,
-        architectures: List[str],
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
@@ -305,14 +306,6 @@ class FlashMistral(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-        if "MixtralForCausalLM" in architectures:
-            from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
-
-            config_cls = MixtralConfig
-            model_cls = FlashMixtralForCausalLM
-        else:
-            config_cls = MistralConfig
-            model_cls = FlashMistralForCausalLM
-
         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
@@ -332,7 +325,7 @@
         model = model_cls(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(FlashMistral, self).__init__(
+        super(BaseFlashMistral, self).__init__(
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
@@ -404,3 +397,23 @@
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits
+
+
+class FlashMistral(BaseFlashMistral):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        super(FlashMistral, self).__init__(
+            config_cls=MistralConfig,
+            model_cls=FlashMistralForCausalLM,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code
+        )
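The refactor replaces the old architectures check inside FlashMistral with constructor injection: each concrete class binds its config and model classes and hands them to BaseFlashMistral. A hypothetical further variant would follow the same pattern (all names below are invented for illustration):

class FlashMyVariant(BaseFlashMistral):
    # Hypothetical subclass showing the injection pattern; MyVariantConfig
    # and FlashMyVariantForCausalLM do not exist in this commit.
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super().__init__(
            config_cls=MyVariantConfig,
            model_cls=FlashMyVariantForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )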

server/text_generation_server/models/flash_mixtral.py

@@ -0,0 +1,26 @@
+import torch
+
+from typing import Optional
+
+from text_generation_server.models.flash_mistral import BaseFlashMistral
+from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
+
+
+class FlashMixtral(BaseFlashMistral):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        super(FlashMixtral, self).__init__(
+            config_cls=MixtralConfig,
+            model_cls=FlashMixtralForCausalLM,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code
+        )
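FlashMixtral is then just the Mixtral binding of the same base class. Constructing it directly (normally get_model does this) might look like the following; the checkpoint id is an example:

# Hypothetical direct construction; in practice get_model builds this object.
model = FlashMixtral(
    model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",  # example checkpoint
    revision=None,
    quantize=None,
    dtype=torch.float16,
    trust_remote_code=False,
)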

server/text_generation_server/server.py

@@ -91,12 +91,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
 
-        # from torch.profiler import profile, ProfilerActivity
-        # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
-        # if self.model.rank == 0:
-        #     prefill_prof.export_chrome_trace("prefill.json")
 
         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
@@ -122,12 +118,8 @@
         else:
             batch = batches[0]
 
-        # from torch.profiler import profile, ProfilerActivity
-        # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
-        # if self.model.rank == 0:
-        #     prefill_prof.export_chrome_trace("decode.json")
 
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
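The deleted lines in both Prefill and Decode were commented-out torch.profiler scaffolding. Reconstructed from those comments, an ad-hoc trace of a single generate_token call, should it ever be needed again, could look like this inside either handler:

from torch.profiler import profile, ProfilerActivity

# Sketch reconstructed from the removed comments: trace one step and
# export a Chrome trace on rank 0 only.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    generations, next_batch = self.model.generate_token(batch)
if self.model.rank == 0:
    prof.export_chrome_trace("prefill.json")  # "decode.json" in Decode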