diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json
new file mode 100644
index 00000000..f9a414fa
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": "",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1730416361,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-7B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.4.1-dev0-native",
+  "usage": null
+}
diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py
index 8ece4fed..dfbb5907 100644
--- a/integration-tests/models/test_flash_qwen2_vl.py
+++ b/integration-tests/models/test_flash_qwen2_vl.py
@@ -40,3 +40,48 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
     )
 
     assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
+    """Stream a chat completion for an image prompt and check the reassembled text."""
+    responses = await flash_qwen2.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+    try:
+        async for response in responses:
+            count += 1
+            # Guard against a chunk whose delta content is None instead of a string.
+            generated += response.choices[0].delta.content or ""
+            last_response = response
+    except Exception:
+        # The client library raises when it can't parse the trailing "[DONE]"
+        # marker as JSON; the assertions below still validate the stream.
+        pass
+
+    assert (
+        generated
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+    )
+    assert count == 58
+    assert last_response == response_snapshot
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 7c07a3c3..bd20eea5 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -518,5 +518,5 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         hidden_states, _ = self.norm(hidden_states)
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 52b06f8d..aa0fe107 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -364,7 +364,7 @@ class VlmCausalLM(FlashCausalLM):
             lm_head_indices = batch.prefill_head_indices
 
         if self.model.config.model_type == "qwen2_vl":
-            if position_ids.dim() == 1:
+            if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
                     input_ids, batch.image_grid_thw
                 )