Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller
This commit is contained in:
parent 1bcfba305b
commit 45e5c2c266
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "The correct answer is: blue",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1733445131,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.4.2-dev0-native",
+  "usage": {
+    "completion_tokens": 7,
+    "prompt_tokens": 27,
+    "total_tokens": 34
+  }
+}
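Note (not part of the diff): the hunk above adds the expected chat.completion snapshot for a Qwen2-VL request. A request shaped roughly like the following against TGI's OpenAI-compatible /v1/chat/completions route would produce a payload of that shape; the host, port, prompt, and image URL below are placeholder assumptions.

# Hedged sketch: assumes a local text-generation-inference server for
# Qwen/Qwen2-VL-2B-Instruct is already listening on port 8080.
import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "Qwen/Qwen2-VL-2B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": "https://example.com/sky.png"}},
                    {"type": "text", "text": "What color is the sky? Answer in one short sentence."},
                ],
            }
        ],
        "max_tokens": 10,
    },
    timeout=60,
)
print(response.json()["choices"][0]["message"]["content"])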
@@ -5,7 +5,7 @@ import pytest
 def flash_qwen2_vl_handle(launcher):
     with launcher(
         "Qwen/Qwen2-VL-2B-Instruct",
-        max_input_tokens=40,
+        max_input_length=40,
         max_batch_prefill_tokens=50,
         max_total_tokens=51,
     ) as handle:
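Note (not part of the diff): integration tests in this suite typically wrap such a handle in a module-scoped fixture and then call its client; the sketch below follows that pattern, but the test name, message contents, and chat arguments are illustrative assumptions rather than the repository's actual test.

import pytest


@pytest.fixture(scope="module")
async def flash_qwen2_vl(flash_qwen2_vl_handle):
    # Wait for the launched server to report healthy, then expose its client.
    await flash_qwen2_vl_handle.health(300)
    return flash_qwen2_vl_handle.client


@pytest.mark.asyncio
async def test_flash_qwen2_vl_smoke(flash_qwen2_vl, response_snapshot):
    # Illustrative call; the real test pins a seed and sends an image prompt.
    response = await flash_qwen2_vl.chat(
        max_tokens=10,
        seed=42,
        messages=[{"role": "user", "content": "What color is the sky?"}],
    )
    assert response == response_snapshot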
@@ -29,6 +29,7 @@ from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
 from text_generation_server.models.globals import ATTENTION
+import text_generation_server.models.globals as globals
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -1217,6 +1218,11 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == QWEN2_VL:
+        # TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
+        logger.warning(
+            "Qwen2-VL requires cuda graphs to be greater than 2. Removing all cuda graphs with a batch size equal or less than 2."
+        )
+        globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
         return VlmCausalLM(
             model_id=model_id,
             model_class=Qwen2VLForConditionalGeneration,
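Note (not part of the diff): the added branch prunes the CUDA graph batch sizes before a Qwen2-VL model is constructed. A standalone illustration of what the filter does, assuming TGI's usual default list:

# Assumed default CUDA graph batch sizes; the server actually takes these from
# its --cuda-graphs setting.
CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]

# The Qwen2-VL edge case drops every captured graph with batch size 2 or smaller.
CUDA_GRAPHS = list(filter(lambda x: x > 2, CUDA_GRAPHS))

assert CUDA_GRAPHS == [4, 8, 16, 32]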
@@ -138,7 +138,12 @@ class Qwen2Attention(torch.nn.Module):
             dim=-1,
         )
 
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(
+            query,
+            torch.select(kv, dim=1, index=0),
+            cos[: query.shape[0], ...],
+            sin[: query.shape[0], ...],
+        )
 
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
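Note (not part of the diff): the functional change above is that the precomputed cos/sin tables are sliced to the number of query rows before the rotary call, so they cannot run longer than the tokens processed in this step. The standalone, textbook-style RoPE sketch below only demonstrates how the sliced shapes broadcast; it is not the kernel TGI invokes, and all dimensions are made up.

import torch

num_tokens, num_heads, head_dim = 3, 4, 64
query = torch.randn(num_tokens, num_heads, head_dim)
key = torch.randn(num_tokens, num_heads, head_dim)

# Precomputed tables may cover more positions than this step actually uses.
cos_full = torch.randn(10, 1, head_dim // 2)
sin_full = torch.randn(10, 1, head_dim // 2)

# Slice to the current number of query rows, as the diff does.
cos = cos_full[: query.shape[0], ...]
sin = sin_full[: query.shape[0], ...]


def apply_rope(x):
    # Rotate the two halves of the head dimension with the shared cos/sin tables.
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


rotated_query = apply_rope(query)
rotated_key = apply_rope(key)
assert rotated_query.shape == query.shape and rotated_key.shape == key.shape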