Fixing non flash tests/imports.

Nicolas Patry 2024-07-01 11:54:34 +00:00
parent 4b1364da92
commit a26e57f9f3
3 changed files with 7 additions and 6 deletions

View File

@@ -12,7 +12,6 @@ from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = True
 try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_gpt2 import FlashGPT2
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -92,6 +92,7 @@ except ImportError as e:
     FLASH_ATTENTION = False
 if FLASH_ATTENTION:
+    __all__.append(FlashCausalLM)
     __all__.append(FlashGPT2)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
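
The change above follows the standard optional-dependency pattern: every import that pulls in flash-attention kernels lives inside a single try/except ImportError, the outcome is recorded in the FLASH_ATTENTION flag, and the flash classes are only exported when the flag is set. Previously FlashCausalLM was imported unconditionally at the top of the module, which broke environments without flash attention at import time. Below is a minimal sketch of that pattern with placeholder module and class names (flash_kernels, FlashModel, BaseModel), not the repository's real layout:

# Guarded optional import: the placeholder names are illustrative only and
# do not exist in this repository.
FLASH_ATTENTION = True
try:
    import flash_kernels  # noqa: F401  # raises ImportError on CPU-only installs
    from mypkg.flash_model import FlashModel
except ImportError:
    FLASH_ATTENTION = False

__all__ = ["BaseModel"]
if FLASH_ATTENTION:
    # Only advertise the flash class when its dependencies imported cleanly.
    __all__.append("FlashModel")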

View File

@@ -387,8 +387,8 @@ class FlashCohereLayer(nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        input_lengths,
         slots,
+        input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -401,8 +401,8 @@ class FlashCohereLayer(nn.Module):
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
-            input_lengths,
             slots,
+            input_lengths,
             max_s,
         )
@@ -475,8 +475,8 @@ class FlashCohereModel(torch.nn.Module):
                 cu_seqlen_prefill,
                 kv_cache[i],
                 block_tables,
-                input_lengths,
                 slots,
+                input_lengths,
                 max_s,
             )
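
All three Cohere hunks make the same fix: slots is now passed before input_lengths, restoring the argument order the downstream attention code expects (the convention used by the other flash models in this codebase). Because both values are plain integer tensors, swapping them raises no type error; the bug only shows up as wrong cache writes or wrong length masking. The following is a hedged illustration of that failure mode, with made-up names and shapes rather than the real signatures:

import torch

# Stand-in for a paged-attention consumer: `slots` says where each token's KV
# entry is written, `input_lengths` says how many tokens each sequence has.
def fake_attention(block_tables, slots, input_lengths, max_s):
    cache = torch.zeros(max_s, dtype=torch.int64)
    cache[slots] = 1  # slot indices are used directly as cache positions
    return cache, input_lengths

slots = torch.tensor([0, 2, 5, 7])       # cache slot per token
input_lengths = torch.tensor([4])        # one sequence of length 4

fake_attention(None, slots, input_lengths, max_s=8)   # correct order
# Swapped order still runs: lengths are silently treated as slot indices and
# slots as lengths, so the error surfaces only as wrong attention output.
fake_attention(None, input_lengths, slots, max_s=8)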

View File

@@ -1086,7 +1086,7 @@ class FlashCausalLM(Model):
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        seqlen = Seqlen(input_lengths=input_lengths)
+        input_lengths = Seqlen(input_lengths=input_lengths)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1097,7 +1097,7 @@
             ),
             kv_cache=self.kv_cache,
             block_tables=None,
-            seqlen=seqlen,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
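
The last two hunks fix the warmup helper: the Seqlen wrapper must not be bound to the name seqlen, because the integer seqlen is still needed for max_s=seqlen a few lines below, and the forward call expects the wrapped lengths under the input_lengths keyword rather than seqlen. Below is a small sketch of the shadowing pitfall, using a placeholder Seqlen dataclass and a dummy forward instead of the real model:

from dataclasses import dataclass

import torch


@dataclass
class Seqlen:  # placeholder stand-in, not the library's class
    input_lengths: torch.Tensor


def forward(*, input_lengths, max_s: int):
    # The model forward needs `max_s` as a plain int.
    assert isinstance(max_s, int)
    return input_lengths, max_s


def warmup(seqlen: int):
    input_lengths = torch.ones(seqlen, dtype=torch.int32)

    # Buggy version: rebinding `seqlen` loses the int that `max_s` needs,
    # and `forward` has no `seqlen` keyword in the first place.
    #   seqlen = Seqlen(input_lengths=input_lengths)
    #   forward(seqlen=seqlen, max_s=seqlen)  # TypeError / wrong max_s

    # Fixed version: keep the int, pass the wrapper as `input_lengths`.
    input_lengths = Seqlen(input_lengths=input_lengths)
    return forward(input_lengths=input_lengths, max_s=seqlen)


warmup(16)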