diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 0954d857..816fb196 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -366,10 +366,6 @@ class FlashGemmaModel(torch.nn.Module):
         self.embed_tokens = TensorParallelEmbedding(
             prefix=pvalue,
             weights=weights,
-            # limit embed_tokens.weight size to the config.vocab_size
-        )
-        self.embed_tokens.weight = torch.nn.Parameter(
-            self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
         )
 
         # TODO: double check why this is needed
diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index 2cf51ea1..3e33032a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -29,29 +29,6 @@ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
 )
 
 
-class PaliGemmaConfig(PretrainedConfig):
-    model_type = "paligemma"
-
-    def from_pretrained(pretrained_model_name_or_path, **kwargs):
-        vision_config = VisionConfig(
-            hidden_size=1152,
-            intermediate_size=4304,
-            model_type="siglip_vision_model",
-            num_attention_heads=16,
-            num_hidden_layers=27,
-            num_image_tokens=256,
-            patch_size=14,
-            projection_dim=2048,
-            projector_hidden_act="gelu_fast",
-            vision_use_head=False,
-            vocab_size=257152,
-        )
-
-        return GemmaConfig.from_pretrained(
-            pretrained_model_name_or_path, vision_config=vision_config, **kwargs
-        )
-
-
 class VisionConfig(PretrainedConfig):
     def __init__(
         self,
@@ -95,6 +72,80 @@ class VisionConfig(PretrainedConfig):
         super().__init__(**kwargs)
 
 
+class PaliGemmaConfig(PretrainedConfig):
+    model_type = "paligemma"
+
+    def __init__(
+        self,
+        text_config: GemmaConfig,
+        vision_config: VisionConfig,
+        vocab_size: int = 257152,
+        image_token_index: int = 256000,
+        **kwargs,
+    ):
+        self.text_config = text_config
+        self.vision_config = vision_config
+
+        self.vocab_size = vocab_size
+        self.image_token_index = image_token_index
+
+        self.intermediate_size = text_config.intermediate_size
+        self.num_hidden_layers = text_config.num_hidden_layers
+        self.num_key_value_heads = text_config.num_key_value_heads
+        self.num_attention_heads = text_config.num_attention_heads
+
+        super().__init__(**kwargs)
+
+    def from_pretrained(pretrained_model_name_or_path, **kwargs):
+        vision_config = VisionConfig(
+            hidden_size=1152,
+            intermediate_size=4304,
+            model_type="siglip_vision_model",
+            num_attention_heads=16,
+            num_hidden_layers=27,
+            num_image_tokens=256,
+            patch_size=14,
+            projection_dim=2048,
+            projector_hidden_act="gelu_fast",
+            vision_use_head=False,
+            vocab_size=257152,
+        )
+
+        text_config = GemmaConfig.from_pretrained(
+            pretrained_model_name_or_path,
+            attention_bias=False,
+            attention_dropout=0.0,
+            bos_token_id=2,
+            eos_token_id=1,
+            head_dim=256,
+            hidden_act="gelu_pytorch_tanh",
+            hidden_activation=None,
+            hidden_size=2048,
+            initializer_range=0.02,
+            intermediate_size=16384,
+            max_position_embeddings=8192,
+            model_type="gemma",
+            num_attention_heads=8,
+            num_hidden_layers=18,
+            num_image_tokens=256,
+            num_key_value_heads=1,
+            pad_token_id=0,
+            rms_norm_eps=1e-06,
+            rope_theta=10000.0,
+            torch_dtype="float32",
+            transformers_version="4.40.0.dev0",
+            use_cache=True,
+            vocab_size=257216,
+            **kwargs,
+        )
+
+        return PaliGemmaConfig(
+            text_config=text_config,
+            vision_config=vision_config,
+            **kwargs,
+        )
+
+
 class FlashPaliGemmaForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -116,8 +167,8 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         self.config = config
 
         self.language_model = load_text_model(
-            prefix=prefix,
-            config=config,
+            prefix="language_model" if not prefix else f"{prefix}.language_model",
+            config=config.text_config,
             weights=weights,
         ).to(weights.device, weights.dtype)
         self.pad_token_id = (
@@ -165,22 +216,18 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         image_outputs = self.vision_tower(pixel_values)
         selected_image_feature = image_outputs.last_hidden_state
         image_features = self.multi_modal_projector(selected_image_feature)
-        image_features = image_features / (self.config.hidden_size**0.5)
-        inputs_embeds = self._merge_input_ids_with_image_features(
+        # NOTE: image_features here matches the values produced by transformers
+
+        # TODO: correctly merge inputs_embeds with image_features
+        merged_inputs_embeds = self._merge_input_ids_with_image_features(
             image_features, inputs_embeds, input_ids
         )
 
         if input_ids.size(0) != 3000:
-            import ipdb
+            # import ipdb
 
-            ipdb.set_trace()
-
-            ## TODO: remove this
-            ## load in values from reference
-            # tensor = torch.load("../../new-model-addition-palma/inputs_embeds.npz")
-            # inputs_embeds = torch.tensor(
-            #     tensor, device=inputs_embeds.device, dtype=inputs_embeds.dtype
-            # ).squeeze()
+            # ipdb.set_trace()
+            pass
 
         hidden_states = self.language_model.model(
            inputs_embeds=inputs_embeds,
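For review: a minimal sketch (not part of the patch) of how the reworked `PaliGemmaConfig` is meant to compose. The constructor arguments below are copied from the hard-coded defaults in `from_pretrained` above; the import paths are assumed to match the modules this diff touches.

```python
# Sketch only: exercises the new PaliGemmaConfig composition.
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
    GemmaConfig,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    PaliGemmaConfig,
    VisionConfig,
)

text_config = GemmaConfig(
    hidden_size=2048,
    intermediate_size=16384,
    num_hidden_layers=18,
    num_attention_heads=8,
    num_key_value_heads=1,
    vocab_size=257216,
)
vision_config = VisionConfig(hidden_size=1152, num_hidden_layers=27)
config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config)

# The text tower's shape parameters are hoisted onto the top-level config,
# which is what BaseFlashGemma reads further below for cache sizing.
assert config.num_hidden_layers == text_config.num_hidden_layers
assert config.num_key_value_heads == text_config.num_key_value_heads
```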
diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py
index ad3aad45..b1f6a192 100644
--- a/server/text_generation_server/models/custom_modeling/siglip.py
+++ b/server/text_generation_server/models/custom_modeling/siglip.py
@@ -122,18 +122,18 @@ class SiglipAttention(nn.Module):
         self.embed_dim = self.embed_dim // weights.process_group.size()
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
-        self.qkv = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=True,
+
+        self.k_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
+        )
+        self.v_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
+        )
+        self.q_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
         )
         self.out_proj = TensorParallelRowLinear.load(
-            config,
-            prefix=f"{prefix}.out_proj",
-            weights=weights,
-            bias=True,
+            config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
         )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -152,18 +152,10 @@ class SiglipAttention(nn.Module):
         """Input shape: Batch x Time x Channel"""
 
         bsz, tgt_len, _ = hidden_states.size()
-        qkv = self.qkv(hidden_states)
-        query_states, key_states, value_states = qkv.split(
-            [
-                self.head_size * self.num_heads,
-            ]
-            * 3,
-            dim=2,
-        )
-        key_states = self._shape(key_states, -1, bsz)
-        value_states = self._shape(value_states, -1, bsz)
-
-        proj_shape = (bsz * self.num_heads, -1, self.head_size)
+        query_states = self.q_proj(hidden_states)
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
         key_states = key_states.view(*proj_shape)
         value_states = value_states.view(*proj_shape)
@@ -196,7 +188,7 @@ class SiglipAttention(nn.Module):
         attn_weights = nn.functional.dropout(
             attn_weights, p=self.dropout, training=self.training
         )
-        attn_output = torch.bmm(attn_weights, value_states)
+        attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
             raise ValueError(
@@ -277,7 +269,6 @@ class SiglipEncoderLayer(nn.Module):
         hidden_states = residual + hidden_states
         if output_attentions:
             return hidden_states, attn_weights
-        print(hidden_states[0, 0, :5].tolist())
 
         return hidden_states, None
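A standalone shape check for the un-fused `SiglipAttention` path may help verify the reshapes above. This is a sketch only, using plain `nn.Linear` stand-ins for the tensor-parallel layers; dimensions assume the SigLIP-so400m vision tower (`embed_dim=1152`, `num_heads=16`).

```python
# Sketch only: shape bookkeeping for separate q/k/v projections.
import torch
import torch.nn as nn

bsz, seq_len, embed_dim, num_heads = 2, 16, 1152, 16
head_dim = embed_dim // num_heads  # 72

q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

def shape(t: torch.Tensor) -> torch.Tensor:
    # mirrors SiglipAttention._shape: (bsz, seq, dim) -> (bsz, heads, seq, head_dim)
    return t.view(bsz, -1, num_heads, head_dim).transpose(1, 2).contiguous()

x = torch.randn(bsz, seq_len, embed_dim)
proj_shape = (bsz * num_heads, -1, head_dim)
q = shape(q_proj(x)).view(*proj_shape)
k = shape(k_proj(x)).view(*proj_shape)
v = shape(v_proj(x)).view(*proj_shape)

attn_weights = torch.softmax(q @ k.transpose(1, 2) * head_dim**-0.5, dim=-1)
attn_output = torch.matmul(attn_weights, v)
assert attn_output.shape == (bsz * num_heads, seq_len, head_dim)
```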
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index d60d49de..03b7bac3 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -99,12 +99,6 @@ class BaseFlashGemma(FlashCausalLM):
         config.quantize = quantize
         config.speculator = speculator
 
-        if is_vlm:
-            config.intermediate_size = config.text_config.get("intermediate_size")
-            config.num_attention_heads = config.text_config.get("num_attention_heads")
-            config.num_hidden_layers = config.text_config.get("num_hidden_layers")
-            config.num_key_value_heads = config.text_config.get("num_key_value_heads")
-
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -116,14 +110,9 @@ class BaseFlashGemma(FlashCausalLM):
 
         torch.distributed.barrier(group=self.process_group)
 
-        if is_vlm:
-            num_layers = config.num_hidden_layers
-            num_kv_heads = config.num_key_value_heads
-            head_size = config.intermediate_size
-        else:
-            num_layers = len(model.model.layers)
-            num_kv_heads = model.model.num_key_value_heads
-            head_size = model.model.head_size
+        num_layers = config.num_hidden_layers
+        num_kv_heads = config.num_key_value_heads
+        head_size = config.intermediate_size
 
         super().__init__(
             model=model,
diff --git a/server/text_generation_server/models/flash_pali_gemma.py b/server/text_generation_server/models/flash_pali_gemma.py
index e3ab8bbf..d8b7f5ac 100644
--- a/server/text_generation_server/models/flash_pali_gemma.py
+++ b/server/text_generation_server/models/flash_pali_gemma.py
@@ -23,9 +23,7 @@ class FlashPaliGemma(PaliVlmCausalLM):
         trust_remote_code: bool = False,
     ):
         self.processor = AutoProcessor.from_pretrained(
-            # TODO: load in the correct processor based on the model_id
             "google/siglip-base-patch16-224",
-            # "google/siglip-so400m-patch14-384",
             revision=revision,
             trust_remote_code=trust_remote_code,
         )
@@ -39,7 +37,6 @@ class FlashPaliGemma(PaliVlmCausalLM):
             use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
-            prefix="language_model",
         )
 
     def get_layer_config(self, model) -> Tuple[int, int, int]:
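With the explicit `prefix="language_model"` kwarg dropped from `FlashPaliGemma`, weight-name resolution now relies entirely on the conditional added in `flash_pali_gemma_modeling.py` above. A small sketch of that logic (the helper name is illustrative, not part of the patch):

```python
# Sketch only: mirrors the prefix conditional in the load_text_model call.
from typing import Optional

def resolve_language_model_prefix(prefix: Optional[str]) -> str:
    return "language_model" if not prefix else f"{prefix}.language_model"

assert resolve_language_model_prefix(None) == "language_model"
assert resolve_language_model_prefix("") == "language_model"
assert resolve_language_model_prefix("paligemma") == "paligemma.language_model"
```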
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index b2dbd0a3..a539b6ad 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -405,6 +405,8 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
     def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
         batch_inputs = []
         image_inputs = []
+        text_inputs = []
+        image_text_replacements = []
         max_truncation = 0
         for r in requests:
             chunks = split(r.inputs)
@@ -413,6 +415,7 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
             for chunk in chunks:
                 if chunk["type"] == "text":
                     full_text += chunk["content"]
+                    text_inputs.append(chunk["content"])
                 elif chunk["type"] == "image":
                     image = chunk["content"]
                     # Should never receive URLs anymore, processing should be done
@@ -427,7 +430,11 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
                             "Cannot process input image not starting with data:"
                         )
                     image_input = processor.image_processor(image, return_tensors="pt")
-                    full_text += image_text_replacement(image_input, config, image_id)
+                    text_replacement = image_text_replacement(
+                        image_input, config, image_id
+                    )
+                    full_text += text_replacement
+                    image_text_replacements.append(text_replacement)
                     image_inputs.append(image_input)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@@ -436,8 +443,30 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
             max_truncation = max(max_truncation, r.truncate)
 
         batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
+            batch_inputs,
+            truncation=True,
+            max_length=max_truncation,
+            add_special_tokens=False,
         )["input_ids"]
+
+        image_token = tokenizer.get_added_vocab()["<image>"]
+
+        # manually add the bos to each sequence, after any leading image tokens
+        bos_inserted = []
+        for batch in batch_tokenized_inputs:
+            # find the index of the first non-image token in this sequence
+            first_non_image = 0
+            for i, token in enumerate(batch):
+                if token != image_token:
+                    first_non_image = i
+                    break
+            bos_inserted.append(
+                batch[:first_non_image]
+                + [tokenizer.bos_token_id]
+                + batch[first_non_image:]
+            )
+        batch_tokenized_inputs = bos_inserted
+
         if image_inputs:
             image_input = image_inputs[0]
             new_image_inputs = {
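Context for the hunk above: the tokenizer is now called with `add_special_tokens=False` because it would otherwise place BOS at position 0, in front of the `<image>` placeholder tokens; PaliGemma expects the image tokens first and BOS immediately before the text. A self-contained sketch of the intended behavior (token ids below are made up for illustration):

```python
# Sketch only: expected effect of the manual BOS insertion.
IMAGE, BOS = 99, 2

def insert_bos(batch: list, image_token: int = IMAGE, bos: int = BOS) -> list:
    # find the index of the first non-image token
    first_non_image = 0
    for i, token in enumerate(batch):
        if token != image_token:
            first_non_image = i
            break
    return batch[:first_non_image] + [bos] + batch[first_non_image:]

# image tokens stay in front; BOS lands immediately before the text tokens
assert insert_bos([IMAGE, IMAGE, 10, 11]) == [IMAGE, IMAGE, BOS, 10, 11]
# with no image tokens, BOS goes at the start, as a tokenizer would do
assert insert_bos([10, 11]) == [BOS, 10, 11]
```

Note that the original added code computed `first_non_image` in one loop and reused the leaked value in a separate list comprehension, which applied the last sequence's offset to every sequence in the batch; the hunk above folds the insertion into the per-sequence loop instead.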