Fixing non-flash tests/imports.

Nicolas Patry 2024-07-01 11:54:34 +00:00
parent 4b1364da92
commit a26e57f9f3
3 changed files with 7 additions and 6 deletions

View File

@@ -12,7 +12,6 @@ from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = True
 try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_gpt2 import FlashGPT2
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -92,6 +92,7 @@ except ImportError as e:
     FLASH_ATTENTION = False
 if FLASH_ATTENTION:
+    __all__.append(FlashCausalLM)
     __all__.append(FlashGPT2)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)

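Note on the hunks above: moving the FlashCausalLM import inside the existing try/except means a machine without the flash-attention stack only flips FLASH_ATTENTION to False instead of breaking the whole models package at import time. A minimal sketch of that guarded-import pattern, with the import list trimmed to a single class and the __all__ handling simplified for brevity:

FLASH_ATTENTION = True
try:
    # Any of the flash_* imports can raise ImportError on hosts without the
    # required CUDA / flash-attention kernels.
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
except ImportError:
    # Non-flash models stay importable; flash ones are simply not exported.
    FLASH_ATTENTION = False

__all__ = []
if FLASH_ATTENTION:
    # Mirrors the diff: the class object itself is appended to __all__.
    __all__.append(FlashCausalLM)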
View File

@@ -387,8 +387,8 @@ class FlashCohereLayer(nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        input_lengths,
         slots,
+        input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -401,8 +401,8 @@ class FlashCohereLayer(nn.Module):
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
-            input_lengths,
             slots,
+            input_lengths,
             max_s,
         )
@@ -475,8 +475,8 @@ class FlashCohereModel(torch.nn.Module):
                 cu_seqlen_prefill,
                 kv_cache[i],
                 block_tables,
-                input_lengths,
                 slots,
+                input_lengths,
                 max_s,
             )

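The three hunks above reorder positional arguments so that slots and input_lengths are passed in the order the callee declares them. Because both are tensors, Python accepts the swapped order silently and the values simply land in the wrong parameters. A tiny, self-contained illustration of that failure mode, using a hypothetical attend() signature rather than TGI code:

def attend(block_tables, slots, input_lengths, max_s):
    # Returns its inputs so the mix-up is visible.
    return {"slots": slots, "input_lengths": input_lengths, "max_s": max_s}

block_tables, slots, input_lengths, max_s = None, [3, 7], [128, 64], 128

# Swapped positional order: input_lengths silently lands in `slots`.
bad = attend(block_tables, input_lengths, slots, max_s)
assert bad["slots"] == input_lengths

# Matching the callee's order (what the diff restores) routes values correctly.
good = attend(block_tables, slots, input_lengths, max_s)
assert good["slots"] == slots and good["input_lengths"] == input_lengths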
View File

@@ -1086,7 +1086,7 @@ class FlashCausalLM(Model):
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        seqlen = Seqlen(input_lengths=input_lengths)
+        input_lengths = Seqlen(input_lengths=input_lengths)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1097,7 +1097,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=self.kv_cache,
             block_tables=None,
-            seqlen=seqlen,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
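These last two hunks pass the wrapped lengths to the warmup forward under the input_lengths= keyword that the model forwards still declare, rather than a seqlen= keyword. A hedged, self-contained sketch of the same shape, with a stand-in Seqlen dataclass (the real container is defined inside text_generation_server) and a hypothetical warmup_forward helper:

from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    # Stand-in for the server's Seqlen container: it carries the per-request
    # input lengths that flash models consume.
    input_lengths: torch.Tensor


def warmup_forward(model, seqlen: int, device: str = "cpu"):
    # Dummy per-request lengths (mirrors the torch.ones(...) line above),
    # wrapped and passed under the keyword the model's forward declares.
    lengths = torch.ones(seqlen, dtype=torch.int32, device=device)
    input_lengths = Seqlen(input_lengths=lengths)
    return model.forward(input_lengths=input_lengths, max_s=seqlen)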