From 6e8a2110f80cd09a2a03d2746e94c163f1b81a44 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 10 May 2024 17:32:14 +0000 Subject: [PATCH] fix: adjust inputs_embeds passed to language model and debug --- .../custom_modeling/flash_gemma_modeling.py | 18 +++++++++++++++--- .../flash_pali_gemma_modeling.py | 15 +++++++-------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 3d700610..028dc7ab 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -39,6 +39,9 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +# TODO: used for debugging; to avoid breaking during warmup +count = 0 + class GemmaConfig(PretrainedConfig): def __init__( @@ -103,7 +106,7 @@ class GemmaConfig(PretrainedConfig): class GemmaFastRMSNorm(FastRMSNorm): @classmethod def load(cls, prefix, weights, eps=1e-6): - weight = weights.get_tensor(f"{prefix}.weight") + 1 + weight = weights.get_tensor(f"{prefix}.weight") return cls(weight, eps) # perform the multiplication in full precision and downcast after @@ -114,7 +117,7 @@ class GemmaFastRMSNorm(FastRMSNorm): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = hidden_states * self.weight + hidden_states = hidden_states * (self.weight.float() + 1.0) return hidden_states.to(self.weight.dtype), residual @@ -211,6 +214,7 @@ class FlashGemmaAttention(torch.nn.Module): input_lengths, max_s, ): + global count qkv = self.query_key_value(hidden_states) query, kv = qkv.split( [ @@ -221,7 +225,11 @@ class FlashGemmaAttention(torch.nn.Module): ) query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + if count > 0: + import ipdb + ipdb.set_trace() + # looks good prior to attention self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) paged_attention.reshape_and_cache( @@ -256,7 +264,10 @@ class FlashGemmaAttention(torch.nn.Module): input_lengths, max_s, ) + if count > 0: + import ipdb + ipdb.set_trace() return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -413,6 +424,7 @@ class FlashGemmaModel(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, ) -> torch.Tensor: + global count hidden_states = inputs_embeds # Get rotary cos and sin for this forward @@ -437,7 +449,7 @@ class FlashGemmaModel(torch.nn.Module): ) hidden_states, _ = self.norm(hidden_states, residual) - + count += 1 # for debugging; to avoid breaking during warmup return hidden_states diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index ffafdd9f..23d4ae22 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -203,7 +203,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): image_features = self.multi_modal_projector(image_outputs.last_hidden_state) # TODO: now we scale them? maybe we can do this up or downstream - scaled_image_features = image_features / (self.config.hidden_size**0.5) + scaled_image_features = image_features / ( + self.config.text_config.hidden_size**0.5 + ) # mask where image or padding tokens mask = input_ids == self.config.image_token_index | (input_ids == 2) @@ -213,15 +215,12 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): -1, scaled_image_features.shape[-1] ) - if input_ids.size(0) != 3000: - # import ipdb - - # ipdb.set_trace() - pass - # NOTE: scale back up since we dont normalize inside the model like transformers # TODO: simplify all the rescaling - inputs_embeds = inputs_embeds * (self.config.hidden_size**0.5) + normalizer = torch.tensor( + self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds * normalizer hidden_states = self.language_model.model( inputs_embeds=inputs_embeds,