From 490ca0ef6ac0fa5fa34e1b019b62ece2708c7580 Mon Sep 17 00:00:00 2001
From: System administrator
Date: Thu, 12 Dec 2024 15:48:56 +0000
Subject: [PATCH] working

---
 .../models/transformers_flash_causal_lm.py   | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index f4f24749..abfaa06e 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -72,11 +72,14 @@ def _flash_attention_forward_patched(
     kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
 
     # Correctly reshape the states
-    _, num_heads, head_dim = query_states.size()
-    # _, num_kv_heads, _ = key_states.size()
+    _, _, num_heads, head_dim = query_states.size()
+    _, _, num_kv_heads, _ = key_states.size()
     # query_states = query_states.view(-1, num_heads, head_dim)
     # key_states = key_states.view(-1, num_kv_heads, head_dim)
     # value_states = value_states.view(-1, num_kv_heads, head_dim)
+    query_states = query_states.squeeze(dim=0)
+    key_states = key_states.squeeze(dim=0)
+    value_states = value_states.squeeze(dim=0)
 
     # Take care of updating the cache in-place
     kv_cache.store(
@@ -316,8 +319,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             max_k=batch.max_current_length,
         )
         logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
+            input_ids=input_ids[None, ...],
+            position_ids=position_ids[None, ...],
             past_key_values=None,
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
             cu_seqlen_prefill=cu_seqlen_prefill,
@@ -329,7 +332,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             kv_head_mapping=self.kv_head_mapping,
-        ).logits
+        ).logits[0, ...]
+        print("SUCCESSFUL FORWARD")
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits, None
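
For reference, a minimal, self-contained sketch of the shape handling this patch performs (all tensor sizes below are made up for illustration; only the unpack/squeeze/unsqueeze pattern mirrors the diff). transformers hands the patched attention forward 4-D states shaped `[batch, seq_len, num_heads, head_dim]` with `batch == 1`, while TGI's paged KV cache expects flattened `[total_tokens, num_heads, head_dim]` tensors; symmetrically, TGI keeps `input_ids`/`position_ids` as flat 1-D tensors, so they gain a leading batch dimension via `[None, ...]` before `model.forward` and the batch dimension is dropped from the returned logits with `[0, ...]`:

```python
import torch

# Illustrative sizes only -- not taken from the patched file.
seq_len, num_heads, num_kv_heads, head_dim = 7, 8, 2, 64

# transformers-style states: [batch=1, seq_len, heads, head_dim].
query_states = torch.randn(1, seq_len, num_heads, head_dim)
key_states = torch.randn(1, seq_len, num_kv_heads, head_dim)
value_states = torch.randn(1, seq_len, num_kv_heads, head_dim)

# The 4-element unpack matches the 4-D layout; the old 3-element unpack
# assumed 3-D states and raises "too many values to unpack" on these tensors.
_, _, num_heads, head_dim = query_states.size()
_, _, num_kv_heads, _ = key_states.size()

# Drop the batch dimension to get the flat [total_tokens, heads, head_dim]
# layout that the paged KV cache store expects.
query_states = query_states.squeeze(dim=0)
key_states = key_states.squeeze(dim=0)
value_states = value_states.squeeze(dim=0)
assert query_states.shape == (seq_len, num_heads, head_dim)

# Going the other way around model.forward: add a batch dim of 1 on the
# way in, strip it from the logits on the way out.
input_ids = torch.arange(seq_len)
assert input_ids[None, ...].shape == (1, seq_len)
vocab_size = 32                               # stand-in value
logits = torch.randn(1, seq_len, vocab_size)  # stand-in for model output
logits = logits[0, ...]
assert logits.shape == (seq_len, vocab_size)
```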