From 7a29e8262954e7154f128fef070ebd6d097c8475 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 29 May 2024 12:35:32 +0000
Subject: [PATCH] Fixing non flash decoding llama path.

---
 .../models/custom_modeling/flash_llama_modeling.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index bc71d598..4c724c3e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -360,7 +361,7 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
-        if cu_seqlen_prefill is None:
+        if cu_seqlen_prefill is None and FLASH_DECODING:
             cu_seqlen_q = torch.arange(
                 input_lengths.shape[0] + 1,
                 device=inputs_embeds.device,
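
For context, a minimal sketch of the decode-path selection this one-line guard restores. The names `FLASH_DECODING`, `cu_seqlen_prefill`, and `input_lengths` come from the diff above; the standalone helper, the `dtype` choice, and the explicit `None` fallback are illustrative assumptions rather than the actual `FlashLlamaModel.forward` code.

```python
import torch

# Stand-in for the flag imported in the patch
# (text_generation_server.models.globals.FLASH_DECODING).
FLASH_DECODING = False


def build_cu_seqlen_q(cu_seqlen_prefill, input_lengths, device):
    """Hypothetical helper mirroring the guarded branch in the patch."""
    if cu_seqlen_prefill is not None:
        # Prefill step: the prefill cu_seqlens are already provided upstream,
        # so there is nothing to build here.
        return None
    if FLASH_DECODING:
        # Flash-decoding decode step: one query token per request, so the
        # cumulative query lengths are simply 0..batch_size.
        return torch.arange(
            input_lengths.shape[0] + 1,
            device=device,
            dtype=torch.int32,  # assumed; the hunk above is cut off before any dtype
        )
    # Regular (non flash decoding) decode step: this path does not use
    # cu_seqlen_q, which is why the extra guard is needed.
    return None
```

Before the change, the `torch.arange` branch ran on every decode step even with flash decoding disabled, breaking the regular (non flash decoding) llama decode path that the subject line refers to.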