fix: prefer position_ids passed from vlm causal lm and reset ids on batch

commit 77c81a29cb (parent fb1ae6d24c)
https://github.com/huggingface/text-generation-inference.git
@@ -5,7 +5,7 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "The image shows a rabbit with a is on floating in outer a a in outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet",
+        "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.",
         "name": null,
         "role": "assistant",
         "tool_calls": null
@@ -13,14 +13,14 @@
       "usage": null
     }
   ],
-  "created": 1730084696,
+  "created": 1730164250,
   "id": "",
   "model": "Qwen/Qwen2-VL-7B-Instruct",
   "object": "chat.completion",
-  "system_fingerprint": "2.3.2-dev0-native",
+  "system_fingerprint": "2.4.1-dev0-native",
   "usage": {
-    "completion_tokens": 41,
+    "completion_tokens": 58,
     "prompt_tokens": 349,
-    "total_tokens": 390
+    "total_tokens": 407
   }
 }
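The snapshot above is the OpenAI-style chat.completion payload TGI serves on its /v1/chat/completions route. For orientation, a rough sketch of the kind of request that produces it; the endpoint URL and image link are placeholder assumptions here, not the test's actual fixture values:

import requests

# Placeholder endpoint and image URL; the real integration test wires
# these through its launcher fixture and test assets.
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "Qwen/Qwen2-VL-7B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/rabbit.png"},
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])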
@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
+    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
         yield handle
 
 
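Passing cuda_graphs=[0] asks the test launcher to run this model with CUDA graph capture disabled, presumably because the new per-batch position-id handling does not yet play well with captured graphs. A loose sketch of the equivalent standalone invocation; the flag spelling is assumed from text-generation-launcher's CLI, where a value of 0 disables graph capture:

import subprocess

# Assumed flag spelling; "0" is the conventional way to turn CUDA graphs off.
subprocess.run(
    [
        "text-generation-launcher",
        "--model-id", "Qwen/Qwen2-VL-7B-Instruct",
        "--cuda-graphs", "0",
    ],
    check=True,
)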
@@ -36,13 +36,7 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
 
     assert (
         response.choices[0].message.content
-        == "The image shows a rabbit with a is on floating in outer a a in outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet"
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
     )
 
-    # # TODO: return reference response
-    # assert (
-    #     response.choices[0].message.content
-    #     == "The image depicts an astronaut with a rabbit's head standing on a rocky, reddish terrain. The astronaut is wearing a space suit with various buttons and"
-    # )
-
     assert response == response_snapshot
@@ -409,7 +409,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 .item()
             )
             # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
-            time_steps, height, width = image_grid_thw[image_index]
+            time_steps, height, width = image_grid_thw[image_index].clone()
             height //= self.spatial_merge_size
             width //= self.spatial_merge_size
 
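The added .clone() guards against aliasing: unpacking a tensor row yields 0-d views, so the in-place //= on the following lines would otherwise write back into the shared image_grid_thw tensor and corrupt it for subsequent forward passes. A standalone toy demonstration of the pitfall (values are illustrative, not TGI code):

import torch

grid = torch.tensor([[2, 8, 8]])   # rows of (t, h, w), like image_grid_thw
t, h, w = grid[0]                  # unpacking yields 0-d views into grid
h //= 2                            # in-place floor-divide writes through the view
print(grid)                        # tensor([[2, 4, 8]])  <- mutated

grid = torch.tensor([[2, 8, 8]])
t, h, w = grid[0].clone()          # detached copies, as in the fix
h //= 2
print(grid)                        # tensor([[2, 8, 8]])  <- intact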
@@ -487,12 +487,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
             if pixel_values is not None:
-                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                image_embeds = self.visual(
+                    pixel_values, grid_thw=image_grid_thw
+                ).squeeze(0)
                 inputs_embeds[input_ids == self.image_token_id] = image_embeds
 
-        position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw)
         hidden_states = self.text_model(
-            inputs_embeds=inputs_embeds.squeeze(0),
+            inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
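For context on the reshuffled .squeeze(0) calls: the vision tower output carries a leading batch dimension, while the boolean-mask assignment into inputs_embeds needs both sides as (num_tokens, hidden). A toy sketch of that masked scatter; the sizes and the image token id are illustrative assumptions:

import torch

hidden_size = 4
image_token_id = 99                          # illustrative id, not Qwen2-VL's real one
input_ids = torch.tensor([1, 99, 99, 2])     # two image-token slots
inputs_embeds = torch.zeros(4, hidden_size)

vision_out = torch.ones(1, 2, hidden_size)   # (batch=1, num_image_tokens, hidden)
image_embeds = vision_out.squeeze(0)         # (num_image_tokens, hidden)

# rows where input_ids holds the image token take the vision embeddings
inputs_embeds[input_ids == image_token_id] = image_embeds
print(inputs_embeds)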
@@ -360,6 +360,16 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
+        if self.model.get_position_ids:
+            if position_ids.shape[0] != 1:
+                position_ids = self.model.get_position_ids(
+                    input_ids.unsqueeze(0), batch.image_grid_thw
+                )
+                batch.position_ids = position_ids[0, 0, :]
+            else:
+                position_ids = position_ids.repeat(3, 1, 1).clone()
+                batch.position_ids = position_ids[0, 0, :]
+
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
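The added branch handles the two shapes position_ids can arrive in: when they are not a single flat row, full multimodal rope ids are recomputed from batch.image_grid_thw; when they are, the flat row is broadcast across Qwen2-VL's three rope sections with repeat(3, 1, 1). Either way, batch.position_ids is reset to a flat row so subsequent batch steps start clean. A text-only sketch of the broadcast (a simplification; ids for image tokens are really derived from the image grid):

import torch

seq_len = 5
flat_ids = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len): the usual flat ids
mrope_ids = flat_ids.repeat(3, 1, 1)           # (3, 1, seq_len): time/height/width sections

print(mrope_ids.shape)       # torch.Size([3, 1, 5])
print(mrope_ids[0, 0, :])    # tensor([0, 1, 2, 3, 4]), what batch.position_ids keeps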