mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Remove other tensor creation.
parent d45408e935
commit f952024533
@@ -1215,13 +1215,6 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
-            seqlen = Seqlen(
-                input_lengths=input_lengths_tensor,
-                prefix_lengths=prefix_lengths_tensor,
-                cu_seqlen_q=None,
-                max_q=1,
-                max_k=max_s,
-            )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
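The hunk above deletes the Seqlen wrapper that was being rebuilt inside the torch.cuda.graph capture; since the captured forward call still references seqlen unchanged, the wrapper is evidently constructed once, before the capture block, from the pre-allocated input_lengths_tensor / prefix_lengths_tensor buffers. A minimal sketch of capturing with static buffers only (names such as static_input_lengths and model_forward are illustrative, not TGI's actual API; assumes a CUDA device):

import torch

MEM_POOL = torch.cuda.graph_pool_handle()

# Pre-allocated "static" buffers: the captured kernels will read these
# exact allocations on every replay.
static_input_lengths = torch.zeros(8, dtype=torch.int32, device="cuda")
static_prefix_lengths = torch.zeros(8, dtype=torch.int32, device="cuda")

def model_forward(input_lengths, prefix_lengths):
    # Stand-in for the real model forward: any kernels reading the buffers.
    return (input_lengths + prefix_lengths).float().sum()

graph = torch.cuda.CUDAGraph()
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
    # Create nothing new inside the capture: every tensor the kernels
    # touch must already exist, otherwise replay reuses stale allocations.
    static_out = model_forward(static_input_lengths, static_prefix_lengths)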
@@ -1527,6 +1520,8 @@ class FlashCausalLM(Model):
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
             input_lengths + prefix_lens_tensor
         )
+        cuda_graph["prefix_lengths"].zero_()
+        cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
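The two added lines update the graph's static prefix_lengths buffer in place. The buffer is sized for the largest captured batch, while a replayed batch may be smaller and only overwrites the head of the buffer, so it is zeroed first to keep stale values from a previous, larger batch out of the padded tail. In terms of the sketch above (prefix_lens_tensor standing in for the fresh per-batch tensor):

prefix_lens_tensor = torch.tensor([3, 0, 7], dtype=torch.int32, device="cuda")
static_prefix_lengths.zero_()  # clear slots left over from a bigger batch
static_prefix_lengths[: prefix_lens_tensor.shape[0]] = prefix_lens_tensor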
@@ -1534,7 +1529,7 @@ class FlashCausalLM(Model):
             input_lengths=batch.input_lengths,
             input_lengths_tensor=cuda_graph["input_lengths"],
             prefix_lens=batch.prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            prefix_lens_tensor=cuda_graph["prefix_lengths"],
             state=cuda_graph.get("state"),
         ):
             # Replay the graph
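The last hunk completes the change: the forward context previously received prefix_lens_tensor, a tensor allocated fresh on every call, which the captured kernels never see; it now receives the same static cuda_graph["prefix_lengths"] buffer that was just updated in place, so the replayed graph reads the current values. Continuing the sketch, replay after an in-place refresh (again with illustrative names):

# Refresh the static inputs, then replay the captured kernels.
static_input_lengths[:3] = torch.tensor([5, 2, 9], dtype=torch.int32, device="cuda")
graph.replay()
torch.cuda.synchronize()
print(static_out)  # the output buffer is static too; replay rewrote it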