mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Addresses comments.
This commit is contained in:
parent 2259d2f78a
commit 9cc58d1cb3
@@ -11,17 +11,16 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
@@ -169,6 +168,11 @@ class ModelType(enum.Enum):
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
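
For readers following the later `if model_type == PALIGEMMA:` change, here is a minimal, hypothetical sketch of how an enum entry like the one added above can be turned into the plain type string that `get_model` compares against. The real module may derive these constants differently; only the `PALIGEMMA` dict comes from the diff.

```python
import enum


class ModelType(enum.Enum):
    # Trimmed to the single entry added in the hunk above.
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }


# Expose each member's "type" string as a module-level constant so dispatch
# code can write `model_type == PALIGEMMA` instead of a raw string literal.
PALIGEMMA = ModelType.PALIGEMMA.value["type"]

assert PALIGEMMA == "paligemma"
```
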
@@ -466,14 +470,16 @@ def get_model(
        )

    if model_id.startswith("facebook/galactica"):
        return GalacticaSharded(
        return CausalLM(
            model_id=model_id,
            # Yes galactica is just an OPT model.
            model_class=OPTForCausalLM,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=GalacticaCausalLMBatch,
        )

    if (
@@ -509,7 +515,7 @@ def get_model(
        )

    if model_type == BLOOM:
        return BLOOMSharded(
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            revision=revision,
@@ -517,6 +523,7 @@ def get_model(
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == MPT:
        return CausalLM(
@@ -527,6 +534,7 @@ def get_model(
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            batch_class=CausalLMBatchKeysLast,
        )
    elif model_type == GPT2:
        if FLASH_ATTENTION:
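
Taken together, the Galactica, BLOOM and MPT hunks above apply one pattern: drop the per-architecture subclass and parameterize the generic CausalLM instead. A hypothetical, trimmed version of that dispatch is sketched below; the imports match the diff, but the real get_model passes many more arguments.

```python
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.bloom_modeling import BloomForCausalLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch


def get_model_sketch(model_id: str, model_type: str, **kwargs):
    """Hypothetical, trimmed dispatch illustrating the pattern in the hunks above."""
    if model_id.startswith("facebook/galactica"):
        # Galactica is just an OPT model; only its batch class (custom prompt
        # handling) differs from the generic causal-LM path.
        return CausalLM(
            model_id=model_id,
            model_class=OPTForCausalLM,
            batch_class=GalacticaCausalLMBatch,
            **kwargs,
        )
    if model_type == "bloom":
        # BLOOM's past-key cache uses a different key layout, signalled via
        # CausalLMBatchKeysLast (keys_head_dim_last=False).
        return CausalLM(
            model_id=model_id,
            model_class=BloomForCausalLM,
            batch_class=CausalLMBatchKeysLast,
            **kwargs,
        )
    raise NotImplementedError(model_type)
```
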
@@ -666,6 +674,8 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@@ -689,6 +699,8 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
@@ -737,6 +749,8 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=DbrxConfig,
@@ -765,6 +779,10 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                aliases={
                    "lm_head.weight": ["transformer.word_embeddings.weight"],
                    "transformer.word_embeddings.weight": ["lm_head.weight"],
                },
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                config_class=RWConfig,
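
The aliases argument above deals with tied weights: the same tensor may be serialized under either name depending on the checkpoint. A minimal, hypothetical sketch of what such an alias lookup does (the real resolution lives in the Weights loader):

```python
aliases = {
    "lm_head.weight": ["transformer.word_embeddings.weight"],
    "transformer.word_embeddings.weight": ["lm_head.weight"],
}


def resolve_tensor_name(name, available_names, aliases):
    """Return the name actually present in the checkpoint, trying aliases as fallbacks."""
    if name in available_names:
        return name
    for alias in aliases.get(name, []):
        if alias in available_names:
            return alias
    raise KeyError(f"{name} not found in checkpoint and no alias matched")


# A checkpoint that only serialized the embedding matrix still satisfies lm_head:
print(resolve_tensor_name("lm_head.weight", {"transformer.word_embeddings.weight"}, aliases))
```
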
@@ -947,7 +965,7 @@ def get_model(
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == "paligemma":
    if model_type == PALIGEMMA:
        if FLASH_ATTENTION:
            return VlmCausalLM(
                model_id=model_id,
@@ -956,6 +974,8 @@ def get_model(
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                # Works better for these models
                default_dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
                batch_class=PaliGemmaBatch,
@@ -489,6 +489,11 @@ class CausalLMBatch(Batch):
        return len(self.requests)


@dataclass
class CausalLMBatchKeysLast(Batch):
    keys_head_dim_last: bool = False


class CausalLM(Model):
    def __init__(
        self,
@@ -498,14 +503,25 @@ class CausalLM(Model):
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        tokenizer_class=AutoTokenizer,
        config_class=AutoConfig,
        batch_class=CausalLMBatch,
    ):
        self.batch_class = batch_class
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 doesn't exist on target.
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype
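
The new default_dtype parameter changes only how an unspecified dtype is resolved; an explicitly requested dtype still wins. A standalone sketch of that resolution order follows, with the device/SYSTEM handling collapsed into a single hypothetical `target` argument:

```python
import torch


def resolve_dtype(requested, default_dtype=torch.float16, target="cuda"):
    # Hypothetical helper mirroring the branches in the hunk above.
    if requested is not None:
        # An explicitly requested dtype always takes precedence.
        return requested
    if target in ("cuda", "xpu"):
        # GPU/XPU: use the per-model default (float16 unless a model opts into bfloat16).
        return default_dtype
    if target == "ipex-cpu":
        # "Float16 doesn't exist on target", so fall back to bfloat16.
        return torch.bfloat16
    # Plain CPU path.
    return torch.float32


assert resolve_dtype(None, default_dtype=torch.bfloat16, target="cuda") is torch.bfloat16
assert resolve_dtype(torch.float16, default_dtype=torch.bfloat16, target="cuda") is torch.float16
```
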
@@ -612,6 +628,7 @@ class CausalLM(Model):
        self = cls.__new__(
            cls,
        )
        self.batch_class = CausalLMBatch
        super().__init__(
            self,
            model_id=model_id,
@@ -625,7 +642,7 @@ class CausalLM(Model):

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch
        return self.batch

    # This is not used anymore
    # def decode(self, generated_ids: List[int]) -> str:
@@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )


class GalacticaSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        tokenizer.pad_token_id = config.pad_token_id
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = OPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -74,19 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs


class PaliGemma(VlmCausalLM):
    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)
@@ -547,6 +547,7 @@ class Seq2SeqLM(Model):
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        config_class=AutoConfig,
        tokenizer_class=AutoTokenizer,
@@ -555,7 +556,15 @@ class Seq2SeqLM(Model):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
            dtype = default_dtype if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = default_dtype if dtype is None else dtype
            else:
                device = torch.device("cpu")
                # Float16 doesn't exist on target.
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype
@@ -1,83 +0,0 @@
import torch
import torch.distributed

from typing import List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    AutoConfig,
)

from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class ShardedSeq2SeqLM(Seq2SeqLM):
    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        config_class=AutoConfig,
        tokenizer_class=AutoTokenizer,
        aliases=None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases=aliases,
        )

        model = model_class(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(Seq2SeqLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
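
With this file removed, the dedicated ShardedSeq2SeqLM wrapper is gone; presumably T5-style models now go through the generic Seq2SeqLM the same way the causal models above go through CausalLM. A hypothetical sketch of what that call might look like follows; the argument list is abbreviated and is not taken from this diff.

```python
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)


def load_t5_sketch(model_id, revision=None, dtype=None, trust_remote_code=False):
    # Hypothetical: parameterize the generic Seq2SeqLM with the custom T5 module,
    # mirroring how CausalLM is parameterized with model_class/batch_class above.
    return Seq2SeqLM(
        model_id=model_id,
        model_class=T5ForConditionalGeneration,
        revision=revision,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
    )
```
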