This commit is contained in:
System administrator 2024-12-12 15:48:56 +00:00
parent 649cb1f5f1
commit 490ca0ef6a

View File

@ -72,11 +72,14 @@ def _flash_attention_forward_patched(
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
# Correctly reshape the states # Correctly reshape the states
_, num_heads, head_dim = query_states.size() _, _, num_heads, head_dim = query_states.size()
# _, num_kv_heads, _ = key_states.size() _, _, num_kv_heads, _ = key_states.size()
# query_states = query_states.view(-1, num_heads, head_dim) # query_states = query_states.view(-1, num_heads, head_dim)
# key_states = key_states.view(-1, num_kv_heads, head_dim) # key_states = key_states.view(-1, num_kv_heads, head_dim)
# value_states = value_states.view(-1, num_kv_heads, head_dim) # value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states = query_states.squeeze(dim=0)
key_states = key_states.squeeze(dim=0)
value_states = value_states.squeeze(dim=0)
# Take care of updating the cache in-place # Take care of updating the cache in-place
kv_cache.store( kv_cache.store(
@ -316,8 +319,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
max_k=batch.max_current_length, max_k=batch.max_current_length,
) )
logits = self.model.forward( logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids[None, ...],
position_ids=position_ids, position_ids=position_ids[None, ...],
past_key_values=None, past_key_values=None,
use_cache=False, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
@ -329,7 +332,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
kv_head_mapping=self.kv_head_mapping, kv_head_mapping=self.kv_head_mapping,
).logits ).logits[0, ...]
print("SUCCESSFUL FORWARD")
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
return logits, None return logits, None