From d1e95ceaffe33974b64d89a9466a3ec1fdfaf761 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 24 Oct 2024 19:01:40 +0200
Subject: [PATCH] cast to int32

---
 server/text_generation_server/models/flash_causal_lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 288c8ccf..fc50d69a 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -945,7 +945,7 @@ class FlashCausalLMBatch(Batch):
         )
         self.cu_seqlen_prefill = torch.nn.functional.pad(
             torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
-        )
+        ).to(torch.int32)
         self.cache_lengths_tensor = torch.tensor(
             self.cache_lengths, dtype=torch.int32, device=device
         )
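
Context for the change (not part of the patch itself): torch.cumsum promotes integer inputs to int64 by default, so without the explicit cast cu_seqlen_prefill would no longer match the int32 dtype used for the other length tensors, which the variable-length attention kernels expect. A minimal standalone sketch of the dtype mismatch and the fix, using a made-up input_lengths tensor purely for illustration:

import torch

# Hypothetical per-request prompt lengths, stored as int32 like
# input_lengths_tensor in FlashCausalLMBatch.
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# torch.cumsum promotes integer inputs to int64 by default, so the padded
# cumulative-lengths tensor silently ends up as int64.
cu_seqlen_prefill = torch.nn.functional.pad(
    torch.cumsum(input_lengths, dim=0), (1, 0)
)
print(cu_seqlen_prefill.dtype)  # torch.int64

# The patch adds this cast so the tensor stays int32 like the other
# length tensors in the batch.
cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
print(cu_seqlen_prefill)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)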