cast to int32

This commit is contained in:
OlivierDehaene 2024-10-24 19:01:40 +02:00
parent ea66379e3c
commit d1e95ceaff
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -945,7 +945,7 @@ class FlashCausalLMBatch(Batch):
)
self.cu_seqlen_prefill = torch.nn.functional.pad(
torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
)
).to(torch.int32)
self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device
)