Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Fixing non flash decoding llama path.
parent 6aeb5a73a1
commit 7a29e82629
@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -360,7 +361,7 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
-        if cu_seqlen_prefill is None:
+        if cu_seqlen_prefill is None and FLASH_DECODING:
             cu_seqlen_q = torch.arange(
                 input_lengths.shape[0] + 1,
                 device=inputs_embeds.device,
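For readers outside the diff context: `FLASH_DECODING` is a feature flag in `text_generation_server.models.globals`, and the change gates the flash-decoding-only `cu_seqlen_q` construction behind it, so a decode step (`cu_seqlen_prefill is None`) no longer builds flash-decoding metadata when the feature is disabled — which appears to be what broke the non-flash-decoding llama path. Below is a minimal, self-contained sketch of that gating, not the upstream implementation: the helper name `build_decode_metadata` and the usage values are hypothetical, while `FLASH_DECODING`, `cu_seqlen_prefill`, `input_lengths`, and `cu_seqlen_q` come from the diff itself.

import torch

# In TGI this flag lives in text_generation_server.models.globals;
# hard-coded here to keep the sketch self-contained.
FLASH_DECODING = False


def build_decode_metadata(cu_seqlen_prefill, input_lengths, inputs_embeds):
    # Hypothetical helper mirroring the patched branch in
    # FlashLlamaModel.forward; the function name is illustrative.
    cu_seqlen_q = None
    # Before the fix this branch ran on every decode step (whenever there
    # is no prefill); after the fix it only runs when flash decoding is on.
    if cu_seqlen_prefill is None and FLASH_DECODING:
        # During decode there is exactly one query token per sequence, so
        # the cumulative query lengths are simply 0, 1, 2, ..., batch_size.
        cu_seqlen_q = torch.arange(
            input_lengths.shape[0] + 1,
            device=inputs_embeds.device,
        )
    return cu_seqlen_q


if __name__ == "__main__":
    # A decode step (cu_seqlen_prefill is None) with an illustrative batch of 3.
    input_lengths = torch.tensor([5, 9, 2])
    inputs_embeds = torch.zeros(3, 16)
    # FLASH_DECODING is False, so the non-flash path builds no cu_seqlen_q.
    print(build_decode_metadata(None, input_lengths, inputs_embeds))  # -> None

The design point of the fix, as this sketch reads it, is that the non-flash decode path simply skips the extra tensor rather than computing flash-decoding bookkeeping it cannot use.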