diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index a6911df8..6e23aa2b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -151,13 +151,14 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - # output tensor - attn_output = torch.empty_like(query) - paged_attention.reshape_and_cache( kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots ) + # output tensor + attn_output = torch.empty_like(query) + + # Prefill if cu_seqlen_prefill is not None: # flash attention flash_attn.attention( diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index cc51fe29..ef3777da 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -214,8 +214,6 @@ class MistralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - attn_output = torch.empty_like(query) - if prefill_cache_indices is not None: kv_to_cache = kv[prefill_cache_indices] else: @@ -224,6 +222,10 @@ class MistralAttention(torch.nn.Module): paged_attention.reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) + + # output tensor + attn_output = torch.empty_like(query) + # Prefill if cu_seqlen_prefill is not None: # flash attention @@ -235,6 +237,7 @@ class MistralAttention(torch.nn.Module): cu_seqlen_prefill, max_s, self.softmax_scale, + window_size_left=self.max_past, ) # Decode else: