First working step.

Nicolas Patry 2024-07-02 11:25:18 +00:00
parent b28946d695
commit 69cb084b5f
5 changed files with 15 additions and 48 deletions

server/text_generation_server/models/__init__.py

@@ -88,6 +88,9 @@ try:
     from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
         FlashMistralForCausalLM,
     )
+    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
+        FlashMixtralForCausalLM,
+    )
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
@@ -106,7 +109,6 @@ if FLASH_ATTENTION:
     # __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
-    __all__.append(FlashMixtral)
    __all__.append(FlashDbrx)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
@@ -773,13 +775,15 @@ def get_model(
     if model_type == MIXTRAL:
         if FLASH_ATTENTION:
-            return FlashMixtral(
-                model_id,
-                revision,
+            return FlashMistral(
+                model_id=model_id,
+                model_class=FlashMixtralForCausalLM,
+                revision=revision,
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
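This hunk is the heart of the change: instead of going through a dedicated FlashMixtral wrapper, get_model instantiates FlashMistral directly and injects the Mixtral modeling class through the model_class argument visible above. A minimal, runnable sketch of this parametrization pattern (all names below are placeholders, not TGI's actual API):

from typing import Optional, Type


class MixtralStandIn:
    """Placeholder for a modeling class such as FlashMixtralForCausalLM."""

    def __init__(self, model_id: str):
        self.model_id = model_id


class ParametrizedRunner:
    """Placeholder for the parametrized FlashMistral: the architecture is
    supplied by the caller rather than hard-coded in a per-model subclass."""

    def __init__(
        self,
        model_id: str,
        model_class: Type,
        revision: Optional[str] = None,
    ):
        self.model = model_class(model_id)
        self.revision = revision


# get_model-style call site: one runner class now serves every architecture.
runner = ParametrizedRunner(model_id="some/mixtral-checkpoint", model_class=MixtralStandIn)
print(type(runner.model).__name__)  # MixtralStandIn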

server/text_generation_server/models/flash_mixtral.py (deleted)

@ -1,31 +0,0 @@
import torch
from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
MixtralConfig,
FlashMixtralForCausalLM,
)
class FlashMixtral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMixtral, self).__init__(
config_cls=MixtralConfig,
model_cls=FlashMixtralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
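The deleted wrapper carried no behavior of its own: it only pinned config_cls=MixtralConfig and model_cls=FlashMixtralForCausalLM before delegating to BaseFlashMistral. With the get_model call site now supplying the modeling class directly (model_class=FlashMixtralForCausalLM above), the subclass is pure boilerplate and can go; config resolution, previously pinned via config_cls, is presumably handled inside FlashMistral itself after this change.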

server/text_generation_server/models/flash_qwen2.py

@@ -8,8 +8,7 @@ from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
-    set_sliding_window,
+    FlashMistral,
 )
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2ForCausalLM,
@@ -24,7 +23,7 @@ from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
-class FlashQwen2(BaseFlashMistral):
+class FlashQwen2(FlashMistral):
     def __init__(
         self,
         model_id: str,
@@ -62,10 +61,6 @@ class FlashQwen2(BaseFlashMistral):
         config.quantize = quantize
         config.speculator = speculator
 
-        # Set context windows
-        if config.sliding_window is not None:
-            set_sliding_window(config.sliding_window)
-
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -78,7 +73,7 @@ class FlashQwen2(BaseFlashMistral):
         self.cuda_graphs = {}
 
         torch.distributed.barrier(group=self.process_group)
-        super(BaseFlashMistral, self).__init__(
+        super(FlashMistral, self).__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
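The two-argument super() call is the subtle part of this rename: super(FlashMistral, self).__init__(...) starts the method-resolution-order lookup after FlashMistral, so FlashQwen2 deliberately bypasses its direct parent's __init__ and runs the grandparent's instead, exactly as super(BaseFlashMistral, self) did before the rename. A self-contained illustration with placeholder classes:

class Grandparent:
    def __init__(self):
        print("Grandparent.__init__")


class Parent(Grandparent):  # plays the role of FlashMistral
    def __init__(self):
        print("Parent.__init__")  # never runs for Child instances below
        super().__init__()


class Child(Parent):  # plays the role of FlashQwen2
    def __init__(self):
        # Start the MRO lookup *after* Parent: Grandparent.__init__ runs,
        # Parent.__init__ does not.
        super(Parent, self).__init__()


Child()  # prints only "Grandparent.__init__"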

server/text_generation_server/models/flash_starcoder2.py

@@ -7,8 +7,7 @@ from typing import Optional
 from transformers.models.gpt2 import GPT2TokenizerFast
 
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
-    set_sliding_window,
+    FlashMistral,
 )
 from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
     Starcoder2Config,
@@ -22,7 +21,7 @@ from text_generation_server.utils import (
 
 
 # Starcoder2 has the same base as Mistral
-class FlashStarcoder2(BaseFlashMistral):
+class FlashStarcoder2(FlashMistral):
     def __init__(
         self,
         model_id: str,

server/text_generation_server/models/vlm_causal_lm.py

@@ -11,7 +11,7 @@ from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
+    FlashMistral,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -239,7 +239,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch
 
 
-class VlmCausalLM(BaseFlashMistral):
+class VlmCausalLM(FlashMistral):
     @property
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch