Update by abstracting away the text model.

Nicolas Patry 2024-04-03 16:41:01 +00:00
parent b68fc4deb1
commit b8be0d1ae7
10 changed files with 1951 additions and 1111 deletions
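
The theme of the diff below: model constructors now take an explicit weight-name prefix instead of hardcoding checkpoint key names, so the same text model can load either standalone or nested inside a multimodal wrapper. A minimal sketch of the naming scheme this implies; the helper below is illustrative, not part of the commit:

def weight_key(prefix: str, suffix: str) -> str:
    # prefix=""               -> "model.norm" (standalone FlashLlama)
    # prefix="language_model" -> "language_model.model.norm" (nested in a VLM)
    return suffix if not prefix else f"{prefix}.{suffix}"

assert weight_key("", "model.norm") == "model.norm"
assert weight_key("language_model", "model.norm") == "language_model.model.norm"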

View File

@@ -281,9 +281,8 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
         self.self_attn = FlashLlamaAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -337,27 +336,36 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()

         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=(
+                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+            ),
+            weights=weights,
         )
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(
-                    layer_id,
-                    config,
-                    weights,
+                    prefix=(
+                        f"model.layers.{layer_id}"
+                        if not prefix
+                        else f"{prefix}.model.layers.{layer_id}"
+                    ),
+                    config=config,
+                    weights=weights,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
         )

         self.gradient_checkpointing = False
@@ -406,13 +414,13 @@ class FlashLlamaModel(torch.nn.Module):
 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()

-        self.model = FlashLlamaModel(config, weights)
+        self.model = FlashLlamaModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
@ -426,6 +434,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model( hidden_states = self.model(

View File

@@ -113,7 +113,13 @@ def load_vision_model(prefix, config, weights):
 def load_text_model(prefix, config, weights):
-    if config.model_type == "mistral":
+    if config.model_type == "llama":
+        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+            FlashLlamaForCausalLM,
+        )
+
+        return FlashLlamaForCausalLM(prefix, config, weights)
+    elif config.model_type == "mistral":
         from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
             FlashMistralForCausalLM,
         )
@@ -124,22 +130,25 @@ def load_text_model(prefix, config, weights):
 class LlavaNextForConditionalGeneration(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
-        self.vision_tower = load_vision_model(
-            prefix="vision_tower", config=config.vision_config, weights=weights
-        )
+        # self.vision_tower = load_vision_model(
+        #     prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", config=config.vision_config, weights=weights
+        # )

-        self.multi_modal_projector = LlavaNextMultiModalProjector(config)
+        # self.multi_modal_projector = LlavaNextMultiModalProjector(config)

         self.image_newline = weights.get_tensor("image_newline")

         self.vocab_size = config.text_config.vocab_size
+        self.config = config
         config.text_config.quantize = config.quantize
         config.text_config.use_medusa = config.use_medusa
         self.language_model = load_text_model(
-            prefix="language_model", config=config.text_config, weights=weights
+            prefix="language_model" if not prefix else f"{prefix}.language_model",
+            config=config.text_config,
+            weights=weights,
         )
         self.pad_token_id = (
             config.pad_token_id if config.pad_token_id is not None else -1
@@ -257,168 +266,141 @@ class LlavaNextForConditionalGeneration(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+        lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
         image_sizes: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         vision_feature_layer: Optional[int] = None,
         vision_feature_select_strategy: Optional[str] = None,
     ):
-        vision_feature_layer = (
-            vision_feature_layer
-            if vision_feature_layer is not None
-            else self.config.vision_feature_layer
-        )
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else self.config.vision_feature_select_strategy
-        )
-
-        if inputs_embeds is None:
-            # 1. Extract the input embeddings
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-
-            # 2. Merge text and images
-            if pixel_values is not None and input_ids.shape[1] != 1:
-                batch_size, num_patches, num_channels, height, width = (
-                    pixel_values.shape
-                )
-                reshaped_pixel_values = pixel_values.view(
-                    batch_size * num_patches, num_channels, height, width
-                )
-                image_features = self.vision_tower(
-                    reshaped_pixel_values, output_hidden_states=True
-                )
-                selected_image_feature = image_features.hidden_states[
-                    vision_feature_layer
-                ]
-                if vision_feature_select_strategy == "default":
-                    selected_image_feature = selected_image_feature[:, 1:]
-                elif vision_feature_select_strategy == "full":
-                    selected_image_feature = selected_image_feature
-                image_features = self.multi_modal_projector(selected_image_feature)
-                # split up image_features for each of the individual images
-                # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-                # if we assume each image has 5 image features (base image + 4 patches)
-                split_sizes = [image.shape[0] for image in pixel_values]
-                image_features = torch.split(image_features, split_sizes, dim=0)
-                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-                height = width = (
-                    self.config.vision_config.image_size
-                    // self.config.vision_config.patch_size
-                )
-                new_image_features = []
-                for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
-                        base_image_feature = image_feature[0]
-                        image_feature = image_feature[1:]
-                        if height * width != base_image_feature.shape[0]:
-                            raise ValueError(
-                                "The number of patches is not consistent with the image size."
-                            )
-                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-                            image_sizes[image_idx],
-                            self.config.image_grid_pinpoints,
-                            self.config.vision_config.image_size,
-                        )
-                        image_feature = image_feature.view(
-                            num_patch_height, num_patch_width, height, width, -1
-                        )
-                        image_feature = image_feature.permute(
-                            4, 0, 2, 1, 3
-                        ).contiguous()
-                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                        image_feature = unpad_image(
-                            image_feature, image_sizes[image_idx]
-                        )
-                        image_feature = torch.cat(
-                            (
-                                image_feature,
-                                self.image_newline[:, None, None].expand(
-                                    *image_feature.shape[:-1], 1
-                                ),
-                            ),
-                            dim=-1,
-                        )
-                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        image_feature = torch.cat(
-                            (base_image_feature, image_feature), dim=0
-                        )
-                    else:
-                        image_feature = image_feature[0]
-                        image_feature = torch.cat(
-                            (image_feature, self.image_newline[None]), dim=0
-                        )
-                    new_image_features.append(image_feature)
-                image_features = torch.stack(new_image_features, dim=0)
-                inputs_embeds, attention_mask, labels, position_ids = (
-                    self._merge_input_ids_with_image_features(
-                        image_features, inputs_embeds, input_ids, attention_mask, labels
-                    )
-                )
-                if labels is None:
-                    labels = torch.full_like(
-                        attention_mask, self.config.ignore_index
-                    ).to(torch.long)
-        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
-        # generation with cache
-        elif (
-            past_key_values is not None
-            and pixel_values is not None
-            and input_ids.shape[1] == 1
-        ):
-            # Retrieve the first layer to inspect the logits and mask out the hidden states
-            # that are set to 0
-            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-            batch_index, non_attended_tokens = torch.where(
-                first_layer_past_key_value.float().sum(-2) == 0
-            )
-            # Get the target length
-            target_seqlen = first_layer_past_key_value.shape[-1] + 1
-            extended_attention_mask = torch.ones(
-                (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
-                dtype=attention_mask.dtype,
-                device=attention_mask.device,
-            )
-            # Filter out only the tokens that can be un-attended, this can happen
-            # if one uses Llava + Fused modules where the cache on the
-            # first iteration is already big enough, or if one passes custom cache
-            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-            new_batch_index = batch_index[valid_indices]
-            new_non_attended_tokens = non_attended_tokens[valid_indices]
-            # Zero-out the places where we don't need to attend
-            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-            attention_mask = torch.cat(
-                (attention_mask, extended_attention_mask), dim=1
-            )
-            position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-        outputs = self.language_model(
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-        )
-        logits = outputs[0]
+        # vision_feature_layer = (
+        #     vision_feature_layer
+        #     if vision_feature_layer is not None
+        #     else self.config.vision_feature_layer
+        # )
+        # vision_feature_select_strategy = (
+        #     vision_feature_select_strategy
+        #     if vision_feature_select_strategy is not None
+        #     else self.config.vision_feature_select_strategy
+        # )
+        # if cu_seqlen_prefill is not None:
+        #     pass
+        # # # 1. Extract the input embeddings
+        # # inputs_embeds = self.get_input_embeddings()(input_ids)
+        # # # 2. Merge text and images
+        # # if pixel_values is not None and input_ids.shape[1] != 1:
+        # #     batch_size, num_patches, num_channels, height, width = (
+        # #         pixel_values.shape
+        # #     )
+        # #     reshaped_pixel_values = pixel_values.view(
+        # #         batch_size * num_patches, num_channels, height, width
+        # #     )
+        # #     image_features = self.vision_tower(
+        # #         reshaped_pixel_values, output_hidden_states=True
+        # #     )
+        # #     selected_image_feature = image_features.hidden_states[
+        # #         vision_feature_layer
+        # #     ]
+        # #     if vision_feature_select_strategy == "default":
+        # #         selected_image_feature = selected_image_feature[:, 1:]
+        # #     elif vision_feature_select_strategy == "full":
+        # #         selected_image_feature = selected_image_feature
+        # #     image_features = self.multi_modal_projector(selected_image_feature)
+        # #     # split up image_features for each of the individual images
+        # #     # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+        # #     # if we assume each image has 5 image features (base image + 4 patches)
+        # #     split_sizes = [image.shape[0] for image in pixel_values]
+        # #     image_features = torch.split(image_features, split_sizes, dim=0)
+        # #     # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+        # #     height = width = (
+        # #         self.config.vision_config.image_size
+        # #         // self.config.vision_config.patch_size
+        # #     )
+        # #     new_image_features = []
+        # #     for image_idx, image_feature in enumerate(image_features):
+        # #         if image_feature.shape[0] > 1:
+        # #             base_image_feature = image_feature[0]
+        # #             image_feature = image_feature[1:]
+        # #             if height * width != base_image_feature.shape[0]:
+        # #                 raise ValueError(
+        # #                     "The number of patches is not consistent with the image size."
+        # #                 )
+        # #             num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+        # #                 image_sizes[image_idx],
+        # #                 self.config.image_grid_pinpoints,
+        # #                 self.config.vision_config.image_size,
+        # #             )
+        # #             image_feature = image_feature.view(
+        # #                 num_patch_height, num_patch_width, height, width, -1
+        # #             )
+        # #             image_feature = image_feature.permute(
+        # #                 4, 0, 2, 1, 3
+        # #             ).contiguous()
+        # #             image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+        # #             image_feature = unpad_image(
+        # #                 image_feature, image_sizes[image_idx]
+        # #             )
+        # #             image_feature = torch.cat(
+        # #                 (
+        # #                     image_feature,
+        # #                     self.image_newline[:, None, None].expand(
+        # #                         *image_feature.shape[:-1], 1
+        # #                     ),
+        # #                 ),
+        # #                 dim=-1,
+        # #             )
+        # #             image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+        # #             image_feature = torch.cat(
+        # #                 (base_image_feature, image_feature), dim=0
+        # #             )
+        # #         else:
+        # #             image_feature = image_feature[0]
+        # #             image_feature = torch.cat(
+        # #                 (image_feature, self.image_newline[None]), dim=0
+        # #             )
+        # #         new_image_features.append(image_feature)
+        # #     image_features = torch.stack(new_image_features, dim=0)
+        # #     inputs_embeds, attention_mask, labels, position_ids = (
+        # #         self._merge_input_ids_with_image_features(
+        # #             image_features, inputs_embeds, input_ids, attention_mask, labels
+        # #         )
+        # #     )
+        # #     if labels is None:
+        # #         labels = torch.full_like(
+        # #             attention_mask, self.config.ignore_index
+        # #         ).to(torch.long)
+        logits = self.language_model(
+            input_ids,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+            prefill_cache_indices,
+            lm_head_indices,
+        )
         return logits
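
The load_text_model dispatch above is what lets LlavaNext reuse the flash Llama/Mistral implementations unchanged; only the weight prefix differs. A self-contained sketch of that dispatch pattern, with stand-in classes rather than the real ones:

class FakeFlashLlama:
    def __init__(self, prefix, config, weights):
        self.prefix = prefix

class FakeFlashMistral:
    def __init__(self, prefix, config, weights):
        self.prefix = prefix

def load_text_model(prefix, config, weights):
    # Dispatch on config.model_type, mirroring the hunk above.
    if config.model_type == "llama":
        return FakeFlashLlama(prefix, config, weights)
    elif config.model_type == "mistral":
        return FakeFlashMistral(prefix, config, weights)
    raise RuntimeError(f"Unsupported model type {config.model_type}")

class Cfg:
    model_type = "llama"

# Nested inside the VLM, the text weights live under "language_model.*".
backbone = load_text_model("language_model", Cfg(), weights=None)
assert isinstance(backbone, FakeFlashLlama)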

View File

@@ -1043,7 +1043,12 @@ class FlashCausalLM(Model):
             batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
             batch.speculative_ids = speculative_ids
             batch.position_ids = next_position_ids + accepted_ids
-            batch.input_lengths_tensor += accepted_ids
+            try:
+                batch.input_lengths_tensor += accepted_ids
+            except Exception:
+                import ipdb
+
+                ipdb.set_trace()
             batch.slot_indices += accepted_ids

         if prefill and prefill_logprobs:

View File

@@ -67,7 +67,8 @@ class FlashLlama(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = FlashLlamaForCausalLM(config, weights)
+        prefix = ""
+        model = FlashLlamaForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,

View File

@@ -6,7 +6,7 @@ import numpy as np

 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, AutoTokenizer
+from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
 from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type
@@ -301,14 +301,15 @@ class FlashMistralBatch(FlashCausalLMBatch):
 class BaseFlashMistral(FlashCausalLM):
     def __init__(
         self,
-        config_cls,
         model_cls,
         model_id: str,
+        config_cls=AutoConfig,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        tokenizer_class=AutoTokenizer,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -317,22 +318,13 @@ class BaseFlashMistral(FlashCausalLM):
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

-        try:
-            tokenizer = LlamaTokenizerFast.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-            )
-        except Exception:
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-            )
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )

         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -341,10 +333,12 @@ class BaseFlashMistral(FlashCausalLM):
         config.use_medusa = use_medusa

         # Set context windows
-        if config.sliding_window is not None:
+        if getattr(config, "sliding_window", None) is not None:
             set_sliding_window(
                 config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
             )
+        else:
+            config.sliding_window = None

         torch.distributed.barrier(group=self.process_group)
@@ -353,17 +347,19 @@ class BaseFlashMistral(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = model_cls(config, weights)
+        prefix = ""
+        model = model_cls(prefix, config, weights)

         self.cuda_graphs = {}

         torch.distributed.barrier(group=self.process_group)
-        super(BaseFlashMistral, self).__init__(
+        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
+            num_layers=num_layers,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
             dtype=dtype,
             device=device,
             rank=rank,
@@ -371,6 +367,16 @@ class BaseFlashMistral(FlashCausalLM):
             sliding_window=config.sliding_window,
         )

+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.model.layers),
+            model.model.num_key_value_heads,
+            model.model.head_size,
+        )
+
+    def max_past(self) -> int:
+        return self.model.max_past
+
     @property
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch
@@ -485,11 +491,11 @@ class BaseFlashMistral(FlashCausalLM):
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices

-        if cu_seqlen_prefill is None and self.model.max_past is not None:
+        if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
             # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.model.max_past, max_s)
+            max_s = min(self.max_past(), max_s)

         bs = input_ids.shape[0]
         padded_bs = bs
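
The get_layer_config/max_past indirection introduced above exists so a subclass whose transformer lives under a wrapper attribute (e.g. language_model in a VLM) can redirect where the runtime reads layer metadata. A sketch of the override pattern with stand-in objects, not the real models:

from types import SimpleNamespace
from typing import Tuple

class Base:
    # Default: read metadata from model.model, as BaseFlashMistral does.
    def get_layer_config(self, model) -> Tuple[int, int, int]:
        m = model.model
        return (len(m.layers), m.num_key_value_heads, m.head_size)

class VlmLike(Base):
    # VLM subclass: the transformer is nested under language_model.
    def get_layer_config(self, model) -> Tuple[int, int, int]:
        m = model.language_model.model
        return (len(m.layers), m.num_key_value_heads, m.head_size)

text = SimpleNamespace(layers=[None] * 32, num_key_value_heads=8, head_size=128)
vlm = SimpleNamespace(language_model=SimpleNamespace(model=text))
assert VlmLike().get_layer_config(vlm) == (32, 8, 128)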

View File

@@ -17,7 +17,7 @@ from transformers import LlamaTokenizerFast
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
 )
-from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -25,7 +25,7 @@ from text_generation_server.utils import (
 )


-class IDEFICSSharded(VlmCausalLM):
+class IDEFICSSharded(IdeficsCausalLM):
     def __init__(
         self,
         model_id: str,

File diff suppressed because it is too large.

View File

@@ -1,24 +1,15 @@
 import torch
-import torch.distributed

-from typing import List, Optional, Tuple
+from typing import Optional

 from transformers import (
-    AutoTokenizer,
-    AutoConfig,
     AutoProcessor,
 )
 from text_generation_server.models.custom_modeling.llava_next import (
     LlavaNextForConditionalGeneration,
 )

+# from transformers import AutoConfig, AutoTokenizer, AutoProcessor
 from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)


 class LlavaNext(VlmCausalLM):
@@ -31,59 +22,15 @@ class LlavaNext(VlmCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            # 9b seems to work correctly enough in float16, but 80b seems
-            # to be really saturating for f16.
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            device = torch.device("cpu")
-            dtype = torch.float32 if dtype is None else dtype
-        self.device, self.dtype = device, dtype
-
-        config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-        )
-        config.quantize = quantize
-        config.use_medusa = use_medusa
-        config.vision_config.quantize = quantize
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
         self.processor = AutoProcessor.from_pretrained(
-            model_id,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        super().__init__(
+            model_cls=LlavaNextForConditionalGeneration,
+            model_id=model_id,
             revision=revision,
-            padding_side="left",
-            truncation_side="left",
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames,
-            device=device,
-            dtype=dtype,
-            process_group=self.process_group,
-        )
-
-        model = LlavaNextForConditionalGeneration(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(VlmCausalLM, self).__init__(
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )

View File

@@ -1,869 +1,50 @@
-import torch
-import time
+import re

-from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import (
-    AutoProcessor,
-    AutoTokenizer,
-    PreTrainedTokenizerBase,
-    ProcessorMixin,
-)
 from typing import Optional, Tuple, List, Type, Dict

-from text_generation_server.models import Model
-from text_generation_server.models.types import (
-    Batch,
-    Tokens,
-    Generation,
-    GeneratedText,
+from text_generation_server.models.flash_mistral import (
+    BaseFlashMistral,
+    FlashMistralBatch,
 )
-from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
-
-import re
+
+tracer = trace.get_tracer(__name__)

 IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


-def split(string):
+def split(string) -> List[Dict[str, str]]:
     parts = []
     cursor = 0
     for pattern in IMAGES.finditer(string):
         start = pattern.start()
         if start != cursor:
-            parts.append(string[cursor:start])
+            parts.append({"type": "text", "content": string[cursor:start]})

-        parts.append(pattern.group(1))
+        parts.append({"type": "image", "content": pattern.group(1)})
         cursor = pattern.end()

     if cursor != len(string):
-        parts.append(string[cursor:])
+        parts.append({"type": "text", "content": string[cursor:]})

     return parts


-tracer = trace.get_tracer(__name__)
+class VlmCausalLMBatch(FlashMistralBatch):
+    pass


+class VlmCausalLM(BaseFlashMistral):
-@dataclass
-class VlmCausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
pixel_values: Optional[torch.Tensor]
image_hidden_states: Optional[torch.Tensor]
image_attention_mask: Optional[torch.Tensor]
past_key_values: Optional[List[Tuple]]
# All tokens
all_input_ids: List[torch.Tensor]
# Lengths of all generations present in the batch
input_lengths: List[int]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
# Metadata used for padding
max_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
# Past metadata
keys_head_dim_last: bool = True
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor: ProcessorMixin, # Hack
dtype: torch.dtype,
device: torch.device,
) -> "VlmCausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
# TODO Check impact on idefics
# prompts = []
# for inp in inputs:
# # Each input is encoded into a list, where each element of this input list is either a string or a URL
# prompts.append(split(inp))
# The processor replaces the call to tokenizer, and
# a/ takes care of fetching images from the URL
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
tokenized_inputs = processor(
inputs,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_truncation,
# TODO Check impact on idefics
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(
input_len - 5
) # To decode without potential fallbacks errors
read_offsets.append(
input_len
) # To decode without potential fallbacks errors
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
pixel_values = tokenized_inputs.get("pixel_values", None)
image_hidden_states = None
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
# Do the same for image_attention_mask
if pixel_values is None:
image_attention_mask = None
else:
image_attention_mask = input_ids.new_zeros(
(
pb.size,
max_input_length + padding_right_offset,
pixel_values.size(1),
)
)
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
"image_attention_mask"
]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(
1, dim=1
) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["VlmCausalLMBatch"]:
# It deletes requests from the batch. For instance when client lost connection
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
position_ids = self.position_ids[keep_indices]
self.attention_mask = self.attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
]
# Do the same for pixel_values and image_attention_mask
pixel_values = self.pixel_values[keep_indices]
self.image_attention_mask = self.image_attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.image_attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
:,
]
if self.image_hidden_states is None:
image_hidden_states = None
else:
image_hidden_states = self.image_hidden_states[keep_indices]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection
past_kv_length = max_input_length - 1
for layer in self.past_key_values:
past_keys, past_values = layer
if len(past_keys.shape) == 3:
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
if self.keys_head_dim_last:
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
else:
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
del past_keys
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.pixel_values = pixel_values
self.image_hidden_states = image_hidden_states
self.position_ids = position_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
return self
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["VlmCausalLMBatch"]) -> "VlmCausalLMBatch":
# It adds new requests to the batch
# Used for padding
total_batch_size = 0
max_input_length = 0
max_num_images = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
max_num_images = max(max_num_images, batch.pixel_values.size(1))
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
# Batch tensors
input_ids = None
attention_mask = None
position_ids = None
pixel_values = None
image_hidden_states = None
image_attention_mask = None
past_key_values = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
raise ValueError("only concatenate prefilled batches")
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_input_length + padding_right_offset),
)
curr_batch_max_num_images = batch.pixel_values.size(1)
if pixel_values is None:
pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224)
)
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
batch.pixel_values
)
if image_attention_mask is None:
image_attention_mask = batch.image_attention_mask.new_zeros(
(
total_batch_size,
max_input_length + padding_right_offset,
max_num_images,
)
)
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
image_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
:curr_batch_max_num_images,
] = batch.image_attention_mask[
:, batch_left_offset : -batch.padding_right_offset, :
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
if position_ids is None:
position_ids = batch.position_ids.new_empty((total_batch_size, 1))
position_ids[start_index:end_index] = batch.position_ids
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
]
elif len(batch.past_key_values[0][0].shape) == 3:
for layer in batch.past_key_values:
for k, t in enumerate(layer):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
start_index = end_index
first_past_kvs = batches[0].past_key_values
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
padded_past_values_shape = (
total_batch_size,
num_heads,
max_input_length - 1,
head_dim,
)
if batches[0].keys_head_dim_last:
padded_past_keys_shape = padded_past_values_shape
else:
# seq_length is last for BLOOM
padded_past_keys_shape = (
total_batch_size,
num_heads,
head_dim,
max_input_length - 1,
)
# Iterate over attention layers
# Concatenate past key values layer by layer to allow incremental garbage collection
for j in range(len(first_past_kvs)):
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
start_index = 0
for batch in batches:
past_keys = batch.past_key_values[j][0]
# Clear reference to the original tensor
batch.past_key_values[j][0] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
past_keys[:, :, -past_seq_len:, :]
)
else:
# BLOOM case
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
past_keys[:, :, :, -past_seq_len:]
)
del past_keys
start_index = end_index
padded_past_values = first_past_kvs[j][1].new_zeros(
padded_past_values_shape
)
start_index = 0
for batch in batches:
past_values = batch.past_key_values[j][1]
# Clear reference to the original tensor
batch.past_key_values[j][1] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
past_values[:, :, -past_seq_len:, :]
)
del past_values
# Update values
start_index = end_index
past_key_values.append([padded_past_keys, padded_past_values])
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
pixel_values=pixel_values,
image_hidden_states=image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
)
def __len__(self):
return len(self.requests)
class VlmCausalLM(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
from text_generation_server.models.custom_modeling.idefics_modeling import (
VlmForVisionText2Text,
)
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = VlmForVisionText2Text.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "<unk>"})
super(VlmCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
     @property
-    def batch_type(self) -> Type[VlmCausalLMBatch]:
-        return VlmCausalLMBatch
+    def batch_type(self) -> Type[FlashMistralBatch]:
+        return FlashMistralBatch

-    def forward(
-        self,
-        input_ids,
-        attention_mask,
-        position_ids,
-        pixel_values,
-        image_hidden_states,
-        image_attention_mask,
-        past_key_values: Optional = None,
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        # Model Forward
-        kwargs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "pixel_values": pixel_values,
-            "image_hidden_states": image_hidden_states,
-            "image_attention_mask": image_attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": True,
-            "return_dict": True,
-        }
-        if self.has_position_ids:
-            kwargs["position_ids"] = position_ids
-        outputs, speculative_logits = self.model.forward(**kwargs)
-        return (
-            outputs.logits,
-            speculative_logits,
-            outputs.past_key_values,
-            outputs.image_hidden_states,
-        )
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.language_model.model.layers),
+            model.language_model.model.num_key_value_heads,
+            model.language_model.model.head_size,
+        )

-    @tracer.start_as_current_span("generate_token")
-    def generate_token(
-        self, batch: VlmCausalLMBatch
-    ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]:
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.language_model, "max_past", None)
start = time.time_ns()
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
if batch.image_attention_mask is None:
image_attention_mask = None
else:
if batch.input_ids.size(1) == 1:
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
# token need to attend to the encoder hidden states (i.e. the vision encoder)
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
image_attention_mask = batch.image_attention_mask[
:, -(batch.padding_right_offset + 1)
].unsqueeze(1)
else:
image_attention_mask = batch.image_attention_mask[
:, : -batch.padding_right_offset
]
logits, speculative_logits, past, image_hidden_states = self.forward(
input_ids=batch.input_ids,
attention_mask=attention_mask,
position_ids=batch.position_ids,
pixel_values=batch.pixel_values,
image_hidden_states=batch.image_hidden_states,
image_attention_mask=image_attention_mask,
past_key_values=batch.past_key_values,
)
# Hardcoded remove image tokens
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
start_decode = time.time_ns()
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
)
# Decrease right offset
batch.padding_right_offset -= 1
# Update position_ids
batch.position_ids = batch.position_ids[:, -1:] + 1
# Update past key values
batch.past_key_values = past
batch.image_hidden_states = image_hidden_states
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
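
For reference, the reworked split() above now tags each chunk with its modality instead of returning bare strings; on a prompt containing one markdown image it behaves roughly like this (illustrative values):

import re

IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")

prompt = "Describe ![](https://example.com/cat.png) briefly"
match = IMAGES.search(prompt)
assert match.group(1) == "https://example.com/cat.png"
# split(prompt) would then return:
# [
#     {"type": "text", "content": "Describe "},
#     {"type": "image", "content": "https://example.com/cat.png"},
#     {"type": "text", "content": " briefly"},
# ]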

View File

@@ -13,9 +13,10 @@ from typing import List, Optional

 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
+from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -78,9 +79,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             except ImportError:
                 pass

-        if (
-            self.model.batch_type == VlmCausalLMBatch
-        ):  # Hack, i would rather use kwargs in the `from_pb` call
+        if self.model.batch_type in {
+            IdeficsCausalLMBatch,
+            VlmCausalLMBatch,
+        }:  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb(
                 request.batch,
                 self.model.tokenizer,
@@ -100,9 +102,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Prefill(self, request, context):
         start = time.time_ns()

-        if (
-            self.model.batch_type == VlmCausalLMBatch
-        ):  # Hack, i would rather use kwargs in the `from_pb` call
+        if self.model.batch_type in {
+            IdeficsCausalLMBatch,
+            VlmCausalLMBatch,
+        }:  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb(
                 request.batch,
                 self.model.tokenizer,
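
A small sketch of the set-membership dispatch the server hunks above switch to, with stand-in batch classes (the real ones come from idefics_causal_lm and vlm_causal_lm):

class IdeficsCausalLMBatch: ...
class VlmCausalLMBatch: ...
class FlashCausalLMBatch: ...

def needs_processor(batch_type) -> bool:
    # Processor-backed batches take an extra `processor` argument in from_pb,
    # so the server special-cases them as a set rather than a single equality.
    return batch_type in {IdeficsCausalLMBatch, VlmCausalLMBatch}

assert needs_processor(VlmCausalLMBatch)
assert not needs_processor(FlashCausalLMBatch)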