fix: create new idefic3 file, simplify logic and adjust llama weight loading

This commit is contained in:
drbh 2024-12-21 00:27:29 +00:00
parent 0d1bf9e983
commit 575d97339c
5 changed files with 1095 additions and 298 deletions

View File

@ -151,6 +151,8 @@ try:
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.idefics3 import (
Idefics3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import (

View File

@ -507,6 +507,7 @@ class FlashLlamaModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
base_model = "" if prefix.endswith("text_model") else ".model"
# Skip fp8 quant for first and last layers
self.layers = nn.ModuleList()
@ -515,7 +516,11 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append(
FlashLlamaLayer(
index=0,
prefix=f"{prefix}.layers.0" if prefix else "model.layers.0",
prefix=(
"model.layers.0"
if not prefix
else f"{prefix}{base_model}.layers.0"
),
config=config,
weights=weights,
)
@ -532,9 +537,9 @@ class FlashLlamaModel(torch.nn.Module):
FlashLlamaCrossLayer(
index=layer_id,
prefix=(
f"{prefix}.layers.{layer_id}"
if prefix
else f"model.layers.{layer_id}"
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}{base_model}.layers.{layer_id}"
),
config=config,
weights=weights,
@ -545,9 +550,9 @@ class FlashLlamaModel(torch.nn.Module):
FlashLlamaLayer(
index=layer_id,
prefix=(
f"{prefix}.layers.{layer_id}"
if prefix
else f"model.layers.{layer_id}"
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}{base_model}.layers.{layer_id}"
),
config=config,
weights=weights,
@ -560,9 +565,9 @@ class FlashLlamaModel(torch.nn.Module):
FlashLlamaLayer(
index=last_layer_id,
prefix=(
f"{prefix}.layers.{last_layer_id}"
if prefix
else f"model.layers.{last_layer_id}"
f"model.layers.{last_layer_id}"
if not prefix
else f"{prefix}{base_model}.layers.{last_layer_id}"
),
config=config,
weights=weights,
@ -570,7 +575,7 @@ class FlashLlamaModel(torch.nn.Module):
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm" if prefix else "model.norm",
prefix="model.norm" if not prefix else f"{prefix}{base_model}.norm",
weights=weights,
eps=config.rms_norm_eps,
)
@ -629,18 +634,20 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
if config.model_type == "mllama_text_model":
prefix = f"{prefix}.model"
base_model = "" if prefix.endswith("text_model") else ".model"
with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding(
prefix=(f"{prefix}.embed_tokens" if prefix else "model.embed_tokens"),
prefix=(
"model.embed_tokens"
if not prefix
else f"{prefix}{base_model}.embed_tokens"
),
weights=weights,
)
self.model = FlashLlamaModel(prefix, config, weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
suffix = f"model.embed_tokens"
else:
suffix = "lm_head"
@ -649,17 +656,17 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier
if config.model_type == "mllama_text_model":
prefix = prefix.replace(".model", "")
suffix = f"{prefix}.{suffix}"
if config.model_type == "granite":
suffix = f"{prefix}.{suffix}"
if not prefix:
head_prefix = suffix
elif prefix.endswith("text_model"):
head_prefix = suffix
else:
head_prefix = f"{prefix}.{suffix}"
with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix,
prefix=head_prefix,
weights=weights,
)

View File

@ -679,215 +679,6 @@ class Idefics2Connector(nn.Module):
return image_hidden_states
class Idefics3Connector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.modality_projection = TensorParallelRowLinear.load(
prefix=f"{prefix}.modality_projection.proj",
config=config,
weights=weights,
bias=False,
)
self.scale_factor = config.scale_factor
def pixel_shuffle(self, x, scale_factor=2):
bsz, seq, embed_dim = x.size()
height = width = int(seq**0.5)
x = x.view(bsz, height, width, embed_dim)
x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
x = x.permute(0, 2, 1, 3)
x = x.reshape(
bsz,
int(width / scale_factor),
int(height / scale_factor),
embed_dim * (scale_factor**2),
)
x = x.permute(0, 2, 1, 3)
x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
return x
def forward(self, image_hidden_states, attention_mask):
print(image_hidden_states.device, self.modality_projection.linear.weight.device)
image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
image_hidden_states = self.modality_projection(image_hidden_states)
return image_hidden_states
class Idefics3ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
vision_config = config.vision_config
self.text_model = load_text_model(
prefix=f"{prefix}.model.text_model" if prefix else "model.text_model",
config=config.text_config,
weights=weights,
name="text_model",
)
self.dtype = weights.dtype
# The vision and connector models are not quantized.
with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
self.vision_model = Idefics2VisionTransformer(
prefix=(
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
),
config=vision_config,
weights=weights,
)
config.quantize = None
self.connector = Idefics3Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
self.config = config
self.image_token_id = config.image_token_id
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index
mask = input_ids == self.config.image_token_id
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask/pP p
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
# TODO: remove when prefill image tokens are handled correctly
# * for now dummy tokens are added instead of the image tokens output byt the vision model
mask_size = (input_ids == self.config.image_token_id).sum().item()
unrolled_image_size = (
image_hidden_states.shape[1] * image_hidden_states.shape[2]
)
diff = mask_size - unrolled_image_size
if diff > 0:
print(
f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
)
if mask_size == unrolled_image_size:
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits
class Idefics2ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()

File diff suppressed because it is too large Load Diff

View File

@ -28,70 +28,27 @@ IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
def _prompt_split_image(
image_seq_len,
image_rows,
image_cols,
fake_token_around_image,
image_token,
global_img_token,
):
"""Prompt with expanded image tokens for when the image is split into patches."""
text_split_images = ""
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (
f"{fake_token_around_image}"
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
+ f"{image_token}" * image_seq_len
)
text_split_images += "\n"
text_split_images += (
f"\n{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)
return text_split_images
def _prompt_single_image(
image_seq_len, fake_token_around_image, image_token, global_img_token
):
"""Prompt with expanded image tokens for a single image."""
return (
f"{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)
def get_image_prompt_string(
image_rows,
image_cols,
image_seq_len,
fake_token_around_image,
image_token,
global_img_token,
rows=0,
cols=0,
seq_len=1,
fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
img_token=IDEFICS3_IMAGE_TOKEN,
global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
):
if image_rows == 0 and image_cols == 0:
return _prompt_single_image(
image_seq_len,
fake_token_around_image=fake_token_around_image,
image_token=image_token,
global_img_token=global_img_token,
)
return _prompt_split_image(
image_seq_len,
image_rows,
image_cols,
fake_token_around_image,
image_token,
global_img_token,
tokens = img_token * seq_len
end_token = f"{fake_token}{global_token}{tokens}{fake_token}"
if rows == 0 or cols == 0:
return end_token
grid = "\n".join(
"".join(f"{fake_token}<row_{i+1}_col_{j+1}>{tokens}" for j in range(cols))
for i in range(rows)
)
return f"{grid}\n\n{end_token}"
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
@ -132,12 +89,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
/ (config.scale_factor**2)
)
image_str = get_image_prompt_string(
n_rows,
n_cols,
image_seq_len,
image_token=IDEFICS3_IMAGE_TOKEN,
fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
rows=n_rows,
cols=n_cols,
seq_len=image_seq_len,
fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
img_token=IDEFICS3_IMAGE_TOKEN,
global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
return image_str
elif config.model_type == "llava_next":