mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

commit 9cc58d1cb3 (parent 2259d2f78a)

    Addresses comments.
@@ -11,17 +11,16 @@ from pathlib import Path
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
-from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
 )
-from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
-from text_generation_server.models.galactica import GalacticaSharded
+from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
     GPTNeoxForCausalLM,
 )
@@ -169,6 +168,11 @@ class ModelType(enum.Enum):
         "name": "Gemma",
         "url": "https://huggingface.co/google/gemma-7b",
     }
+    PALIGEMMA = {
+        "type": "paligemma",
+        "name": "PaliGemma",
+        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
+    }
     GEMMA2 = {
         "type": "gemma2",
         "name": "Gemma2",
@@ -466,14 +470,16 @@ def get_model(
         )
 
     if model_id.startswith("facebook/galactica"):
-        return GalacticaSharded(
+        return CausalLM(
             model_id=model_id,
+            # Yes galactica is just an OPT model.
             model_class=OPTForCausalLM,
             revision=revision,
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            batch_class=GalacticaCausalLMBatch,
         )
 
     if (
@@ -509,7 +515,7 @@ def get_model(
         )
 
     if model_type == BLOOM:
-        return BLOOMSharded(
+        return CausalLM(
             model_id=model_id,
             model_class=BloomForCausalLM,
             revision=revision,
@@ -517,6 +523,7 @@ def get_model(
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            batch_class=CausalLMBatchKeysLast,
         )
     elif model_type == MPT:
         return CausalLM(
@@ -527,6 +534,7 @@ def get_model(
             speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            batch_class=CausalLMBatchKeysLast,
         )
     elif model_type == GPT2:
         if FLASH_ATTENTION:
@@ -666,6 +674,8 @@ def get_model(
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                # Works better for these models
+                default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
@@ -689,6 +699,8 @@ def get_model(
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                # Works better for these models
+                default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
@@ -737,6 +749,8 @@ def get_model(
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                # Dbrx works better in bfloat16.
+                default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 config_class=DbrxConfig,
@@ -765,6 +779,10 @@ def get_model(
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                aliases={
+                    "lm_head.weight": ["transformer.word_embeddings.weight"],
+                    "transformer.word_embeddings.weight": ["lm_head.weight"],
+                },
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 config_class=RWConfig,
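A side note on the aliases argument added above (my reading, not something stated in the diff): Falcon-style checkpoints tie the LM head to the input embeddings, so the same tensor may be serialized under either name, and the alias map lets the weight loader fall back to the other name when one is missing. A rough sketch of that lookup idea, with a simplified dict-based loader standing in for the server's real Weights class:

    from typing import Dict, List

    def resolve_weight_name(
        available: Dict[str, object], name: str, aliases: Dict[str, List[str]]
    ) -> str:
        """Resolve `name` to a key that exists in the checkpoint, trying aliases
        (e.g. tied embedding / lm_head weights) before giving up."""
        if name in available:
            return name
        for alias in aliases.get(name, []):
            if alias in available:
                return alias
        raise KeyError(f"weight {name} not found, aliases tried: {aliases.get(name, [])}")

    checkpoint = {"transformer.word_embeddings.weight": "tensor-placeholder"}
    aliases = {
        "lm_head.weight": ["transformer.word_embeddings.weight"],
        "transformer.word_embeddings.weight": ["lm_head.weight"],
    }
    # "lm_head.weight" is absent from the checkpoint but resolves through its alias.
    assert resolve_weight_name(checkpoint, "lm_head.weight", aliases) == "transformer.word_embeddings.weight"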
@@ -947,7 +965,7 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
-    if model_type == "paligemma":
+    if model_type == PALIGEMMA:
         if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
@@ -956,6 +974,8 @@ def get_model(
                 quantize=quantize,
                 speculator=speculator,
                 dtype=dtype,
+                # Works better for these models
+                default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 batch_class=PaliGemmaBatch,
@@ -489,6 +489,11 @@ class CausalLMBatch(Batch):
         return len(self.requests)
 
 
+@dataclass
+class CausalLMBatchKeysLast(CausalLMBatch):
+    keys_head_dim_last: bool = False
+
+
 class CausalLM(Model):
     def __init__(
         self,
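For context (an assumption drawn from how these batches are used, not stated in this diff): keys_head_dim_last tells the shared CausalLMBatch code which axis of a cached key tensor is the sequence axis. BLOOM-style caches keep past keys as [batch * heads, head_dim, seq_len], while the default layout is [batch * heads, seq_len, head_dim]. A minimal sketch of cache trimming branching on that flag (the helper and shapes are illustrative only):

    import torch

    def trim_past_keys(past_keys: torch.Tensor, keep_last: int, keys_head_dim_last: bool) -> torch.Tensor:
        """Keep only the last `keep_last` cached positions of a key tensor."""
        if keys_head_dim_last:
            # Layout [batch * heads, seq_len, head_dim]: the sequence axis is dim -2.
            return past_keys[:, -keep_last:, :]
        # Layout [batch * heads, head_dim, seq_len]: the sequence axis is dim -1.
        return past_keys[:, :, -keep_last:]

    # A fake BLOOM-style cache (32 head-batches, head_dim=64, seq_len=10) trimmed to 4 tokens.
    bloom_style = torch.zeros(32, 64, 10)
    assert trim_past_keys(bloom_style, 4, keys_head_dim_last=False).shape == (32, 64, 4)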
@@ -498,14 +503,25 @@ class CausalLM(Model):
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
         config_class=AutoConfig,
+        batch_class=CausalLMBatch,
     ):
+        self.batch_class = batch_class
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
@@ -612,6 +628,7 @@ class CausalLM(Model):
         self = cls.__new__(
             cls,
         )
+        self.batch_class = CausalLMBatch
         super().__init__(
             self,
             model_id=model_id,
@@ -625,7 +642,7 @@ class CausalLM(Model):
 
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
-        return CausalLMBatch
+        return self.batch_class
 
     # This is not used anymore
     # def decode(self, generated_ids: List[int]) -> str:
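Stepping back from the diff for a moment: the change above is plain constructor injection. The concrete batch class is handed to CausalLM.__init__, stored on the instance, and surfaced through the batch_type property that the serving code already consults. A minimal, self-contained sketch of the pattern (the class and method names below are simplified stand-ins, not the server's real interfaces):

    from dataclasses import dataclass
    from typing import Type

    @dataclass
    class GenericBatch:
        """Stand-in for a request batch; real batches carry tensors and metadata."""
        prompts: list

    @dataclass
    class KeysLastBatch(GenericBatch):
        """Variant used by models whose KV cache stores keys in a different layout."""
        keys_head_dim_last: bool = False

    class TinyModel:
        def __init__(self, batch_class: Type[GenericBatch] = GenericBatch):
            # The batch class is injected once at construction time...
            self.batch_class = batch_class

        @property
        def batch_type(self) -> Type[GenericBatch]:
            # ...and callers that need to build batches always go through this property.
            return self.batch_class

    bloom_like = TinyModel(batch_class=KeysLastBatch)
    batch = bloom_like.batch_type(prompts=["hello"])
    assert isinstance(batch, KeysLastBatch)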
@@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
         )
-
-
-class GalacticaSharded(CausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
-            tp_parallel=True,
-            trust_remote_code=trust_remote_code,
-        )
-        config.quantize = quantize
-        tokenizer.pad_token_id = config.pad_token_id
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
-        )
-        if config.quantize in ["gptq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = OPTForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
-
-    @property
-    def batch_type(self) -> Type[CausalLMBatch]:
-        return GalacticaCausalLMBatch
-
-    def decode(self, generated_ids: List[int]) -> str:
-        # Do not skip special tokens as they are used for custom parsing rules of the generated text
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-        )
-
-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ):
-        outputs, speculative_logits = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -74,19 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
         else:
             image_inputs = None
         return batch_tokenized_inputs, image_inputs
-
-
-class PaliGemma(VlmCausalLM):
-    @property
-    def batch_type(self):
-        return PaliGemmaBatch
-
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.text_model.model.layers),
-            model.text_model.model.num_key_value_heads,
-            model.text_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.text_model, "max_past", None)
@@ -547,6 +547,7 @@ class Seq2SeqLM(Model):
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
+        default_dtype=torch.float16,
         trust_remote_code: bool = False,
         config_class=AutoConfig,
         tokenizer_class=AutoTokenizer,
@@ -555,7 +556,15 @@ class Seq2SeqLM(Model):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
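The same device and dtype selection now appears in both CausalLM and Seq2SeqLM. Pulled out on its own, the branching added above amounts to the following standalone helper (the function name is mine, and the system string is passed in explicitly rather than imported, so this is a sketch rather than the server's code):

    from typing import Optional, Tuple

    import torch

    def pick_device_and_dtype(
        system: str, rank: int, default_dtype: torch.dtype, dtype: Optional[torch.dtype]
    ) -> Tuple[torch.device, torch.dtype]:
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = default_dtype if dtype is None else dtype
        elif system == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 isn't supported on the CPU target, so fall back to bfloat16.
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype
        return device, dtype

    # On a host with neither CUDA nor ipex this resolves to CPU / float32.
    print(pick_device_and_dtype(system="cpu", rank=0, default_dtype=torch.float16, dtype=None))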
@@ -1,83 +0,0 @@
-import torch
-import torch.distributed
-
-from typing import List, Optional, Tuple
-
-from transformers import (
-    AutoTokenizer,
-    AutoConfig,
-)
-
-from text_generation_server.models import Seq2SeqLM
-from text_generation_server.models.custom_modeling.t5_modeling import (
-    T5ForConditionalGeneration,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-
-
-class ShardedSeq2SeqLM(Seq2SeqLM):
-    def __init__(
-        self,
-        model_id: str,
-        model_class,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        config_class=AutoConfig,
-        tokenizer_class=AutoTokenizer,
-        aliases=None,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
-
-        config = config_class.from_pretrained(
-            model_id,
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        tokenizer = tokenizer_class.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        tokenizer.bos_token_id = config.decoder_start_token_id
-
-        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames,
-            device=device,
-            dtype=dtype,
-            process_group=self.process_group,
-            aliases=aliases,
-        )
-
-        model = model_class(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(Seq2SeqLM, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )