diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 028dc7ab..e1061f45 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -159,6 +159,107 @@ def _load_gqa(config, prefix: str, weights):
     )
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class GemmaRotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.register_buffer("inv_freq", None, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                self.base
+                ** (
+                    torch.arange(
+                        0, self.dim, 2, dtype=torch.int64, device=x.device
+                    ).float()
+                    / self.dim
+                )
+            )
+
+        position_ids = position_ids.unsqueeze(0)
+
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        )
+
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = (
+            device_type
+            if isinstance(device_type, str) and device_type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):
+
+            # import ipdb
+            # ipdb.set_trace()
+            freqs = (
+                inv_freq_expanded.float() @ position_ids_expanded.float()
+            ).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 class FlashGemmaAttention(torch.nn.Module):
     def __init__(
         self,
@@ -170,9 +271,16 @@ class FlashGemmaAttention(torch.nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim
 
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
+        # self._rotary_emb = PositionRotaryEmbedding.static(
+        #     config=config,
+        #     dim=self.head_size,
+        #     base=config.rope_theta,
+        #     device=weights.device,
+        # )
+
+        self.rotary_emb = GemmaRotaryEmbedding(
             dim=self.head_size,
+            max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             device=weights.device,
         )
@@ -189,7 +297,21 @@ class FlashGemmaAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        self.query_key_value = load_attention(config, prefix, weights)
+        # TODO: prefer this implementation after debugging
+        # self.query_key_value = load_attention(config, prefix, weights)
+
+        self.k_proj = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.k_proj",
+            weights=weights,
+            bias=False,
+        )
+        self.v_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
+        )
+        self.q_proj = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
+        )
 
         self.o_proj = TensorParallelRowLinear.load(
             config,
@@ -201,6 +323,7 @@ class FlashGemmaAttention(torch.nn.Module):
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
 
     def forward(
         self,
@@ -215,25 +338,54 @@ class FlashGemmaAttention(torch.nn.Module):
         max_s,
     ):
         global count
-        qkv = self.query_key_value(hidden_states)
-        query, kv = qkv.split(
-            [
-                self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
-            ],
-            dim=1,
+
+        # TODO: replace with better implementation after debugging
+        tgt_len, src_len = hidden_states.size()
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = (
+            query_states.unsqueeze(0).view(1, tgt_len, 8, 256).transpose(1, 2)
        )
-        query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+        key_states = key_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
+        value_states = (
+            value_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
+        )
+
+        # reshape for flash/paged attention
+        kv = torch.cat([key_states, value_states], dim=1).transpose(0, 2)
+
+        # cos2, sin2 = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, None
+        )
+
+        # replace the kv with the rotated kv
+        kv[:, 0] = key_states.squeeze(0).transpose(0, 1)
+        query = query_states.squeeze(0).transpose(0, 1)
+
+        # TODO: remove after debugging
+        # ipdb> ref_query_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_query_states.pt").to("cuda:0")
+        # ipdb> torch.allclose(query,ref_query_states.squeeze(0).transpose(0,1))
+
+        # ipdb> ref_key_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_key_states.pt").to("cuda:0")
+        # ipdb> torch.allclose(kv[:, 0],ref_key_states.squeeze(0).transpose(0,1))
+
+        # ipdb> ref_value_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_value_states.pt").to("cuda:0")
+        # ipdb> torch.allclose(kv[:, 1],ref_value_states.squeeze(0).transpose(0,1))
+
         if count > 0:
             import ipdb
 
-            ipdb.set_trace()
-        # looks good prior to attention
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+            # ipdb.set_trace()
 
         paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+            kv[:, 0],
+            kv[:, 1],
+            kv_cache[0],
+            kv_cache[1],
+            slots,
         )
 
         # output tensor
@@ -264,10 +416,7 @@ class FlashGemmaAttention(torch.nn.Module):
             input_lengths,
             max_s,
         )
-        if count > 0:
-            import ipdb
 
-            ipdb.set_trace()
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
 
@@ -427,14 +576,14 @@ class FlashGemmaModel(torch.nn.Module):
         global count
         hidden_states = inputs_embeds
 
-        # Get rotary cos and sin for this forward
-        # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
-        )
-
         residual = None
         for i, layer in enumerate(self.layers):
+            # TODO: prefer a single rotary embedding implementation after debugging
+            cos, sin = self.layers[i].self_attn.rotary_emb(
+                hidden_states,
+                position_ids,
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
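
For reference, here is a minimal, self-contained sketch (not part of the patch) of the rotation convention implemented by the `rotate_half`/`apply_rotary_pos_emb` helpers added above: the head dimension is split in half rather than interleaved, and `unsqueeze_dim=1` broadcasts the cos/sin tables over the heads axis. The shapes mirror the values hardcoded in the debugging forward pass (1 batch, 8 query heads, 1 key/value head, head_dim 256); the toy sequence length and the `base=10000` default are assumptions for illustration only.

```python
import torch


def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin are [batch, seq_len, head_dim]; unsqueeze_dim=1 broadcasts them
    # over the heads dimension of q/k, which are [batch, heads, seq_len, head_dim].
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


# Toy shapes mirroring the hardcoded reshape in the diff: 8 query heads,
# 1 key/value head, head_dim 256. seq_len and base are illustrative assumptions.
batch, n_heads, n_kv_heads, seq_len, head_dim = 1, 8, 1, 16, 256
q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Build cos/sin the same way GemmaRotaryEmbedding does: outer(position, inv_freq),
# then duplicate the frequencies so the last dimension matches head_dim.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
positions = torch.arange(seq_len).float()
freqs = torch.outer(positions, inv_freq)        # [seq_len, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)[None]   # [1, seq_len, head_dim]

q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```

The key/value tensor with a single head broadcasts against the duplicated cos/sin exactly as in the patched forward pass, which is why only `kv[:, 0]` needs to be overwritten with the rotated keys before `reshape_and_cache`.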