Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

fix: adjust siglip attention

This commit is contained in:
parent 23294344c6
commit e13c08f57f
@@ -366,10 +366,6 @@ class FlashGemmaModel(torch.nn.Module):
        self.embed_tokens = TensorParallelEmbedding(
            prefix=pvalue,
            weights=weights,
            # limit embed_tokens.weight size to the config.vocab_size
        )
        self.embed_tokens.weight = torch.nn.Parameter(
            self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
        )

        # TODO: double check why this is needed
@@ -29,29 +29,6 @@ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
)


class PaliGemmaConfig(PretrainedConfig):
    model_type = "paligemma"

    def from_pretrained(pretrained_model_name_or_path, **kwargs):
        vision_config = VisionConfig(
            hidden_size=1152,
            intermediate_size=4304,
            model_type="siglip_vision_model",
            num_attention_heads=16,
            num_hidden_layers=27,
            num_image_tokens=256,
            patch_size=14,
            projection_dim=2048,
            projector_hidden_act="gelu_fast",
            vision_use_head=False,
            vocab_size=257152,
        )

        return GemmaConfig.from_pretrained(
            pretrained_model_name_or_path, vision_config=vision_config, **kwargs
        )


class VisionConfig(PretrainedConfig):
    def __init__(
        self,
@@ -95,6 +72,80 @@ class VisionConfig(PretrainedConfig):
        super().__init__(**kwargs)


class PaliGemmaConfig(PretrainedConfig):
    model_type = "paligemma"

    def __init__(
        self,
        text_config: GemmaConfig,
        vision_config: VisionConfig,
        vocab_size: int = 257152,
        image_token_index: int = 256000,
        **kwargs,
    ):
        self.text_config = text_config
        self.vision_config = vision_config

        self.vocab_size = vocab_size
        self.image_token_index = image_token_index

        self.intermediate_size = text_config.intermediate_size
        self.num_hidden_layers = text_config.num_hidden_layers
        self.num_key_value_heads = text_config.num_key_value_heads
        self.num_attention_heads = text_config.num_attention_heads

        super().__init__(**kwargs)

    def from_pretrained(pretrained_model_name_or_path, **kwargs):
        vision_config = VisionConfig(
            hidden_size=1152,
            intermediate_size=4304,
            model_type="siglip_vision_model",
            num_attention_heads=16,
            num_hidden_layers=27,
            num_image_tokens=256,
            patch_size=14,
            projection_dim=2048,
            projector_hidden_act="gelu_fast",
            vision_use_head=False,
            vocab_size=257152,
        )

        text_config = GemmaConfig.from_pretrained(
            pretrained_model_name_or_path,
            attention_bias=False,
            attention_dropout=0.0,
            bos_token_id=2,
            eos_token_id=1,
            head_dim=256,
            hidden_act="gelu_pytorch_tanh",
            hidden_activation=None,
            hidden_size=2048,
            initializer_range=0.02,
            intermediate_size=16384,
            max_position_embeddings=8192,
            model_type="gemma",
            num_attention_heads=8,
            num_hidden_layers=18,
            num_image_tokens=256,
            num_key_value_heads=1,
            pad_token_id=0,
            rms_norm_eps=1e-06,
            rope_theta=10000.0,
            torch_dtype="float32",
            transformers_version="4.40.0.dev0",
            use_cache=True,
            vocab_size=257216,
            **kwargs,
        )

        return PaliGemmaConfig(
            text_config=text_config,
            vision_config=vision_config,
            **kwargs,
        )


class FlashPaliGemmaForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
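For orientation, a minimal usage sketch of the composite config defined above. The checkpoint id is only an assumption for illustration; note that this `from_pretrained` builds the Gemma text config and the hard-coded SigLIP vision config itself rather than reading them from the checkpoint's config.json.

    # Assumed checkpoint id, for illustration only.
    config = PaliGemmaConfig.from_pretrained("google/paligemma-3b-pt-224")

    # The wrapper copies the text-model geometry onto itself, so callers such as
    # BaseFlashGemma can read layer counts without reaching into text_config.
    print(config.num_hidden_layers, config.num_key_value_heads, config.vocab_size)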
@@ -116,8 +167,8 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
        self.config = config

        self.language_model = load_text_model(
            prefix=prefix,
            config=config,
            prefix="language_model" if not prefix else f"{prefix}.language_model",
            config=config.text_config,
            weights=weights,
        ).to(weights.device, weights.dtype)
        self.pad_token_id = (
@@ -165,22 +216,18 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = image_features / (self.config.hidden_size**0.5)
        inputs_embeds = self._merge_input_ids_with_image_features(
        # NOTE: image_features returns the exact values as transformers

        # TODO: correctly merge inputs_embeds with image_features
        merged_inputs_embeds = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids
        )

        if input_ids.size(0) != 3000:
            import ipdb
            # import ipdb

            ipdb.set_trace()

            ## TODO: remove this
            ## load in values from reference
            # tensor = torch.load("../../new-model-addition-palma/inputs_embeds.npz")
            # inputs_embeds = torch.tensor(
            # tensor, device=inputs_embeds.device, dtype=inputs_embeds.dtype
            # ).squeeze()
            # ipdb.set_trace()
            pass

        hidden_states = self.language_model.model(
            inputs_embeds=inputs_embeds,
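The `_merge_input_ids_with_image_features` helper itself is not shown in this diff. As a rough sketch of the usual strategy (assumed here, not taken from the repo): the embedding slots holding the `<image>` placeholder token are overwritten with the projected image features, leaving text embeddings untouched.

    import torch

    def merge_input_ids_with_image_features(
        image_features: torch.Tensor,  # (num_image_tokens, hidden_size), already projected
        inputs_embeds: torch.Tensor,   # (seq_len, hidden_size) token embeddings
        input_ids: torch.Tensor,       # (seq_len,)
        image_token_index: int = 256000,  # value used by PaliGemmaConfig above
    ) -> torch.Tensor:
        # Boolean mask of the <image> placeholder positions; their count is
        # assumed to match the number of image feature rows.
        mask = input_ids == image_token_index
        inputs_embeds[mask] = image_features.to(inputs_embeds.dtype)
        return inputs_embeds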
@@ -122,18 +122,18 @@ class SiglipAttention(nn.Module):
        self.embed_dim = self.embed_dim // weights.process_group.size()
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,

        self.k_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
        )
        self.q_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
        )
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=True,
            config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -152,18 +152,10 @@ class SiglipAttention(nn.Module):
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, _ = hidden_states.size()
        qkv = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [
                self.head_size * self.num_heads,
            ]
            * 3,
            dim=2,
        )
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_size)
        query_states = self.q_proj(hidden_states)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)
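For context, a standalone sketch of the fused-QKV pattern the new forward path uses: a single projection emits q, k and v concatenated along the last dimension, and `split` slices them back apart. The geometry follows the SigLIP config above (hidden_size 1152, 16 heads, so head_size 72); the plain `nn.Linear` is only a stand-in for `TensorParallelColumnLinear.load_multi`.

    import torch
    from torch import nn

    num_heads, head_size, embed_dim = 16, 72, 1152

    # Stand-in for the fused projection: one matmul produces q, k and v together.
    qkv_proj = nn.Linear(embed_dim, 3 * num_heads * head_size, bias=True)

    hidden_states = torch.randn(2, 256, embed_dim)  # (batch, seq_len, channels)
    qkv = qkv_proj(hidden_states)

    # Slice the concatenated output back into the three projections.
    query, key, value = qkv.split([num_heads * head_size] * 3, dim=2)
    assert query.shape == (2, 256, num_heads * head_size)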
@@ -196,7 +188,7 @@ class SiglipAttention(nn.Module):
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.bmm(attn_weights, value_states)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
            raise ValueError(
@@ -277,7 +269,6 @@ class SiglipEncoderLayer(nn.Module):
        hidden_states = residual + hidden_states
        if output_attentions:
            return hidden_states, attn_weights
        print(hidden_states[0, 0, :5].tolist())
        return hidden_states, None


@@ -99,12 +99,6 @@ class BaseFlashGemma(FlashCausalLM):
        config.quantize = quantize
        config.speculator = speculator

        if is_vlm:
            config.intermediate_size = config.text_config.get("intermediate_size")
            config.num_attention_heads = config.text_config.get("num_attention_heads")
            config.num_hidden_layers = config.text_config.get("num_hidden_layers")
            config.num_key_value_heads = config.text_config.get("num_key_value_heads")

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -116,14 +110,9 @@ class BaseFlashGemma(FlashCausalLM):

        torch.distributed.barrier(group=self.process_group)

        if is_vlm:
            num_layers = config.num_hidden_layers
            num_kv_heads = config.num_key_value_heads
            head_size = config.intermediate_size
        else:
            num_layers = len(model.model.layers)
            num_kv_heads = model.model.num_key_value_heads
            head_size = model.model.head_size
        num_layers = config.num_hidden_layers
        num_kv_heads = config.num_key_value_heads
        head_size = config.intermediate_size

        super().__init__(
            model=model,
@@ -23,9 +23,7 @@ class FlashPaliGemma(PaliVlmCausalLM):
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            # TODO: load in the correct processor based on the model_id
            "google/siglip-base-patch16-224",
            # "google/siglip-so400m-patch14-384",
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
@@ -39,7 +37,6 @@ class FlashPaliGemma(PaliVlmCausalLM):
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            prefix="language_model",
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
@@ -405,6 +405,8 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        batch_inputs = []
        image_inputs = []
        text_inputs = []
        image_text_replacements = []
        max_truncation = 0
        for r in requests:
            chunks = split(r.inputs)
@@ -413,6 +415,7 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
            for chunk in chunks:
                if chunk["type"] == "text":
                    full_text += chunk["content"]
                    text_inputs.append(chunk["content"])
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # Should never receive URLs anymore, processing should be done
@@ -427,7 +430,11 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
                            "Cannot process input image not starting with data:"
                        )
                    image_input = processor.image_processor(image, return_tensors="pt")
                    full_text += image_text_replacement(image_input, config, image_id)
                    text_replacement = image_text_replacement(
                        image_input, config, image_id
                    )
                    full_text += text_replacement
                    image_text_replacements.append(text_replacement)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@@ -436,8 +443,28 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]

        image_token = tokenizer.get_added_vocab()["<image>"]

        # find the index of the first non-image token
        for batch in batch_tokenized_inputs:
            first_non_image = 0
            for i, token in enumerate(batch):
                if token != image_token:
                    first_non_image = i
                    break

        # manually add the bos to the left of the text
        batch_tokenized_inputs = [
            batch[:first_non_image] + [tokenizer.bos_token_id] + batch[first_non_image:]
            for batch in batch_tokenized_inputs
        ]

        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
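For clarity, a self-contained sketch of the BOS-insertion step above: the prompt is tokenized without special tokens, so the BOS id is spliced in right after the run of leading `<image>` placeholder tokens, giving the `<image>...<image><bos>text` layout. The token ids here are made up for illustration.

    # Hypothetical token ids, for illustration only.
    IMAGE_TOKEN_ID = 257152
    BOS_TOKEN_ID = 2

    def insert_bos_after_image_tokens(token_ids: list[int]) -> list[int]:
        """Insert BOS right after the leading run of <image> tokens."""
        first_non_image = 0
        for i, token in enumerate(token_ids):
            if token != IMAGE_TOKEN_ID:
                first_non_image = i
                break
        return token_ids[:first_non_image] + [BOS_TOKEN_ID] + token_ids[first_non_image:]

    # Two <image> placeholders followed by text tokens -> BOS lands at index 2.
    print(insert_bos_after_image_tokens([257152, 257152, 10, 11]))
    # [257152, 257152, 2, 10, 11]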