Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-27 13:02:12 +00:00)

commit 372e071135 (parent 7e282b4153)

    Fix the issues of tgi-gaudi for v.2.3.1

    Signed-off-by: yuanwu <yuan.wu@intel.com>
@@ -13,7 +13,6 @@ COPY benchmark benchmark
 COPY router router
 COPY backends backends
 COPY launcher launcher

 RUN cargo chef prepare --recipe-path recipe.json

 FROM chef AS builder
@@ -44,6 +43,10 @@ RUN cargo build --profile release-opt
 # Text Generation Inference base image
 FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest as base

+ENV ATTENTION=default
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0

 # Text Generation Inference base env
 ENV HF_HOME=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
@@ -27,6 +27,8 @@ impl BackendV2 {
     pub(crate) fn new(
         client: ShardedClient,
         waiting_served_ratio: f32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
         max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
@@ -48,7 +50,16 @@ impl BackendV2 {
         } else {
             16
         };
-        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
+        let queue = Queue::new(
+            requires_padding,
+            block_size,
+            window_size,
+            speculate,
+            max_input_tokens,
+            max_total_tokens,
+        );

         let batching_task_notifier = Arc::new(Notify::new());

         // Spawn batching background task that contains all the inference logic
@@ -109,6 +109,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
@@ -174,6 +175,7 @@ impl Client {
             max_input_length,
             max_prefill_tokens,
             max_total_tokens,
+            max_batch_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
@@ -105,6 +105,7 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
@@ -115,6 +116,7 @@ impl ShardedClient {
                     max_input_length,
                     max_prefill_tokens,
                     max_total_tokens,
+                    max_batch_total_tokens,
                     max_batch_size,
                 ))
             })
@@ -92,6 +92,7 @@ pub async fn connect_backend(
             max_input_tokens as u32,
             max_batch_prefill_tokens,
             max_total_tokens as u32,
+            max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
             max_batch_size,
         )
         .await
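Illustrative note (not part of the commit): when --max-batch-total-tokens is not supplied, the warmup call above falls back to the largest of 16000, max_total_tokens, and max_batch_prefill_tokens. A minimal Python sketch of that arithmetic, with made-up values:

    def warmup_budget(max_batch_total_tokens, max_total_tokens, max_batch_prefill_tokens):
        # Mirrors the unwrap_or(16000.max(...)) expression in the hunk above.
        if max_batch_total_tokens is not None:
            return max_batch_total_tokens
        return max(16000, max_total_tokens, max_batch_prefill_tokens)

    assert warmup_budget(None, 4096, 4096) == 16000     # small limits fall back to the 16000 floor
    assert warmup_budget(None, 32768, 8192) == 32768    # a large context wins over the floor
    assert warmup_budget(65536, 32768, 8192) == 65536   # an explicit flag is passed through unchanged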
@@ -112,6 +113,8 @@ pub async fn connect_backend(
     let backend = BackendV2::new(
         sharded_client,
         waiting_served_ratio,
+        max_input_tokens as u32,
+        max_total_tokens as u32,
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
@@ -43,6 +43,8 @@ impl Queue {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,6 +55,8 @@ impl Queue {
             block_size,
             window_size,
             speculate,
+            max_input_tokens,
+            max_total_tokens,
             queue_receiver,
         ));

@@ -103,9 +107,18 @@ async fn queue_task(
     block_size: u32,
     window_size: Option<u32>,
     speculate: u32,
+    max_input_tokens: u32,
+    max_total_tokens: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size, speculate);
+    let mut state = State::new(
+        requires_padding,
+        block_size,
+        window_size,
+        speculate,
+        max_input_tokens,
+        max_total_tokens,
+    );

     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -153,6 +166,12 @@ struct State {

     /// Speculation amount
     speculate: u32,
+
+    /// max input tokens
+    max_input_tokens: u32,
+
+    /// max total tokens,
+    max_total_tokens: u32,
 }

 impl State {
@@ -161,6 +180,8 @@ impl State {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
     ) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
@@ -170,6 +191,8 @@ impl State {
             block_size,
             window_size,
             speculate,
+            max_input_tokens,
+            max_total_tokens,
         }
     }

@@ -224,7 +247,6 @@ impl State {
         let mut batch_entries =
             IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

-        let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;

@@ -241,8 +263,7 @@ impl State {
             if self.requires_padding {
                 // We pad to max input length in the Python shards
                 // We need to take these padding tokens into the equation
-                max_input_length = max_input_length.max(entry.request.input_length);
-                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
+                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_tokens
             } else {
                 // pad to block size
                 prefill_tokens += ((entry.request.input_length + self.block_size - 1)
@@ -251,7 +272,7 @@ impl State {
             }

             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
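Illustrative note (not part of the commit): with the two hunks above, a padding backend such as Gaudi budgets tokens for the worst case instead of summing per-request lengths; every entry is assumed to be padded to max_input_tokens and to decode up to max_total_tokens - max_input_tokens, so the budget depends only on the candidate batch size. A minimal sketch with invented numbers:

    def padded_budget(batch_size: int, max_input_tokens: int, max_total_tokens: int):
        # Worst-case accounting: every request is padded to the full input length
        # and may generate up to the remaining budget, regardless of its real size.
        prefill_tokens = batch_size * max_input_tokens
        decode_tokens = batch_size * (max_total_tokens - max_input_tokens)
        return prefill_tokens, decode_tokens

    # e.g. max_input_tokens=1024, max_total_tokens=2048: each extra request reserves
    # a fixed 1024 prefill tokens and 1024 decode tokens.
    print(padded_budget(4, 1024, 2048))  # (4096, 4096)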
@@ -156,7 +156,6 @@ pub(crate) async fn batching_task(
         .await;
         let mut waiting_tokens = 1;

-        tracing::error!("Enter cached batch loop");
         // We loop until we do not receive any cached batch from the inference server (== until
         // all requests have met their stopping criteria)
         while let Some(batch) = cached_batch {
@@ -23,6 +23,7 @@ pub enum Attention {
     Paged,
     FlashDecoding,
     FlashInfer,
+    Default,
 }

 impl Attention {
@@ -31,6 +32,7 @@ impl Attention {
             Attention::FlashDecoding => 256,
             Attention::FlashInfer => 1,
             Attention::Paged => 16,
+            Attention::Default => 16,
         }
     }
 }
@@ -52,6 +54,7 @@ impl std::str::FromStr for Attention {
             "paged" => Ok(Attention::Paged),
             "flashdecoding" => Ok(Attention::FlashDecoding),
             "flashinfer" => Ok(Attention::FlashInfer),
+            "default" => Ok(Attention::Default),
             _ => Err(ParseError),
         }
     }
@@ -17,14 +17,14 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
+#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
 from text_generation_server.models.custom_modeling.llava_next import (
     LlavaNextForConditionalGeneration,
 )
-from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
-from text_generation_server.models.custom_modeling.mllama import (
-    MllamaForConditionalGeneration,
-)
+# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
+# from text_generation_server.models.custom_modeling.mllama import (
+#     MllamaForConditionalGeneration,
+# )
 from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
@@ -196,20 +196,6 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == "mllama":
-        return MllamaCausalLM(
-            model_id=model_id,
-            model_class=MllamaForConditionalGeneration,
-            batch_class=MllamaCausalLMBatch,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            default_dtype=torch.bfloat16,
-            trust_remote_code=trust_remote_code,
-            lora_adapter_ids=lora_adapter_ids,
-        )
-
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,
@@ -1215,7 +1215,7 @@ class CausalLM(Model):
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
         self.limit_hpu_graph = True
         try:
-            for batch_size in range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE):
+            for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE):
                 batches= []
                 iters = math.floor(batch_size/max_prefill_batch_size)
                 DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
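Illustrative note (not part of the commit): the change above only moves the stop value of the range, but it decides whether the smallest decode bucket ever gets warmed up. With hypothetical values:

    BATCH_BUCKET_SIZE = 8          # hypothetical bucket size
    max_decode_batch_size = 32     # hypothetical upper bound from the token budget

    old_sizes = list(range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE))
    new_sizes = list(range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE))

    print(old_sizes)  # [32, 24, 16]     -> the smallest bucket (8) was skipped
    print(new_sizes)  # [32, 24, 16, 8]  -> every bucket down to BATCH_BUCKET_SIZE is warmed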
@@ -1,8 +1,40 @@
 import torch
 import os
 from typing import Dict, Optional
+from loguru import logger
+from text_generation_server.utils.log import log_master
+
+ATTENTION = os.environ["ATTENTION"]
+# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
+    "1",
+    "true",
+}
+PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
+_expected = {"paged", "flashdecoding", "flashinfer", "default"}
+assert (
+    ATTENTION in _expected
+), f"Attention is not valid {ATTENTION}, expected {_expected}"
+log_master(logger.info, f"Using Attention = {ATTENTION}")
+
+if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+    raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
+assert TGI_WIGGLE_ROOM > 0
+assert TGI_WIGGLE_ROOM < 1
+
+# This is overridden by the cli
+BLOCK_SIZE: int
+if ATTENTION == "flashdecoding":
+    BLOCK_SIZE = 256
+elif ATTENTION == "flashinfer":
+    BLOCK_SIZE = 1
+else:
+    BLOCK_SIZE = 16
+
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
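Illustrative note (not part of the commit): globals.py now reads ATTENTION and PREFIX_CACHING with os.environ[...], so they must exist before the module is imported; the ENV lines added to the Gaudi Dockerfile provide exactly those defaults. A small sketch of how the defaults interact with the parsing rules (the setdefault calls are only for illustration outside the container):

    import os

    # The Dockerfile exports these; without them the os.environ["..."] lookups
    # in globals.py would raise KeyError at import time.
    os.environ.setdefault("ATTENTION", "default")
    os.environ.setdefault("PREFIX_CACHING", "0")
    os.environ.setdefault("PREFILL_CHUNKING", "0")

    # Same truthiness rule as the module: only "1" or "true" enable a feature.
    prefix_caching = os.environ["PREFIX_CACHING"].lower() in {"1", "true"}
    assert prefix_caching is False  # "0" keeps prefix caching off, so the flashinfer check passes
    assert os.environ["ATTENTION"] in {"paged", "flashdecoding", "flashinfer", "default"}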
@@ -117,10 +117,6 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
         from loguru import logger

-        logger.info(
-            f"Found {num_features} features in image of resolution {height}x{width}",
-        )
         return "<image>" * num_features

     elif config.model_type == "paligemma":
@@ -373,9 +369,9 @@ class VlmCausalLMBatch(CausalLMBatch):
                     (image_inputs["pixel_attention_mask"], dummy_attention), dim=0
                 )
             if "image_sizes" in image_inputs:
-                dummy_shape = list(image_inputs['image_sizes'].shape)
-                dummy_shape[0] = missing_inputs
-                dummy_sizes = torch.randint(dummy_shape)
+                dummy_shape = list(list(image_inputs['image_sizes'])[0])
+                dummy_shape = missing_inputs*[dummy_shape]
+                dummy_sizes = torch.IntTensor(dummy_shape)
                 new_image_inputs["image_sizes"] = torch.cat(
                     (image_inputs["image_sizes"], dummy_sizes), dim=0
                 )
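Illustrative note (not part of the commit): the old code handed a shape list to torch.randint, which expects (high, size) style arguments, so building the dummy image_sizes rows failed; the new code replicates the first real (height, width) row missing_inputs times and builds an integer tensor from it. A simplified sketch with invented sizes (plain ints stand in for the tensor scalars kept by the committed code):

    import torch

    image_sizes = torch.IntTensor([[336, 336], [672, 336]])  # two real images in the batch
    missing_inputs = 2                                        # two dummy rows needed for padding

    first_row = [int(v) for v in image_sizes[0]]    # [336, 336]
    dummy_rows = missing_inputs * [first_row]       # [[336, 336], [336, 336]]
    dummy_sizes = torch.IntTensor(dummy_rows)       # shape (2, 2), same dtype as image_sizes

    padded = torch.cat((image_sizes, dummy_sizes), dim=0)
    print(padded.shape)  # torch.Size([4, 2])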
@@ -25,7 +25,6 @@ from text_generation_server.utils.adapter import AdapterInfo

 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
-    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
     from text_generation_server.models.vlm_causal_lm import (
         VlmCausalLMBatch,
     )
@@ -35,7 +34,6 @@ try:
         PaliGemmaBatch,
         VlmCausalLMBatch,
         IdeficsCausalLMBatch,
-        MllamaCausalLMBatch,
     }
 except (ImportError, NotImplementedError):
     # These imports can fail on CPU/Non flash.