Mirror of https://github.com/huggingface/text-generation-inference.git
feat: improve qwen2-vl startup (#2802)
* feat: tokenize each request individually and increase warmup image size
* feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller
* fix: address image resize and rebase changes
* feat: update to run qwen2-vl tests
* fix: tweak param types
Commit eecca27113 (parent 6e982f43a1)
@@ -86,6 +86,6 @@ impl ChunksToString for Vec<InputChunk> {
     }
 }
 
-static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
 
 pub type Result<T> = std::result::Result<T, ClientError>;
@@ -63,6 +63,6 @@ impl From<transport::Error> for ClientError {
     }
 }
 
-static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
 
 pub type Result<T> = std::result::Result<T, ClientError>;
@@ -62,6 +62,6 @@ impl From<Chunk> for InputChunk {
     }
 }
 
-static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
+static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
 
 pub type Result<T> = std::result::Result<T, ClientError>;
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "The correct answer is: blue",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1733445131,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.4.2-dev0-native",
+  "usage": {
+    "completion_tokens": 7,
+    "prompt_tokens": 27,
+    "total_tokens": 34
+  }
+}
@@ -1,81 +1,80 @@
-# Disabled because it's broken.
-# import pytest
-#
-#
-# @pytest.fixture(scope="module")
-# def flash_qwen2_vl_handle(launcher):
-#     with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
-#         yield handle
-#
-#
-# @pytest.fixture(scope="module")
-# async def flash_qwen2(flash_qwen2_vl_handle):
-#     await flash_qwen2_vl_handle.health(300)
-#     return flash_qwen2_vl_handle.client
-#
-#
-# @pytest.mark.private
-# async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
-#     response = await flash_qwen2.chat(
-#         max_tokens=100,
-#         seed=42,
-#         messages=[
-#             {
-#                 "role": "user",
-#                 "content": [
-#                     {
-#                         "type": "image_url",
-#                         "image_url": {
-#                             "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
-#                         },
-#                     },
-#                     {"type": "text", "text": "Describe this image."},
-#                 ],
-#             },
-#         ],
-#     )
-#
-#     assert (
-#         response.choices[0].message.content
-#         == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
-#     )
-#
-#     assert response == response_snapshot
-#
-#
-# @pytest.mark.private
-# async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
-#     responses = await flash_qwen2.chat(
-#         max_tokens=100,
-#         seed=42,
-#         messages=[
-#             {
-#                 "role": "user",
-#                 "content": [
-#                     {
-#                         "type": "image_url",
-#                         "image_url": {
-#                             "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
-#                         },
-#                     },
-#                     {"type": "text", "text": "Describe this image."},
-#                 ],
-#             },
-#         ],
-#         stream=True,
-#     )
-#
-#     count = 0
-#     generated = ""
-#     last_response = None
-#     async for response in responses:
-#         count += 1
-#         generated += response.choices[0].delta.content
-#         last_response = response
-#
-#     assert (
-#         generated
-#         == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
-#     )
-#     assert count == 58
-#     assert last_response == response_snapshot
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_qwen2_vl_handle(launcher):
+    with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_qwen2(flash_qwen2_vl_handle):
+    await flash_qwen2_vl_handle.health(300)
+    return flash_qwen2_vl_handle.client
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
+    response = await flash_qwen2.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+        ],
+    )
+
+    assert (
+        response.choices[0].message.content
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+    )
+
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
+    responses = await flash_qwen2.chat(
+        max_tokens=100,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        generated += response.choices[0].delta.content
+        last_response = response
+
+    assert (
+        generated
+        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
+    )
+    assert count == 58
+    assert last_response == response_snapshot
integration-tests/models/test_flash_qwen2_vl_warmup.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_qwen2_vl_handle(launcher):
+    with launcher(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        max_input_length=40,
+        max_batch_prefill_tokens=50,
+        max_total_tokens=51,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_qwen2(flash_qwen2_vl_handle):
+    await flash_qwen2_vl_handle.health(300)
+    return flash_qwen2_vl_handle.client
+
+
+@pytest.mark.private
+async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
+    response = await flash_qwen2.chat(
+        max_tokens=20,
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What is the color of the sky?"},
+                ],
+            },
+        ],
+    )
+
+    assert response.choices[0].message.content == "The correct answer is: blue"
+
+    assert response == response_snapshot
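A plausible reading of the tight launcher budget in the new warmup test: a VLM warmup request (prompt plus the warmup image's token expansion) has to fit inside `max_batch_prefill_tokens`, so the test only passes if the warmup image stays small. The sketch below is a back-of-the-envelope check with purely assumed token counts, not values taken from the server or this diff.

```python
# Illustrative budget check only; the token counts are assumptions.
max_input_length = 40
max_batch_prefill_tokens = 50
max_total_tokens = 51

warmup_text_tokens = 10   # assumed size of the warmup prompt
warmup_image_tokens = 25  # assumed expansion of the small warmup image

prefill = warmup_text_tokens + warmup_image_tokens
assert prefill <= max_input_length
assert prefill <= max_batch_prefill_tokens
assert max_total_tokens > max_input_length  # leaves room to generate at least one token
```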
@@ -29,6 +29,7 @@ from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
 from text_generation_server.models.globals import ATTENTION
+import text_generation_server.models.globals as globals
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -1217,6 +1218,11 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        # TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
+        logger.warning(
+            "Qwen2-VL requires cuda graphs to be greater than 2. Removing all cuda graphs with a batch size equal to or less than 2."
+        )
+        globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
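For context on the two hunks above, here is a minimal sketch (not TGI code) of why `get_model` reassigns `globals.CUDA_GRAPHS` through the module object: a `from ... import CUDA_GRAPHS` binding taken elsewhere would keep pointing at the old list, while readers that go through the module attribute see the filtered value. The module name below is a stand-in.

```python
# Stand-in for text_generation_server.models.globals; illustrative only.
import types

globals_module = types.SimpleNamespace(CUDA_GRAPHS=[1, 2, 4, 8, 16])

# Equivalent of `from ...globals import CUDA_GRAPHS`: a copy of the binding.
copied_cuda_graphs = globals_module.CUDA_GRAPHS

# What the Qwen2-VL branch does: drop graph sizes of 2 or less by reassigning
# the module attribute.
globals_module.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals_module.CUDA_GRAPHS))

print(copied_cuda_graphs)          # stale copied binding: [1, 2, 4, 8, 16]
print(globals_module.CUDA_GRAPHS)  # what `globals.CUDA_GRAPHS` reads now see: [4, 8, 16]
```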
@@ -138,7 +138,12 @@ class Qwen2Attention(torch.nn.Module):
             dim=-1,
         )
 
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(
+            query,
+            torch.select(kv, dim=1, index=0),
+            cos[: query.shape[0], ...],
+            sin[: query.shape[0], ...],
+        )
 
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
@@ -517,11 +517,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
-        pixel_attention_mask=None,
+        pixel_attention_mask: Optional[torch.Tensor] = None,
         image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
+        image_indices: Optional[torch.Tensor] = None,
     ):
         inputs_embeds = self.embed_tokens(input_ids)
 
@@ -533,6 +533,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             ).squeeze(0)
             inputs_embeds[input_ids == self.image_token_id] = image_embeds
 
+        max_s = max(max_s, inputs_embeds.size(0))
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
@@ -56,11 +56,13 @@ from text_generation_server.models.globals import (
     MEM_POOL,
     ATTENTION,
     BLOCK_SIZE,
-    CUDA_GRAPHS,
     REQUEST_LOGPROBS,
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
+
+# avoid copying CUDA_GRAPHS value by importing globals as a module
+import text_generation_server.models.globals as globals
 from text_generation_server.layers.attention import KVCache, Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -1635,8 +1637,8 @@ class FlashCausalLM(Model):
                     int(val)
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
-            elif CUDA_GRAPHS is not None:
-                tuning_sequences = CUDA_GRAPHS
+            elif globals.CUDA_GRAPHS is not None:
+                tuning_sequences = globals.CUDA_GRAPHS
             else:
                 tuning_sequences = [1, 2, 3, 4, 5, 6, 7]
 
@@ -1675,13 +1677,14 @@ class FlashCausalLM(Model):
                 "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
             )
 
-        if CUDA_GRAPHS:
+        if globals.CUDA_GRAPHS:
             try:
                 log_master(
-                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
+                    logger.info,
+                    f"Cuda Graphs are enabled for sizes {globals.CUDA_GRAPHS}",
                 )
                 # Warmup cuda graphs
-                for bs in CUDA_GRAPHS:
+                for bs in globals.CUDA_GRAPHS:
                     synchronize(self.device)
                     free_memory = get_free_memory(
                         self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
@@ -1705,7 +1708,8 @@ class FlashCausalLM(Model):
                 logger.exception("Decode cuda graph warmup failed")
         else:
             log_master(
-                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
+                logger.info,
+                f"Cuda Graphs are disabled (CUDA_GRAPHS={globals.CUDA_GRAPHS}).",
             )
 
         assert max_input_tokens is not None
@@ -236,7 +236,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                             w = image.width * 2
                             h = image.height * 2
                             image = image.resize((w, h))
-
                         if config.model_type == "llava_next":
                             images.append(image)
                         else: