Improve vlm support (add idefics3 support) (#2437)

* feat: expand vlm support and add image token logic and tests

* fix: avoid unused perceiver config

* feat: integrate image tokens into inputs embeds

* feat: add simple idefics3 test

* feat: update docs, image token logic and weight names

* fix: improve image processing

* feat: improve prefix for idefics3

* fix: bump idefics3 tests and snapshots

* fix: improve text model loading

* feat: consolidate changes with existing vlms and add support and test for smolvlm

* fix: create new idefic3 file, simplify logic and adjust llama weight loading

* fix: lint with ruff

* fix: clean up idefics 3 and improve prefix handling

* fix: improve typing

* fix: improve prompt_split_image with ref to original impl

* fix: adjust ruff lints and small refactors

* fix: adjust FlashLlamaModel prefix logic
This commit is contained in:
drbh 2025-01-09 10:35:32 -05:00 committed by GitHub
parent a9c7d2e3b6
commit da5ab46705
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 988 additions and 43 deletions

View File

@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f) - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)

View File

@ -354,6 +354,7 @@ def launcher(event_loop):
kv_cache_dtype: Optional[str] = None, kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
max_input_length: Optional[int] = None, max_input_length: Optional[int] = None,
max_input_tokens: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None, lora_adapters: Optional[List[str]] = None,
@ -402,6 +403,9 @@ def launcher(event_loop):
if max_input_length: if max_input_length:
args.append("--max-input-length") args.append("--max-input-length")
args.append(str(max_input_length)) args.append(str(max_input_length))
if max_input_tokens:
args.append("--max-input-tokens")
args.append(str(max_input_tokens))
if max_batch_prefill_tokens: if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens") args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens)) args.append(str(max_batch_prefill_tokens))

View File

@ -0,0 +1,67 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 9,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2684,
"logprob": -0.24902344,
"special": false,
"text": " There"
},
{
"id": 374,
"logprob": -0.0703125,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.23535156,
"special": false,
"text": " a"
},
{
"id": 35372,
"logprob": -0.125,
"special": false,
"text": " statue"
},
{
"id": 304,
"logprob": -0.30273438,
"special": false,
"text": " in"
},
{
"id": 279,
"logprob": -0.20507812,
"special": false,
"text": " the"
},
{
"id": 2217,
"logprob": -0.076171875,
"special": false,
"text": " image"
},
{
"id": 13,
"logprob": -0.053710938,
"special": false,
"text": "."
},
{
"id": 128258,
"logprob": -0.011352539,
"special": true,
"text": "<end_of_utterance>"
}
],
"top_tokens": null
},
"generated_text": " There is a statue in the image."
}

View File

@ -0,0 +1,61 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 8,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 330,
"logprob": -0.118652344,
"special": false,
"text": " A"
},
{
"id": 11426,
"logprob": -0.28320312,
"special": false,
"text": " bee"
},
{
"id": 335,
"logprob": -0.95703125,
"special": false,
"text": " on"
},
{
"id": 253,
"logprob": -0.06982422,
"special": false,
"text": " a"
},
{
"id": 11986,
"logprob": -0.49414062,
"special": false,
"text": " pink"
},
{
"id": 8525,
"logprob": -0.07763672,
"special": false,
"text": " flower"
},
{
"id": 30,
"logprob": -1.0703125,
"special": false,
"text": "."
},
{
"id": 49154,
"logprob": -0.092285156,
"special": true,
"text": "<end_of_utterance>"
}
],
"top_tokens": null
},
"generated_text": " A bee on a pink flower."
}

View File

@ -0,0 +1,31 @@
import pytest
@pytest.fixture(scope="module")
def flash_idefics3_next_handle(launcher):
with launcher("HuggingFaceM4/Idefics3-8B-Llama3") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_idefics3_next(flash_idefics3_next_handle):
await flash_idefics3_next_handle.health(300)
return flash_idefics3_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot):
ny_skyline = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
query = "What is in this image?"
response = await flash_idefics3_next.generate(
f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
max_new_tokens=10,
seed=1337,
)
print(response)
assert (
response.generated_text == " There is a statue in the image."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 9
assert response == response_snapshot

View File

@ -0,0 +1,31 @@
import pytest
@pytest.fixture(scope="module")
def flash_smolvlm_next_handle(launcher):
with launcher("HuggingFaceTB/SmolVLM-Instruct") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_smolvlm_next(flash_smolvlm_next_handle):
await flash_smolvlm_next_handle.health(300)
return flash_smolvlm_next_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot):
ny_skyline = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"
query = "What is in this image?"
response = await flash_smolvlm_next.generate(
f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
max_new_tokens=10,
seed=1337,
)
print(response)
assert (
response.generated_text == " A bee on a pink flower."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 8
assert response == response_snapshot

View File

@ -110,6 +110,24 @@ pub struct ClipVisionModel {
patch_size: usize, patch_size: usize,
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Idefics3 {}
impl Idefics3 {
pub fn get_max_longest_edge(&self) -> usize {
364
}
pub fn get_number_of_features(&self) -> usize {
169
}
pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
1456
}
}
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct Idefics2 {} pub struct Idefics2 {}
@ -178,6 +196,7 @@ pub enum Config {
Idefics, Idefics,
Mllama, Mllama,
Idefics2(Idefics2), Idefics2(Idefics2),
Idefics3(Idefics3),
Ssm, Ssm,
GptBigcode, GptBigcode,
Granite, Granite,

View File

@ -170,6 +170,7 @@ impl TokenizerConfigToken {
#[serde(tag = "processor_class")] #[serde(tag = "processor_class")]
pub enum HubPreprocessorConfig { pub enum HubPreprocessorConfig {
Idefics2Processor(Idefics2Preprocessor), Idefics2Processor(Idefics2Preprocessor),
Idefics3Processor(Idefics2Preprocessor),
} }
impl HubPreprocessorConfig { impl HubPreprocessorConfig {

View File

@ -614,6 +614,73 @@ fn image_tokens(
image_string image_string
} }
Idefics3(config) => {
const FAKE: &str = "<fake_token_around_image>";
const IMAGE: &str = "<image>";
const GLOBAL_IMG: &str = "<global-img>";
let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();
// resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio
let (height, width) = if height > max_longest_edge_for_image_resize
|| width > max_longest_edge_for_image_resize
{
let aspect_ratio = height as f32 / width as f32;
if height > width {
(
max_longest_edge_for_image_resize,
(max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
)
} else {
(
(max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
max_longest_edge_for_image_resize,
)
}
} else {
(height, width)
};
let image_seq_len = config.get_number_of_features();
let max_edge = config.get_max_longest_edge();
let (image_rows, image_cols) = if height > max_edge || width > max_edge {
(
(height as f32 / max_edge as f32).ceil() as usize,
(width as f32 / max_edge as f32).ceil() as usize,
)
} else {
(0, 0)
};
let mut image_string = String::new();
if image_rows == 0 && image_cols == 0 {
// Single image case
image_string.push_str(FAKE);
image_string.push_str(GLOBAL_IMG);
image_string.push_str(&IMAGE.repeat(image_seq_len));
image_string.push_str(FAKE);
} else {
// Split image case
for n_h in 0..image_rows {
for n_w in 0..image_cols {
image_string.push_str(FAKE);
image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
image_string.push_str(&IMAGE.repeat(image_seq_len));
}
image_string.push('\n');
}
image_string.push('\n');
image_string.push_str(FAKE);
image_string.push_str(GLOBAL_IMG);
image_string.push_str(&IMAGE.repeat(image_seq_len));
image_string.push_str(FAKE);
}
image_string
}
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)), Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)), LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
Qwen2Vl(config) => format!( Qwen2Vl(config) => format!(
@ -647,7 +714,8 @@ fn prepare_input<T: TokenizerTrait>(
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config { let (tokenizer_query, input_chunks) = match config {
Some( Some(
config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)), config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
| Qwen2Vl(_)),
) => { ) => {
let mut input_chunks = Vec::new(); let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());

View File

@ -152,6 +152,9 @@ try:
from text_generation_server.models.custom_modeling.idefics2 import ( from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration, Idefics2ForConditionalGeneration,
) )
from text_generation_server.models.custom_modeling.idefics3 import (
Idefics3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import ( from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration, Qwen2VLForConditionalGeneration,
) )
@ -188,6 +191,12 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
"multimodal": True, "multimodal": True,
} }
IDEFICS3 = {
"type": "idefics3",
"name": "Idefics 3",
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
"multimodal": True,
}
LLAVA_NEXT = { LLAVA_NEXT = {
"type": "llava_next", "type": "llava_next",
"name": "Llava Next (1.6)", "name": "Llava Next (1.6)",
@ -1253,6 +1262,24 @@ def get_model(
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS3:
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
model_class=Idefics3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == PALIGEMMA: if model_type == PALIGEMMA:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return VlmCausalLM( return VlmCausalLM(

View File

@ -515,9 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
index=0, index=0,
prefix=( prefix=f"{prefix}.layers.0",
"model.layers.0" if not prefix else f"{prefix}.model.layers.0"
),
config=config, config=config,
weights=weights, weights=weights,
) )
@ -533,11 +531,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaCrossLayer( FlashLlamaCrossLayer(
index=layer_id, index=layer_id,
prefix=( prefix=(f"{prefix}.layers.{layer_id}"),
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config, config=config,
weights=weights, weights=weights,
) )
@ -546,11 +540,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
index=layer_id, index=layer_id,
prefix=( prefix=(f"{prefix}.layers.{layer_id}"),
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config, config=config,
weights=weights, weights=weights,
) )
@ -561,18 +551,14 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
index=last_layer_id, index=last_layer_id,
prefix=( prefix=(f"{prefix}.layers.{last_layer_id}"),
f"model.layers.{last_layer_id}"
if not prefix
else f"{prefix}.model.layers.{last_layer_id}"
),
config=config, config=config,
weights=weights, weights=weights,
) )
) )
self.norm = FastRMSNorm.load( self.norm = FastRMSNorm.load(
prefix="model.norm" if not prefix else f"{prefix}.model.norm", prefix=f"{prefix}.norm",
weights=weights, weights=weights,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
) )
@ -629,19 +615,24 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights): def __init__(self, prefix: str, config, weights, name=None):
if name is None:
name = "model"
super().__init__() super().__init__()
with no_fp8(weights): with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix=( prefix=(
"model.embed_tokens" f"{name}.embed_tokens"
if not prefix if not prefix
else f"{prefix}.model.embed_tokens" else f"{prefix}.{name}.embed_tokens"
), ),
weights=weights, weights=weights,
) )
self.model = FlashLlamaModel(prefix, config, weights) self.model = FlashLlamaModel(
prefix=name if not prefix else f"{prefix}.{name}",
config=config,
weights=weights,
)
if config.tie_word_embeddings: if config.tie_word_embeddings:
suffix = "model.embed_tokens" suffix = "model.embed_tokens"
else: else:
@ -652,11 +643,13 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if embedding_multiplier is not None: if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier self.embed_tokens.weight.data *= embedding_multiplier
prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
with no_fp8(weights): with no_fp8(weights):
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix=suffix if not prefix else f"{prefix}.{suffix}", prefix,
weights=weights, weights,
) )
# Used in Granite # Used in Granite

View File

@ -0,0 +1,584 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics3 model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
)
from text_generation_server.layers.attention import Seqlen
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Idefics3VisionEmbeddings(nn.Module):
"""
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
resolution.
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
which allows treating images in their native aspect ratio and without the need to resize them to the same
fixed size. In particular, we start from the original pre-trained SigLIP model
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
"""
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.patch_embedding.bias = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.position_embedding", weights=weights
)
def forward(
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size,
max_im_w // self.patch_size,
)
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(
fractional_coords_h, boundaries, right=True
)
bucket_coords_w = torch.bucketize(
fractional_coords_w, boundaries, right=True
)
pos_ids = (
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class Idefics3VisionAttention(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = self.embed_dim // self.num_heads
if self.head_size * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_size**-0.5
self.dropout = config.attention_dropout
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
self.out_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
self.head_size * self.num_heads,
self.head_size * self.num_heads,
],
dim=2,
)
query_states = query_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
key_states = key_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
value_states = value_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
)
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class Idefics3VisionMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Idefics3EncoderLayer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics3VisionAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
)
self.mlp = Idefics3VisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Idefics3Encoder(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
Idefics3EncoderLayer(
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
)
for i in range(config.num_hidden_layers)
]
)
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
)
return hidden_states
class Idefics3VisionTransformer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embeddings = Idefics3VisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
)
self.encoder = Idefics3Encoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
):
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_size = self.config.patch_size
patch_attention_mask = torch.ones(
(
batch_size,
pixel_values.size(2) // patch_size,
pixel_values.size(3) // patch_size,
)
)
patch_attention_mask = patch_attention_mask.to(
dtype=torch.bool, device=pixel_values.device
)
hidden_states = self.embeddings(
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
else:
patch_attention_mask = _prepare_4d_attention_mask(
patch_attention_mask, hidden_states.dtype
)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=patch_attention_mask,
)
last_hidden_state = encoder_outputs
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
class Idefics3SimpleMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
input_size = config.vision_config.hidden_size * (config.scale_factor**2)
output_size = config.text_config.hidden_size
proj = nn.Parameter(
weights.get_tensor(f"{prefix}.modality_projection.proj.weight"),
requires_grad=False,
).to(weights.dtype)
self.proj = nn.Linear(input_size, output_size, bias=False)
self.proj.weight = proj
def forward(self, x):
return self.proj(x)
class Idefics3Connector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
self.scale_factor = config.scale_factor
def pixel_shuffle(self, x, scale_factor=2):
bsz, seq, embed_dim = x.size()
height = width = int(seq**0.5)
x = x.view(bsz, height, width, embed_dim)
x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
x = x.permute(0, 2, 1, 3)
x = x.reshape(
bsz,
int(width / scale_factor),
int(height / scale_factor),
embed_dim * (scale_factor**2),
)
x = x.permute(0, 2, 1, 3)
x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
return x
def forward(self, image_hidden_states):
image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
image_hidden_states = self.modality_projection(image_hidden_states)
return image_hidden_states
class Idefics3ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
# set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
# since Idefics3 uses the `embed_tokens` for the final prediction
# config.text_config.tie_word_embeddings = True
vision_config = config.vision_config
self.text_model = load_text_model(
prefix="model" if not prefix else f"{prefix}.model",
config=config.text_config,
weights=weights,
name="text_model",
)
self.dtype = weights.dtype
# The vision and connector models are not quantized.
with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
self.vision_model = Idefics3VisionTransformer(
prefix=(
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
),
config=vision_config,
weights=weights,
)
config.quantize = None
self.connector = Idefics3Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
self.config = config
self.image_token_id = config.image_token_id
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index
mask = input_ids == self.config.image_token_id
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask/pP p
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None):
FlashLlamaForCausalLM, FlashLlamaForCausalLM,
) )
return FlashLlamaForCausalLM(prefix, config, weights) return FlashLlamaForCausalLM(prefix, config, weights, name=name)
elif config.model_type == "mistral": elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM, FlashMistralForCausalLM,

View File

@ -1288,7 +1288,7 @@ class FlashCausalLM(Model):
weights_loader=weights_loader, weights_loader=weights_loader,
) )
prefix = "" prefix = None
model = model_class(prefix, config, weights) model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import (
FlashCausalLM, FlashCausalLM,
) )
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from loguru import logger
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
from transformers import AutoProcessor from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
@ -23,6 +24,40 @@ tracer = trace.get_tracer(__name__)
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>" IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>" IDEFICS2_IMAGE_TOKEN = "<image>"
IDEFICS3_IMAGE_TOKEN = "<image>"
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
def _prompt_split_image(
*,
image_seq_len: int,
image_rows: int,
image_cols: int,
fake_token_around_image: str,
image_token: str,
global_img_token: str,
):
"""Prompt with expanded image tokens for when the image is split into patches."""
text_split_images = ""
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (
f"{fake_token_around_image}"
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
+ f"{image_token}" * image_seq_len
)
text_split_images += "\n"
text_split_images += (
f"\n{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)
return text_split_images
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
""" """
@ -54,10 +89,26 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
if processor.image_processor.do_image_splitting: if processor.image_processor.do_image_splitting:
image_str *= 5 image_str *= 5
return image_str return image_str
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_cols = image_input["cols"][0][image_id]
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
)
image_str = _prompt_split_image(
image_seq_len=image_seq_len,
image_rows=n_rows,
image_cols=n_cols,
fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
image_token=IDEFICS3_IMAGE_TOKEN,
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
return image_str
elif config.model_type == "llava_next": elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id] height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config) num_features = get_number_of_features(height, width, config)
from loguru import logger
log_master( log_master(
logger.info, logger.info,
@ -194,12 +245,21 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
raise RuntimeError(f"Invalid chunk type {chunk_type}") raise RuntimeError(f"Invalid chunk type {chunk_type}")
if images: if images:
image_inputs = processor.image_processor(images, return_tensors="pt") kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
image_inputs = processor.image_processor(
images, return_tensors="pt", **kwargs
)
else: else:
image_inputs = None image_inputs = None
batch_inputs = [] batch_tokenized_inputs = []
max_truncation = 0 max_length = 0
image_id = 0 image_id = 0
for r in requests: for r in requests:
full_text = "" full_text = ""
@ -214,16 +274,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
image_id += 1 image_id += 1
full_text = image_text_replacement_fixup(config, full_text) full_text = image_text_replacement_fixup(config, full_text)
input_ids = tokenizer(
batch_inputs.append(full_text) full_text,
max_truncation = max(max_truncation, r.truncate) truncation=True,
max_length=r.truncate,
batch_tokenized_inputs = tokenizer( add_special_tokens=r.add_special_tokens,
batch_inputs, )["input_ids"]
truncation=True, max_length = max(max_length, len(input_ids))
max_length=max_truncation, batch_tokenized_inputs.append(input_ids)
add_special_tokens=not config.model_type == "paligemma",
)["input_ids"]
return batch_tokenized_inputs, image_inputs return batch_tokenized_inputs, image_inputs