From 2f67c53075857e7dd72a8ef319c3e8e425148d86 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 22 Apr 2025 02:06:57 +0530 Subject: [PATCH] nit: rename input_embeds to inputs_embeds in VlmCausalLM --- server/text_generation_server/models/vlm_causal_lm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index d8c5103f..46657a7a 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -834,7 +834,7 @@ class VlmCausalLM(FlashCausalLM): cache_lengths = [0] * bs if max_bs is None: input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) - input_embeds = torch.zeros( + inputs_embeds = torch.zeros( (bs, self.model.config.text_config.hidden_size), device=self.device, dtype=self.dtype, @@ -873,7 +873,7 @@ class VlmCausalLM(FlashCausalLM): "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" ) input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] - input_embeds = self.cuda_graphs[max_bs]["input_embeds"][:bs] + inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs] position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] if ATTENTION == "flashinfer": block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] @@ -906,7 +906,7 @@ class VlmCausalLM(FlashCausalLM): graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs] = { "input_ids": input_ids, - "input_embeds": input_embeds, + "inputs_embeds": inputs_embeds, "position_ids": position_ids, "kv_cache": self.kv_cache, "block_tables": block_tables, @@ -935,7 +935,7 @@ class VlmCausalLM(FlashCausalLM): ) self.model.forward( input_ids=input_ids, - inputs_embeds=input_embeds, + inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, @@ -960,7 +960,7 @@ class VlmCausalLM(FlashCausalLM): ) logits, speculative_logits = self.model.forward( input_ids=input_ids, -
inputs_embeds=input_embeds, + inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache,