Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Commit 575d97339c (parent 0d1bf9e983)
fix: create new idefic3 file, simplify logic and adjust llama weight loading
@@ -151,6 +151,8 @@ try:
     )
     from text_generation_server.models.custom_modeling.idefics2 import (
         Idefics2ForConditionalGeneration,
+    )
+    from text_generation_server.models.custom_modeling.idefics3 import (
         Idefics3ForConditionalGeneration,
     )
     from text_generation_server.models.custom_modeling.qwen2_vl import (
@@ -507,6 +507,7 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
+        base_model = "" if prefix.endswith("text_model") else ".model"
 
         # Skip fp8 quant for first and last layers
         self.layers = nn.ModuleList()
@@ -515,7 +516,11 @@ class FlashLlamaModel(torch.nn.Module):
             self.layers.append(
                 FlashLlamaLayer(
                     index=0,
-                    prefix=f"{prefix}.layers.0" if prefix else "model.layers.0",
+                    prefix=(
+                        "model.layers.0"
+                        if not prefix
+                        else f"{prefix}{base_model}.layers.0"
+                    ),
                     config=config,
                     weights=weights,
                 )
@@ -532,9 +537,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaCrossLayer(
                         index=layer_id,
                         prefix=(
-                            f"{prefix}.layers.{layer_id}"
-                            if prefix
-                            else f"model.layers.{layer_id}"
+                            f"model.layers.{layer_id}"
+                            if not prefix
+                            else f"{prefix}{base_model}.layers.{layer_id}"
                         ),
                         config=config,
                         weights=weights,
@@ -545,9 +550,9 @@ class FlashLlamaModel(torch.nn.Module):
                     FlashLlamaLayer(
                         index=layer_id,
                         prefix=(
-                            f"{prefix}.layers.{layer_id}"
-                            if prefix
-                            else f"model.layers.{layer_id}"
+                            f"model.layers.{layer_id}"
+                            if not prefix
+                            else f"{prefix}{base_model}.layers.{layer_id}"
                        ),
                         config=config,
                         weights=weights,
@@ -560,9 +565,9 @@ class FlashLlamaModel(torch.nn.Module):
                 FlashLlamaLayer(
                     index=last_layer_id,
                     prefix=(
-                        f"{prefix}.layers.{last_layer_id}"
-                        if prefix
-                        else f"model.layers.{last_layer_id}"
+                        f"model.layers.{last_layer_id}"
+                        if not prefix
+                        else f"{prefix}{base_model}.layers.{last_layer_id}"
                     ),
                     config=config,
                     weights=weights,
@@ -570,7 +575,7 @@ class FlashLlamaModel(torch.nn.Module):
             )
 
         self.norm = FastRMSNorm.load(
-            prefix=f"{prefix}.norm" if prefix else "model.norm",
+            prefix="model.norm" if not prefix else f"{prefix}{base_model}.norm",
            weights=weights,
             eps=config.rms_norm_eps,
         )
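Note: taken together, the FlashLlamaModel hunks above replace the old `f"{prefix}.layers.N"` / `"model.layers.N"` ternaries with a `base_model` infix, so the same decoder code can be loaded standalone, under a checkpoint that nests the decoder beneath `<prefix>.model`, or as the `text_model` tower of Idefics3. A minimal sketch of that resolution logic as read from the diff; the helper name and the example prefixes are hypothetical, not part of the commit:

def resolve_layer_prefix(prefix: str, layer_id: int) -> str:
    # "" when the weights already sit directly under a "...text_model" prefix,
    # ".model" when the checkpoint nests the decoder under "<prefix>.model".
    base_model = "" if prefix.endswith("text_model") else ".model"
    return (
        f"model.layers.{layer_id}"
        if not prefix
        else f"{prefix}{base_model}.layers.{layer_id}"
    )

print(resolve_layer_prefix("", 0))                  # model.layers.0 (plain Llama)
print(resolve_layer_prefix("model.text_model", 0))  # model.text_model.layers.0 (Idefics3 text tower)
print(resolve_layer_prefix("language_model", 0))    # language_model.model.layers.0 (hypothetical nested prefix)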
@@ -629,18 +634,20 @@ class FlashLlamaModel(torch.nn.Module):
 class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
-
-        if config.model_type == "mllama_text_model":
-            prefix = f"{prefix}.model"
+        base_model = "" if prefix.endswith("text_model") else ".model"
 
         with no_fp8(weights):
             self.embed_tokens = TensorParallelEmbedding(
-                prefix=(f"{prefix}.embed_tokens" if prefix else "model.embed_tokens"),
+                prefix=(
+                    "model.embed_tokens"
+                    if not prefix
+                    else f"{prefix}{base_model}.embed_tokens"
+                ),
                 weights=weights,
             )
         self.model = FlashLlamaModel(prefix, config, weights)
         if config.tie_word_embeddings:
-            suffix = "model.embed_tokens"
+            suffix = f"model.embed_tokens"
         else:
             suffix = "lm_head"
 
@@ -649,17 +656,17 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
 
-        if config.model_type == "mllama_text_model":
-            prefix = prefix.replace(".model", "")
-            suffix = f"{prefix}.{suffix}"
-
-        if config.model_type == "granite":
-            suffix = f"{prefix}.{suffix}"
+        if not prefix:
+            head_prefix = suffix
+        elif prefix.endswith("text_model"):
+            head_prefix = suffix
+        else:
+            head_prefix = f"{prefix}.{suffix}"
 
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
-                prefix=suffix,
+                prefix=head_prefix,
                 weights=weights,
             )
 
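Note: the lm_head hunk above drops the mllama/granite special cases in favour of a `head_prefix` derived from the shape of `prefix`. A rough standalone sketch of that selection, assuming `suffix` is either "lm_head" or "model.embed_tokens" (tied embeddings); the helper name and the example prefixes are hypothetical:

def resolve_head_prefix(prefix: str, suffix: str) -> str:
    if not prefix:
        return suffix                  # standalone model: "lm_head" / "model.embed_tokens"
    elif prefix.endswith("text_model"):
        return suffix                  # Idefics3 text tower reuses the top-level head weights
    else:
        return f"{prefix}.{suffix}"    # nested checkpoints keep the full path

print(resolve_head_prefix("", "lm_head"))                  # lm_head
print(resolve_head_prefix("model.text_model", "lm_head"))  # lm_head
print(resolve_head_prefix("language_model", "lm_head"))    # language_model.lm_head (hypothetical prefix)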
@@ -679,215 +679,6 @@ class Idefics2Connector(nn.Module):
         return image_hidden_states
 
 
-class Idefics3Connector(nn.Module):
-    def __init__(self, prefix, config, weights):
-        super().__init__()
-        self.modality_projection = TensorParallelRowLinear.load(
-            prefix=f"{prefix}.modality_projection.proj",
-            config=config,
-            weights=weights,
-            bias=False,
-        )
-        self.scale_factor = config.scale_factor
-
-    def pixel_shuffle(self, x, scale_factor=2):
-        bsz, seq, embed_dim = x.size()
-        height = width = int(seq**0.5)
-        x = x.view(bsz, height, width, embed_dim)
-        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
-        x = x.permute(0, 2, 1, 3)
-        x = x.reshape(
-            bsz,
-            int(width / scale_factor),
-            int(height / scale_factor),
-            embed_dim * (scale_factor**2),
-        )
-        x = x.permute(0, 2, 1, 3)
-        x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
-        return x
-
-    def forward(self, image_hidden_states, attention_mask):
-        print(image_hidden_states.device, self.modality_projection.linear.weight.device)
-        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
-        image_hidden_states = self.modality_projection(image_hidden_states)
-        return image_hidden_states
-
-
-class Idefics3ForConditionalGeneration(nn.Module):
-    def __init__(self, prefix, config, weights):
-        super().__init__()
-        config.vision_config.quantize = None
-        config.vision_config.speculator = config.speculator
-        config.text_config.quantize = config.quantize
-        config.text_config.speculator = config.speculator
-
-        vision_config = config.vision_config
-        self.text_model = load_text_model(
-            prefix=f"{prefix}.model.text_model" if prefix else "model.text_model",
-            config=config.text_config,
-            weights=weights,
-            name="text_model",
-        )
-        self.dtype = weights.dtype
-
-        # The vision and connector models are not quantized.
-        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
-            self.vision_model = Idefics2VisionTransformer(
-                prefix=(
-                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
-                ),
-                config=vision_config,
-                weights=weights,
-            )
-
-            config.quantize = None
-            self.connector = Idefics3Connector(
-                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
-                config=config,
-                weights=weights,
-            )
-
-        self.config = config
-        self.image_token_id = config.image_token_id
-        self.pad_token_id = (
-            config.pad_token_id if config.pad_token_id is not None else -1
-        )
-
-    def _merge_input_ids_with_image_features(
-        self,
-        input_ids: torch.Tensor,
-        inputs_embeds: torch.Tensor,
-        image_features: torch.Tensor,
-    ):
-        """In place merges in vision_embeddings with inputs_embeds."""
-        # mask = input_ids == self.config.image_token_index
-        mask = input_ids == self.config.image_token_id
-        # Let's pray we have enabled enough slots !
-        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-        return inputs_embeds
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlen_prefill: Optional[torch.Tensor],
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        block_tables: torch.Tensor,
-        slots: torch.Tensor,
-        seqlen: Seqlen,
-        max_s: int,
-        prefill_cache_indices: Optional[torch.Tensor],
-        lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-        # Unused here
-        image_sizes: Optional[torch.Tensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
-        video_grid_thw: Optional[torch.LongTensor] = None,
-        cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
-    ):
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            batch_size, num_images, num_channels, height, width = pixel_values.shape
-            all_states = []
-            all_pixel_values = pixel_values
-            all_pixel_mask = pixel_attention_mask
-            for i in range(batch_size):
-                pixel_values = all_pixel_values.to(
-                    dtype=self.dtype
-                )  # fp16 compatibility
-                pixel_values = pixel_values[i : i + 1]
-                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
-
-                # Remove padding images - padding images are full 0.
-                nb_values_per_image = pixel_values.shape[1:].numel()
-                real_images_inds = (pixel_values == 0.0).sum(
-                    dim=(-1, -2, -3)
-                ) != nb_values_per_image
-                pixel_values = pixel_values[real_images_inds].contiguous()
-
-                # Handle the vision attention mask
-                if pixel_attention_mask is None:
-                    pixel_attention_mask = torch.ones(
-                        size=(
-                            pixel_values.size(0),
-                            pixel_values.size(2),
-                            pixel_values.size(3),
-                        ),
-                        dtype=torch.bool,
-                        device=pixel_values.device,
-                    )
-                else:
-                    # Remove padding images from the mask
-                    pixel_attention_mask = all_pixel_mask[i : i + 1]
-                    pixel_attention_mask = pixel_attention_mask.view(
-                        1 * num_images, *pixel_attention_mask.shape[2:]
-                    )
-                    pixel_attention_mask = pixel_attention_mask[
-                        real_images_inds
-                    ].contiguous()
-
-                patch_size = self.config.vision_config.patch_size
-                patches_subgrid = pixel_attention_mask.unfold(
-                    dimension=1, size=patch_size, step=patch_size
-                )
-                patches_subgrid = patches_subgrid.unfold(
-                    dimension=2, size=patch_size, step=patch_size
-                )
-                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
-                # Get sequence from the vision encoder
-                image_hidden_states = self.vision_model(
-                    pixel_values=pixel_values,
-                    patch_attention_mask=patch_attention_mask,
-                )
-
-                # Modality projection & resampling
-                image_hidden_states = self.connector(
-                    image_hidden_states,
-                    attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
-                )
-
-                all_states.append(image_hidden_states)
-            image_hidden_states = torch.stack(all_states, dim=0)
-            # TODO: remove when prefill image tokens are handled correctly
-            # * for now dummy tokens are added instead of the image tokens output byt the vision model
-            mask_size = (input_ids == self.config.image_token_id).sum().item()
-            unrolled_image_size = (
-                image_hidden_states.shape[1] * image_hidden_states.shape[2]
-            )
-            diff = mask_size - unrolled_image_size
-            if diff > 0:
-                print(
-                    f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
-                )
-
-            if mask_size == unrolled_image_size:
-                inputs_embeds = self._merge_input_ids_with_image_features(
-                    input_ids, inputs_embeds, image_hidden_states
-                )
-
-        hidden_states = self.text_model.model(
-            inputs_embeds=inputs_embeds,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            seqlen=seqlen,
-            max_s=max_s,
-            true_max_s=max_s,
-            prefill_cache_indices=None,
-            adapter_data=adapter_data,
-        )
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.text_model.lm_head(hidden_states)
-        return logits, speculative_logits
-
-
 class Idefics2ForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
server/text_generation_server/models/custom_modeling/idefics3.py — new file, 1040 lines.
File diff suppressed because it is too large.
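Note: since the diff of the new idefics3.py is suppressed, the pixel-shuffle step performed by the relocated Idefics3Connector (visible in the lines removed from idefics2.py above) can be hard to picture. Below is a self-contained sketch of the same reshaping with toy shapes; the patch-grid and embedding sizes are illustrative, not taken from the model config:

import torch

def pixel_shuffle(x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor:
    # Folds each (scale_factor x scale_factor) neighbourhood of patch embeddings
    # into the channel dimension, shrinking the sequence by scale_factor**2.
    bsz, seq, embed_dim = x.size()
    height = width = int(seq**0.5)
    x = x.view(bsz, height, width, embed_dim)
    x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(
        bsz,
        int(width / scale_factor),
        int(height / scale_factor),
        embed_dim * (scale_factor**2),
    )
    x = x.permute(0, 2, 1, 3)
    return x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))

hidden = torch.randn(1, 1024, 1152)    # 1 image, 32x32 patch grid, hypothetical sizes
print(pixel_shuffle(hidden, 2).shape)  # torch.Size([1, 256, 4608]): 4x fewer tokens, 4x wider features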
@@ -28,70 +28,27 @@ IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
 IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
 
 
-def _prompt_split_image(
-    image_seq_len,
-    image_rows,
-    image_cols,
-    fake_token_around_image,
-    image_token,
-    global_img_token,
-):
-    """Prompt with expanded image tokens for when the image is split into patches."""
-    text_split_images = ""
-    for n_h in range(image_rows):
-        for n_w in range(image_cols):
-            text_split_images += (
-                f"{fake_token_around_image}"
-                + f"<row_{n_h + 1}_col_{n_w + 1}>"
-                + f"{image_token}" * image_seq_len
-            )
-        text_split_images += "\n"
-
-    text_split_images += (
-        f"\n{fake_token_around_image}"
-        + f"{global_img_token}"
-        + f"{image_token}" * image_seq_len
-        + f"{fake_token_around_image}"
-    )
-    return text_split_images
-
-
-def _prompt_single_image(
-    image_seq_len, fake_token_around_image, image_token, global_img_token
-):
-    """Prompt with expanded image tokens for a single image."""
-    return (
-        f"{fake_token_around_image}"
-        + f"{global_img_token}"
-        + f"{image_token}" * image_seq_len
-        + f"{fake_token_around_image}"
-    )
-
-
 def get_image_prompt_string(
-    image_rows,
-    image_cols,
-    image_seq_len,
-    fake_token_around_image,
-    image_token,
-    global_img_token,
+    rows=0,
+    cols=0,
+    seq_len=1,
+    fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
+    img_token=IDEFICS3_IMAGE_TOKEN,
+    global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
 ):
-    if image_rows == 0 and image_cols == 0:
-        return _prompt_single_image(
-            image_seq_len,
-            fake_token_around_image=fake_token_around_image,
-            image_token=image_token,
-            global_img_token=global_img_token,
-        )
-    return _prompt_split_image(
-        image_seq_len,
-        image_rows,
-        image_cols,
-        fake_token_around_image,
-        image_token,
-        global_img_token,
-    )
+    tokens = img_token * seq_len
+    end_token = f"{fake_token}{global_token}{tokens}{fake_token}"
+
+    if rows == 0 or cols == 0:
+        return end_token
+
+    grid = "\n".join(
+        "".join(f"{fake_token}<row_{i+1}_col_{j+1}>{tokens}" for j in range(cols))
+        for i in range(rows)
+    )
+
+    return f"{grid}\n\n{end_token}"
 
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
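Note: the rewritten get_image_prompt_string folds _prompt_single_image and _prompt_split_image into a single keyword-argument function. A quick check of what it emits, assuming IDEFICS3_IMAGE_TOKEN is "<image>" (its value is not shown in this hunk); the body is copied from the new version above:

IDEFICS3_IMAGE_TOKEN = "<image>"  # assumed value for the demo
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"

def get_image_prompt_string(
    rows=0,
    cols=0,
    seq_len=1,
    fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
    img_token=IDEFICS3_IMAGE_TOKEN,
    global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
):
    tokens = img_token * seq_len
    end_token = f"{fake_token}{global_token}{tokens}{fake_token}"
    if rows == 0 or cols == 0:
        return end_token
    grid = "\n".join(
        "".join(f"{fake_token}<row_{i+1}_col_{j+1}>{tokens}" for j in range(cols))
        for i in range(rows)
    )
    return f"{grid}\n\n{end_token}"

# Unsplit image: just the global block.
print(get_image_prompt_string(seq_len=2))
# <fake_token_around_image><global-img><image><image><fake_token_around_image>

# 2x2 split: one prompt line per grid row, then the global block after a blank line.
print(get_image_prompt_string(rows=2, cols=2, seq_len=1))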
@@ -132,12 +89,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
             / (config.scale_factor**2)
         )
         image_str = get_image_prompt_string(
-            n_rows,
-            n_cols,
-            image_seq_len,
-            image_token=IDEFICS3_IMAGE_TOKEN,
-            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
-            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+            rows=n_rows,
+            cols=n_cols,
+            seq_len=image_seq_len,
+            fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
+            img_token=IDEFICS3_IMAGE_TOKEN,
+            global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
         )
         return image_str
     elif config.model_type == "llava_next":