diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index ce1db7a5..fa3a78f8 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -145,13 +145,14 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
+        # output tensor
+        attn_output = torch.empty_like(query)
+
+        # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             flash_attn.attention(
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index d81d8080..65043dee 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -181,8 +181,6 @@ class MistralAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        attn_output = torch.empty_like(query)
-
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
         else:
@@ -191,6 +189,10 @@ class MistralAttention(torch.nn.Module):
         paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
+
+        # output tensor
+        attn_output = torch.empty_like(query)
+
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
@@ -202,6 +204,7 @@ class MistralAttention(torch.nn.Module):
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
+                window_size_left=self.max_past,
             )
         # Decode
         else: