mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
nit
This commit is contained in:
parent
26212b9f35
commit
2f67c53075
@@ -834,7 +834,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
cache_lengths = [0] * bs
|
||||
if max_bs is None:
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
- input_embeds = torch.zeros(
|
||||
+ inputs_embeds = torch.zeros(
|
||||
(bs, self.model.config.text_config.hidden_size),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
@@ -873,7 +873,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
|
||||
)
|
||||
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
|
||||
- input_embeds = self.cuda_graphs[max_bs]["input_embeds"][:bs]
|
||||
+ inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs]
|
||||
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
|
||||
@@ -906,7 +906,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs] = {
|
||||
"input_ids": input_ids,
|
||||
- "input_embeds": input_embeds,
|
||||
+ "inputs_embeds": inputs_embeds,
|
||||
"position_ids": position_ids,
|
||||
"kv_cache": self.kv_cache,
|
||||
"block_tables": block_tables,
|
||||
@@ -935,7 +935,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
)
|
||||
self.model.forward(
|
||||
input_ids=input_ids,
|
||||
- inputs_embeds=input_embeds,
|
||||
+ inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=self.kv_cache,
|
||||
@@ -960,7 +960,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
- inputs_embeds=input_embeds,
|
||||
+ inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=self.kv_cache,
|
||||
|
Loading…
Reference in New Issue
Block a user