Llava next dump.

Nicolas Patry 2024-04-02 09:40:27 +00:00
parent f9958ee191
commit b68fc4deb1
8 changed files with 1638 additions and 59 deletions

View File

@@ -67,6 +67,7 @@ try:
         FlashSantacoderSharded,
     )
     from text_generation_server.models.idefics import IDEFICSSharded
+    from text_generation_server.models.llava_next import LlavaNext
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
@@ -571,6 +572,19 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == "llava_next":
+        if FLASH_ATTENTION:
+            return LlavaNext(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
+
     if sharded:
         raise NotImplementedError("sharded is not supported for AutoModel")
     if quantize == "gptq":

File diff suppressed because it is too large

View File

@@ -285,9 +285,8 @@ class MistralMLP(nn.Module):
 class MistralLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
         self.self_attn = MistralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -343,27 +342,27 @@ class MistralLayer(nn.Module):
 class MistralModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=f"{prefix}.embed_tokens", weights=weights
         )
         self.layers = nn.ModuleList(
             [
                 MistralLayer(
-                    layer_id,
-                    config,
-                    weights,
+                    prefix=f"{prefix}.layers.{layer_id}",
+                    config=config,
+                    weights=weights,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )
         self.gradient_checkpointing = False
@@ -415,13 +414,17 @@ class MistralModel(torch.nn.Module):
 class FlashMistralForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        self.model = MistralModel(config, weights)
+        self.model = MistralModel(
+            prefix="model" if not prefix else f"{prefix}.model",
+            config=config,
+            weights=weights,
+        )
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
         self.max_past = config.sliding_window
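
Why this refactor matters: threading `prefix` through the Mistral stack lets the same modeling code load its weights either standalone (keys like `model.layers.0.self_attn`) or embedded inside LlavaNext under the `language_model` prefix. A minimal sketch of how the keys compose (the helper name is illustrative, not from the diff):

    # Illustrative only: how the new `prefix` argument maps to safetensors keys.
    def attn_key(prefix: str, layer_id: int) -> str:
        root = "model" if not prefix else f"{prefix}.model"
        return f"{root}.layers.{layer_id}.self_attn"

    print(attn_key("", 0))                # model.layers.0.self_attn
    print(attn_key("language_model", 0))  # language_model.model.layers.0.self_attn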

View File

@@ -0,0 +1,424 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""

from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from transformers import AutoModel, AutoModelForCausalLM


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (width, height).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size
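
# Illustrative example (values are assumptions, not from this file): if the best
# pinpoint selected for an image is (672, 672), and `patch_size` is 336 (the
# vision tower input resolution, which is what `forward` below passes in), the
# grid is (672 // 336, 672 // 336) == (2, 2), i.e. 2 x 2 high-resolution crops
# alongside the base image.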

def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (`torch.Tensor`):
            The image tensor, assumed to be of shape (num_channels, height, width).
        original_size (`tuple`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: The unpadded image tensor.
    """
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor
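
# Illustrative example: a 1000x500 (width x height) original image embedded in a
# 24x24 feature grid. The original aspect ratio (1000/500 = 2.0) exceeds the
# current one (24/24 = 1.0), so the resize was width-bound: scale_factor =
# 24/1000, new_height = int(500 * 0.024) = 12, padding = (24 - 12) // 2 = 6,
# and rows 6:18 are kept, yielding a (num_channels, 12, 24) tensor.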

# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
        )
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
        )

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


def load_vision_model(prefix, config, weights):
    if config.model_type == "clip_vision_model":
        from text_generation_server.models.custom_modeling.clip import (
            CLIPVisionTransformer,
        )

        return CLIPVisionTransformer(prefix, config, weights)
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")


def load_text_model(prefix, config, weights):
    if config.model_type == "mistral":
        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
            FlashMistralForCausalLM,
        )

        return FlashMistralForCausalLM(prefix, config, weights)
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")


class LlavaNextForConditionalGeneration(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.config = config
        config.vision_config.quantize = config.quantize
        self.vision_tower = load_vision_model(
            prefix="vision_tower", config=config.vision_config, weights=weights
        )
        self.multi_modal_projector = LlavaNextMultiModalProjector(config)
        self.image_newline = weights.get_tensor("image_newline")
        self.vocab_size = config.text_config.vocab_size
        config.text_config.quantize = config.quantize
        config.text_config.use_medusa = config.use_medusa
        self.language_model = load_text_model(
            prefix="language_model", config=config.text_config, weights=weights
        )
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )
        # self.post_init()

    def get_input_embeddings(self):
        # Used by `forward`; the Mistral text tower loaded via load_text_model
        # exposes its embedding table as `model.embed_tokens`.
        return self.language_model.model.embed_tokens

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids, attention_mask, labels
    ):
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(self.pad_token_id)
        )
        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == self.config.image_token_index
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (
            num_special_image_tokens.max() * (num_image_patches - 1)
        ) + sequence_length
        batch_indices, non_image_indices = torch.where(
            input_ids != self.config.image_token_index
        )

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged image-text sequence.
        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = (
            torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
            - 1
        )
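        # Worked example (illustrative): for input_ids = ["hey", <image>, "how"]
        # with num_image_patches = 576:
        #   special_image_token_mask           -> [0, 1, 0]
        #   mask * (num_image_patches - 1) + 1 -> [1, 576, 1]
        #   cumsum(...) - 1                    -> [0, 576, 577]
        # "hey" stays at position 0, "how" moves to 577, and slots 1..576 are
        # left free for the 576 patch embeddings of the image.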
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(
            batch_size,
            max_embed_dim,
            dtype=attention_mask.dtype,
            device=inputs_embeds.device,
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                self.config.ignore_index,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
            batch_indices, non_image_indices
        ]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
            batch_indices, non_image_indices
        ]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[
                batch_indices, non_image_indices
            ]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[
            :, None
        ].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = (
            image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        )
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1
        )

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]
        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
    ):
        vision_feature_layer = (
            vision_feature_layer
            if vision_feature_layer is not None
            else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if inputs_embeds is None:
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and images
            if pixel_values is not None and input_ids.shape[1] != 1:
                batch_size, num_patches, num_channels, height, width = (
                    pixel_values.shape
                )
                reshaped_pixel_values = pixel_values.view(
                    batch_size * num_patches, num_channels, height, width
                )
                image_features = self.vision_tower(
                    reshaped_pixel_values, output_hidden_states=True
                )

                selected_image_feature = image_features.hidden_states[
                    vision_feature_layer
                ]

                if vision_feature_select_strategy == "default":
                    selected_image_feature = selected_image_feature[:, 1:]
                elif vision_feature_select_strategy == "full":
                    selected_image_feature = selected_image_feature

                image_features = self.multi_modal_projector(selected_image_feature)

                # split up image_features for each of the individual images
                # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
                # if we assume each image has 5 image features (base image + 4 patches)
                split_sizes = [image.shape[0] for image in pixel_values]
                image_features = torch.split(image_features, split_sizes, dim=0)

                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
                height = width = (
                    self.config.vision_config.image_size
                    // self.config.vision_config.patch_size
                )

                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]

                        if height * width != base_image_feature.shape[0]:
                            raise ValueError(
                                "The number of patches is not consistent with the image size."
                            )
                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                            image_sizes[image_idx],
                            self.config.image_grid_pinpoints,
                            self.config.vision_config.image_size,
                        )
                        image_feature = image_feature.view(
                            num_patch_height, num_patch_width, height, width, -1
                        )
                        image_feature = image_feature.permute(
                            4, 0, 2, 1, 3
                        ).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(
                            image_feature, image_sizes[image_idx]
                        )
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.image_newline[:, None, None].expand(
                                    *image_feature.shape[:-1], 1
                                ),
                            ),
                            dim=-1,
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        image_feature = torch.cat(
                            (base_image_feature, image_feature), dim=0
                        )
                    else:
                        image_feature = image_feature[0]
                        image_feature = torch.cat(
                            (image_feature, self.image_newline[None]), dim=0
                        )
                    new_image_features.append(image_feature)
                image_features = torch.stack(new_image_features, dim=0)

                inputs_embeds, attention_mask, labels, position_ids = (
                    self._merge_input_ids_with_image_features(
                        image_features, inputs_embeds, input_ids, attention_mask, labels
                    )
                )
                if labels is None:
                    labels = torch.full_like(
                        attention_mask, self.config.ignore_index
                    ).to(torch.long)

            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
            # generation with cache
            elif (
                past_key_values is not None
                and pixel_values is not None
                and input_ids.shape[1] == 1
            ):
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(
                    first_layer_past_key_value.float().sum(-2) == 0
                )

                # Get the target length
                target_seqlen = first_layer_past_key_value.shape[-1] + 1

                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                attention_mask = torch.cat(
                    (attention_mask, extended_attention_mask), dim=1
                )
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )
        logits = outputs[0]
        return logits
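
For orientation, a sketch of the tensor shapes through the vision path above, assuming the typical LLaVA-NeXT configuration (CLIP ViT-L/14 at 336px, so image_size=336, patch_size=14, vision hidden_size=1024, and 5 crops per image); the concrete numbers are assumptions, not taken from the diff:

    import torch

    batch_size, num_crops = 1, 5                      # base image + 4 high-res crops
    pixel_values = torch.randn(batch_size, num_crops, 3, 336, 336)
    reshaped = pixel_values.view(batch_size * num_crops, 3, 336, 336)  # (5, 3, 336, 336)
    # vision tower output: (336 / 14)^2 = 576 patch tokens + 1 CLS token
    hidden_states = torch.randn(5, 577, 1024)
    selected = hidden_states[:, 1:]                   # "default" strategy drops CLS -> (5, 576, 1024)
    # multi_modal_projector then maps 1024 -> the text hidden size, e.g. (5, 576, 4096)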

View File

@@ -17,7 +17,7 @@ from transformers import LlamaTokenizerFast
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
 )
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
+from text_generation_server.models.vlm_causal_lm import VlmCausalLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -25,7 +25,7 @@ from text_generation_server.utils import (
 )

-class IDEFICSSharded(IdeficsCausalLM):
+class IDEFICSSharded(VlmCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -82,7 +82,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         model = IdeficsForVisionText2Text(config, weights)
         torch.distributed.barrier(group=self.process_group)
-        super(IdeficsCausalLM, self).__init__(
+        super(VlmCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,

View File

@@ -0,0 +1,89 @@
import torch
import torch.distributed

from typing import List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoProcessor,
)

from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)

# from transformers import AutoConfig, AutoTokenizer, AutoProcessor
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class LlavaNext(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            # 9b seems to work correctly enough in float16, but 80b seems
            # to be really saturating for f16.
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype
        self.device, self.dtype = device, dtype

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.use_medusa = use_medusa
        config.vision_config.quantize = quantize

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
        )

        model = LlavaNextForConditionalGeneration(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(VlmCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
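
A minimal sketch of constructing this wrapper directly (the checkpoint id is an assumed example, not taken from the diff; in the server this class is normally reached through `get_model` rather than instantiated by hand):

    import torch
    from text_generation_server.models.llava_next import LlavaNext

    model = LlavaNext(
        "llava-hf/llava-v1.6-mistral-7b-hf",  # assumed checkpoint id
        revision=None,
        quantize=None,
        use_medusa=None,
        dtype=torch.float16,
        trust_remote_code=False,
    )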

View File

@@ -47,7 +47,7 @@ tracer = trace.get_tracer(__name__)

 @dataclass
-class IdeficsCausalLMBatch(Batch):
+class VlmCausalLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
     requests_idx_mapping: Dict[int, int]
@@ -99,7 +99,7 @@ class IdeficsCausalLMBatch(Batch):
         processor: ProcessorMixin,  # Hack
         dtype: torch.dtype,
         device: torch.device,
-    ) -> "IdeficsCausalLMBatch":
+    ) -> "VlmCausalLMBatch":
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
@@ -127,21 +127,23 @@ class IdeficsCausalLMBatch(Batch):
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

-        prompts = []
-        for inp in inputs:
-            # Each input is encoded into a list, where each element of this input list is either a string or a URL
-            prompts.append(split(inp))
+        # TODO Check impact on idefics
+        # prompts = []
+        # for inp in inputs:
+        #     # Each input is encoded into a list, where each element of this input list is either a string or a URL
+        #     prompts.append(split(inp))

         # The processor replaces the call to tokenizer, and
         # a/ takes care of fetching images from the URL
         # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
         tokenized_inputs = processor(
-            prompts,
+            inputs,
             return_tensors="pt",
             padding=True,
             truncation=True,
             max_length=max_truncation,
-            add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
+            # TODO Check impact on idefics
+            # add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token
         ).to(device)
         for _ in pb.requests:
             input_len = tokenized_inputs["input_ids"].shape[1]
@@ -156,7 +158,7 @@ class IdeficsCausalLMBatch(Batch):
         max_input_length = input_lengths.max()

         input_ids = tokenized_inputs["input_ids"]
-        pixel_values = tokenized_inputs["pixel_values"]
+        pixel_values = tokenized_inputs.get("pixel_values", None)
         image_hidden_states = None
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
@@ -165,16 +167,19 @@ class IdeficsCausalLMBatch(Batch):
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
         # Do the same for image_attention_mask
-        image_attention_mask = input_ids.new_zeros(
-            (
-                pb.size,
-                max_input_length + padding_right_offset,
-                tokenized_inputs["pixel_values"].size(1),
+        if pixel_values is None:
+            image_attention_mask = None
+        else:
+            image_attention_mask = input_ids.new_zeros(
+                (
+                    pb.size,
+                    max_input_length + padding_right_offset,
+                    pixel_values.size(1),
+                )
             )
-        )
-        image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
-            "image_attention_mask"
-        ]
+            image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
+                "image_attention_mask"
+            ]

         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@@ -207,7 +212,7 @@ class IdeficsCausalLMBatch(Batch):
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
+    def filter(self, request_ids: List[int]) -> Optional["VlmCausalLMBatch"]:
         # It deletes requests from the batch. For instance when client lost connection
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
@@ -323,9 +328,7 @@ class IdeficsCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(
-        cls, batches: List["IdeficsCausalLMBatch"]
-    ) -> "IdeficsCausalLMBatch":
+    def concatenate(cls, batches: List["VlmCausalLMBatch"]) -> "VlmCausalLMBatch":
         # It adds new requests to the batch
         # Used for padding
         total_batch_size = 0
@@ -564,7 +567,7 @@ class IdeficsCausalLMBatch(Batch):
         return len(self.requests)

-class IdeficsCausalLM(Model):
+class VlmCausalLM(Model):
     def __init__(
         self,
         model_id: str,
@@ -574,7 +577,7 @@ class IdeficsCausalLM(Model):
         trust_remote_code: bool = False,
     ):
         from text_generation_server.models.custom_modeling.idefics_modeling import (
-            IdeficsForVisionText2Text,
+            VlmForVisionText2Text,
         )

         if torch.cuda.is_available():
@@ -601,7 +604,7 @@ class IdeficsCausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        model = IdeficsForVisionText2Text.from_pretrained(
+        model = VlmForVisionText2Text.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
@@ -626,7 +629,7 @@ class IdeficsCausalLM(Model):
         else:
             tokenizer.add_special_tokens({"pad_token": "<unk>"})

-        super(IdeficsCausalLM, self).__init__(
+        super(VlmCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
@@ -635,8 +638,8 @@ class IdeficsCausalLM(Model):
         )

     @property
-    def batch_type(self) -> Type[IdeficsCausalLMBatch]:
-        return IdeficsCausalLMBatch
+    def batch_type(self) -> Type[VlmCausalLMBatch]:
+        return VlmCausalLMBatch

     def forward(
         self,
@@ -672,24 +675,27 @@ class IdeficsCausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batch: IdeficsCausalLMBatch
-    ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
+        self, batch: VlmCausalLMBatch
+    ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        if batch.input_ids.size(1) == 1:
-            # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
-            # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
-            # this is due to the nature of IDEFICS: it's an encoder-decoder, so when decoding, only the currently generated
-            # token needs to attend to the encoder hidden states (i.e. the vision encoder)
-            # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
-            image_attention_mask = batch.image_attention_mask[
-                :, -(batch.padding_right_offset + 1)
-            ].unsqueeze(1)
+        if batch.image_attention_mask is None:
+            image_attention_mask = None
         else:
-            image_attention_mask = batch.image_attention_mask[
-                :, : -batch.padding_right_offset
-            ]
+            if batch.input_ids.size(1) == 1:
+                # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
+                # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
+                # this is due to the nature of IDEFICS: it's an encoder-decoder, so when decoding, only the currently generated
+                # token needs to attend to the encoder hidden states (i.e. the vision encoder)
+                # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
+                image_attention_mask = batch.image_attention_mask[
+                    :, -(batch.padding_right_offset + 1)
+                ].unsqueeze(1)
+            else:
+                image_attention_mask = batch.image_attention_mask[
+                    :, : -batch.padding_right_offset
+                ]

         logits, speculative_logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
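
The shapes involved in the slicing above, for a concrete (assumed) batch: with bs=2, max_seq_len=16, max_num_images=1 and padding_right_offset=4, the prefill path keeps every position up to the right padding, while the decode path keeps only the single position currently being generated:

    import torch

    bs, max_seq_len, max_num_images, padding_right_offset = 2, 16, 1, 4
    image_attention_mask = torch.zeros(bs, max_seq_len, max_num_images)
    prefill = image_attention_mask[:, :-padding_right_offset]                   # (2, 12, 1)
    decode = image_attention_mask[:, -(padding_right_offset + 1)].unsqueeze(1)  # (2, 1, 1)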

View File

@@ -15,7 +15,7 @@ from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
+from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch

 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -79,7 +79,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             pass

         if (
-            self.model.batch_type == IdeficsCausalLMBatch
+            self.model.batch_type == VlmCausalLMBatch
         ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb(
                 request.batch,
@@ -101,7 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Prefill(self, request, context):
         start = time.time_ns()
         if (
-            self.model.batch_type == IdeficsCausalLMBatch
+            self.model.batch_type == VlmCausalLMBatch
         ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb(
                 request.batch,