mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

feat: expand vlm support and add image token logic and tests

This commit is contained in:
parent 7eeefa3b57
commit cf29c5b5cd

integration-tests/models/test_idefics3.py (new file, 104 additions)
@@ -0,0 +1,104 @@
+import pytest
+import base64
+
+
+# TODO fix the server parser to count inline image tokens correctly
+def get_chicken():
+    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.fixture(scope="module")
+def flash_idefics3_next_handle(launcher):
+    with launcher(
+        "HuggingFaceM4/Idefics3-8B-Llama3",
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_idefics3_next(flash_idefics3_next_handle):
+    await flash_idefics3_next_handle.health(300)
+    return flash_idefics3_next_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics3_next_simple(flash_idefics3_next, response_snapshot):
+    chicken = get_chicken()
+    response = await flash_idefics3_next.generate(
+        f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
+        max_new_tokens=10,
+    )
+    assert (
+        response.generated_text == " A chicken is sitting on a pile of money."
+    ), f"{repr(response.generated_text)}"
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics3_two_images(flash_idefics3_next, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await flash_idefics3_next.generate(
+        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
+        max_new_tokens=20,
+    )
+    assert (
+        response.generated_text
+        == " The cow is standing on the beach and the chicken is sitting on a pile of money."
+    ), f"{repr(response.generated_text)}"
+    assert response.details.generated_tokens == 19
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics3_next_all_params(flash_idefics3_next, response_snapshot):
+    response = await flash_idefics3_next.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        stop_sequences=["test"],
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics3_next_load(
+    flash_idefics3_next, generate_load, response_snapshot
+):
+    chicken = get_chicken()
+    responses = await generate_load(
+        flash_idefics3_next,
+        f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
+        max_new_tokens=10,
+        n=4,
+    )
+    generated_texts = [r.generated_text for r in responses]
+    assert generated_texts[0] == " A chicken is sitting on a pile of money."
+    assert len(generated_texts) == 4
+    assert all([r.generated_text == generated_texts[0] for r in responses])
+
+    assert responses == response_snapshot
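Note: for reference, a minimal sketch (not part of this commit; host, port, and the running server are assumptions) of the kind of request these integration tests drive. TGI's /generate route takes the prompt with the image inlined as a base64 data URL in markdown image syntax:

import base64
import requests  # assumed available in the test environment

with open("integration-tests/images/chicken_on_money.png", "rb") as f:
    image = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": f"User:![](data:image/png;base64,{image})Write me a short story<end_of_utterance> \nAssistant:",
    "parameters": {"max_new_tokens": 10},
}
# Assumes a TGI server for HuggingFaceM4/Idefics3-8B-Llama3 is already listening locally.
print(requests.post("http://localhost:8080/generate", json=payload).json())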
@@ -110,6 +110,13 @@ pub struct ClipVisionModel {
     patch_size: usize,
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Idefics3 {
+    pub(crate) vision_encoder_max_image_size: usize,
+    pub(crate) image_seq_len: usize,
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct Idefics2 {}
@@ -178,6 +185,7 @@ pub enum Config {
     Idefics,
     Mllama,
     Idefics2(Idefics2),
+    Idefics3(Idefics3),
     Ssm,
     GptBigcode,
     Granite,
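Note: a rough illustration of the model config fields the new Idefics3 variant deserializes. The field names come from the struct above; the values here are placeholders, not taken from any real checkpoint:

# Placeholder values only; the real numbers live in the model's config.json.
idefics3_router_config = {
    "model_type": "idefics3",
    "vision_encoder_max_image_size": 980,  # placeholder
    "image_seq_len": 169,  # placeholder
}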
@@ -170,6 +170,7 @@ impl TokenizerConfigToken {
 #[serde(tag = "processor_class")]
 pub enum HubPreprocessorConfig {
     Idefics2Processor(Idefics2Preprocessor),
+    Idefics3Processor(Idefics2Preprocessor),
 }
 
 impl HubPreprocessorConfig {
@@ -151,6 +151,7 @@ try:
     )
     from text_generation_server.models.custom_modeling.idefics2 import (
         Idefics2ForConditionalGeneration,
+        Idefics3ForConditionalGeneration,
     )
     from text_generation_server.models.custom_modeling.qwen2_vl import (
         Qwen2VLForConditionalGeneration,
@@ -188,6 +189,12 @@ class ModelType(enum.Enum):
         "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
         "multimodal": True,
     }
+    IDEFICS3 = {
+        "type": "idefics3",
+        "name": "Idefics 3",
+        "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
+        "multimodal": True,
+    }
     LLAVA_NEXT = {
         "type": "llava_next",
         "name": "Llava Next (1.6)",
@@ -1253,6 +1260,23 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == IDEFICS3:
+        if FLASH_ATTENTION:
+            return VlmCausalLM(
+                model_id=model_id,
+                model_class=Idefics3ForConditionalGeneration,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                # XXX: Extremely important to cap resolution in order to limit
+                # VRAM usage.
+                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
         if FLASH_ATTENTION:
             return VlmCausalLM(
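Note: the processor_kwargs above are forwarded when VlmCausalLM loads the processor, which is what actually caps the image resolution. A minimal sketch of the equivalent call, assuming the kwargs reach AutoProcessor.from_pretrained unchanged:

from transformers import AutoProcessor

# Same size cap as in get_model above; keeping the longest edge at 448 bounds the
# number of image patches, and therefore the image tokens and VRAM used per image.
processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/Idefics3-8B-Llama3",
    size={"longest_edge": 448, "shortest_edge": 378},
)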
@@ -515,9 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers.append(
             FlashLlamaLayer(
                 index=0,
-                prefix=(
-                    "model.layers.0" if not prefix else f"{prefix}.model.layers.0"
-                ),
+                prefix=("model.layers.0" if not prefix else f"{prefix}.layers.0"),
                 config=config,
                 weights=weights,
             )
@@ -564,7 +562,7 @@ class FlashLlamaModel(torch.nn.Module):
                 prefix=(
                     f"model.layers.{last_layer_id}"
                     if not prefix
-                    else f"{prefix}.model.layers.{last_layer_id}"
+                    else f"{prefix}.layers.{last_layer_id}"
                 ),
                 config=config,
                 weights=weights,
@@ -572,7 +570,7 @@ class FlashLlamaModel(torch.nn.Module):
         )
 
         self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+            prefix="model.norm" if not prefix else f"{prefix}.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
@@ -635,9 +633,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         with no_fp8(weights):
             self.embed_tokens = TensorParallelEmbedding(
                 prefix=(
-                    "model.embed_tokens"
-                    if not prefix
-                    else f"{prefix}.model.embed_tokens"
+                    "model.embed_tokens" if not prefix else f"{prefix}.embed_tokens"
                 ),
                 weights=weights,
             )
@@ -655,7 +651,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
-                prefix=suffix if not prefix else f"{prefix}.{suffix}",
+                prefix=suffix,
                 weights=weights,
             )
 
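Note: the prefix changes in these hunks make the Llama weight names compose with the wrapper prefix instead of always inserting a literal "model." segment. A small sketch of the resulting key layout, assumed purely from the expressions in the diff and the load_text_model change further down:

def layer_prefix(prefix: str, i: int) -> str:
    # Mirrors the new expression: "model.layers.{i}" if not prefix else "{prefix}.layers.{i}"
    return f"model.layers.{i}" if not prefix else f"{prefix}.layers.{i}"

# Standalone Llama checkpoint:
assert layer_prefix("", 0) == "model.layers.0"
# Llama wrapped as the Idefics3 text model (prefix built as "model.text_model" by load_text_model):
assert layer_prefix("model.text_model", 0) == "model.text_model.layers.0"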
@@ -679,6 +679,281 @@ class Idefics2Connector(nn.Module):
         return image_hidden_states
 
 
+class Idefics3Connector(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.modality_projection = TensorParallelRowLinear.load(
+            prefix=f"{prefix}.modality_projection.proj",
+            config=config,
+            weights=weights,
+            bias=False,
+        )
+        self.scale_factor = config.scale_factor
+
+    def pixel_shuffle(self, x, scale_factor=2):
+        bsz, seq, embed_dim = x.size()
+        height = width = int(seq**0.5)
+        x = x.view(bsz, height, width, embed_dim)
+        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
+        x = x.permute(0, 2, 1, 3)
+        x = x.reshape(
+            bsz,
+            int(width / scale_factor),
+            int(height / scale_factor),
+            embed_dim * (scale_factor**2),
+        )
+        x = x.permute(0, 2, 1, 3)
+        x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
+        return x
+
+    def forward(self, image_hidden_states, attention_mask):
+        print(image_hidden_states.device, self.modality_projection.linear.weight.device)
+        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
+        image_hidden_states = self.modality_projection(image_hidden_states)
+        return image_hidden_states
+
+
+class PerceiverConfig:
+    def __init__(self, config_dict):
+        self._name_or_path = config_dict.get("_name_or_path", "")
+        self.add_cross_attention = config_dict.get("add_cross_attention", False)
+        self.architectures = config_dict.get("architectures", None)
+        self.attention_dropout = config_dict.get("attention_dropout", 0.0)
+        self.bad_words_ids = config_dict.get("bad_words_ids", None)
+        self.begin_suppress_tokens = config_dict.get("begin_suppress_tokens", None)
+        self.bos_token_id = config_dict.get("bos_token_id", None)
+        self.chunk_size_feed_forward = config_dict.get("chunk_size_feed_forward", 0)
+        self.cross_attention_hidden_size = config_dict.get(
+            "cross_attention_hidden_size", None
+        )
+        self.decoder_start_token_id = config_dict.get("decoder_start_token_id", None)
+        self.diversity_penalty = config_dict.get("diversity_penalty", 0.0)
+        self.do_sample = config_dict.get("do_sample", False)
+        self.early_stopping = config_dict.get("early_stopping", False)
+        self.encoder_no_repeat_ngram_size = config_dict.get(
+            "encoder_no_repeat_ngram_size", 0
+        )
+        self.eos_token_id = config_dict.get("eos_token_id", None)
+        self.exponential_decay_length_penalty = config_dict.get(
+            "exponential_decay_length_penalty", None
+        )
+        self.finetuning_task = config_dict.get("finetuning_task", None)
+        self.forced_bos_token_id = config_dict.get("forced_bos_token_id", None)
+        self.forced_eos_token_id = config_dict.get("forced_eos_token_id", None)
+        self.hidden_act = config_dict.get("hidden_act", "silu")
+        self.id2label = config_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"})
+        self.is_decoder = config_dict.get("is_decoder", False)
+        self.is_encoder_decoder = config_dict.get("is_encoder_decoder", False)
+        self.label2id = config_dict.get("label2id", {"LABEL_0": 0, "LABEL_1": 1})
+        self.length_penalty = config_dict.get("length_penalty", 1.0)
+        self.max_length = config_dict.get("max_length", 20)
+        self.min_length = config_dict.get("min_length", 0)
+        self.model_type = config_dict.get("model_type", "idefics3")
+        self.no_repeat_ngram_size = config_dict.get("no_repeat_ngram_size", 0)
+        self.num_beam_groups = config_dict.get("num_beam_groups", 1)
+        self.num_beams = config_dict.get("num_beams", 1)
+        self.num_key_value_heads = config_dict.get("num_key_value_heads", 1)
+        self.num_return_sequences = config_dict.get("num_return_sequences", 1)
+        self.output_attentions = config_dict.get("output_attentions", False)
+        self.output_hidden_states = config_dict.get("output_hidden_states", False)
+        self.output_scores = config_dict.get("output_scores", False)
+        self.pad_token_id = config_dict.get("pad_token_id", 128002)
+        self.prefix = config_dict.get("prefix", None)
+        self.problem_type = config_dict.get("problem_type", None)
+        self.pruned_heads = config_dict.get("pruned_heads", {})
+        self.qk_layer_norms_perceiver = config_dict.get(
+            "qk_layer_norms_perceiver", False
+        )
+        self.remove_invalid_values = config_dict.get("remove_invalid_values", False)
+        self.repetition_penalty = config_dict.get("repetition_penalty", 1.0)
+        self.resampler_depth = config_dict.get("resampler_depth", 6)
+        self.resampler_head_dim = config_dict.get("resampler_head_dim", 96)
+        self.resampler_n_heads = config_dict.get("resampler_n_heads", 16)
+        self.resampler_n_latents = config_dict.get("resampler_n_latents", 64)
+        self.return_dict = config_dict.get("return_dict", True)
+        self.return_dict_in_generate = config_dict.get("return_dict_in_generate", False)
+        self.sep_token_id = config_dict.get("sep_token_id", None)
+        self.suppress_tokens = config_dict.get("suppress_tokens", None)
+        self.task_specific_params = config_dict.get("task_specific_params", None)
+        self.temperature = config_dict.get("temperature", 1.0)
+        self.tf_legacy_loss = config_dict.get("tf_legacy_loss", False)
+        self.tie_encoder_decoder = config_dict.get("tie_encoder_decoder", False)
+        self.tie_word_embeddings = config_dict.get("tie_word_embeddings", True)
+        self.tokenizer_class = config_dict.get("tokenizer_class", None)
+        self.top_k = config_dict.get("top_k", 50)
+        self.top_p = config_dict.get("top_p", 1.0)
+        self.torch_dtype = config_dict.get("torch_dtype", None)
+        self.torchscript = config_dict.get("torchscript", False)
+        self.transformers_version = config_dict.get("transformers_version", "4.43.2")
+        self.typical_p = config_dict.get("typical_p", 1.0)
+        self.use_bfloat16 = config_dict.get("use_bfloat16", False)
+
+
+class Idefics3ForConditionalGeneration(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        config.vision_config.quantize = None
+        config.vision_config.speculator = config.speculator
+        config.text_config.quantize = config.quantize
+        config.text_config.speculator = config.speculator
+
+        vision_config = config.vision_config
+        self.text_model = load_text_model(
+            prefix="model" if not prefix else f"{prefix}.model",
+            config=config.text_config,
+            weights=weights,
+            name="text_model",
+        )
+        self.dtype = weights.dtype
+
+        # The vision and connector models are not quantized.
+        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
+            self.vision_model = Idefics2VisionTransformer(
+                prefix=(
+                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+                ),
+                config=vision_config,
+                weights=weights,
+            )
+
+            config.quantize = None
+            self.connector = Idefics3Connector(
+                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
+                config=config,
+                weights=weights,
+            )
+
+        self.config = config
+        config.text_config.perceiver_config = PerceiverConfig(
+            config_dict=config.text_config.perceiver_config
+        )
+        self.image_seq_len = config.text_config.perceiver_config.resampler_n_latents
+        self.image_token_id = config.image_token_id
+        self.pad_token_id = (
+            config.pad_token_id if config.pad_token_id is not None else -1
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+        lm_head_indices: Optional[torch.Tensor] = None,
+        pixel_values: torch.FloatTensor = None,
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
+        # Unused here
+        image_sizes: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
+        if pixel_values is not None:
+            batch_size, num_images, num_channels, height, width = pixel_values.shape
+            all_states = []
+            all_pixel_values = pixel_values
+            all_pixel_mask = pixel_attention_mask
+            for i in range(batch_size):
+                pixel_values = all_pixel_values.to(
+                    dtype=self.dtype
+                )  # fp16 compatibility
+                pixel_values = pixel_values[i : i + 1]
+                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
+
+                # Remove padding images - padding images are full 0.
+                nb_values_per_image = pixel_values.shape[1:].numel()
+                real_images_inds = (pixel_values == 0.0).sum(
+                    dim=(-1, -2, -3)
+                ) != nb_values_per_image
+                pixel_values = pixel_values[real_images_inds].contiguous()
+
+                # Handle the vision attention mask
+                if pixel_attention_mask is None:
+                    pixel_attention_mask = torch.ones(
+                        size=(
+                            pixel_values.size(0),
+                            pixel_values.size(2),
+                            pixel_values.size(3),
+                        ),
+                        dtype=torch.bool,
+                        device=pixel_values.device,
+                    )
+                else:
+                    # Remove padding images from the mask
+                    pixel_attention_mask = all_pixel_mask[i : i + 1]
+                    pixel_attention_mask = pixel_attention_mask.view(
+                        1 * num_images, *pixel_attention_mask.shape[2:]
+                    )
+                    pixel_attention_mask = pixel_attention_mask[
+                        real_images_inds
+                    ].contiguous()
+
+                patch_size = self.config.vision_config.patch_size
+                patches_subgrid = pixel_attention_mask.unfold(
+                    dimension=1, size=patch_size, step=patch_size
+                )
+                patches_subgrid = patches_subgrid.unfold(
+                    dimension=2, size=patch_size, step=patch_size
+                )
+                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+                # Get sequence from the vision encoder
+                image_hidden_states = self.vision_model(
+                    pixel_values=pixel_values,
+                    patch_attention_mask=patch_attention_mask,
+                )
+
+                # Modality projection & resampling
+                image_hidden_states = self.connector(
+                    image_hidden_states,
+                    attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
+                )
+
+                all_states.append(image_hidden_states)
+            image_hidden_states = torch.stack(all_states, dim=0)
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
+            # TODO: finish implementing the image token replacement
+
+            # inputs_embeds = self.inputs_merger(
+            #     input_ids=input_ids,
+            #     inputs_embeds=inputs_embeds,
+            #     image_hidden_states=image_hidden_states,
+            # )
+
+            # import ipdb; ipdb.set_trace()
+            # num_images, _, vision_hidden_size = image_hidden_states.shape
+            # special_image_token_mask = input_ids == self.image_token_id
+            # new_inputs_embeds = inputs_embeds.clone()
+            # reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size).to(
+            #     inputs_embeds.dtype
+            # )  # cast to the dtype of the input_embeds to support quantized models
+            # new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
+            # inputs_embeds = new_inputs_embeds
+
+        hidden_states = self.text_model.model(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            true_max_s=max_s,
+            prefill_cache_indices=None,
+            adapter_data=adapter_data,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.text_model.lm_head(hidden_states)
+        return logits, speculative_logits
+
+
 class Idefics2ForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
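Note: a quick, self-contained shape check (illustrative only) of the pixel-shuffle step in Idefics3Connector above. It folds scale_factor**2 neighbouring patches into the channel dimension, so the vision sequence shrinks by that factor before the modality projection:

import torch

bsz, seq, dim, scale = 2, 64, 16, 2  # seq must be a perfect square
x = torch.randn(bsz, seq, dim)
h = w = int(seq**0.5)
# Inline the same reshapes as Idefics3Connector.pixel_shuffle.
y = x.view(bsz, h, w, dim)
y = y.view(bsz, h, w // scale, dim * scale).permute(0, 2, 1, 3)
y = y.reshape(bsz, w // scale, h // scale, dim * scale**2).permute(0, 2, 1, 3)
y = y.reshape(bsz, seq // scale**2, dim * scale**2)
assert y.shape == (2, 16, 64)  # 64 tokens -> 16 tokens, 16 dims -> 64 dims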
@@ -707,7 +982,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
             )
 
             config.quantize = None
-            self.connector = Idefics2Connector(
+            self.connector = Idefics3Connector(
                 prefix=f"{prefix}.model.connector" if prefix else "model.connector",
                 config=config,
                 weights=weights,
@@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None):
             FlashLlamaForCausalLM,
         )
 
-        return FlashLlamaForCausalLM(prefix, config, weights)
+        return FlashLlamaForCausalLM(f"{prefix}.text_model", config, weights)
     elif config.model_type == "mistral":
         from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
             FlashMistralForCausalLM,
@@ -54,6 +54,10 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
         if processor.image_processor.do_image_splitting:
             image_str *= 5
         return image_str
+    if config.model_type == "idefics3":
+        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
+        image_str = ""
+        return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
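Note: the idefics3 branch above builds the same fake-token/image-token wrapper string used for idefics2 and then immediately overwrites it with an empty string, which matches the TODO in the tests about counting inline image tokens. A sketch of what the (currently discarded) replacement would expand to; the token values are assumed from the idefics2 handling in this module, not confirmed by this diff:

IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"  # assumed value
IDEFICS2_IMAGE_TOKEN = "<image>"  # assumed value

image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN}{IDEFICS2_FAKE_TOKEN}"
print(image_str)  # <fake_token_around_image><image><fake_token_around_image>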
@@ -288,6 +292,7 @@ class VlmCausalLM(FlashCausalLM):
             **processor_kwargs,
         )
         self.batch_class = batch_class
+        # import ipdb; ipdb.set_trace()
         super().__init__(
             model_id=model_id,
             revision=revision,