commit 2f67c53075 (parent 26212b9f35)
Author: Mohit Sharma
Date:   2025-04-22 02:06:57 +05:30 (committed by GitHub)


@@ -834,7 +834,7 @@ class VlmCausalLM(FlashCausalLM):
         cache_lengths = [0] * bs
         if max_bs is None:
             input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-            input_embeds = torch.zeros(
+            inputs_embeds = torch.zeros(
                 (bs, self.model.config.text_config.hidden_size),
                 device=self.device,
                 dtype=self.dtype,
@@ -873,7 +873,7 @@ class VlmCausalLM(FlashCausalLM):
                 "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
             )
             input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
-            input_embeds = self.cuda_graphs[max_bs]["input_embeds"][:bs]
+            inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs]
             position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
             if ATTENTION == "flashinfer":
                 block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
@@ -906,7 +906,7 @@ class VlmCausalLM(FlashCausalLM):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
-            "input_embeds": input_embeds,
+            "inputs_embeds": inputs_embeds,
             "position_ids": position_ids,
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
@@ -935,7 +935,7 @@ class VlmCausalLM(FlashCausalLM):
             )
             self.model.forward(
                 input_ids=input_ids,
-                inputs_embeds=input_embeds,
+                inputs_embeds=inputs_embeds,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
                 kv_cache=self.kv_cache,
@@ -960,7 +960,7 @@ class VlmCausalLM(FlashCausalLM):
         )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
-            inputs_embeds=input_embeds,
+            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
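
The change puts the local variable, the cuda_graphs[bs] dictionary key, and the model.forward(inputs_embeds=...) keyword argument on a single spelling, inputs_embeds, so the buffers captured during warmup are the same ones looked up and passed at replay time. Below is a minimal sketch of that buffer-reuse pattern; the toy forward function, shapes, and dictionary layout are illustrative assumptions, not the actual TGI interfaces.

import torch


def forward(input_ids, inputs_embeds, position_ids):
    # Toy stand-in for model.forward: only checks that the buffers line up.
    assert input_ids.shape[0] == inputs_embeds.shape[0] == position_ids.shape[0]
    return inputs_embeds.sum(dim=-1)


cuda_graphs = {}
max_bs, hidden_size = 8, 16

# Warmup: allocate static buffers once and record them under one key spelling.
inputs_embeds = torch.zeros((max_bs, hidden_size))
cuda_graphs[max_bs] = {
    "input_ids": torch.zeros(max_bs, dtype=torch.int64),
    "inputs_embeds": inputs_embeds,
    "position_ids": torch.zeros(max_bs, dtype=torch.int64),
}

# Replay path: smaller batches slice views out of the largest recorded buffers,
# so the lookup key must match the key used at capture time exactly.
bs = 4
inputs_embeds = cuda_graphs[max_bs]["inputs_embeds"][:bs]
logits = forward(
    input_ids=cuda_graphs[max_bs]["input_ids"][:bs],
    inputs_embeds=inputs_embeds,
    position_ids=cuda_graphs[max_bs]["position_ids"][:bs],
)
print(logits.shape)  # torch.Size([4])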