mirror of https://github.com/huggingface/text-generation-inference.git

commit 5799e5cae9 ("transformers format")
parent e69eed8ea3
server/Makefile
@@ -18,6 +18,7 @@ gen-server:
 
 install: gen-server
 	pip install pip --upgrade
+	pip install git+https://github.com/OlivierDehaene/megablocks#33fad2b0eae7c47b8fedfb3ad415af8169386918
 	pip install -r requirements_cuda.txt
 	pip install -e ".[bnb, accelerate, quantize, peft]"
 
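The new Makefile line pins megablocks to a specific commit of OlivierDehaene's fork, and the error message later in this diff lists flash attention v2, stk and megablocks as Mixtral prerequisites. As a rough illustration only (the helper below is hypothetical and not part of this commit), the optional kernel packages can be probed at runtime before enabling the Mixtral path:

# Hypothetical helper, not part of this commit: probe the optional Mixtral
# dependencies installed by the Makefile target above.
import importlib.util


def mixtral_kernels_available() -> bool:
    # megablocks provides the MoE kernels; stk is the sparse toolkit it uses.
    return all(
        importlib.util.find_spec(name) is not None
        for name in ("megablocks", "stk")
    )


if __name__ == "__main__":
    print("Mixtral kernels available:", mixtral_kernels_available())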
server/text_generation_server/models/__init__.py
@@ -77,6 +77,18 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)
 
+
+MIXTRAL = True
+try:
+    from text_generation_server.models.flash_mixtral import FlashMixtral
+except ImportError as e:
+    logger.warning(f"Could not import Mixtral model: {e}")
+    MIXTRAL = False
+
+if MIXTRAL:
+    __all__.append(FlashMixtral)
+
+
 def get_model(
     model_id: str,
     revision: Optional[str],
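Mixtral support follows the same guarded-import pattern already used for Mistral above: optimistically set a capability flag, try the import, then log and clear the flag if the optional dependencies are missing, so the server still starts for other model types. A minimal, self-contained sketch of the pattern (the module and class names below are placeholders, not text-generation-inference imports):

# Minimal sketch of the guarded-import pattern; "optional_backend" is a
# placeholder module, not a real text-generation-inference dependency.
import logging

logger = logging.getLogger(__name__)

__all__ = []

HAS_BACKEND = True
try:
    from optional_backend import Backend  # fails when the extra is not installed
except ImportError as e:
    logger.warning(f"Could not import optional backend: {e}")
    HAS_BACKEND = False

if HAS_BACKEND:
    __all__.append("Backend")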
@@ -282,7 +294,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    if model_type in ["mistral", "mixtral"]:
+    if model_type == "mistral":
         if MISTRAL:
             return FlashMistral(
                 model_id,
@@ -294,6 +306,17 @@ def get_model(
             )
         raise NotImplementedError("Mistral models requires flash attention v2")
 
+    if model_type == "mixtral":
+        if MIXTRAL:
+            return FlashMixtral(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
+
     if model_type == "opt":
         return OPTSharded(
             model_id,
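With the flags in place, get_model dispatches "mistral" and "mixtral" as separate branches, each raising NotImplementedError when its optional dependencies are unavailable instead of failing at import time. A stripped-down sketch of that dispatch shape (the flags, class and function names here are stand-ins, not the real TGI objects):

# Stand-in sketch of the dispatch added above; MISTRAL/MIXTRAL mirror the
# capability flags set by the guarded imports, _Model is a placeholder class.
from typing import Optional

MISTRAL = True
MIXTRAL = False  # e.g. megablocks / stk not installed


class _Model:
    def __init__(self, model_id: str, revision: Optional[str] = None):
        self.model_id = model_id
        self.revision = revision


def get_model_sketch(model_type: str, model_id: str) -> _Model:
    if model_type == "mistral":
        if MISTRAL:
            return _Model(model_id)
        raise NotImplementedError("Mistral models require flash attention v2")
    if model_type == "mixtral":
        if MIXTRAL:
            return _Model(model_id)
        raise NotImplementedError(
            "Mixtral models require flash attention v2, stk and megablocks"
        )
    raise ValueError(f"Unsupported model_type: {model_type}")


print(get_model_sketch("mistral", "mistralai/Mistral-7B-v0.1").model_id)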
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -56,7 +56,7 @@ except ImportError:
 
 
 class MixtralConfig(PretrainedConfig):
-    model_type = "mistral"
+    model_type = "mixtral"
 
     def __init__(
         self,
server/text_generation_server/models/flash_mistral.py
@@ -45,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         global SLIDING_WINDOW
         global SLIDING_WINDOW_BLOCKS
@@ -99,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
 
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenized_input[-r.truncate :]
+            tokenized_input = tokenized_input[-r.truncate:]
 
            input_length = len(tokenized_input)
            input_lengths.append(input_length)
@@ -277,15 +277,16 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
 
 
-class FlashMistral(FlashCausalLM):
+class BaseFlashMistral(FlashCausalLM):
     def __init__(
         self,
-        model_id: str,
-        architectures: List[str],
+        config_cls,
+        model_cls,
+        model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         global SLIDING_WINDOW
         global SLIDING_WINDOW_BLOCKS
@@ -305,14 +306,6 @@ class FlashMistral(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-        if "MixtralForCausalLM" in architectures:
-            from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
-            config_cls = MixtralConfig
-            model_cls = FlashMixtralForCausalLM
-        else:
-            config_cls = MistralConfig
-            model_cls = FlashMistralForCausalLM
-
         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
@@ -332,7 +325,7 @@ class FlashMistral(FlashCausalLM):
         model = model_cls(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(FlashMistral, self).__init__(
+        super(BaseFlashMistral, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
@@ -404,3 +397,23 @@ class FlashMistral(FlashCausalLM):
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits
+
+
+class FlashMistral(BaseFlashMistral):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        super(FlashMistral, self).__init__(
+            config_cls=MistralConfig,
+            model_cls=FlashMistralForCausalLM,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code
+        )
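The net effect in flash_mistral.py: the architecture sniffing (`if "MixtralForCausalLM" in architectures`) is gone, the shared loading logic lives in BaseFlashMistral parameterized by config_cls and model_cls, and FlashMistral becomes a thin subclass. A hypothetical further backend would plug in the same way; the sketch below uses stand-in classes so it runs on its own and is not the real TGI code:

# Hypothetical illustration of the BaseFlashMistral pattern with stand-in
# classes; the real base class also loads the tokenizer, weights and kernels.
from typing import Optional

import torch


class BaseFlashSketch:
    def __init__(
        self,
        config_cls,
        model_cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        # The real BaseFlashMistral uses config_cls/model_cls to build the
        # model instead of inspecting config.architectures.
        self.config_cls = config_cls
        self.model_cls = model_cls
        self.model_id = model_id


class MyConfig:
    pass


class MyFlashForCausalLM:
    pass


class FlashMyModel(BaseFlashSketch):
    def __init__(self, model_id: str, **kwargs):
        super().__init__(
            config_cls=MyConfig,
            model_cls=MyFlashForCausalLM,
            model_id=model_id,
            **kwargs,
        )


print(FlashMyModel("my-org/my-model").model_cls.__name__)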
server/text_generation_server/models/flash_mixtral.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+import torch
+
+from typing import Optional
+
+from text_generation_server.models.flash_mistral import BaseFlashMistral
+from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
+
+
+class FlashMixtral(BaseFlashMistral):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        super(FlashMixtral, self).__init__(
+            config_cls=MixtralConfig,
+            model_cls=FlashMixtralForCausalLM,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code
+        )
server/text_generation_server/server.py
@@ -91,12 +91,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
 
-        # from torch.profiler import profile, ProfilerActivity
-        # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
-        # if self.model.rank == 0:
-        #     prefill_prof.export_chrome_trace("prefill.json")
 
         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
@@ -122,12 +118,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         else:
             batch = batches[0]
 
-        # from torch.profiler import profile, ProfilerActivity
-        # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
-        # if self.model.rank == 0:
-        #     prefill_prof.export_chrome_trace("decode.json")
 
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
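Both Prefill and Decode drop the commented-out torch.profiler scaffolding rather than keeping dead code in the request path. If a per-step trace is needed again, an equivalent wrapper can be applied locally; the sketch below only uses the public torch.profiler API, and `model` and `batch` are placeholders for the objects used in TextGenerationService:

# Local profiling sketch (not part of this commit); `model` and `batch` are
# placeholders for the service's model and batch objects.
import torch
from torch.profiler import ProfilerActivity, profile


def profiled_generate_token(model, batch, trace_path: str = "prefill.json"):
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        generations, next_batch = model.generate_token(batch)
    # Open the resulting trace in chrome://tracing or Perfetto.
    prof.export_chrome_trace(trace_path)
    return generations, next_batch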