Fix issues in tgi-gaudi for v2.3.1

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2024-10-27 06:01:17 +00:00
parent 7e282b4153
commit 372e071135
13 changed files with 93 additions and 37 deletions

View File

@@ -13,7 +13,6 @@ COPY benchmark benchmark
 COPY router router
 COPY backends backends
 COPY launcher launcher
 RUN cargo chef prepare --recipe-path recipe.json
 FROM chef AS builder
@@ -44,6 +43,10 @@ RUN cargo build --profile release-opt
 # Text Generation Inference base image
 FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest as base
+ENV ATTENTION=default
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
 # Text Generation Inference base env
 ENV HF_HOME=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \

View File

@@ -27,6 +27,8 @@ impl BackendV2 {
     pub(crate) fn new(
         client: ShardedClient,
         waiting_served_ratio: f32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
         max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
@@ -48,7 +50,16 @@
         } else {
             16
         };
-        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
+        let queue = Queue::new(
+            requires_padding,
+            block_size,
+            window_size,
+            speculate,
+            max_input_tokens,
+            max_total_tokens,
+        );
         let batching_task_notifier = Arc::new(Notify::new());
         // Spawn batching background task that contains all the inference logic

View File

@@ -109,6 +109,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
@@ -174,6 +175,7 @@
             max_input_length,
             max_prefill_tokens,
             max_total_tokens,
+            max_batch_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();

View File

@@ -105,6 +105,7 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
@@ -115,6 +116,7 @@
                     max_input_length,
                     max_prefill_tokens,
                     max_total_tokens,
+                    max_batch_total_tokens,
                     max_batch_size,
                 ))
             })

View File

@@ -92,6 +92,7 @@ pub async fn connect_backend(
                 max_input_tokens as u32,
                 max_batch_prefill_tokens,
                 max_total_tokens as u32,
+                max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
                 max_batch_size,
             )
             .await
@@ -112,6 +113,8 @@
     let backend = BackendV2::new(
         sharded_client,
         waiting_served_ratio,
+        max_input_tokens as u32,
+        max_total_tokens as u32,
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,

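When `--max-batch-total-tokens` is not supplied, the warmup call now falls back to the largest of 16000, `max_total_tokens`, and `max_batch_prefill_tokens`. A small Python sketch of that fallback (the function name is illustrative, not part of the commit):

```python
# Same arithmetic as max_batch_total_tokens.unwrap_or(16000.max(...)) above.
def warmup_token_budget(max_batch_total_tokens, max_total_tokens, max_batch_prefill_tokens):
    if max_batch_total_tokens is not None:
        return max_batch_total_tokens
    return max(16000, max_total_tokens, max_batch_prefill_tokens)

print(warmup_token_budget(None, 2048, 4096))    # 16000
print(warmup_token_budget(None, 32768, 4096))   # 32768
print(warmup_token_budget(65536, 2048, 4096))   # 65536
```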
View File

@@ -43,6 +43,8 @@ impl Queue {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,6 +55,8 @@
             block_size,
             window_size,
             speculate,
+            max_input_tokens,
+            max_total_tokens,
             queue_receiver,
         ));
@@ -103,9 +107,18 @@
     block_size: u32,
     window_size: Option<u32>,
     speculate: u32,
+    max_input_tokens: u32,
+    max_total_tokens: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size, speculate);
+    let mut state = State::new(
+        requires_padding,
+        block_size,
+        window_size,
+        speculate,
+        max_input_tokens,
+        max_total_tokens,
+    );
     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -153,6 +166,12 @@ struct State {
     /// Speculation amount
     speculate: u32,
+    /// max input tokens
+    max_input_tokens: u32,
+    /// max total tokens,
+    max_total_tokens: u32,
 }
 impl State {
@@ -161,6 +180,8 @@
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
     ) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
@@ -170,6 +191,8 @@
             block_size,
             window_size,
             speculate,
+            max_input_tokens,
+            max_total_tokens,
         }
     }
@@ -224,7 +247,6 @@
         let mut batch_entries =
             IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
-        let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;
@@ -241,8 +263,7 @@
             if self.requires_padding {
                 // We pad to max input length in the Python shards
                 // We need to take these padding tokens into the equation
-                max_input_length = max_input_length.max(entry.request.input_length);
-                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
+                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_tokens
             } else {
                 // pad to block size
                 prefill_tokens += ((entry.request.input_length + self.block_size - 1)
@@ -251,7 +272,7 @@
             }
             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,

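Because the Gaudi CausalLM path pads every sequence, the queue now budgets batches against the static worst case: `max_input_tokens` per request for prefill and `max_total_tokens - max_input_tokens` per request for decode, instead of tracking the longest prompt seen so far. A Python sketch of that accounting with illustrative numbers (the helper function is hypothetical):

```python
# Worst-case token budget for a padded (non-paged) batch, mirroring the
# State::next_batch change above.
def padded_budget(n_requests: int, max_input_tokens: int, max_total_tokens: int):
    prefill_tokens = n_requests * max_input_tokens
    decode_tokens = n_requests * (max_total_tokens - max_input_tokens)
    return prefill_tokens, decode_tokens

# 8 padded requests with max_input_tokens=1024 and max_total_tokens=2048
# reserve 8192 prefill and 8192 decode tokens, regardless of prompt length.
print(padded_budget(8, 1024, 2048))  # (8192, 8192)
```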
View File

@@ -156,7 +156,6 @@ pub(crate) async fn batching_task(
         .await;
     let mut waiting_tokens = 1;
-    tracing::error!("Enter cached batch loop");
     // We loop until we do not receive any cached batch from the inference server (== until
     // all requests have met their stopping criteria)
     while let Some(batch) = cached_batch {

View File

@@ -23,6 +23,7 @@ pub enum Attention {
     Paged,
     FlashDecoding,
     FlashInfer,
+    Default,
 }
 impl Attention {
@@ -31,6 +32,7 @@ impl Attention {
             Attention::FlashDecoding => 256,
             Attention::FlashInfer => 1,
             Attention::Paged => 16,
+            Attention::Default => 16,
         }
     }
 }
@@ -52,6 +54,7 @@ impl std::str::FromStr for Attention {
             "paged" => Ok(Attention::Paged),
             "flashdecoding" => Ok(Attention::FlashDecoding),
             "flashinfer" => Ok(Attention::FlashInfer),
+            "default" => Ok(Attention::Default),
             _ => Err(ParseError),
         }
     }

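The new `Default` variant lets `ATTENTION=default` (set in the Dockerfile above) parse cleanly and uses the same 16-token block size as `Paged`. A quick Python table of the resulting mapping (illustrative only, mirroring the `block_size()` match above):

```python
# Block size per attention variant after this change.
BLOCK_SIZES = {"paged": 16, "default": 16, "flashdecoding": 256, "flashinfer": 1}

def block_size(attention: str) -> int:
    if attention not in BLOCK_SIZES:
        raise ValueError(f"unknown attention variant: {attention}")
    return BLOCK_SIZES[attention]

print(block_size("default"))  # 16
```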
View File

@@ -17,14 +17,14 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
+#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
 from text_generation_server.models.custom_modeling.llava_next import (
     LlavaNextForConditionalGeneration,
 )
-from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
-from text_generation_server.models.custom_modeling.mllama import (
-    MllamaForConditionalGeneration,
-)
+# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
+# from text_generation_server.models.custom_modeling.mllama import (
+#     MllamaForConditionalGeneration,
+# )
 from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
@@ -196,20 +196,6 @@
             trust_remote_code=trust_remote_code,
         )
-    if model_type == "mllama":
-        return MllamaCausalLM(
-            model_id=model_id,
-            model_class=MllamaForConditionalGeneration,
-            batch_class=MllamaCausalLMBatch,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            default_dtype=torch.bfloat16,
-            trust_remote_code=trust_remote_code,
-            lora_adapter_ids=lora_adapter_ids,
-        )
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,

View File

@@ -1215,7 +1215,7 @@ class CausalLM(Model):
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
         self.limit_hpu_graph = True
         try:
-            for batch_size in range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE):
+            for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE):
                 batches= []
                 iters = math.floor(batch_size/max_prefill_batch_size)
                 DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)

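Changing the stop value of the decode warmup loop from `BATCH_BUCKET_SIZE` to `0` means the smallest batch bucket also gets warmed up. A quick check with illustrative values (the bucket size and max decode batch size below are just examples):

```python
# Before vs. after for the decode warmup batch sizes.
BATCH_BUCKET_SIZE = 8
max_decode_batch_size = 32

before = list(range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE))
after = list(range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE))

print(before)  # [32, 24, 16] -- batch size 8 was never warmed up
print(after)   # [32, 24, 16, 8]
```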
View File

@@ -1,8 +1,40 @@
 import torch
 import os
 from typing import Dict, Optional
+from loguru import logger
+from text_generation_server.utils.log import log_master
+ATTENTION = os.environ["ATTENTION"]
+# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
+    "1",
+    "true",
+}
+PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
+_expected = {"paged", "flashdecoding", "flashinfer", "default"}
+assert (
+    ATTENTION in _expected
+), f"Attention is not valid {ATTENTION}, expected {_expected}"
+log_master(logger.info, f"Using Attention = {ATTENTION}")
+if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+    raise RuntimeError("Prefix caching is only supported with flashinfer")
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
+assert TGI_WIGGLE_ROOM > 0
+assert TGI_WIGGLE_ROOM < 1
+# This is overridden by the cli
+BLOCK_SIZE: int
+if ATTENTION == "flashdecoding":
+    BLOCK_SIZE = 256
+elif ATTENTION == "flashinfer":
+    BLOCK_SIZE = 1
+else:
+    BLOCK_SIZE = 16
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:

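`globals.py` now reads `ATTENTION` and `PREFIX_CACHING` with `os.environ[...]` (no fallback), which is why the Dockerfile in this commit sets defaults for them: a missing variable would raise `KeyError` before the server starts. A short sketch of the resulting configuration with the image defaults (the `setdefault` calls here are illustrative, the actual module indexes the environment directly):

```python
import os

# With the Dockerfile defaults, the derived settings come out as below.
os.environ.setdefault("ATTENTION", "default")   # required by globals.py
os.environ.setdefault("PREFIX_CACHING", "0")    # required by globals.py

ATTENTION = os.environ["ATTENTION"]
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {"1", "true"}
BLOCK_SIZE = {"flashdecoding": 256, "flashinfer": 1}.get(ATTENTION, 16)

print(ATTENTION, PREFIX_CACHING, BLOCK_SIZE)  # default False 16
```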
View File

@@ -117,10 +117,6 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
         from loguru import logger
-        logger.info(
-            f"Found {num_features} features in image of resolution {height}x{width}",
-        )
         return "<image>" * num_features
     elif config.model_type == "paligemma":
@@ -373,9 +369,9 @@ class VlmCausalLMBatch(CausalLMBatch):
                     (image_inputs["pixel_attention_mask"], dummy_attention), dim=0
                 )
             if "image_sizes" in image_inputs:
-                dummy_shape = list(image_inputs['image_sizes'].shape)
-                dummy_shape[0] = missing_inputs
-                dummy_sizes = torch.randint(dummy_shape)
+                dummy_shape = list(list(image_inputs['image_sizes'])[0])
+                dummy_shape = missing_inputs*[dummy_shape]
+                dummy_sizes = torch.IntTensor(dummy_shape)
                 new_image_inputs["image_sizes"] = torch.cat(
                     (image_inputs["image_sizes"], dummy_sizes), dim=0
                 )

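The old code resized the shape list of the `image_sizes` tensor and then called `torch.randint(dummy_shape)`, which is not a valid `torch.randint` signature; the fix repeats the first real `[height, width]` row once per missing input and builds an `IntTensor` so the concatenation along dim 0 lines up. A small self-contained check of the same idea (tensor values are made up, and `.tolist()` stands in for the commit's `list(list(...)[0])` expression):

```python
import torch

# Pad "image_sizes" for missing inputs by repeating the first real row.
image_sizes = torch.IntTensor([[336, 672], [336, 336]])  # example sizes
missing_inputs = 2

first_size = image_sizes[0].tolist()                      # [336, 672]
dummy_sizes = torch.IntTensor(missing_inputs * [first_size])

padded = torch.cat((image_sizes, dummy_sizes), dim=0)
print(padded.shape)  # torch.Size([4, 2])
```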
View File

@@ -25,7 +25,6 @@ from text_generation_server.utils.adapter import AdapterInfo
 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
-    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
     from text_generation_server.models.vlm_causal_lm import (
         VlmCausalLMBatch,
     )
@@ -35,7 +34,6 @@
         PaliGemmaBatch,
         VlmCausalLMBatch,
         IdeficsCausalLMBatch,
-        MllamaCausalLMBatch,
     }
 except (ImportError, NotImplementedError):
     # These imports can fail on CPU/Non flash.