mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
transformers format
This commit is contained in:
parent
e69eed8ea3
commit
5799e5cae9
@ -18,6 +18,7 @@ gen-server:
|
||||
|
||||
install: gen-server
|
||||
pip install pip --upgrade
|
||||
pip install git+https://github.com/OlivierDehaene/megablocks#33fad2b0eae7c47b8fedfb3ad415af8169386918
|
||||
pip install -r requirements_cuda.txt
|
||||
pip install -e ".[bnb, accelerate, quantize, peft]"
|
||||
|
||||
|
@ -77,6 +77,18 @@ except ImportError as e:
|
||||
if MISTRAL:
|
||||
__all__.append(FlashMistral)
|
||||
|
||||
MIXTRAL = True
|
||||
try:
|
||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import Mixtral model: {e}")
|
||||
MIXTRAL = False
|
||||
|
||||
if MIXTRAL:
|
||||
__all__.append(FlashMixtral)
|
||||
|
||||
|
||||
|
||||
def get_model(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
@ -282,7 +294,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type in ["mistral", "mixtral"]:
|
||||
if model_type == "mistral":
|
||||
if MISTRAL:
|
||||
return FlashMistral(
|
||||
model_id,
|
||||
@ -294,6 +306,17 @@ def get_model(
|
||||
)
|
||||
raise NotImplementedError("Mistral models requires flash attention v2")
|
||||
|
||||
if model_type == "mixtral":
|
||||
if MIXTRAL:
|
||||
return FlashMixtral(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
|
||||
|
||||
if model_type == "opt":
|
||||
return OPTSharded(
|
||||
model_id,
|
||||
|
@ -56,7 +56,7 @@ except ImportError:
|
||||
|
||||
|
||||
class MixtralConfig(PretrainedConfig):
|
||||
model_type = "mistral"
|
||||
model_type = "mixtral"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -45,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
@ -99,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
|
||||
# Parse batch
|
||||
for i, (r, tokenized_input) in enumerate(
|
||||
zip(pb.requests, batch_tokenized_inputs)
|
||||
zip(pb.requests, batch_tokenized_inputs)
|
||||
):
|
||||
# request id -> idx in list mapping
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
||||
tokenized_input = tokenized_input[-r.truncate :]
|
||||
tokenized_input = tokenized_input[-r.truncate:]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
@ -277,15 +277,16 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
)
|
||||
|
||||
|
||||
class FlashMistral(FlashCausalLM):
|
||||
class BaseFlashMistral(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
architectures: List[str],
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
self,
|
||||
config_cls,
|
||||
model_cls,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
@ -305,14 +306,6 @@ class FlashMistral(FlashCausalLM):
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if "MixtralForCausalLM" in architectures:
|
||||
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
|
||||
config_cls = MixtralConfig
|
||||
model_cls = FlashMixtralForCausalLM
|
||||
else:
|
||||
config_cls = MistralConfig
|
||||
model_cls = FlashMistralForCausalLM
|
||||
|
||||
config = config_cls.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
@ -332,7 +325,7 @@ class FlashMistral(FlashCausalLM):
|
||||
model = model_cls(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashMistral, self).__init__(
|
||||
super(BaseFlashMistral, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
@ -404,3 +397,23 @@ class FlashMistral(FlashCausalLM):
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits
|
||||
|
||||
|
||||
class FlashMistral(BaseFlashMistral):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
super(FlashMistral, self).__init__(
|
||||
config_cls=MistralConfig,
|
||||
model_cls=FlashMistralForCausalLM,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
26
server/text_generation_server/models/flash_mixtral.py
Normal file
26
server/text_generation_server/models/flash_mixtral.py
Normal file
@ -0,0 +1,26 @@
|
||||
import torch
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models.flash_mistral import BaseFlashMistral
|
||||
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
|
||||
|
||||
|
||||
class FlashMixtral(BaseFlashMistral):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
super(FlashMixtral, self).__init__(
|
||||
config_cls=MixtralConfig,
|
||||
model_cls=FlashMixtralForCausalLM,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code
|
||||
)
|
@ -91,12 +91,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
|
||||
# from torch.profiler import profile, ProfilerActivity
|
||||
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
# if self.model.rank == 0:
|
||||
# prefill_prof.export_chrome_trace("prefill.json")
|
||||
|
||||
return generate_pb2.PrefillResponse(
|
||||
generations=[generation.to_pb() for generation in generations],
|
||||
@ -122,12 +118,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
else:
|
||||
batch = batches[0]
|
||||
|
||||
# from torch.profiler import profile, ProfilerActivity
|
||||
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
# if self.model.rank == 0:
|
||||
# prefill_prof.export_chrome_trace("decode.json")
|
||||
|
||||
return generate_pb2.DecodeResponse(
|
||||
generations=[generation.to_pb() for generation in generations],
|
||||
|
Loading…
Reference in New Issue
Block a user