From eecca27113538716ff14b05738eb6e8c5f7afd8a Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 17 Jan 2025 11:50:41 -0500 Subject: [PATCH] feat: improve qwen2-vl startup (#2802) * feat: tokenize each request individually and increase warmup image size * feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller * fix: address image resize and rebase changes * feat: update to run qwen2-vl tests * fix: tweak param types --- backends/client/src/lib.rs | 2 +- backends/v2/src/client/mod.rs | 2 +- backends/v3/src/client/mod.rs | 2 +- .../test_flash_qwen2_vl_simple.json | 26 +++ .../models/test_flash_qwen2_vl.py | 161 +++++++++--------- .../models/test_flash_qwen2_vl_warmup.py | 38 +++++ .../text_generation_server/models/__init__.py | 6 + .../custom_modeling/flash_qwen2_modeling.py | 7 +- .../models/custom_modeling/qwen2_vl.py | 5 +- .../models/flash_causal_lm.py | 18 +- .../models/vlm_causal_lm.py | 1 - 11 files changed, 173 insertions(+), 95 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json create mode 100644 integration-tests/models/test_flash_qwen2_vl_warmup.py diff --git a/backends/client/src/lib.rs b/backends/client/src/lib.rs index 45bee10c..fbe2e7e6 100644 --- a/backends/client/src/lib.rs +++ b/backends/client/src/lib.rs @@ -86,6 +86,6 @@ impl ChunksToString for Vec { } } -static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; +static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; pub type Result = std::result::Result; diff --git a/backends/v2/src/client/mod.rs b/backends/v2/src/client/mod.rs index fa9d4406..9fe114a2 100644 --- a/backends/v2/src/client/mod.rs +++ b/backends/v2/src/client/mod.rs @@ -63,6 +63,6 @@ impl From for ClientError { } } -static WARMUP_IMAGE_BASE64 :&str = 
"iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; +static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; pub type Result = std::result::Result; diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs index d4ac50c9..ab4311c3 100644 --- a/backends/v3/src/client/mod.rs +++ b/backends/v3/src/client/mod.rs @@ -62,6 +62,6 @@ impl From for InputChunk { } } -static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; +static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; pub type Result = std::result::Result; diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json 
b/integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json new file mode 100644 index 00000000..a986510f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl_warmup/test_flash_qwen2_vl_simple.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "The correct answer is: blue", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1733445131, + "id": "", + "model": "Qwen/Qwen2-VL-2B-Instruct", + "object": "chat.completion", + "system_fingerprint": "2.4.2-dev0-native", + "usage": { + "completion_tokens": 7, + "prompt_tokens": 27, + "total_tokens": 34 + } +} diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py index 97a533fc..946ab2f1 100644 --- a/integration-tests/models/test_flash_qwen2_vl.py +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -1,81 +1,80 @@ -# Disabled because it's broken. -# import pytest -# -# -# @pytest.fixture(scope="module") -# def flash_qwen2_vl_handle(launcher): -# with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle: -# yield handle -# -# -# @pytest.fixture(scope="module") -# async def flash_qwen2(flash_qwen2_vl_handle): -# await flash_qwen2_vl_handle.health(300) -# return flash_qwen2_vl_handle.client -# -# -# @pytest.mark.private -# async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): -# response = await flash_qwen2.chat( -# max_tokens=100, -# seed=42, -# messages=[ -# { -# "role": "user", -# "content": [ -# { -# "type": "image_url", -# "image_url": { -# "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" -# }, -# }, -# {"type": "text", "text": "Describe this image."}, -# ], -# }, -# ], -# ) -# -# assert ( -# response.choices[0].message.content -# == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." -# ) -# -# assert response == response_snapshot -# -# -# @pytest.mark.private -# async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot): -# responses = await flash_qwen2.chat( -# max_tokens=100, -# seed=42, -# messages=[ -# { -# "role": "user", -# "content": [ -# { -# "type": "image_url", -# "image_url": { -# "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" -# }, -# }, -# {"type": "text", "text": "Describe this image."}, -# ], -# }, -# ], -# stream=True, -# ) -# -# count = 0 -# generated = "" -# last_response = None -# async for response in responses: -# count += 1 -# generated += response.choices[0].delta.content -# last_response = response -# -# assert ( -# generated -# == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." 
-# ) -# assert count == 58 -# assert last_response == response_snapshot +import pytest + + +@pytest.fixture(scope="module") +def flash_qwen2_vl_handle(launcher): + with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_qwen2(flash_qwen2_vl_handle): + await flash_qwen2_vl_handle.health(300) + return flash_qwen2_vl_handle.client + + +@pytest.mark.private +async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): + response = await flash_qwen2.chat( + max_tokens=100, + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + ], + ) + + assert ( + response.choices[0].message.content + == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." + ) + + assert response == response_snapshot + + +@pytest.mark.private +async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot): + responses = await flash_qwen2.chat( + max_tokens=100, + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + ], + stream=True, + ) + + count = 0 + generated = "" + last_response = None + async for response in responses: + count += 1 + generated += response.choices[0].delta.content + last_response = response + + assert ( + generated + == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." 
+    )
+    assert count == 58
+    assert last_response == response_snapshot
diff --git a/integration-tests/models/test_flash_qwen2_vl_warmup.py b/integration-tests/models/test_flash_qwen2_vl_warmup.py
new file mode 100644
index 00000000..5be87ee2
--- /dev/null
+++ b/integration-tests/models/test_flash_qwen2_vl_warmup.py
@@ -0,0 +1,38 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_qwen2_vl_handle(launcher):
+    with launcher(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        max_input_length=40,
+        max_batch_prefill_tokens=50,
+        max_total_tokens=51,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_qwen2(flash_qwen2_vl_handle):
+    await flash_qwen2_vl_handle.health(300)
+    return flash_qwen2_vl_handle.client
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
+    response = await flash_qwen2.chat(
+        max_tokens=20,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What is the color of the sky?"},
+                ],
+            },
+        ],
+    )
+
+    assert response.choices[0].message.content == "The correct answer is: blue"
+
+    assert response == response_snapshot
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e2d24643..7521fb46 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -29,6 +29,7 @@ from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
 from text_generation_server.models.globals import ATTENTION
+import text_generation_server.models.globals as globals
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -1217,6 +1218,11 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        # TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
+        logger.warning(
+            "Qwen2-VL requires cuda graph batch sizes greater than 2. Removing all cuda graphs with a batch size of 2 or smaller."
+        )
+        globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index cc4039b1..01d3bf1a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -138,7 +138,12 @@ class Qwen2Attention(torch.nn.Module):
             dim=-1,
         )

-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(
+            query,
+            torch.select(kv, dim=1, index=0),
+            cos[: query.shape[0], ...],
+            sin[: query.shape[0], ...],
+        )

         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index a8e1e8c1..95cf6a31 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -517,11 +517,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
-        pixel_attention_mask=None,
+        pixel_attention_mask: Optional[torch.Tensor] = None,
         image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
+        image_indices: Optional[torch.Tensor] = None,
     ):
         inputs_embeds = self.embed_tokens(input_ids)

@@ -533,6 +533,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             ).squeeze(0)
             inputs_embeds[input_ids == self.image_token_id] = image_embeds

+        max_s = max(max_s, inputs_embeds.size(0))
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d097c54f..f2d27db9 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -56,11 +56,13 @@ from text_generation_server.models.globals import (
     MEM_POOL,
     ATTENTION,
     BLOCK_SIZE,
-    CUDA_GRAPHS,
     REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
+
+# avoid copying the CUDA_GRAPHS value by importing globals as a module
+import text_generation_server.models.globals as globals
 from text_generation_server.layers.attention import KVCache, Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -1635,8 +1637,8 @@ class FlashCausalLM(Model):
                     int(val)
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
-            elif CUDA_GRAPHS is not None:
-                tuning_sequences = CUDA_GRAPHS
+            elif globals.CUDA_GRAPHS is not None:
+                tuning_sequences = globals.CUDA_GRAPHS
             else:
                 tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

@@ -1675,13 +1677,14 @@ class FlashCausalLM(Model):
                 "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
             )

-        if CUDA_GRAPHS:
+        if globals.CUDA_GRAPHS:
             try:
                 log_master(
-                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
+                    logger.info,
+                    f"Cuda Graphs are enabled for sizes {globals.CUDA_GRAPHS}",
                 )
                 # Warmup cuda graphs
-                for bs in CUDA_GRAPHS:
+                for bs in globals.CUDA_GRAPHS:
                     synchronize(self.device)
                     free_memory = get_free_memory(
                         self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
@@ -1705,7 +1708,8 @@ class FlashCausalLM(Model):
                 logger.exception("Decode cuda graph warmup failed")
         else:
             log_master(
-                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
+                logger.info,
+                f"Cuda Graphs are disabled (CUDA_GRAPHS={globals.CUDA_GRAPHS}).",
             )

         assert max_input_tokens is not None
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index db78341d..4d6ea84e 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -236,7 +236,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         w = image.width * 2
                         h = image.height * 2
                         image = image.resize((w, h))
-
                 if config.model_type == "llava_next":
                     images.append(image)
                 else:
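
Why flash_causal_lm.py switches from importing CUDA_GRAPHS directly to importing text_generation_server.models.globals as a module: a from-import copies the binding once at import time, so the Qwen2-VL branch in models/__init__.py that reassigns globals.CUDA_GRAPHS (to drop cuda graph batch sizes of 2 or smaller) would not be visible through the old binding. The snippet below is a minimal, self-contained sketch of that Python behavior; the "settings" module name is invented for illustration and is not part of TGI.

    import types

    # Stand-in config module (hypothetical; plays the role of text_generation_server.models.globals).
    settings = types.ModuleType("settings")
    settings.CUDA_GRAPHS = [1, 2, 4, 8]

    # Equivalent of "from settings import CUDA_GRAPHS": the current binding is copied once.
    cuda_graphs_copy = settings.CUDA_GRAPHS

    # Equivalent of the Qwen2-VL edge case above: rebind the module attribute,
    # keeping only batch sizes greater than 2.
    settings.CUDA_GRAPHS = list(filter(lambda bs: bs > 2, settings.CUDA_GRAPHS))

    print(cuda_graphs_copy)      # [1, 2, 4, 8] -- the copied binding never sees the update
    print(settings.CUDA_GRAPHS)  # [4, 8]       -- reading through the module does

Reading CUDA_GRAPHS through the module attribute at each use is what lets the per-model override in get_model take effect before the cuda graph warmup loop runs.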