diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 288c8ccf..fc50d69a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -945,7 +945,7 @@ class FlashCausalLMBatch(Batch): ) self.cu_seqlen_prefill = torch.nn.functional.pad( torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0) - ) + ).to(torch.int32) self.cache_lengths_tensor = torch.tensor( self.cache_lengths, dtype=torch.int32, device=device )