Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
commit 69cb084b5f
parent b28946d695

    First working step.
@@ -88,6 +88,9 @@ try:
     from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
         FlashMistralForCausalLM,
     )
+    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
+        FlashMixtralForCausalLM,
+    )
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
@@ -106,7 +109,6 @@ if FLASH_ATTENTION:
     # __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
-    __all__.append(FlashMixtral)
     __all__.append(FlashDbrx)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
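For context, this export list is built at import time and gated on kernel availability, so deleting the dedicated FlashMixtral wrapper (see the removed file below) also means dropping it here. A simplified sketch of the pattern the surrounding file uses; FLASH_ATTENTION and the model classes are assumed to be defined earlier in the module:

__all__ = []  # simplified; the real list starts with the non-flash models
if FLASH_ATTENTION:
    # A class is only exported when its flash-attention kernels imported
    # successfully; FlashMixtral no longer exists, so its entry is removed.
    __all__.append(FlashMistral)
    __all__.append(FlashQwen2)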
@@ -773,13 +775,15 @@ def get_model(
 
     if model_type == MIXTRAL:
         if FLASH_ATTENTION:
-            return FlashMixtral(
-                model_id,
-                revision,
+            return FlashMistral(
+                model_id=model_id,
+                model_class=FlashMixtralForCausalLM,
+                revision=revision,
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
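The call site above implies that FlashMistral now takes the concrete model class as a constructor argument instead of relying on per-model subclasses. A minimal sketch of what that signature presumably looks like; the parameter list is inferred from the keyword arguments in the hunk, and the body is an assumption:

from typing import List, Optional, Type

import torch

class FlashMistral:
    def __init__(
        self,
        model_id: str,
        model_class: Type,  # e.g. FlashMixtralForCausalLM, per the call site
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[List[str]] = None,
    ):
        # Instead of hard-coding a model class in a subclass, the shared
        # loading path instantiates whatever class it was handed.
        self.model_class = model_class
        ...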
@@ -1,31 +0,0 @@
-import torch
-
-from typing import Optional
-
-from text_generation_server.models.flash_mistral import BaseFlashMistral
-from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
-    MixtralConfig,
-    FlashMixtralForCausalLM,
-)
-
-
-class FlashMixtral(BaseFlashMistral):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashMixtral, self).__init__(
-            config_cls=MixtralConfig,
-            model_cls=FlashMixtralForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
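With the FlashMixtral wrapper deleted wholesale (the -1,31 +0,0 hunk above), its config_cls/model_cls plumbing disappears, and the same model is constructed through the generic class instead, mirroring the get_model hunk. A hypothetical direct call, reusing the sketched signature and the names from the diff; the model id is only an example:

model = FlashMistral(
    model_id="mistralai/Mixtral-8x7B-v0.1",  # example id, not from the diff
    model_class=FlashMixtralForCausalLM,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=torch.float16,
    trust_remote_code=False,
    lora_adapter_ids=[],
)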
@@ -8,8 +8,7 @@ from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
-    set_sliding_window,
+    FlashMistral,
 )
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2ForCausalLM,
@@ -24,7 +23,7 @@ from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
-class FlashQwen2(BaseFlashMistral):
+class FlashQwen2(FlashMistral):
     def __init__(
         self,
         model_id: str,
@@ -62,10 +61,6 @@ class FlashQwen2(BaseFlashMistral):
         config.quantize = quantize
         config.speculator = speculator
 
-        # Set context windows
-        if config.sliding_window is not None:
-            set_sliding_window(config.sliding_window)
-
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
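The dropped lines called set_sliding_window, previously imported from flash_mistral (see the import hunks above); removing the call from this subclass suggests the window is now configured inside the shared FlashMistral path. A plausible shape for the helper, reconstructed from its call site; only its name appears in the diff, so this is an assumption:

from typing import Optional

SLIDING_WINDOW: Optional[int] = None

def set_sliding_window(sliding_window: int):
    # Stash the model's attention window in a module-level global so the
    # attention code can consult it at runtime (assumed behavior).
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window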
@@ -78,7 +73,7 @@ class FlashQwen2(BaseFlashMistral):
         self.cuda_graphs = {}
 
         torch.distributed.barrier(group=self.process_group)
-        super(BaseFlashMistral, self).__init__(
+        super(FlashMistral, self).__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
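Note the two-argument super above: super(FlashMistral, self).__init__ deliberately skips FlashMistral's own __init__ and dispatches to its base class, since FlashQwen2 has already done the model setup itself. A self-contained demonstration of that Python behavior; the class names here are illustrative, not from the diff:

class Base:
    def __init__(self):
        print("Base.__init__")

class Middle(Base):
    def __init__(self):
        print("Middle.__init__")

class Child(Middle):
    def __init__(self):
        # Passing the *parent* class to super() starts the MRO lookup after
        # Middle, so Middle.__init__ is skipped, like the hunk above.
        super(Middle, self).__init__()

Child()  # prints "Base.__init__"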
@@ -7,8 +7,7 @@ from typing import Optional
 from transformers.models.gpt2 import GPT2TokenizerFast
 
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
-    set_sliding_window,
+    FlashMistral,
 )
 from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
     Starcoder2Config,
@@ -22,7 +21,7 @@ from text_generation_server.utils import (
 
 
 # Starcoder2 has the same base as Mistral
-class FlashStarcoder2(BaseFlashMistral):
+class FlashStarcoder2(FlashMistral):
     def __init__(
         self,
         model_id: str,
@@ -11,7 +11,7 @@ from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
+    FlashMistral,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -239,7 +239,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch
 
 
-class VlmCausalLM(BaseFlashMistral):
+class VlmCausalLM(FlashMistral):
     @property
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch