mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

commit 69cb084b5f (parent b28946d695)

    First working step.
server/text_generation_server/models/__init__.py

@@ -88,6 +88,9 @@ try:
     from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
         FlashMistralForCausalLM,
     )
+    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
+        FlashMixtralForCausalLM,
+    )
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
@@ -106,7 +109,6 @@ if FLASH_ATTENTION:
     # __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
-    __all__.append(FlashMixtral)
     __all__.append(FlashDbrx)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
@@ -773,13 +775,15 @@ def get_model(

     if model_type == MIXTRAL:
         if FLASH_ATTENTION:
-            return FlashMixtral(
-                model_id,
-                revision,
+            return FlashMistral(
+                model_id=model_id,
+                model_class=FlashMixtralForCausalLM,
+                revision=revision,
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
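The new call site above assumes FlashMistral itself has become the generic entry point, taking the concrete architecture as a model_class argument instead of each architecture keeping a thin subclass (see the deleted flash_mixtral.py below). A minimal sketch of that constructor shape, inferred only from this call site; the parameter list and body are illustrative assumptions, not the actual FlashMistral implementation:

    from typing import List, Optional, Type

    import torch


    class FlashMistral:
        # Hypothetical stand-in for the real class in flash_mistral.py, which
        # also handles tokenizer/weight loading, sharding and CUDA graphs.
        def __init__(
            self,
            model_id: str,
            model_class: Type[torch.nn.Module],  # e.g. FlashMixtralForCausalLM
            revision: Optional[str] = None,
            quantize: Optional[str] = None,
            speculator: Optional[str] = None,
            dtype: Optional[torch.dtype] = None,
            trust_remote_code: bool = False,
            lora_adapter_ids: Optional[List[str]] = None,
        ):
            # One generic loader: resolve config and weights for model_id,
            # then instantiate whichever architecture was requested.
            self.model_id = model_id
            self.model_class = model_class
            self.lora_adapter_ids = lora_adapter_ids or []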
server/text_generation_server/models/flash_mixtral.py (deleted)

@@ -1,31 +0,0 @@
-import torch
-
-from typing import Optional
-
-from text_generation_server.models.flash_mistral import BaseFlashMistral
-from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
-    MixtralConfig,
-    FlashMixtralForCausalLM,
-)
-
-
-class FlashMixtral(BaseFlashMistral):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashMixtral, self).__init__(
-            config_cls=MixtralConfig,
-            model_cls=FlashMixtralForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
server/text_generation_server/models/flash_qwen2.py

@@ -8,8 +8,7 @@ from transformers import AutoTokenizer, AutoConfig
 from typing import Optional

 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
-    set_sliding_window,
+    FlashMistral,
 )
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2ForCausalLM,
@@ -24,7 +23,7 @@ from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)


-class FlashQwen2(BaseFlashMistral):
+class FlashQwen2(FlashMistral):
     def __init__(
         self,
         model_id: str,
@@ -62,10 +61,6 @@ class FlashQwen2(BaseFlashMistral):
         config.quantize = quantize
         config.speculator = speculator

-        # Set context windows
-        if config.sliding_window is not None:
-            set_sliding_window(config.sliding_window)
-
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -78,7 +73,7 @@ class FlashQwen2(BaseFlashMistral):
         self.cuda_graphs = {}

         torch.distributed.barrier(group=self.process_group)
-        super(BaseFlashMistral, self).__init__(
+        super(FlashMistral, self).__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
server/text_generation_server/models/flash_starcoder2.py

@@ -7,8 +7,7 @@ from typing import Optional
 from transformers.models.gpt2 import GPT2TokenizerFast

 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
-    set_sliding_window,
+    FlashMistral,
 )
 from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
     Starcoder2Config,
@@ -22,7 +21,7 @@ from text_generation_server.utils import (


 # Starcoder2 has the same base as Mistral
-class FlashStarcoder2(BaseFlashMistral):
+class FlashStarcoder2(FlashMistral):
     def __init__(
         self,
         model_id: str,
server/text_generation_server/models/vlm_causal_lm.py

@@ -11,7 +11,7 @@ from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
+    FlashMistral,
 )

 tracer = trace.get_tracer(__name__)
@@ -239,7 +239,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch


-class VlmCausalLM(BaseFlashMistral):
+class VlmCausalLM(FlashMistral):
     @property
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch