From 01e4442ef6061fa3fa32cd4852d018116c6b07fe Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 24 May 2024 14:18:00 +0000
Subject: [PATCH] Revert changes in modeling.

---
 .../models/custom_modeling/flash_llama_modeling.py   | 7 ++++---
 .../models/custom_modeling/flash_mistral_modeling.py | 7 +++++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index a6911df8..6e23aa2b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -151,13 +151,14 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
+        # output tensor
+        attn_output = torch.empty_like(query)
+
+        # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             flash_attn.attention(
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index cc51fe29..ef3777da 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -214,8 +214,6 @@ class MistralAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        attn_output = torch.empty_like(query)
-
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
         else:
@@ -224,6 +222,10 @@ class MistralAttention(torch.nn.Module):
         paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
+
+        # output tensor
+        attn_output = torch.empty_like(query)
+
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
@@ -235,6 +237,7 @@ class MistralAttention(torch.nn.Module):
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
+                window_size_left=self.max_past,
             )
         # Decode
         else: