mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Llava next dump.
This commit is contained in:
parent
f9958ee191
commit
b68fc4deb1
@ -67,6 +67,7 @@ try:
|
||||
FlashSantacoderSharded,
|
||||
)
|
||||
from text_generation_server.models.idefics import IDEFICSSharded
|
||||
from text_generation_server.models.llava_next import LlavaNext
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||
from text_generation_server.models.flash_phi import FlashPhi
|
||||
@ -571,6 +572,19 @@ def get_model(
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
|
||||
if model_type == "llava_next":
|
||||
if FLASH_ATTENTION:
|
||||
return LlavaNext(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
|
||||
|
||||
if sharded:
|
||||
raise NotImplementedError("sharded is not supported for AutoModel")
|
||||
if quantize == "gptq":
|
||||
|
1043
server/text_generation_server/models/custom_modeling/clip.py
Normal file
1043
server/text_generation_server/models/custom_modeling/clip.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -285,9 +285,8 @@ class MistralMLP(nn.Module):
|
||||
|
||||
|
||||
class MistralLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = MistralAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
@ -343,27 +342,27 @@ class MistralLayer(nn.Module):
|
||||
|
||||
|
||||
class MistralModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MistralLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -415,13 +414,17 @@ class MistralModel(torch.nn.Module):
|
||||
|
||||
|
||||
class FlashMistralForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = MistralModel(config, weights)
|
||||
self.model = MistralModel(
|
||||
prefix="model" if not prefix else f"{prefix}.model",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
|
@ -0,0 +1,424 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
from transformers import AutoModel, AutoModelForCausalLM
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def unpad_image(tensor, original_size):
|
||||
"""
|
||||
Unpads a PyTorch tensor of a padded and resized image.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`):
|
||||
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||
original_size (`tuple`):
|
||||
The original size of the image (height, width).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The unpadded image tensor.
|
||||
"""
|
||||
original_height, original_width = original_size
|
||||
current_height, current_width = tensor.shape[1:]
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||
|
||||
return unpadded_tensor
|
||||
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||
class LlavaNextMultiModalProjector(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def load_vision_model(prefix, config, weights):
|
||||
if config.model_type == "clip_vision_model":
|
||||
from text_generation_server.models.custom_modeling.clip import (
|
||||
CLIPVisionTransformer,
|
||||
)
|
||||
|
||||
return CLIPVisionTransformer(prefix, config, weights)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
def load_text_model(prefix, config, weights):
|
||||
if config.model_type == "mistral":
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
)
|
||||
|
||||
return FlashMistralForCausalLM(prefix, config, weights)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
class LlavaNextForConditionalGeneration(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower", config=config.vision_config, weights=weights
|
||||
)
|
||||
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(config)
|
||||
|
||||
self.image_newline = weights.get_tensor("image_newline")
|
||||
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.use_medusa = config.use_medusa
|
||||
self.language_model = load_text_model(
|
||||
prefix="language_model", config=config.text_config, weights=weights
|
||||
)
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
# self.post_init()
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
|
||||
def _merge_input_ids_with_image_features(
|
||||
self, image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(
|
||||
input_ids[:, -1] == torch.tensor(self.pad_token_id)
|
||||
)
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == self.config.image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (
|
||||
num_special_image_tokens.max() * (num_image_patches - 1)
|
||||
) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(
|
||||
input_ids != self.config.image_token_index
|
||||
)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = (
|
||||
torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
|
||||
- 1
|
||||
)
|
||||
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size,
|
||||
max_embed_dim,
|
||||
embed_dim,
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size,
|
||||
max_embed_dim,
|
||||
dtype=attention_mask.dtype,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim),
|
||||
self.config.ignore_index,
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
|
||||
batch_indices, non_image_indices
|
||||
]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
|
||||
batch_indices, non_image_indices
|
||||
]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[
|
||||
batch_indices, non_image_indices
|
||||
]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
||||
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[
|
||||
:, None
|
||||
].to(target_device)
|
||||
|
||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
|
||||
final_embedding[image_to_overwrite] = (
|
||||
image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
|
||||
(final_attention_mask == 0), 1
|
||||
)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
vision_feature_layer: Optional[int] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
):
|
||||
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer
|
||||
if vision_feature_layer is not None
|
||||
else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text and images
|
||||
if pixel_values is not None and input_ids.shape[1] != 1:
|
||||
batch_size, num_patches, num_channels, height, width = (
|
||||
pixel_values.shape
|
||||
)
|
||||
reshaped_pixel_values = pixel_values.view(
|
||||
batch_size * num_patches, num_channels, height, width
|
||||
)
|
||||
image_features = self.vision_tower(
|
||||
reshaped_pixel_values, output_hidden_states=True
|
||||
)
|
||||
|
||||
selected_image_feature = image_features.hidden_states[
|
||||
vision_feature_layer
|
||||
]
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||
# if we assume each image has 5 image features (base image + 4 patches)
|
||||
split_sizes = [image.shape[0] for image in pixel_values]
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(
|
||||
4, 0, 2, 1, 3
|
||||
).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(
|
||||
image_feature, image_sizes[image_idx]
|
||||
)
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.image_newline[:, None, None].expand(
|
||||
*image_feature.shape[:-1], 1
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat(
|
||||
(base_image_feature, image_feature), dim=0
|
||||
)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
image_feature = torch.cat(
|
||||
(image_feature, self.image_newline[None]), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
|
||||
inputs_embeds, attention_mask, labels, position_ids = (
|
||||
self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
)
|
||||
if labels is None:
|
||||
labels = torch.full_like(
|
||||
attention_mask, self.config.ignore_index
|
||||
).to(torch.long)
|
||||
|
||||
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
elif (
|
||||
past_key_values is not None
|
||||
and pixel_values is not None
|
||||
and input_ids.shape[1] == 1
|
||||
):
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(
|
||||
first_layer_past_key_value.float().sum(-2) == 0
|
||||
)
|
||||
|
||||
# Get the target length
|
||||
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
||||
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = torch.cat(
|
||||
(attention_mask, extended_attention_mask), dim=1
|
||||
)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
return logits
|
@ -17,7 +17,7 @@ from transformers import LlamaTokenizerFast
|
||||
from text_generation_server.models.custom_modeling.idefics_modeling import (
|
||||
IdeficsForVisionText2Text,
|
||||
)
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
@ -25,7 +25,7 @@ from text_generation_server.utils import (
|
||||
)
|
||||
|
||||
|
||||
class IDEFICSSharded(IdeficsCausalLM):
|
||||
class IDEFICSSharded(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
@ -82,7 +82,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||
model = IdeficsForVisionText2Text(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(IdeficsCausalLM, self).__init__(
|
||||
super(VlmCausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
89
server/text_generation_server/models/llava_next.py
Normal file
89
server/text_generation_server/models/llava_next.py
Normal file
@ -0,0 +1,89 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
AutoProcessor,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
|
||||
# from transformers import AutoConfig, AutoTokenizer, AutoProcessor
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
|
||||
class LlavaNext(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
# 9b seems to work correctly enough in float16, but 80b seems
|
||||
# to be really saturating for f16.
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
self.device, self.dtype = device, dtype
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
config.vision_config.quantize = quantize
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
)
|
||||
|
||||
model = LlavaNextForConditionalGeneration(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(VlmCausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
@ -47,7 +47,7 @@ tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IdeficsCausalLMBatch(Batch):
|
||||
class VlmCausalLMBatch(Batch):
|
||||
batch_id: int
|
||||
requests: List[generate_pb2.Request]
|
||||
requests_idx_mapping: Dict[int, int]
|
||||
@ -99,7 +99,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
processor: ProcessorMixin, # Hack
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
) -> "VlmCausalLMBatch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
@ -127,21 +127,23 @@ class IdeficsCausalLMBatch(Batch):
|
||||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
||||
prompts = []
|
||||
for inp in inputs:
|
||||
# Each input is encoded into a list, where each element of this input list is either a string or a URL
|
||||
prompts.append(split(inp))
|
||||
# TODO Check impact on idefics
|
||||
# prompts = []
|
||||
# for inp in inputs:
|
||||
# # Each input is encoded into a list, where each element of this input list is either a string or a URL
|
||||
# prompts.append(split(inp))
|
||||
|
||||
# The processor replaces the call to tokenizer, and
|
||||
# a/ takes care of fetching images from the URL
|
||||
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
|
||||
tokenized_inputs = processor(
|
||||
prompts,
|
||||
inputs,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_truncation,
|
||||
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
# TODO Check impact on idefics
|
||||
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
).to(device)
|
||||
for _ in pb.requests:
|
||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||
@ -156,7 +158,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
max_input_length = input_lengths.max()
|
||||
|
||||
input_ids = tokenized_inputs["input_ids"]
|
||||
pixel_values = tokenized_inputs["pixel_values"]
|
||||
pixel_values = tokenized_inputs.get("pixel_values", None)
|
||||
image_hidden_states = None
|
||||
# Allocate maximum attention_mask
|
||||
attention_mask = input_ids.new_zeros(
|
||||
@ -165,11 +167,14 @@ class IdeficsCausalLMBatch(Batch):
|
||||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
|
||||
# Do the same for image_attention_mask
|
||||
if pixel_values is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
image_attention_mask = input_ids.new_zeros(
|
||||
(
|
||||
pb.size,
|
||||
max_input_length + padding_right_offset,
|
||||
tokenized_inputs["pixel_values"].size(1),
|
||||
pixel_values.size(1),
|
||||
)
|
||||
)
|
||||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||
@ -207,7 +212,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
|
||||
def filter(self, request_ids: List[int]) -> Optional["VlmCausalLMBatch"]:
|
||||
# It deletes requests from the batch. For instance when client lost connection
|
||||
if len(request_ids) == 0:
|
||||
raise ValueError("Batch must have at least one request")
|
||||
@ -323,9 +328,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(
|
||||
cls, batches: List["IdeficsCausalLMBatch"]
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
def concatenate(cls, batches: List["VlmCausalLMBatch"]) -> "VlmCausalLMBatch":
|
||||
# It adds new requests to the batch
|
||||
# Used for padding
|
||||
total_batch_size = 0
|
||||
@ -564,7 +567,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||
return len(self.requests)
|
||||
|
||||
|
||||
class IdeficsCausalLM(Model):
|
||||
class VlmCausalLM(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
@ -574,7 +577,7 @@ class IdeficsCausalLM(Model):
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
from text_generation_server.models.custom_modeling.idefics_modeling import (
|
||||
IdeficsForVisionText2Text,
|
||||
VlmForVisionText2Text,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
@ -601,7 +604,7 @@ class IdeficsCausalLM(Model):
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model = IdeficsForVisionText2Text.from_pretrained(
|
||||
model = VlmForVisionText2Text.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
@ -626,7 +629,7 @@ class IdeficsCausalLM(Model):
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "<unk>"})
|
||||
|
||||
super(IdeficsCausalLM, self).__init__(
|
||||
super(VlmCausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
@ -635,8 +638,8 @@ class IdeficsCausalLM(Model):
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[IdeficsCausalLMBatch]:
|
||||
return IdeficsCausalLMBatch
|
||||
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
||||
return VlmCausalLMBatch
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -672,11 +675,14 @@ class IdeficsCausalLM(Model):
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
self, batch: IdeficsCausalLMBatch
|
||||
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
|
||||
self, batch: VlmCausalLMBatch
|
||||
) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]:
|
||||
start = time.time_ns()
|
||||
# slice the attention mask to the correct shape
|
||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||
if batch.image_attention_mask is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
if batch.input_ids.size(1) == 1:
|
||||
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
|
||||
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
|
@ -15,7 +15,7 @@ from text_generation_server.interceptor import ExceptionInterceptor
|
||||
from text_generation_server.models import Model, get_model
|
||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
|
||||
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
@ -79,7 +79,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
pass
|
||||
|
||||
if (
|
||||
self.model.batch_type == IdeficsCausalLMBatch
|
||||
self.model.batch_type == VlmCausalLMBatch
|
||||
): # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch,
|
||||
@ -101,7 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
async def Prefill(self, request, context):
|
||||
start = time.time_ns()
|
||||
if (
|
||||
self.model.batch_type == IdeficsCausalLMBatch
|
||||
self.model.batch_type == VlmCausalLMBatch
|
||||
): # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch,
|
||||
|
Loading…
Reference in New Issue
Block a user