mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: return correct shape logits and add streaming test
This commit is contained in:
parent
17de5998e5
commit
e2b394e3a0
@ -0,0 +1,20 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1730416361,
|
||||||
|
"id": "",
|
||||||
|
"model": "Qwen/Qwen2-VL-7B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.4.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
}
|
@ -40,3 +40,45 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
|
||||||
|
responses = await flash_qwen2.chat(
|
||||||
|
max_tokens=100,
|
||||||
|
seed=42,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Describe this image."},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
generated = ""
|
||||||
|
last_response = None
|
||||||
|
try:
|
||||||
|
async for response in responses:
|
||||||
|
count += 1
|
||||||
|
generated += response.choices[0].delta.content
|
||||||
|
last_response = response
|
||||||
|
except Exception as e:
|
||||||
|
# handle when the client library raises an exception when it cant parse "[DONE]" as JSON
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert (
|
||||||
|
generated
|
||||||
|
== "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
|
||||||
|
)
|
||||||
|
assert count == 58
|
||||||
|
assert last_response == response_snapshot
|
||||||
|
@ -518,5 +518,5 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
hidden_states, _ = self.norm(hidden_states)
|
hidden_states, _ = self.norm(hidden_states)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits = self.lm_head(hidden_states)
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
return logits, None
|
return logits, speculative_logits
|
||||||
|
@ -364,7 +364,7 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
if self.model.config.model_type == "qwen2_vl":
|
if self.model.config.model_type == "qwen2_vl":
|
||||||
if position_ids.dim() == 1:
|
if position_ids.dim() == 1 and batch.prefilling:
|
||||||
position_ids = self.model.get_position_ids(
|
position_ids = self.model.get_position_ids(
|
||||||
input_ids, batch.image_grid_thw
|
input_ids, batch.image_grid_thw
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user