mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

feat: remove debug cuda avoid

This commit is contained in:
parent 66f89120b5
commit de421dc53e
@@ -8,6 +8,7 @@ from typing import Optional
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -45,20 +46,6 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
 FLASH_ATTENTION = True
 
-# FlashCausalLM reqiures CUDA Graphs to be enabled on the system. This will throw a RuntimeError
-# if CUDA Graphs are not available when calling `torch.cuda.graph_pool_handle()` in the FlashCausalLM
-HAS_CUDA_GRAPH = False
-try:
-    from text_generation_server.models.flash_causal_lm import FlashCausalLM
-
-    HAS_CUDA_GRAPH = True
-except RuntimeError as e:
-    logger.warning(f"Could not import FlashCausalLM: {e}")
-
-if HAS_CUDA_GRAPH:
-    __all__.append(FlashCausalLM)
-
-
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
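For context, the deleted comment describes a capability probe: `torch.cuda.graph_pool_handle()` raises a RuntimeError when CUDA Graphs are unavailable, which is why the import was wrapped in a try/except before this commit made it unconditional. A minimal standalone sketch of that probe (illustrative only, not the repository's code):

import torch

def cuda_graphs_available() -> bool:
    """Return True if CUDA Graphs can be used on this system (sketch)."""
    if not torch.cuda.is_available():
        return False
    try:
        # Raises RuntimeError when CUDA Graphs are not supported;
        # this is the failure mode the removed guard downgraded to a warning.
        torch.cuda.graph_pool_handle()
        return True
    except RuntimeError:
        return False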