fix: prefer position_ids passed from vlm causal lm and reset ids on batch

This commit is contained in:
David Holtz 2024-10-29 01:13:17 +00:00
parent fb1ae6d24c
commit 77c81a29cb
4 changed files with 22 additions and 17 deletions
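For orientation, the shape of the change: VlmCausalLM.forward now owns the M-RoPE position ids (it computes the full grid during prefill and expands the flat ids otherwise), while Qwen2VLForConditionalGeneration.forward consumes the position_ids it is given instead of recomputing them from input_ids. A minimal sketch of that call order; the (3, batch, seq) layout and the stand-in functions are assumptions for illustration, not the TGI code.

import torch


def get_position_ids(input_ids: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the model's get_position_ids: one row each for the
    # temporal, height and width rotary components, shaped (3, batch, seq).
    seq_len = input_ids.shape[-1]
    base = torch.arange(seq_len).unsqueeze(0)  # (1, seq)
    return base.repeat(3, 1).unsqueeze(1)      # (3, 1, seq)


def model_forward(inputs_embeds: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # After this commit the model trusts the position ids it is handed.
    assert position_ids.shape[0] == 3
    return inputs_embeds  # stand-in for the text model


input_ids = torch.tensor([1, 2, 3, 4])
position_ids = get_position_ids(input_ids.unsqueeze(0))  # computed by the runner
out = model_forward(torch.zeros(4, 8), position_ids)     # passed down, not recomputed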


@@ -5,7 +5,7 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "The image shows a rabbit with a is on floating in outer a a in outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet",
+        "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.",
         "name": null,
         "role": "assistant",
         "tool_calls": null
@@ -13,14 +13,14 @@
       "usage": null
     }
   ],
-  "created": 1730084696,
+  "created": 1730164250,
   "id": "",
   "model": "Qwen/Qwen2-VL-7B-Instruct",
   "object": "chat.completion",
-  "system_fingerprint": "2.3.2-dev0-native",
+  "system_fingerprint": "2.4.1-dev0-native",
   "usage": {
-    "completion_tokens": 41,
+    "completion_tokens": 58,
     "prompt_tokens": 349,
-    "total_tokens": 390
+    "total_tokens": 407
   }
 }


@@ -3,7 +3,7 @@ import pytest
 @pytest.fixture(scope="module")
 def flash_qwen2_vl_handle(launcher):
-    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
+    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
         yield handle
@@ -36,13 +36,7 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
     assert (
         response.choices[0].message.content
-        == "The image shows a rabbit with a is on floating in outer a a in outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet"
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
     )
-    # # TODO: return reference response
-    # assert (
-    #     response.choices[0].message.content
-    #     == "The image depicts an astronaut with a rabbit's head standing on a rocky, reddish terrain. The astronaut is wearing a space suit with various buttons and"
-    # )
     assert response == response_snapshot


@@ -409,7 +409,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 .item()
             )
             # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
-            time_steps, height, width = image_grid_thw[image_index]
+            time_steps, height, width = image_grid_thw[image_index].clone()
             height //= self.spatial_merge_size
             width //= self.spatial_merge_size
@@ -487,12 +487,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
            if pixel_values is not None:
-                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                image_embeds = self.visual(
+                    pixel_values, grid_thw=image_grid_thw
+                ).squeeze(0)
                 inputs_embeds[input_ids == self.image_token_id] = image_embeds

-        position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw)
         hidden_states = self.text_model(
-            inputs_embeds=inputs_embeds.squeeze(0),
+            inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,


@@ -360,6 +360,16 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

+        if self.model.get_position_ids:
+            if position_ids.shape[0] != 1:
+                position_ids = self.model.get_position_ids(
+                    input_ids.unsqueeze(0), batch.image_grid_thw
+                )
+                batch.position_ids = position_ids[0, 0, :]
+            else:
+                position_ids = position_ids.repeat(3, 1, 1).clone()
+                batch.position_ids = position_ids[0, 0, :]
+
         if cu_seqlen_prefill is None and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
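This added block is the runner-side half of the change: when more than one position id is in flight it asks the model for the full 3-row M-RoPE grid, otherwise it expands the single flat id into 3 rows by repetition, and in both cases only a flat 1-D row is written back onto batch.position_ids so the usual per-token bookkeeping keeps working. A rough sketch of that bookkeeping, with the (3, batch, seq) shape assumed for illustration:

import torch

seq_len = 6
# Prefill-style case: the model-provided grid, assumed (3, batch, seq) here.
full_grid = torch.arange(seq_len).repeat(3, 1).unsqueeze(1)  # (3, 1, seq)
model_position_ids = full_grid           # handed to the model forward
batch_position_ids = full_grid[0, 0, :]  # kept on the batch: shape (seq,)

# Single-id case: a flat next-position id, expanded back to 3 rows,
# mirroring position_ids.repeat(3, 1, 1) in the diff.
next_ids = batch_position_ids[-1:] + 1      # (1,)
decode_grid = next_ids.repeat(3, 1, 1)      # (3, 1, 1)
batch_position_ids = decode_grid[0, 0, :]   # back to a flat row for the batch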