Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Fixing non flash tests/imports.
parent 4b1364da92
commit a26e57f9f3
```diff
@@ -12,7 +12,6 @@ from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
```
```diff
@@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = True
 
 try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_gpt2 import FlashGPT2
     from text_generation_server.models.flash_neox import FlashNeoXSharded
```
```diff
@@ -92,6 +92,7 @@ except ImportError as e:
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
+    __all__.append(FlashCausalLM)
     __all__.append(FlashGPT2)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
```
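Taken together, the three hunks above move the `FlashCausalLM` import out of the unconditional import list, place it inside the existing `try:` block, and append it to `__all__` only when `FLASH_ATTENTION` remains true, so installations without the flash attention kernels can still import the package and use the non-flash models. A minimal sketch of that optional-import pattern; `fast_kernels`, `FastModel`, and `HAS_FAST_KERNELS` are hypothetical stand-ins, not TGI names:

```python
# Optional-import guard: keep the fast path importable only where it works.
__all__ = ["BaselineModel"]

HAS_FAST_KERNELS = True

try:
    # Hypothetical module; raises ImportError on machines without the kernels.
    from fast_kernels import FastModel
except ImportError as e:
    print(f"Fast kernels unavailable, falling back to baseline models: {e}")
    HAS_FAST_KERNELS = False

if HAS_FAST_KERNELS:
    # Only advertise the fast path when its import actually succeeded.
    __all__.append("FastModel")
```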
```diff
@@ -387,8 +387,8 @@ class FlashCohereLayer(nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        input_lengths,
         slots,
+        input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
```
```diff
@@ -401,8 +401,8 @@ class FlashCohereLayer(nn.Module):
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
-            input_lengths,
             slots,
+            input_lengths,
             max_s,
         )
 
```
```diff
@@ -475,8 +475,8 @@ class FlashCohereModel(torch.nn.Module):
                 cu_seqlen_prefill,
                 kv_cache[i],
                 block_tables,
-                input_lengths,
                 slots,
+                input_lengths,
                 max_s,
             )
 
```
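The three FlashCohere hunks make the same change in the layer signature and both call sites: `input_lengths` now follows `slots`, keeping the positional argument order consistent across the call chain. A heavily simplified sketch of why signature and callers must move together when arguments are forwarded positionally; the function names and plain-list arguments below are stand-ins, not the real TGI signatures:

```python
# Simplified stand-ins for the layer/attention forward pair; the real code
# passes tensors and a KV cache, elided here.
def attention_forward(cu_seqlen_prefill, kv_cache, block_tables, slots, input_lengths, max_s):
    mode = "prefill" if cu_seqlen_prefill is not None else "decode"
    return mode, input_lengths, max_s


def layer_forward(cu_seqlen_prefill, kv_cache, block_tables, slots, input_lengths, max_s):
    # Arguments are forwarded positionally, so the `slots` / `input_lengths`
    # order must match between this signature and the call below.
    return attention_forward(
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    )


print(layer_forward(None, [], [], [0], [3], 3))  # -> ('decode', [3], 3)
```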
```diff
@@ -1086,7 +1086,7 @@ class FlashCausalLM(Model):
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        seqlen = Seqlen(input_lengths=input_lengths)
+        input_lengths = Seqlen(input_lengths=input_lengths)
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
```
```diff
@@ -1097,7 +1097,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=self.kv_cache,
             block_tables=None,
-            seqlen=seqlen,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
```
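In the warmup path, the raw lengths tensor is now wrapped in `Seqlen` and passed to `self.model.forward` under the `input_lengths=` keyword instead of a separate `seqlen=` keyword. A small sketch of the wrapper idea; the `Seqlen` dataclass and `warmup_forward` below are simplified stand-ins, not the actual classes from `text_generation_server`:

```python
from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    # Simplified stand-in: the real class carries additional sequence metadata.
    input_lengths: torch.Tensor


def warmup_forward(*, input_lengths: Seqlen, max_s: int) -> torch.Tensor:
    # Downstream code receives the wrapper type rather than a bare tensor.
    assert int(input_lengths.input_lengths.max()) <= max_s
    return input_lengths.input_lengths


lengths = torch.ones(4, dtype=torch.int32)  # dummy per-sequence lengths
print(warmup_forward(input_lengths=Seqlen(input_lengths=lengths), max_s=4))
```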