mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 04:52:07 +00:00
Fix issues in tgi-gaudi for v2.3.1
Signed-off-by: yuanwu <yuan.wu@intel.com>
parent 7e282b4153
commit 372e071135
@@ -13,7 +13,6 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher

RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder
@@ -44,6 +43,10 @@ RUN cargo build --profile release-opt
# Text Generation Inference base image
FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest as base

ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0

# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
@@ -27,6 +27,8 @@ impl BackendV2 {
    pub(crate) fn new(
        client: ShardedClient,
        waiting_served_ratio: f32,
        max_input_tokens: u32,
        max_total_tokens: u32,
        max_batch_prefill_tokens: u32,
        max_batch_total_tokens: u32,
        max_waiting_tokens: usize,
@@ -48,7 +50,16 @@ impl BackendV2 {
        } else {
            16
        };
        let queue = Queue::new(requires_padding, block_size, window_size, speculate);

        let queue = Queue::new(
            requires_padding,
            block_size,
            window_size,
            speculate,
            max_input_tokens,
            max_total_tokens,
        );

        let batching_task_notifier = Arc::new(Notify::new());

        // Spawn batching background task that contains all the inference logic
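The hunk above threads two new limits, max_input_tokens and max_total_tokens, from BackendV2::new into Queue::new (the single-line call is the removed version, the multi-line call the replacement). A rough Python sketch of the same plumbing pattern; the parameter names come from the diff, the classes themselves are illustrative only:

from dataclasses import dataclass

@dataclass
class TokenLimits:
    max_input_tokens: int   # longest prompt a padded batch must accommodate
    max_total_tokens: int   # prompt plus generated tokens per sequence

class Queue:
    def __init__(self, requires_padding: bool, block_size: int,
                 window_size, speculate: int, limits: TokenLimits):
        # The queue keeps the limits so it can budget padded batches later
        # (see the prefill/decode token computation further down in this diff).
        self.requires_padding = requires_padding
        self.block_size = block_size
        self.window_size = window_size
        self.speculate = speculate
        self.limits = limits

queue = Queue(True, 16, None, 0, TokenLimits(max_input_tokens=1024, max_total_tokens=2048))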
@@ -109,6 +109,7 @@ impl Client {
        max_input_length: u32,
        max_prefill_tokens: u32,
        max_total_tokens: u32,
        max_batch_total_tokens: u32,
        max_batch_size: Option<usize>,
    ) -> Result<Option<u32>> {
        let mut n_tokens = 0;
@@ -174,6 +175,7 @@ impl Client {
            max_input_length,
            max_prefill_tokens,
            max_total_tokens,
            max_batch_total_tokens,
        })
        .inject_context();
        let response = self.stub.warmup(request).await?.into_inner();
@@ -105,6 +105,7 @@ impl ShardedClient {
        max_input_length: u32,
        max_prefill_tokens: u32,
        max_total_tokens: u32,
        max_batch_total_tokens: u32,
        max_batch_size: Option<usize>,
    ) -> Result<Option<u32>> {
        let futures: Vec<_> = self
@@ -115,6 +116,7 @@ impl ShardedClient {
                    max_input_length,
                    max_prefill_tokens,
                    max_total_tokens,
                    max_batch_total_tokens,
                    max_batch_size,
                ))
            })
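Both client layers now forward max_batch_total_tokens into the warmup call. As a loose illustration of the fan-out in the ShardedClient hunk, here is a Python sketch; only the parameter names come from the diff, while the helper names and the min-aggregation across shards are assumptions:

import asyncio

async def warmup_shard(shard_id: int, max_input_length: int, max_prefill_tokens: int,
                       max_total_tokens: int, max_batch_total_tokens: int):
    # Placeholder for the per-shard gRPC warmup call; returns the budget the
    # shard reports it can support (or None).
    await asyncio.sleep(0)
    return max_batch_total_tokens

async def warmup_all(num_shards: int, **limits):
    # Send the warmup request to every shard in parallel and keep the smallest
    # advertised budget, one plausible way a sharded client could aggregate.
    results = await asyncio.gather(*(warmup_shard(i, **limits) for i in range(num_shards)))
    return min(r for r in results if r is not None)

print(asyncio.run(warmup_all(2, max_input_length=1024, max_prefill_tokens=4096,
                             max_total_tokens=2048, max_batch_total_tokens=16000)))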
@@ -92,6 +92,7 @@ pub async fn connect_backend(
            max_input_tokens as u32,
            max_batch_prefill_tokens,
            max_total_tokens as u32,
            max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
            max_batch_size,
        )
        .await
@@ -112,6 +113,8 @@ pub async fn connect_backend(
    let backend = BackendV2::new(
        sharded_client,
        waiting_served_ratio,
        max_input_tokens as u32,
        max_total_tokens as u32,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
    );
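The unwrap_or above picks a warmup budget when max_batch_total_tokens is not provided: the largest of 16000, max_total_tokens, and max_batch_prefill_tokens. The same arithmetic in Python, with made-up inputs:

def default_batch_total_tokens(max_batch_total_tokens, max_total_tokens, max_batch_prefill_tokens):
    # Mirrors max_batch_total_tokens.unwrap_or(16000.max(max_total_tokens.max(max_batch_prefill_tokens)))
    if max_batch_total_tokens is not None:
        return max_batch_total_tokens
    return max(16000, max_total_tokens, max_batch_prefill_tokens)

print(default_batch_total_tokens(None, 2048, 4096))    # 16000
print(default_batch_total_tokens(None, 32768, 4096))   # 32768
print(default_batch_total_tokens(65536, 2048, 4096))   # 65536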
@@ -43,6 +43,8 @@ impl Queue {
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_input_tokens: u32,
        max_total_tokens: u32,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,6 +55,8 @@ impl Queue {
            block_size,
            window_size,
            speculate,
            max_input_tokens,
            max_total_tokens,
            queue_receiver,
        ));

@@ -103,9 +107,18 @@ async fn queue_task(
    block_size: u32,
    window_size: Option<u32>,
    speculate: u32,
    max_input_tokens: u32,
    max_total_tokens: u32,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
    let mut state = State::new(requires_padding, block_size, window_size, speculate);
    let mut state = State::new(
        requires_padding,
        block_size,
        window_size,
        speculate,
        max_input_tokens,
        max_total_tokens,
    );

    while let Some(cmd) = receiver.recv().await {
        match cmd {
@@ -153,6 +166,12 @@ struct State {

    /// Speculation amount
    speculate: u32,

    /// max input tokens
    max_input_tokens: u32,

    /// max total tokens,
    max_total_tokens: u32,
}

impl State {
@@ -161,6 +180,8 @@ impl State {
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_input_tokens: u32,
        max_total_tokens: u32,
    ) -> Self {
        Self {
            entries: VecDeque::with_capacity(128),
@@ -170,6 +191,8 @@ impl State {
            block_size,
            window_size,
            speculate,
            max_input_tokens,
            max_total_tokens,
        }
    }

@@ -224,7 +247,6 @@ impl State {
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;

@@ -241,8 +263,7 @@ impl State {
            if self.requires_padding {
                // We pad to max input length in the Python shards
                // We need to take these padding tokens into the equation
                max_input_length = max_input_length.max(entry.request.input_length);
                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_tokens
            } else {
                // pad to block size
                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
@@ -251,7 +272,7 @@ impl State {
            }

            if self.requires_padding {
                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
            } else {
                let max_new_tokens = match self.window_size {
                    None => entry.request.stopping_parameters.max_new_tokens,
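This is the most substantive queue change: with padded batches on Gaudi, every sequence reserves the full max_input_tokens for prefill and (max_total_tokens - max_input_tokens) for decode regardless of its actual length, so the budget has to be computed from the configured maxima rather than the observed lengths. A worked Python example; the limits, request lengths, and the simplified fits-check are invented for illustration:

# Padded-batch token budgeting, following the queue hunk above.
MAX_INPUT_TOKENS = 1024
MAX_TOTAL_TOKENS = 2048
MAX_BATCH_TOTAL_TOKENS = 16000

def padded_budget(num_requests_in_batch: int) -> tuple[int, int]:
    # Every padded request reserves the full prompt and full decode length.
    prefill_tokens = num_requests_in_batch * MAX_INPUT_TOKENS
    decode_tokens = num_requests_in_batch * (MAX_TOTAL_TOKENS - MAX_INPUT_TOKENS)
    return prefill_tokens, decode_tokens

def old_budget(input_lengths, max_new_tokens):
    # Previous behaviour: prefill padded to the longest *observed* prompt,
    # decode summed per request, which under-counts once the shards pad
    # everything to the configured maxima.
    return len(input_lengths) * max(input_lengths), sum(max_new_tokens)

for n in (1, 4, 8):
    prefill, decode = padded_budget(n)
    fits = prefill + decode <= MAX_BATCH_TOTAL_TOKENS
    print(f"batch of {n}: prefill={prefill} decode={decode} fits={fits}")
# batch of 1: prefill=1024 decode=1024 fits=True
# batch of 4: prefill=4096 decode=4096 fits=True
# batch of 8: prefill=8192 decode=8192 fits=False (16384 > 16000)

print(old_budget([64, 200, 512], [128, 128, 256]))  # (1536, 512), far below the padded reality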
@@ -156,7 +156,6 @@ pub(crate) async fn batching_task(
            .await;
        let mut waiting_tokens = 1;

        tracing::error!("Enter cached batch loop");
        // We loop until we do not receive any cached batch from the inference server (== until
        // all requests have met their stopping criteria)
        while let Some(batch) = cached_batch {
@@ -23,6 +23,7 @@ pub enum Attention {
    Paged,
    FlashDecoding,
    FlashInfer,
    Default,
}

impl Attention {
@@ -31,6 +32,7 @@ impl Attention {
            Attention::FlashDecoding => 256,
            Attention::FlashInfer => 1,
            Attention::Paged => 16,
            Attention::Default => 16,
        }
    }
}
@@ -52,6 +54,7 @@ impl std::str::FromStr for Attention {
            "paged" => Ok(Attention::Paged),
            "flashdecoding" => Ok(Attention::FlashDecoding),
            "flashinfer" => Ok(Attention::FlashInfer),
            "default" => Ok(Attention::Default),
            _ => Err(ParseError),
        }
    }
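A new Default attention variant joins Paged, FlashDecoding, and FlashInfer; it is parsed from the string "default" and maps to a block size of 16, the same as Paged. The equivalent lookup written out in Python for reference (the table values come from the hunk above, the function is illustrative):

# Attention kind -> KV-cache block size, as per the enum hunk above.
BLOCK_SIZE_BY_ATTENTION = {
    "paged": 16,
    "flashdecoding": 256,
    "flashinfer": 1,
    "default": 16,   # new variant; same block size as paged
}

def block_size(attention: str) -> int:
    try:
        return BLOCK_SIZE_BY_ATTENTION[attention.lower()]
    except KeyError:
        raise ValueError(f"Attention is not valid: {attention}") from None

print(block_size("default"))  # 16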
@@ -17,14 +17,14 @@ from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
from text_generation_server.models.custom_modeling.mllama import (
    MllamaForConditionalGeneration,
)
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
# from text_generation_server.models.custom_modeling.mllama import (
#     MllamaForConditionalGeneration,
# )
from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
@@ -196,20 +196,6 @@ def get_model(
            trust_remote_code=trust_remote_code,
        )

    if model_type == "mllama":
        return MllamaCausalLM(
            model_id=model_id,
            model_class=MllamaForConditionalGeneration,
            batch_class=MllamaCausalLMBatch,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            default_dtype=torch.bfloat16,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
@@ -1215,7 +1215,7 @@ class CausalLM(Model):
        max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
        self.limit_hpu_graph = True
        try:
            for batch_size in range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE):
            for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE):
                batches= []
                iters = math.floor(batch_size/max_prefill_batch_size)
                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
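The decode-warmup loop now counts down to 0 instead of stopping at BATCH_BUCKET_SIZE, so the smallest bucket also gets warmed up. A quick illustration with example values; the constants below are invented, only the range bounds come from the diff:

import math

# Example configuration, not the real defaults.
MAX_BATCH_TOTAL_TOKENS = 16384
MAX_TOTAL_TOKENS = 2048
BATCH_BUCKET_SIZE = 2

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)  # 8

old = list(range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE))
new = list(range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE))
print(old)  # [8, 6, 4] -- batch size 2 was never warmed up
print(new)  # [8, 6, 4, 2] -- the smallest bucket is now included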
@@ -1,8 +1,40 @@
import torch
import os
from typing import Dict, Optional
from loguru import logger
from text_generation_server.utils.log import log_master

ATTENTION = os.environ["ATTENTION"]
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
    "1",
    "true",
}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
_expected = {"paged", "flashdecoding", "flashinfer", "default"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")

if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
    raise RuntimeError("Prefix caching is only supported with flashinfer")

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1

# This is overridden by the cli
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
    BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
    BLOCK_SIZE = 1
else:
    BLOCK_SIZE = 16

# This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
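With the Dockerfile defaults added at the top of this diff (ATTENTION=default, PREFIX_CACHING=0, PREFILL_CHUNKING=0), this module resolves to prefix caching off and BLOCK_SIZE 16, so the "default" assertion passes and the flashinfer-only prefix-caching guard is never triggered. A condensed, standalone sketch of just that resolution path (the torch/logging parts are left out):

import os

# Defaults the Gaudi Dockerfile sets; override before running to experiment.
os.environ.setdefault("ATTENTION", "default")
os.environ.setdefault("PREFIX_CACHING", "0")
os.environ.setdefault("PREFILL_CHUNKING", "0")

ATTENTION = os.environ["ATTENTION"]
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {"1", "true"}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}

assert ATTENTION in {"paged", "flashdecoding", "flashinfer", "default"}
if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
    raise RuntimeError("Prefix caching is only supported with flashinfer")

BLOCK_SIZE = {"flashdecoding": 256, "flashinfer": 1}.get(ATTENTION, 16)
print(ATTENTION, PREFIX_CACHING, PREFILL_CHUNKING, BLOCK_SIZE)  # default False False 16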
@@ -117,10 +117,6 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
        height, width = image_input["image_sizes"][image_id]
        num_features = get_number_of_features(height, width, config)
        from loguru import logger

        logger.info(
            f"Found {num_features} features in image of resolution {height}x{width}",
        )
        return "<image>" * num_features

    elif config.model_type == "paligemma":
@@ -373,9 +369,9 @@ class VlmCausalLMBatch(CausalLMBatch):
                    (image_inputs["pixel_attention_mask"], dummy_attention), dim=0
                )
            if "image_sizes" in image_inputs:
                dummy_shape = list(image_inputs['image_sizes'].shape)
                dummy_shape[0] = missing_inputs
                dummy_sizes = torch.randint(dummy_shape)
                dummy_shape = list(list(image_inputs['image_sizes'])[0])
                dummy_shape = missing_inputs*[dummy_shape]
                dummy_sizes = torch.IntTensor(dummy_shape)
                new_image_inputs["image_sizes"] = torch.cat(
                    (image_inputs["image_sizes"], dummy_sizes), dim=0
                )
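The second hunk fixes how dummy image_sizes rows are fabricated when a batch has fewer images than slots: torch.randint(dummy_shape) is not a valid call (randint needs high and size), and random sizes would be meaningless anyway, so the new lines repeat the first real [height, width] pair missing_inputs times and build an IntTensor from it. A standalone repro of the new behaviour; the tensor values here are made up:

import torch

image_sizes = torch.IntTensor([[336, 336], [672, 336]])  # two real images
missing_inputs = 2                                        # two padded slots to fill

# Old code: torch.randint(dummy_shape) raises a TypeError (missing high/size).

# New code: repeat the first real size for every missing slot.
dummy_shape = list(list(image_sizes)[0])        # [336, 336]
dummy_shape = missing_inputs * [dummy_shape]    # [[336, 336], [336, 336]]
dummy_sizes = torch.IntTensor(dummy_shape)
padded = torch.cat((image_sizes, dummy_sizes), dim=0)
print(padded.shape)  # torch.Size([4, 2])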
@@ -25,7 +25,6 @@ from text_generation_server.utils.adapter import AdapterInfo

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
    from text_generation_server.models.vlm_causal_lm import (
        VlmCausalLMBatch,
    )
@@ -35,7 +34,6 @@ try:
        PaliGemmaBatch,
        VlmCausalLMBatch,
        IdeficsCausalLMBatch,
        MllamaCausalLMBatch,
    }
except (ImportError, NotImplementedError):
    # These imports can fail on CPU/Non flash.
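Both hunks drop MllamaCausalLMBatch from the server's set of VLM batch types, consistent with mllama being disabled in get_model above. The set's name is not visible in this hunk; assuming it is the usual VLM batch-type registry, the membership check it feeds looks roughly like this sketch (names and usage are assumptions, not shown in the diff):

# Hypothetical sketch of how a VLM batch-type set is typically consulted.
VLM_BATCH_TYPES = {"PaliGemmaBatch", "VlmCausalLMBatch", "IdeficsCausalLMBatch"}

def needs_image_processing(batch_class_name: str) -> bool:
    # After this commit, "MllamaCausalLMBatch" is no longer in the set,
    # so mllama batches would not be routed through the VLM path.
    return batch_class_name in VLM_BATCH_TYPES

print(needs_image_processing("MllamaCausalLMBatch"))  # False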