Fixing the non-flash decoding llama path.

Nicolas Patry 2024-05-29 12:35:32 +00:00
parent 6aeb5a73a1
commit 7a29e82629


@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
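
For context, the `FLASH_DECODING` symbol imported above is a module-level flag in `text_generation_server/models/globals.py`. A minimal sketch of such a flag is shown below, assuming it is driven by an environment variable of the same name; this is an illustration, not the exact contents of that module.

```python
# Hypothetical sketch of a FLASH_DECODING-style feature flag (assumed, not the
# actual globals.py): a boolean read once from the environment at import time.
import os

# Any non-empty value (e.g. FLASH_DECODING=1) would enable the flash decoding path.
FLASH_DECODING = bool(os.getenv("FLASH_DECODING"))
```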
@@ -360,7 +361,7 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
-        if cu_seqlen_prefill is None:
+        if cu_seqlen_prefill is None and FLASH_DECODING:
             cu_seqlen_q = torch.arange(
                 input_lengths.shape[0] + 1,
                 device=inputs_embeds.device,
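
The added `and FLASH_DECODING` guard means the arange-based `cu_seqlen_q` setup only runs when flash decoding is enabled; with the flag off, the decode step goes through the regular (non-flash) attention path without this metadata. The sketch below illustrates that control flow; the function name, fallback return values, and the `cu_seqlen_k` construction are assumptions for illustration, not the repository's exact code.

```python
# Hypothetical illustration of the gated decode-path setup (names simplified).
# During decode (cu_seqlen_prefill is None), flash decoding needs per-sequence
# cumulative lengths, while the plain non-flash path does not.
import torch
import torch.nn.functional as F


def build_decode_seqlens(cu_seqlen_prefill, input_lengths, device, flash_decoding):
    if cu_seqlen_prefill is None and flash_decoding:
        # One query token per sequence at decode time, so the cumulative
        # query lengths are simply 0..batch_size.
        cu_seqlen_q = torch.arange(
            input_lengths.shape[0] + 1, device=device, dtype=torch.int32
        )
        # Cumulative key lengths over the tokens already in the KV cache,
        # prefixed with a leading zero.
        cu_seqlen_k = F.pad(input_lengths.cumsum(0, dtype=torch.int32), (1, 0))
        return cu_seqlen_q, cu_seqlen_k
    # Prefill, or flash decoding disabled: leave both unset so the
    # non-flash attention path is taken downstream.
    return None, None
```

Gating at this point keeps the change local: when the flag is off, the rest of the forward pass behaves as it did before the flash decoding feature was introduced.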