From 45eb84e4b6c1442af85ccf9b5fc3938a6fe028d9 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 2 Dec 2024 19:40:57 +0100
Subject: [PATCH] Adding assertion.

---
 server/text_generation_server/models/flash_causal_lm.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 33888fe4..cf9ac405 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1416,6 +1416,10 @@ class FlashCausalLM(Model):
                     max_current_length=max_s,
                 )
         else:
+            if bs > max_bs:
+                raise RuntimeError(
+                    "Cuda graphs should be generated in decreasing order of size to reduce VRAM usage"
+                )
             input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
             position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
             if ATTENTION == "flashinfer":
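
Reviewer context (not part of the patch): the else branch serves batch sizes smaller than the largest already-captured CUDA graph by taking views into that graph's static tensors via the `[:bs]` slices, so warmup must proceed from the largest batch size downward. The following is a minimal, self-contained Python sketch of the invariant the new assertion enforces; the names (`graphs`, `warmup_one`) and the plain dict-of-tensors layout are illustrative assumptions, not TGI's actual API.

# Sketch, under the assumptions stated above: smaller batch sizes alias the
# head of the buffers allocated for the largest batch, so no extra VRAM is
# spent per graph -- which is exactly why the warmup order must be decreasing.
import torch

graphs: dict[int, dict[str, torch.Tensor]] = {}

def warmup_one(bs: int) -> None:
    max_bs = max(graphs.keys(), default=None)
    if max_bs is None:
        # First (largest) batch size: allocate the static buffers once.
        input_ids = torch.zeros(bs, dtype=torch.int64)
        position_ids = torch.zeros(bs, dtype=torch.int32)
    else:
        if bs > max_bs:
            # Mirrors the patch: a larger batch cannot view into buffers
            # sized for a smaller one, so the order must be decreasing.
            raise RuntimeError(
                "Cuda graphs should be generated in decreasing order of size to reduce VRAM usage"
            )
        # Smaller batch sizes reuse slices of the largest graph's tensors.
        input_ids = graphs[max_bs]["input_ids"][:bs]
        position_ids = graphs[max_bs]["position_ids"][:bs]
    graphs[bs] = {"input_ids": input_ids, "position_ids": position_ids}

# Decreasing order works; an ascending order would trip the assertion.
for bs in (32, 16, 8, 4, 2, 1):
    warmup_one(bs)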