fix: adjust siglip attention

Authored by drbh on 2024-05-09 19:13:56 +00:00; committed by Nicolas Patry
parent 23294344c6
commit e13c08f57f
6 changed files with 130 additions and 83 deletions

View File

@@ -366,10 +366,6 @@ class FlashGemmaModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
prefix=pvalue,
weights=weights,
# limit embed_tokens.weight size to the config.vocab_size
)
self.embed_tokens.weight = torch.nn.Parameter(
self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
)
# TODO: double check why this is needed
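For reference, a minimal standalone sketch (toy sizes, plain PyTorch, not the repository loader) of the embedding truncation this hunk removes: the checkpoint weight is sliced down to (config.vocab_size, config.hidden_size) and re-wrapped as a Parameter.

import torch

# Stand-ins: full_weight mimics an oversized checkpoint embedding,
# 12 and 8 mimic config.vocab_size and config.hidden_size.
full_weight = torch.randn(16, 8)
vocab_size, hidden_size = 12, 8

embed = torch.nn.Embedding(full_weight.shape[0], full_weight.shape[1])
embed.weight = torch.nn.Parameter(full_weight[:vocab_size, :hidden_size])
print(embed.weight.shape)  # torch.Size([12, 8])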

View File

@@ -29,29 +29,6 @@ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
)
class PaliGemmaConfig(PretrainedConfig):
model_type = "paligemma"
def from_pretrained(pretrained_model_name_or_path, **kwargs):
vision_config = VisionConfig(
hidden_size=1152,
intermediate_size=4304,
model_type="siglip_vision_model",
num_attention_heads=16,
num_hidden_layers=27,
num_image_tokens=256,
patch_size=14,
projection_dim=2048,
projector_hidden_act="gelu_fast",
vision_use_head=False,
vocab_size=257152,
)
return GemmaConfig.from_pretrained(
pretrained_model_name_or_path, vision_config=vision_config, **kwargs
)
class VisionConfig(PretrainedConfig):
def __init__(
self,
@@ -95,6 +72,80 @@ class VisionConfig(PretrainedConfig):
super().__init__(**kwargs)
class PaliGemmaConfig(PretrainedConfig):
model_type = "paligemma"
def __init__(
self,
text_config: GemmaConfig,
vision_config: VisionConfig,
vocab_size: int = 257152,
image_token_index: int = 256000,
**kwargs,
):
self.text_config = text_config
self.vision_config = vision_config
self.vocab_size = vocab_size
self.image_token_index = image_token_index
self.intermediate_size = text_config.intermediate_size
self.num_hidden_layers = text_config.num_hidden_layers
self.num_key_value_heads = text_config.num_key_value_heads
self.num_attention_heads = text_config.num_attention_heads
super().__init__(**kwargs)
def from_pretrained(pretrained_model_name_or_path, **kwargs):
vision_config = VisionConfig(
hidden_size=1152,
intermediate_size=4304,
model_type="siglip_vision_model",
num_attention_heads=16,
num_hidden_layers=27,
num_image_tokens=256,
patch_size=14,
projection_dim=2048,
projector_hidden_act="gelu_fast",
vision_use_head=False,
vocab_size=257152,
)
text_config = GemmaConfig.from_pretrained(
pretrained_model_name_or_path,
attention_bias=False,
attention_dropout=0.0,
bos_token_id=2,
eos_token_id=1,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
hidden_size=2048,
initializer_range=0.02,
intermediate_size=16384,
max_position_embeddings=8192,
model_type="gemma",
num_attention_heads=8,
num_hidden_layers=18,
num_image_tokens=256,
num_key_value_heads=1,
pad_token_id=0,
rms_norm_eps=1e-06,
rope_theta=10000.0,
torch_dtype="float32",
transformers_version="4.40.0.dev0",
use_cache=True,
vocab_size=257216,
**kwargs,
)
return PaliGemmaConfig(
text_config=text_config,
vision_config=vision_config,
**kwargs,
)
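The new wrapper config hoists the text model's shape parameters onto the top level so generic Gemma code can keep reading config.num_hidden_layers and friends, which appears to be why the per-field overrides in BaseFlashGemma are deleted further down. A minimal sketch of that hoisting, using SimpleNamespace stand-ins instead of the real GemmaConfig/VisionConfig classes (values mirror the defaults hard-coded in this diff):

from types import SimpleNamespace

# Stand-ins for GemmaConfig / VisionConfig.
text_config = SimpleNamespace(
    intermediate_size=16384,
    num_hidden_layers=18,
    num_key_value_heads=1,
    num_attention_heads=8,
)
vision_config = SimpleNamespace(hidden_size=1152, num_hidden_layers=27)

class ToyPaliGemmaConfig:
    def __init__(self, text_config, vision_config, vocab_size=257152, image_token_index=256000):
        self.text_config = text_config
        self.vision_config = vision_config
        self.vocab_size = vocab_size
        self.image_token_index = image_token_index
        # Hoist the text-model shape parameters to the top level.
        self.intermediate_size = text_config.intermediate_size
        self.num_hidden_layers = text_config.num_hidden_layers
        self.num_key_value_heads = text_config.num_key_value_heads
        self.num_attention_heads = text_config.num_attention_heads

cfg = ToyPaliGemmaConfig(text_config, vision_config)
print(cfg.num_hidden_layers, cfg.vision_config.hidden_size)  # 18 1152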
class FlashPaliGemmaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
@@ -116,8 +167,8 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
self.config = config
self.language_model = load_text_model(
prefix=prefix,
config=config,
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
).to(weights.device, weights.dtype)
self.pad_token_id = (
@@ -165,22 +216,18 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
image_outputs = self.vision_tower(pixel_values)
selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature)
image_features = image_features / (self.config.hidden_size**0.5)
inputs_embeds = self._merge_input_ids_with_image_features(
# NOTE: image_features returns the exact values as transformers
# TODO: correctly merge inputs_embeds with image_features
merged_inputs_embeds = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids
)
if input_ids.size(0) != 3000:
import ipdb
# import ipdb
ipdb.set_trace()
## TODO: remove this
## load in values from reference
# tensor = torch.load("../../new-model-addition-palma/inputs_embeds.npz")
# inputs_embeds = torch.tensor(
# tensor, device=inputs_embeds.device, dtype=inputs_embeds.dtype
# ).squeeze()
# ipdb.set_trace()
pass
hidden_states = self.language_model.model(
inputs_embeds=inputs_embeds,
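The diff calls _merge_input_ids_with_image_features but its body is not shown here. As a rough illustration of what such a merge typically does, replacing the token embeddings at the <image> placeholder positions with the projected image features, here is a toy sketch; the function signature, shapes, and placeholder id are assumptions, not the repository implementation.

import torch

def merge_image_features(image_features, inputs_embeds, input_ids, image_token_index):
    # image_features: (num_image_tokens, hidden); inputs_embeds: (seq, hidden);
    # input_ids: (seq,). Overwrite the <image> positions, in order.
    merged = inputs_embeds.clone()
    mask = input_ids == image_token_index
    merged[mask] = image_features.to(merged.dtype)[: int(mask.sum())]
    return merged

input_ids = torch.tensor([7, 7, 5, 9])        # two <image> tokens (id 7), then text
inputs_embeds = torch.zeros(4, 4)
image_features = torch.ones(2, 4)
print(merge_image_features(image_features, inputs_embeds, input_ids, image_token_index=7))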

View File

@@ -122,18 +122,18 @@ class SiglipAttention(nn.Module):
self.embed_dim = self.embed_dim // weights.process_group.size()
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
self.k_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
)
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
)
self.out_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=True,
config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -152,18 +152,10 @@ class SiglipAttention(nn.Module):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
]
* 3,
dim=2,
)
key_states = self._shape(key_states, -1, bsz)
value_states = self._shape(value_states, -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_size)
query_states = self.q_proj(hidden_states)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
@@ -196,7 +188,7 @@ class SiglipAttention(nn.Module):
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_weights, value_states)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
raise ValueError(
@@ -277,7 +269,6 @@ class SiglipEncoderLayer(nn.Module):
hidden_states = residual + hidden_states
if output_attentions:
return hidden_states, attn_weights
print(hidden_states[0, 0, :5].tolist())
return hidden_states, None
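Taken together, the SiglipAttention hunks drop the fused qkv projection in favour of separate q_proj/k_proj/v_proj loads and run the batched attention on (bsz * num_heads, seq, head_dim) views. Below is a self-contained sketch of that layout with toy dimensions, plain nn.Linear layers instead of the tensor-parallel loaders, and no attention mask; it is illustrative only, not the repository module.

import torch
import torch.nn as nn

class ToySiglipStyleAttention(nn.Module):
    def __init__(self, embed_dim=64, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5
        # Separate projections, mirroring the q_proj / k_proj / v_proj split above.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def _shape(self, tensor, seq_len, bsz):
        # (bsz, seq, embed_dim) -> (bsz, num_heads, seq, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, hidden_states):
        bsz, tgt_len, _ = hidden_states.size()
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(self.q_proj(hidden_states) * self.scale, tgt_len, bsz).view(*proj_shape)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz).view(*proj_shape)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz).view(*proj_shape)
        attn_weights = torch.softmax(torch.bmm(query_states, key_states.transpose(1, 2)), dim=-1)
        attn_output = torch.bmm(attn_weights, value_states)
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)
        return self.out_proj(attn_output.reshape(bsz, tgt_len, self.num_heads * self.head_dim))

x = torch.randn(2, 7, 64)
print(ToySiglipStyleAttention()(x).shape)  # torch.Size([2, 7, 64])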

View File

@@ -99,12 +99,6 @@ class BaseFlashGemma(FlashCausalLM):
config.quantize = quantize
config.speculator = speculator
if is_vlm:
config.intermediate_size = config.text_config.get("intermediate_size")
config.num_attention_heads = config.text_config.get("num_attention_heads")
config.num_hidden_layers = config.text_config.get("num_hidden_layers")
config.num_key_value_heads = config.text_config.get("num_key_value_heads")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -116,14 +110,9 @@ class BaseFlashGemma(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
if is_vlm:
num_layers = config.num_hidden_layers
num_kv_heads = config.num_key_value_heads
head_size = config.intermediate_size
else:
num_layers = len(model.model.layers)
num_kv_heads = model.model.num_key_value_heads
head_size = model.model.head_size
super().__init__(
model=model,

View File

@@ -23,9 +23,7 @@ class FlashPaliGemma(PaliVlmCausalLM):
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
# TODO: load in the correct processor based on the model_id
"google/siglip-base-patch16-224",
# "google/siglip-so400m-patch14-384",
revision=revision,
trust_remote_code=trust_remote_code,
)
@@ -39,7 +37,6 @@ class FlashPaliGemma(PaliVlmCausalLM):
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
prefix="language_model",
)
def get_layer_config(self, model) -> Tuple[int, int, int]:

View File

@@ -405,6 +405,8 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
batch_inputs = []
image_inputs = []
text_inputs = []
image_text_replacements = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
@@ -413,6 +415,7 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
text_inputs.append(chunk["content"])
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore, processing should be done
@@ -427,7 +430,11 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config, image_id)
text_replacement = image_text_replacement(
image_input, config, image_id
)
full_text += text_replacement
image_text_replacements.append(text_replacement)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@@ -436,8 +443,28 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=False,
)["input_ids"]
image_token = tokenizer.get_added_vocab()["<image>"]
# find the index of the first non-image token
for batch in batch_tokenized_inputs:
first_non_image = 0
for i, token in enumerate(batch):
if token != image_token:
first_non_image = i
break
# manually add the bos to the left of the text
batch_tokenized_inputs = [
batch[:first_non_image] + [tokenizer.bos_token_id] + batch[first_non_image:]
for batch in batch_tokenized_inputs
]
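The splice above puts the bos token between the leading run of <image> placeholder tokens and the text tokens, since the tokenizer is called with add_special_tokens=False and bos is added manually. A standalone per-sequence sketch with toy token ids (99 and 2 stand in for the <image> id from tokenizer.get_added_vocab() and for tokenizer.bos_token_id):

IMAGE_TOKEN_ID = 99  # stand-in for tokenizer.get_added_vocab()["<image>"]
BOS_TOKEN_ID = 2     # stand-in for tokenizer.bos_token_id

def insert_bos_after_image_prefix(token_ids):
    # Find the first token that is not the <image> placeholder ...
    first_non_image = 0
    for i, token in enumerate(token_ids):
        if token != IMAGE_TOKEN_ID:
            first_non_image = i
            break
    # ... and splice bos in just before it, right after the image prefix.
    return token_ids[:first_non_image] + [BOS_TOKEN_ID] + token_ids[first_non_image:]

print(insert_bos_after_image_prefix([99, 99, 99, 17, 23]))  # [99, 99, 99, 2, 17, 23]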
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {