diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 13885f28..63c97d94 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1412,7 +1412,7 @@ class FlashCausalLM(Model): ).view(-1) prefix_lens_tensor = ( batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) - ).view(-1) + ).reshape(-1) # Add Copy the block tables for all members block_tables = (