mirror of https://github.com/huggingface/text-generation-inference.git
fix: adjust inputs_embeds passed to language model and debug
parent 4df1b25ddb
commit 6e8a2110f8
@@ -39,6 +39,9 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 
+# TODO: used for debugging; to avoid breaking during warmup
+count = 0
+
 
 class GemmaConfig(PretrainedConfig):
     def __init__(
@@ -103,7 +106,7 @@ class GemmaConfig(PretrainedConfig):
 class GemmaFastRMSNorm(FastRMSNorm):
     @classmethod
     def load(cls, prefix, weights, eps=1e-6):
-        weight = weights.get_tensor(f"{prefix}.weight") + 1
+        weight = weights.get_tensor(f"{prefix}.weight")
         return cls(weight, eps)
 
     # perform the multiplication in full precision and downcast after
@@ -114,7 +117,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = hidden_states * self.weight
+        hidden_states = hidden_states * (self.weight.float() + 1.0)
         return hidden_states.to(self.weight.dtype), residual
 
 
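Context for the two norm hunks above: Gemma stores its RMSNorm weight as an offset from 1, so the effective scale is (weight + 1). Previously the +1 was baked into the tensor at load time; the change keeps the raw weight and applies the offset in float32 inside forward, so the addition is not rounded in the checkpoint's half-precision dtype. A minimal standalone sketch of the resulting computation (the function name and example tensors are illustrative, not part of the diff, and the residual handling of the real FastRMSNorm is omitted):

import torch

def gemma_rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Perform the normalization in full precision and downcast at the end.
    dtype = weight.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # Gemma's stored weight is an offset from 1; add it back in float32.
    hidden_states = hidden_states * (weight.float() + 1.0)
    return hidden_states.to(dtype)

# Example: a zero-initialized offset weight leaves the normalized activations unscaled.
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
w = torch.zeros(8, dtype=torch.bfloat16)
y = gemma_rms_norm(x, w)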
@@ -211,6 +214,7 @@ class FlashGemmaAttention(torch.nn.Module):
         input_lengths,
         max_s,
     ):
+        global count
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(
             [
@@ -221,7 +225,11 @@ class FlashGemmaAttention(torch.nn.Module):
         )
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+        if count > 0:
+            import ipdb
 
+            ipdb.set_trace()
+        # looks good prior to attention
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         paged_attention.reshape_and_cache(
@@ -256,7 +264,10 @@ class FlashGemmaAttention(torch.nn.Module):
                 input_lengths,
                 max_s,
             )
+        if count > 0:
+            import ipdb
 
+            ipdb.set_trace()
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
 
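The module-level count introduced in the first hunk, the global count declarations, and the if count > 0: guards above form a temporary debugging aid: the server runs a warmup forward pass before serving, so an unconditional ipdb.set_trace() would fire during warmup. Because count is only incremented at the end of FlashGemmaModel.forward (see the later hunks), the breakpoints stay silent on the first pass and only trigger on later requests. A stripped-down sketch of that pattern (the forward function and its body here are placeholders, not TGI code):

count = 0  # module-level pass counter, as in the diff

def forward(x: float) -> float:
    global count
    if count > 0:
        # Skipped during the warmup pass; drops into the debugger afterwards.
        import ipdb
        ipdb.set_trace()
    result = x * 2  # placeholder for the real attention/model computation
    count += 1
    return result

forward(1.0)   # warmup pass: no breakpoint
# forward(2.0) # a later call would stop in ipdb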
@@ -413,6 +424,7 @@ class FlashGemmaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
+        global count
         hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward
@@ -437,7 +449,7 @@ class FlashGemmaModel(torch.nn.Module):
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
-
+        count += 1  # for debugging; to avoid breaking during warmup
         return hidden_states
 
 
@@ -203,7 +203,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
 
         # TODO: now we scale them? maybe we can do this up or downstream
-        scaled_image_features = image_features / (self.config.hidden_size**0.5)
+        scaled_image_features = image_features / (
+            self.config.text_config.hidden_size**0.5
+        )
 
         # mask where image or padding tokens
         mask = input_ids == self.config.image_token_index | (input_ids == 2)
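The hunk above only changes where the scaling factor comes from: the projected image features are still divided by the square root of the language model's hidden size, but that size is now read from config.text_config rather than the top-level config, presumably because the composite PaliGemma config nests the text model's dimensions under text_config. A toy illustration of the scaling (shapes and the hidden_size value are made up):

import torch

hidden_size = 2048  # stand-in for config.text_config.hidden_size
image_features = torch.randn(1, 256, hidden_size)

# Scale projected image features down by sqrt(hidden_size) so they match the
# magnitude of the token embeddings they will be merged with.
scaled_image_features = image_features / (hidden_size**0.5)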
@@ -213,15 +215,12 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             -1, scaled_image_features.shape[-1]
         )
 
-        if input_ids.size(0) != 3000:
-            # import ipdb
-
-            # ipdb.set_trace()
-            pass
-
         # NOTE: scale back up since we dont normalize inside the model like transformers
         # TODO: simplify all the rescaling
-        inputs_embeds = inputs_embeds * (self.config.hidden_size**0.5)
+        normalizer = torch.tensor(
+            self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
+        )
+        inputs_embeds = inputs_embeds * normalizer
 
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
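The last hunk drops the stale size-3000 debug branch and rebuilds the embedding rescale: instead of multiplying inputs_embeds by a Python float, it builds the sqrt(hidden_size) normalizer as a tensor in the embeddings' dtype, again sourced from text_config. As the NOTE comment says, this compensates for the TGI Gemma model not normalizing embeddings internally the way the transformers implementation does. A minimal sketch of that step (hidden_size and the tensors are illustrative stand-ins):

import torch

hidden_size = 2048  # stand-in for config.text_config.hidden_size
inputs_embeds = torch.randn(5, hidden_size, dtype=torch.bfloat16)

# Build the sqrt(hidden_size) normalizer in the embeddings' dtype so the
# multiplication is done (and rounded) in the same precision as the model.
normalizer = torch.tensor(hidden_size**0.5, dtype=inputs_embeds.dtype)
inputs_embeds = inputs_embeds * normalizer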