mirror of https://github.com/huggingface/text-generation-inference.git

commit 2f67c53075
parent 26212b9f35

    nit
@@ -834,7 +834,7 @@ class VlmCausalLM(FlashCausalLM):
         cache_lengths = [0] * bs
         if max_bs is None:
             input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-            input_embeds = torch.zeros(
+            inputs_embeds = torch.zeros(
                 (bs, self.model.config.text_config.hidden_size),
                 device=self.device,
                 dtype=self.dtype,
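For orientation, the zero tensors allocated above are static placeholders: a captured CUDA graph replays fixed kernels against fixed device addresses, so each batch size gets buffers that are created once and then overwritten in place. A minimal sketch of that allocation, with hypothetical hidden_size and dtype values standing in for self.model.config.text_config.hidden_size and self.dtype:

import torch

# Hypothetical values; the real code reads hidden_size from
# self.model.config.text_config.hidden_size and dtype from self.dtype.
bs, hidden_size = 4, 4096
device, dtype = torch.device("cuda"), torch.float16

# Static placeholders: the captured graph replays against these exact
# tensors, so they are zero-filled now and copied into before each replay.
input_ids = torch.zeros(bs, dtype=torch.int64, device=device)
inputs_embeds = torch.zeros((bs, hidden_size), device=device, dtype=dtype)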
@@ -873,7 +873,7 @@ class VlmCausalLM(FlashCausalLM):
                 "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
             )
             input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
-            input_embeds = self.cuda_graphs[max_bs]["input_embeds"][:bs]
+            inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs]
             position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
             if ATTENTION == "flashinfer":
                 block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
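The "decreasing order size" message above and the slicing in this hunk go together: once the largest batch size has been captured, every smaller batch takes a prefix view of the max_bs buffers instead of allocating its own. A small sketch of that sharing, with hypothetical sizes:

import torch

# Hypothetical sizes; graphs are captured for max_bs first, then smaller bs.
max_bs, hidden_size = 8, 4096
cuda_graphs = {
    max_bs: {"inputs_embeds": torch.zeros((max_bs, hidden_size), device="cuda")}
}

# A smaller batch is a view into the max_bs buffer: same storage,
# so no extra VRAM is spent on per-batch-size copies.
bs = 4
inputs_embeds = cuda_graphs[max_bs]["inputs_embeds"][:bs]
assert inputs_embeds.data_ptr() == cuda_graphs[max_bs]["inputs_embeds"].data_ptr()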
@@ -906,7 +906,7 @@ class VlmCausalLM(FlashCausalLM):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
-            "input_embeds": input_embeds,
+            "inputs_embeds": inputs_embeds,
             "position_ids": position_ids,
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
@@ -935,7 +935,7 @@ class VlmCausalLM(FlashCausalLM):
             )
             self.model.forward(
                 input_ids=input_ids,
-                inputs_embeds=input_embeds,
+                inputs_embeds=inputs_embeds,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
                 kv_cache=self.kv_cache,
@@ -960,7 +960,7 @@ class VlmCausalLM(FlashCausalLM):
             )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
-                inputs_embeds=input_embeds,
+                inputs_embeds=inputs_embeds,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
                 kv_cache=self.kv_cache,
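Taken together, the last two hunks follow the standard torch.cuda.CUDAGraph recipe: an eager warmup forward, then a second forward recorded into the graph, with the renamed inputs_embeds buffer fed to both calls. A minimal, self-contained sketch of that capture-and-replay pattern; a toy nn.Linear stands in for self.model.forward and all shapes are hypothetical:

import torch

# Toy stand-in for self.model.forward; only the capture/replay mechanics
# mirror the warmup code above, shapes and dtype are made up.
model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)
inputs_embeds = torch.zeros((4, 4096), device="cuda", dtype=torch.float16)

# 1. Eager warmup on a side stream before capture (PyTorch requires the
#    workload to have run at least once outside the graph).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(inputs_embeds)
torch.cuda.current_stream().wait_stream(s)

# 2. Record one forward pass; the output is a static tensor as well.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_logits = model(inputs_embeds)

# 3. Decode step: overwrite the static input in place, replay the recorded
#    kernels, and read the result out of the static output tensor.
inputs_embeds.copy_(torch.randn_like(inputs_embeds))
graph.replay()
print(static_logits.sum())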