cast to int32

This commit is contained in:
OlivierDehaene 2024-10-24 19:01:40 +02:00
parent ea66379e3c
commit d1e95ceaff
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -945,7 +945,7 @@ class FlashCausalLMBatch(Batch):
) )
self.cu_seqlen_prefill = torch.nn.functional.pad( self.cu_seqlen_prefill = torch.nn.functional.pad(
torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0) torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
) ).to(torch.int32)
self.cache_lengths_tensor = torch.tensor( self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device self.cache_lengths, dtype=torch.int32, device=device
) )