mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
fix bug
This commit is contained in:
parent
422bf1f986
commit
cd3c28cfe7
@ -827,7 +827,7 @@ class FlashCausalLM(Model):
|
|||||||
self.device,
|
self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
if SYSTEM == "rocm" and self.speculate is None or self.speculate == 0:
|
||||||
if (
|
if (
|
||||||
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
|
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
|
||||||
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
|
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
|
||||||
@ -875,7 +875,11 @@ class FlashCausalLM(Model):
|
|||||||
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
|
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
|
||||||
# Warmup cuda graphs
|
# Warmup cuda graphs
|
||||||
for bs in CUDA_GRAPHS:
|
for bs in CUDA_GRAPHS:
|
||||||
if self.speculate is None or self.speculate + 1 <= bs:
|
if (
|
||||||
|
self.speculate is None
|
||||||
|
or self.speculate == 0
|
||||||
|
or self.speculate + 1 <= bs
|
||||||
|
):
|
||||||
self.cuda_graph_warmup(bs, max_s, max_bt)
|
self.cuda_graph_warmup(bs, max_s, max_bt)
|
||||||
except torch.cuda.OutOfMemoryError:
|
except torch.cuda.OutOfMemoryError:
|
||||||
logger.exception(f"Decode cuda graph warmup failed")
|
logger.exception(f"Decode cuda graph warmup failed")
|
||||||
|
Loading…
Reference in New Issue
Block a user