Adding assertion.

This commit is contained in:
Nicolas Patry 2024-12-02 19:40:57 +01:00
parent b4c5ca5a58
commit 45eb84e4b6
No known key found for this signature in database
GPG Key ID: D2920555C90F704C

View File

@ -1416,6 +1416,10 @@ class FlashCausalLM(Model):
max_current_length=max_s, max_current_length=max_s,
) )
else: else:
if bs > max_bs:
raise RuntimeError(
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
)
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
if ATTENTION == "flashinfer": if ATTENTION == "flashinfer":