feat: improve qwen2-vl startup (#2802)

* feat: tokenize each request individually and increase warmup image size

* feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller

* fix: address image resize and rebase changes

* feat: update to run qwen2-vl tests

* fix: tweak param types
drbh 2025-01-17 11:50:41 -05:00 committed by GitHub
parent 6e982f43a1
commit eecca27113
GPG Key ID: B5690EEEBB952194
11 changed files with 173 additions and 95 deletions

View File

@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The correct answer is: blue",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1733445131,
  "id": "",
  "model": "Qwen/Qwen2-VL-2B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "2.4.2-dev0-native",
  "usage": {
    "completion_tokens": 7,
    "prompt_tokens": 27,
    "total_tokens": 34
  }
}

View File

@@ -1,81 +1,80 @@
(Previously this file was entirely commented out, starting with "# Disabled because it's broken."; the commit re-enables the tests by dropping that header line and uncommenting the rest:)
import pytest


@pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher):
    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_qwen2(flash_qwen2_vl_handle):
    await flash_qwen2_vl_handle.health(300)
    return flash_qwen2_vl_handle.client


@pytest.mark.private
async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
    response = await flash_qwen2.chat(
        max_tokens=100,
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                        },
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            },
        ],
    )

    assert (
        response.choices[0].message.content
        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
    )

    assert response == response_snapshot


@pytest.mark.private
async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
    responses = await flash_qwen2.chat(
        max_tokens=100,
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                        },
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            },
        ],
        stream=True,
    )

    count = 0
    generated = ""
    last_response = None
    async for response in responses:
        count += 1
        generated += response.choices[0].delta.content
        last_response = response

    assert (
        generated
        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
    )
    assert count == 58
    assert last_response == response_snapshot

View File

@@ -0,0 +1,38 @@
import pytest


@pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher):
    with launcher(
        "Qwen/Qwen2-VL-2B-Instruct",
        max_input_length=40,
        max_batch_prefill_tokens=50,
        max_total_tokens=51,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_qwen2(flash_qwen2_vl_handle):
    await flash_qwen2_vl_handle.health(300)
    return flash_qwen2_vl_handle.client


@pytest.mark.private
async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
    response = await flash_qwen2.chat(
        max_tokens=20,
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is the color of the sky?"},
                ],
            },
        ],
    )

    assert response.choices[0].message.content == "The correct answer is: blue"

    assert response == response_snapshot
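
For reference, the request this test issues maps directly onto TGI's OpenAI-compatible chat endpoint. A hedged sketch of the equivalent standalone call, assuming a server is already running locally on port 8080 (adjust the URL to your deployment):

import requests

# Assumed local deployment; not part of this commit.
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "Qwen/Qwen2-VL-2B-Instruct",
        "max_tokens": 20,
        "seed": 42,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is the color of the sky?"},
                ],
            },
        ],
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])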

View File

@@ -29,6 +29,7 @@ from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
 from text_generation_server.models.globals import ATTENTION
+import text_generation_server.models.globals as globals
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -1217,6 +1218,11 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        # TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
+        logger.warning(
+            "Qwen2-VL requires cuda graphs to be greater than 2. Removing all cuda graphs with a batch size equal or less than 2."
+        )
+        globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
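
The filter added in get_model simply drops the small graph sizes before warmup ever sees them. A minimal sketch of that behaviour in plain Python (the starting list is an assumed example such as a typical default, not a value taken from this commit):

# Illustration only: mimic the lambda filter from the diff above.
CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]  # assumed configured graph batch sizes
CUDA_GRAPHS = list(filter(lambda x: x > 2, CUDA_GRAPHS))
assert CUDA_GRAPHS == [4, 8, 16, 32]  # sizes 1 and 2 are skipped for Qwen2-VL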

View File

@@ -138,7 +138,12 @@ class Qwen2Attention(torch.nn.Module):
             dim=-1,
         )
 
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(
+            query,
+            torch.select(kv, dim=1, index=0),
+            cos[: query.shape[0], ...],
+            sin[: query.shape[0], ...],
+        )
 
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
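
The new call slices the cos and sin tables down to the number of query rows; as far as one can tell from the diff, the intent is to keep precomputed position tables aligned with a query that may be shorter than they are. A shape-only sketch of that idea in plain PyTorch (hypothetical sizes, and a placeholder update in place of the real rotary kernel):

import torch

num_tokens, num_heads, head_dim = 5, 8, 64   # assumed sizes for illustration
query = torch.randn(num_tokens, num_heads, head_dim)
cos = torch.randn(32, 1, head_dim)           # tables built for more positions than needed
sin = torch.randn(32, 1, head_dim)

cos_q = cos[: query.shape[0], ...]           # (5, 1, 64), matches the query rows
sin_q = sin[: query.shape[0], ...]
rotated = query * cos_q + query * sin_q      # placeholder for the real rotary update
print(rotated.shape)                         # torch.Size([5, 8, 64])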

View File

@@ -517,11 +517,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
-        pixel_attention_mask=None,
+        pixel_attention_mask: Optional[torch.Tensor] = None,
         image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
+        image_indices: Optional[torch.Tensor] = None,
     ):
         inputs_embeds = self.embed_tokens(input_ids)
 
@@ -533,6 +533,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             ).squeeze(0)
             inputs_embeds[input_ids == self.image_token_id] = image_embeds
 
+        max_s = max(max_s, inputs_embeds.size(0))
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,

View File

@@ -56,11 +56,13 @@ from text_generation_server.models.globals import (
     MEM_POOL,
     ATTENTION,
     BLOCK_SIZE,
-    CUDA_GRAPHS,
     REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
+
+# avoid coping CUDA_GRAPHS value by importing globals as a module
+import text_generation_server.models.globals as globals
 from text_generation_server.layers.attention import KVCache, Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -1635,8 +1637,8 @@ class FlashCausalLM(Model):
                     int(val)
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
-            elif CUDA_GRAPHS is not None:
-                tuning_sequences = CUDA_GRAPHS
+            elif globals.CUDA_GRAPHS is not None:
+                tuning_sequences = globals.CUDA_GRAPHS
             else:
                 tuning_sequences = [1, 2, 3, 4, 5, 6, 7]
 
@@ -1675,13 +1677,14 @@ class FlashCausalLM(Model):
                 "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
             )
 
-        if CUDA_GRAPHS:
+        if globals.CUDA_GRAPHS:
             try:
                 log_master(
-                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
+                    logger.info,
+                    f"Cuda Graphs are enabled for sizes {globals.CUDA_GRAPHS}",
                 )
                 # Warmup cuda graphs
-                for bs in CUDA_GRAPHS:
+                for bs in globals.CUDA_GRAPHS:
                     synchronize(self.device)
                     free_memory = get_free_memory(
                         self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
@@ -1705,7 +1708,8 @@ class FlashCausalLM(Model):
                 logger.exception("Decode cuda graph warmup failed")
         else:
             log_master(
-                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
+                logger.info,
+                f"Cuda Graphs are disabled (CUDA_GRAPHS={globals.CUDA_GRAPHS}).",
             )
 
         assert max_input_tokens is not None
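
The module-style import above exists because a from-import of CUDA_GRAPHS would bind the value once at import time, so the filtering that get_model performs later would never be visible here; reading the attribute through the module object always sees the current value. A small standalone sketch of that Python behaviour (illustrative only, using a fake module rather than TGI's globals):

import types

# Stand-in for text_generation_server.models.globals (assumed values).
fake_globals = types.ModuleType("fake_globals")
fake_globals.CUDA_GRAPHS = [1, 2, 4, 8]

copied = fake_globals.CUDA_GRAPHS   # what a from-import of CUDA_GRAPHS would capture
fake_globals.CUDA_GRAPHS = [4, 8]   # roughly what the Qwen2-VL filter in get_model does

print(copied)                    # [1, 2, 4, 8] -- stale snapshot
print(fake_globals.CUDA_GRAPHS)  # [4, 8] -- current value, as read via globals.CUDA_GRAPHS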

View File

@@ -236,7 +236,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         w = image.width * 2
                         h = image.height * 2
                         image = image.resize((w, h))
-
                 if config.model_type == "llava_next":
                     images.append(image)
                 else: