diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index d8c5103f..46657a7a 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -834,7 +834,7 @@ class VlmCausalLM(FlashCausalLM):
         cache_lengths = [0] * bs
         if max_bs is None:
             input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-            input_embeds = torch.zeros(
+            inputs_embeds = torch.zeros(
                 (bs, self.model.config.text_config.hidden_size),
                 device=self.device,
                 dtype=self.dtype,
@@ -873,7 +873,7 @@ class VlmCausalLM(FlashCausalLM):
                     "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
                 )
             input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
-            input_embeds = self.cuda_graphs[max_bs]["input_embeds"][:bs]
+            inputs_embeds = self.cuda_graphs[max_bs]["inputs_embeds"][:bs]
             position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
             if ATTENTION == "flashinfer":
                 block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
@@ -906,7 +906,7 @@ class VlmCausalLM(FlashCausalLM):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
-            "input_embeds": input_embeds,
+            "inputs_embeds": inputs_embeds,
             "position_ids": position_ids,
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
@@ -935,7 +935,7 @@ class VlmCausalLM(FlashCausalLM):
             )
             self.model.forward(
                 input_ids=input_ids,
-                inputs_embeds=input_embeds,
+                inputs_embeds=inputs_embeds,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
                 kv_cache=self.kv_cache,
@@ -960,7 +960,7 @@ class VlmCausalLM(FlashCausalLM):
             )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
-                inputs_embeds=input_embeds,
+                inputs_embeds=inputs_embeds,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
                 kv_cache=self.kv_cache,
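
For context, here is a minimal sketch of the pattern the rename keeps consistent: the warmup path stores buffers in the per-batch-size cache under a key, and the replay path reads them back by that same key, so `"inputs_embeds"` must match on both sides or the lookup fails. The helper names `warmup` and `replay_slice` below are hypothetical and not part of TGI's API.

```python
# Toy illustration (not the TGI implementation) of matching write/read keys
# in a cuda-graph buffer cache keyed by batch size.
import torch

cuda_graphs: dict[int, dict[str, torch.Tensor]] = {}


def warmup(bs: int, hidden_size: int = 16) -> None:
    # Store the padded buffers for this batch size under "inputs_embeds".
    cuda_graphs[bs] = {
        "input_ids": torch.zeros(bs, dtype=torch.int64),
        "inputs_embeds": torch.zeros((bs, hidden_size)),
    }


def replay_slice(bs: int, max_bs: int) -> torch.Tensor:
    # Reuse the largest graph's buffers; a mismatched key here raises KeyError.
    return cuda_graphs[max_bs]["inputs_embeds"][:bs]


warmup(8)
print(replay_slice(4, 8).shape)  # torch.Size([4, 16])
```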