fix: prefer position_ids passed from vlm causal lm and reset ids on batch

This commit is contained in:
David Holtz 2024-10-29 01:13:17 +00:00
parent fb1ae6d24c
commit 77c81a29cb
4 changed files with 22 additions and 17 deletions

View File

@ -5,7 +5,7 @@
"index": 0,
"logprobs": null,
"message": {
"content": "The image shows a rabbit with a is on floating in outer a a in outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet",
"content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.",
"name": null,
"role": "assistant",
"tool_calls": null
@ -13,14 +13,14 @@
"usage": null
}
],
"created": 1730084696,
"created": 1730164250,
"id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.2-dev0-native",
"system_fingerprint": "2.4.1-dev0-native",
"usage": {
"completion_tokens": 41,
"completion_tokens": 58,
"prompt_tokens": 349,
"total_tokens": 390
"total_tokens": 407
}
}

View File

@ -3,7 +3,7 @@ import pytest
@pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher):
with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
yield handle
@ -36,13 +36,7 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
assert (
response.choices[0].message.content
== "The image shows a rabbit with a is on floating in outer a a in outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet"
== "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
)
# # TODO: return reference response
# assert (
# response.choices[0].message.content
# == "The image depicts an astronaut with a rabbit's head standing on a rocky, reddish terrain. The astronaut is wearing a space suit with various buttons and"
# )
assert response == response_snapshot

View File

@ -409,7 +409,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
.item()
)
# TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
time_steps, height, width = image_grid_thw[image_index]
time_steps, height, width = image_grid_thw[image_index].clone()
height //= self.spatial_merge_size
width //= self.spatial_merge_size
@ -487,12 +487,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw)
hidden_states = self.text_model(
inputs_embeds=inputs_embeds.squeeze(0),
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,

View File

@ -360,6 +360,16 @@ class VlmCausalLM(FlashCausalLM):
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if self.model.get_position_ids:
if position_ids.shape[0] != 1:
position_ids = self.model.get_position_ids(
input_ids.unsqueeze(0), batch.image_grid_thw
)
batch.position_ids = position_ids[0, 0, :]
else:
position_ids = position_ids.repeat(3, 1, 1).clone()
batch.position_ids = position_ids[0, 0, :]
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.