diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c6218a22..65180499 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1520,6 +1520,10 @@ class FlashCausalLM(Model): ) # assert block_tables.shape[0] >= slots.shape[0] cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables # XXX: This is working only because block 0 is reserved for the healthcheck # so it doesn't matter if we override it with bogus values.