Remove other tensor creation.

This commit is contained in:
Nicolas Patry 2024-09-06 16:59:11 +02:00
parent d45408e935
commit f952024533
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@@ -1215,13 +1215,6 @@ class FlashCausalLM(Model):
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@@ -1527,6 +1520,8 @@ class FlashCausalLM(Model):
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + prefix_lens_tensor
)
cuda_graph["prefix_lengths"].zero_()
cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
with self._forward_context(
block_tables=cuda_graph["block_tables"],
@@ -1534,7 +1529,7 @@ class FlashCausalLM(Model):
input_lengths=batch.input_lengths,
input_lengths_tensor=cuda_graph["input_lengths"],
prefix_lens=batch.prefix_lens,
- prefix_lens_tensor=prefix_lens_tensor,
+ prefix_lens_tensor=cuda_graph["prefix_lengths"],
state=cuda_graph.get("state"),
):
# Replay the graph