Mirror of https://github.com/huggingface/text-generation-inference.git
commit 490ca0ef6a
parent 649cb1f5f1

    working
@@ -72,11 +72,14 @@ def _flash_attention_forward_patched(
     kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))

     # Correctly reshape the states
-    _, num_heads, head_dim = query_states.size()
-    # _, num_kv_heads, _ = key_states.size()
+    _, _, num_heads, head_dim = query_states.size()
+    _, _, num_kv_heads, _ = key_states.size()
     # query_states = query_states.view(-1, num_heads, head_dim)
     # key_states = key_states.view(-1, num_kv_heads, head_dim)
     # value_states = value_states.view(-1, num_kv_heads, head_dim)
+    query_states = query_states.squeeze(dim=0)
+    key_states = key_states.squeeze(dim=0)
+    value_states = value_states.squeeze(dim=0)

     # Take care of updating the cache in-place
     kv_cache.store(
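
The re-unpacked size() calls and the three new squeeze() calls handle a layout mismatch: the states arriving from transformers carry a leading batch dimension of size 1, while TGI's flash-attention path and kv_cache.store work on flat (total_tokens, num_heads, head_dim) tensors. A minimal sketch of the reshape, with illustrative sizes (the concrete shapes below are assumptions for demonstration, not taken from the diff):

import torch

# Assumed input layout: (1, total_tokens, num_heads, head_dim).
total_tokens, num_heads, num_kv_heads, head_dim = 7, 32, 8, 128
query_states = torch.randn(1, total_tokens, num_heads, head_dim)
key_states = torch.randn(1, total_tokens, num_kv_heads, head_dim)
value_states = torch.randn(1, total_tokens, num_kv_heads, head_dim)

# Unpack the 4-D shape (the extra leading `_` is what the diff adds).
_, _, num_heads, head_dim = query_states.size()
_, _, num_kv_heads, _ = key_states.size()

# Drop the dummy batch dimension so the tensors match the
# (total_tokens, heads, head_dim) layout expected downstream.
query_states = query_states.squeeze(dim=0)
key_states = key_states.squeeze(dim=0)
value_states = value_states.squeeze(dim=0)

print(query_states.shape)  # torch.Size([7, 32, 128])
print(key_states.shape)    # torch.Size([7, 8, 128])
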
@@ -316,8 +319,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             max_k=batch.max_current_length,
         )
         logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
+            input_ids=input_ids[None, ...],
+            position_ids=position_ids[None, ...],
             past_key_values=None,
             use_cache=False, # we use self.kv_cache instead of transformers cache object
             cu_seqlen_prefill=cu_seqlen_prefill,
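
The [None, ...] added to input_ids and position_ids is the opposite adjustment: TGI's flash models flatten every token of the batch into a single 1-D tensor (sequence boundaries travel in cu_seqlen_prefill), while a transformers model's forward expects a (batch, seq_len) layout, so a dummy batch dimension of 1 is prepended. A small sketch with made-up token ids:

import torch

# Two prefill sequences of lengths 4 and 3, flattened into one tensor
# (the token ids here are invented for illustration).
input_ids = torch.tensor([101, 7592, 2088, 102, 101, 2129, 102])
position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 2])
cu_seqlen_prefill = torch.tensor([0, 4, 7])  # cumulative sequence boundaries

# Indexing with None prepends a size-1 dimension, equivalent to unsqueeze(0).
batched_input_ids = input_ids[None, ...]
batched_position_ids = position_ids[None, ...]

print(batched_input_ids.shape)     # torch.Size([1, 7])
print(batched_position_ids.shape)  # torch.Size([1, 7])
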
@@ -329,7 +332,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             prefill_cache_indices=batch.prefill_cache_indices,
             lm_head_indices=lm_head_indices,
             kv_head_mapping=self.kv_head_mapping,
-        ).logits
+        ).logits[0, ...]
+        print("SUCCESSFUL FORWARD")
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
         return logits, None
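
Because the model was called on that size-1 dummy batch, its logits come back with a leading batch dimension; the new [0, ...] strips it so the rest of the pipeline keeps seeing a (num_logits, vocab_size) tensor as before. (The print is a temporary debug statement left in by the commit.) A sketch with assumed shapes:

import torch

# Assumed output layout: (1, num_logits, vocab_size); the sizes are
# illustrative, not taken from the diff.
num_logits, vocab_size = 2, 32000
batched_logits = torch.randn(1, num_logits, vocab_size)

logits = batched_logits[0, ...]  # drop the dummy batch dimension
assert logits.shape == (num_logits, vocab_size)
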