fix: adjust inputs_embeds passed to language model and debug

drbh 2024-05-10 17:32:14 +00:00 committed by Nicolas Patry
parent 4df1b25ddb
commit 6e8a2110f8
2 changed files with 22 additions and 11 deletions

server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py

@@ -39,6 +39,9 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 
+# TODO: used for debugging; to avoid breaking during warmup
+count = 0
+
 
 class GemmaConfig(PretrainedConfig):
     def __init__(
@@ -103,7 +106,7 @@ class GemmaConfig(PretrainedConfig):
 class GemmaFastRMSNorm(FastRMSNorm):
     @classmethod
     def load(cls, prefix, weights, eps=1e-6):
-        weight = weights.get_tensor(f"{prefix}.weight") + 1
+        weight = weights.get_tensor(f"{prefix}.weight")
         return cls(weight, eps)
 
     # perform the multiplication in full precision and downcast after
@@ -114,7 +117,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = hidden_states * self.weight
+        hidden_states = hidden_states * (self.weight.float() + 1.0)
         return hidden_states.to(self.weight.dtype), residual
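Taken together, the two hunks above move Gemma's RMSNorm "+ 1" offset out of weight loading and into the forward pass. A minimal standalone sketch of that convention in plain PyTorch (not TGI's FastRMSNorm; the class name is illustrative):

import torch


class GemmaStyleRMSNorm(torch.nn.Module):
    """Gemma checkpoints store the RMSNorm weight as (gamma - 1), so the
    effective scale is (weight + 1), applied in float32 and downcast after,
    matching the hunk above."""

    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(weight)  # stored as (gamma - 1)
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # perform the multiplication in full precision and downcast after
        hidden_states = hidden_states * (self.weight.float() + 1.0)
        return hidden_states.to(input_dtype)


# with a stored weight of zero the effective gamma is 1.0, i.e. plain RMSNorm
norm = GemmaStyleRMSNorm(torch.zeros(8))
out = norm(torch.randn(2, 8, dtype=torch.float16))
print(out.dtype, out.shape)  # torch.float16 torch.Size([2, 8])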
@@ -211,6 +214,7 @@ class FlashGemmaAttention(torch.nn.Module):
         input_lengths,
         max_s,
     ):
+        global count
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(
             [
@@ -221,7 +225,11 @@ class FlashGemmaAttention(torch.nn.Module):
         )
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+        if count > 0:
+            import ipdb
+            ipdb.set_trace()
+            # looks good prior to attention
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         paged_attention.reshape_and_cache(
@@ -256,7 +264,10 @@ class FlashGemmaAttention(torch.nn.Module):
                 input_lengths,
                 max_s,
             )
+        if count > 0:
+            import ipdb
+            ipdb.set_trace()
 
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -413,6 +424,7 @@ class FlashGemmaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
+        global count
         hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward
@@ -437,7 +449,7 @@ class FlashGemmaModel(torch.nn.Module):
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
+        count += 1  # for debugging; to avoid breaking during warmup
 
         return hidden_states
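The count changes in this file form a small warmup-aware breakpoint trick: the counter only advances after a complete forward pass, so the "if count > 0" guards keep ipdb from firing during the server's warmup run. A stripped-down sketch of the pattern (toy module, illustrative names):

import torch

count = 0  # completed forward passes; module-level, shared by all layers


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        global count
        if count > 0:
            # past warmup: safe to drop into the debugger here, e.g.
            # import ipdb; ipdb.set_trace()
            pass
        out = self.proj(x)
        count += 1  # advance only after a full pass, as in the hunk above
        return out


model = TinyModel()
model(torch.randn(1, 4))  # warmup pass: guard is skipped
model(torch.randn(1, 4))  # later passes would hit the breakpoint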

server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py

@@ -203,7 +203,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
 
         # TODO: now we scale them? maybe we can do this up or downstream
-        scaled_image_features = image_features / (self.config.hidden_size**0.5)
+        scaled_image_features = image_features / (
+            self.config.text_config.hidden_size**0.5
+        )
 
         # mask where image or padding tokens
         mask = input_ids == self.config.image_token_index | (input_ids == 2)
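For reference, a toy sketch of what this hunk computes: the projected image features are scaled by 1/sqrt of the text model's hidden size (the fix here is that on the composite PaliGemma config the hidden size lives under text_config), and a mask marks image/padding positions. The ids below are placeholders, not values from the source; note that Python's `|` binds tighter than `==`, so the unparenthesized mask line above actually compares input_ids against the OR-ed right-hand side:

import torch

# placeholder ids; in the real model these come from the config
IMAGE_TOKEN_INDEX = 256000  # hypothetical stand-in for config.image_token_index
PAD_TOKEN_ID = 2            # the hunk treats token id 2 as padding

text_hidden_size = 2048     # stand-in for config.text_config.hidden_size
image_features = torch.randn(16, text_hidden_size)
scaled_image_features = image_features / (text_hidden_size**0.5)

input_ids = torch.tensor([IMAGE_TOKEN_INDEX, 42, PAD_TOKEN_ID, 7])
# parenthesize both comparisons, since `|` binds tighter than `==`
mask = (input_ids == IMAGE_TOKEN_INDEX) | (input_ids == PAD_TOKEN_ID)
print(mask)  # tensor([ True, False,  True, False])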
@@ -213,15 +215,12 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             -1, scaled_image_features.shape[-1]
         )
 
-        if input_ids.size(0) != 3000:
-            # import ipdb
-            # ipdb.set_trace()
-            pass
-
         # NOTE: scale back up since we dont normalize inside the model like transformers
         # TODO: simplify all the rescaling
-        inputs_embeds = inputs_embeds * (self.config.hidden_size**0.5)
+        normalizer = torch.tensor(
+            self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
+        )
+        inputs_embeds = inputs_embeds * normalizer
 
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
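The replacement mirrors how transformers applies PaliGemma's sqrt(hidden_size) normalization: since this server's Gemma model does not scale embeddings internally, the wrapper rescales inputs_embeds itself, and building the scalar as a tensor in the embeddings' dtype rounds the factor to that dtype before the multiply. A minimal sketch under those assumptions (the hidden size is a stand-in):

import torch

text_hidden_size = 2048  # stand-in for config.text_config.hidden_size
inputs_embeds = torch.randn(4, text_hidden_size, dtype=torch.float16)

# cast the scale to the embedding dtype first, as in the hunk above
normalizer = torch.tensor(text_hidden_size**0.5, dtype=inputs_embeds.dtype)
inputs_embeds = inputs_embeds * normalizer

print(normalizer)           # the factor as rounded to float16
print(inputs_embeds.dtype)  # torch.float16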