From 372e071135761a3dfbb4231cab6bca965991bc07 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Sun, 27 Oct 2024 06:01:17 +0000
Subject: [PATCH] Fix issues in tgi-gaudi for v2.3.1

Signed-off-by: yuanwu
---
 Dockerfile                                     |  5 ++-
 backends/v2/src/backend.rs                     | 13 +++++++-
 backends/v2/src/client/grpc_client.rs          |  2 ++
 backends/v2/src/client/sharded_client.rs       |  2 ++
 backends/v2/src/lib.rs                         |  3 ++
 backends/v2/src/queue.rs                       | 31 +++++++++++++++---
 backends/v3/src/backend.rs                     |  1 -
 router/src/lib.rs                              |  3 ++
 .../text_generation_server/models/__init__.py  | 24 +++-----
 .../models/causal_lm.py                        |  2 +-
 .../text_generation_server/models/globals.py   | 32 +++++++++++++++++++
 .../models/vlm_causal_lm.py                    | 10 ++----
 server/text_generation_server/server.py        |  2 --
 13 files changed, 93 insertions(+), 37 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index b64c6079..aaae495e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -13,7 +13,6 @@ COPY benchmark benchmark
 COPY router router
 COPY backends backends
 COPY launcher launcher
-
 RUN cargo chef prepare --recipe-path recipe.json
 
 FROM chef AS builder
@@ -44,6 +43,10 @@ RUN cargo build --profile release-opt
 # Text Generation Inference base image
 FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest as base
 
+ENV ATTENTION=default
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
+
 # Text Generation Inference base env
 ENV HF_HOME=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs
index 086fc6dc..1ab582eb 100644
--- a/backends/v2/src/backend.rs
+++ b/backends/v2/src/backend.rs
@@ -27,6 +27,8 @@ impl BackendV2 {
     pub(crate) fn new(
         client: ShardedClient,
         waiting_served_ratio: f32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
         max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
@@ -48,7 +50,16 @@ impl BackendV2 {
         } else {
             16
         };
-        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
+
+        let queue = Queue::new(
+            requires_padding,
+            block_size,
+            window_size,
+            speculate,
+            max_input_tokens,
+            max_total_tokens,
+        );
+
         let batching_task_notifier = Arc::new(Notify::new());
 
         // Spawn batching background task that contains all the inference logic
diff --git a/backends/v2/src/client/grpc_client.rs b/backends/v2/src/client/grpc_client.rs
index b4943521..a56b8a54 100644
--- a/backends/v2/src/client/grpc_client.rs
+++ b/backends/v2/src/client/grpc_client.rs
@@ -109,6 +109,7 @@ impl Client {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
@@ -174,6 +175,7 @@ impl Client {
             max_input_length,
             max_prefill_tokens,
             max_total_tokens,
+            max_batch_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
diff --git a/backends/v2/src/client/sharded_client.rs b/backends/v2/src/client/sharded_client.rs
index eccf76d5..238bd773 100644
--- a/backends/v2/src/client/sharded_client.rs
+++ b/backends/v2/src/client/sharded_client.rs
@@ -105,6 +105,7 @@ impl ShardedClient {
         max_input_length: u32,
         max_prefill_tokens: u32,
         max_total_tokens: u32,
+        max_batch_total_tokens: u32,
         max_batch_size: Option<usize>,
     ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
@@ -115,6 +116,7 @@ impl ShardedClient {
                     max_input_length,
                     max_prefill_tokens,
                     max_total_tokens,
+                    max_batch_total_tokens,
                     max_batch_size,
                 ))
             })
diff --git a/backends/v2/src/lib.rs b/backends/v2/src/lib.rs
index 85c36931..90f03230 100644
--- a/backends/v2/src/lib.rs
+++ b/backends/v2/src/lib.rs
@@ -92,6 +92,7 @@ pub async fn connect_backend(
                 max_input_tokens as u32,
                 max_batch_prefill_tokens,
                 max_total_tokens as u32,
+                max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
                 max_batch_size,
             )
             .await
@@ -112,6 +113,8 @@ pub async fn connect_backend(
     let backend = BackendV2::new(
         sharded_client,
         waiting_served_ratio,
+        max_input_tokens as u32,
+        max_total_tokens as u32,
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
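Illustrative note (not part of the patched sources): the extra warmup argument falls back to a floor of 16000 batch-total tokens when no explicit value is configured. A minimal Python sketch of the unwrap_or(...) expression added in backends/v2/src/lib.rs:

    # Sketch of the fallback only; the real code is the Rust expression above.
    def fallback_max_batch_total_tokens(
        max_batch_total_tokens,    # None when not configured (Option<u32> in Rust)
        max_total_tokens,
        max_batch_prefill_tokens,
    ):
        if max_batch_total_tokens is not None:
            return max_batch_total_tokens
        # unwrap_or(16000.max(max_total_tokens.max(max_batch_prefill_tokens)))
        return max(16000, max_total_tokens, max_batch_prefill_tokens)

    # Example: with no configured value, max_total_tokens=2048 and
    # max_batch_prefill_tokens=4096, the warmup budget becomes 16000.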
diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs
index bf52900f..5f793d09 100644
--- a/backends/v2/src/queue.rs
+++ b/backends/v2/src/queue.rs
@@ -43,6 +43,8 @@ impl Queue {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,6 +55,8 @@ impl Queue {
             block_size,
             window_size,
             speculate,
+            max_input_tokens,
+            max_total_tokens,
             queue_receiver,
         ));
 
@@ -103,9 +107,18 @@ async fn queue_task(
     block_size: u32,
     window_size: Option<u32>,
     speculate: u32,
+    max_input_tokens: u32,
+    max_total_tokens: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size, speculate);
+    let mut state = State::new(
+        requires_padding,
+        block_size,
+        window_size,
+        speculate,
+        max_input_tokens,
+        max_total_tokens,
+    );
 
     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -153,6 +166,12 @@ struct State {
 
     /// Speculation amount
     speculate: u32,
+
+    /// max input tokens
+    max_input_tokens: u32,
+
+    /// max total tokens,
+    max_total_tokens: u32,
 }
 
 impl State {
@@ -161,6 +180,8 @@ impl State {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_input_tokens: u32,
+        max_total_tokens: u32,
     ) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
@@ -170,6 +191,8 @@ impl State {
             block_size,
             window_size,
             speculate,
+            max_input_tokens,
+            max_total_tokens,
         }
     }
 
@@ -224,7 +247,6 @@ impl State {
         let mut batch_entries =
             IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
 
-        let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;
 
@@ -241,8 +263,7 @@ impl State {
             if self.requires_padding {
                 // We pad to max input length in the Python shards
                 // We need to take these padding tokens into the equation
-                max_input_length = max_input_length.max(entry.request.input_length);
-                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
+                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_tokens
             } else {
                 // pad to block size
                 prefill_tokens += ((entry.request.input_length + self.block_size - 1)
@@ -251,7 +272,7 @@ impl State {
             }
 
             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
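Illustrative note (not part of the patched sources): with requires_padding set, the queue now reserves tokens from the static limits instead of per-request lengths, so the budget depends only on the batch size. The arithmetic implemented by the new prefill_tokens/decode_tokens assignments, sketched in Python:

    # Sketch of the padded token budget computed by State::next_batch above.
    def padded_budget(batch_size, max_input_tokens, max_total_tokens):
        prefill_tokens = batch_size * max_input_tokens
        decode_tokens = batch_size * (max_total_tokens - max_input_tokens)
        return prefill_tokens, decode_tokens

    # Example: 8 requests with max_input_tokens=1024 and max_total_tokens=2048
    # reserve (8192, 8192) tokens, regardless of the actual prompt lengths.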
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 122b4909..b1268152 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -156,7 +156,6 @@ pub(crate) async fn batching_task(
                 .await;
             let mut waiting_tokens = 1;
 
-            tracing::error!("Enter cached batch loop");
             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
             while let Some(batch) = cached_batch {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 0901bafa..1f6337b3 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -23,6 +23,7 @@ pub enum Attention {
     Paged,
     FlashDecoding,
     FlashInfer,
+    Default,
 }
 
 impl Attention {
@@ -31,6 +32,7 @@ impl Attention {
             Attention::FlashDecoding => 256,
             Attention::FlashInfer => 1,
             Attention::Paged => 16,
+            Attention::Default => 16,
         }
     }
 }
@@ -52,6 +54,7 @@ impl std::str::FromStr for Attention {
             "paged" => Ok(Attention::Paged),
             "flashdecoding" => Ok(Attention::FlashDecoding),
             "flashinfer" => Ok(Attention::FlashInfer),
+            "default" => Ok(Attention::Default),
             _ => Err(ParseError),
         }
     }
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index da064d9c..aebff738 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -17,14 +17,14 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
+#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
 from text_generation_server.models.custom_modeling.llava_next import (
     LlavaNextForConditionalGeneration,
 )
-from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
-from text_generation_server.models.custom_modeling.mllama import (
-    MllamaForConditionalGeneration,
-)
+# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
+# from text_generation_server.models.custom_modeling.mllama import (
+#     MllamaForConditionalGeneration,
+# )
 from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
@@ -196,20 +196,6 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "mllama":
-        return MllamaCausalLM(
-            model_id=model_id,
-            model_class=MllamaForConditionalGeneration,
-            batch_class=MllamaCausalLMBatch,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            default_dtype=torch.bfloat16,
-            trust_remote_code=trust_remote_code,
-            lora_adapter_ids=lora_adapter_ids,
-        )
-
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 39f1f7c9..065b22f2 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1215,7 +1215,7 @@ class CausalLM(Model):
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
         self.limit_hpu_graph = True
         try:
-            for batch_size in range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE):
+            for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE):
                 batches= []
                 iters = math.floor(batch_size/max_prefill_batch_size)
                 DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
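Illustrative note (not part of the patched sources): the range() stop value matters because the old bound of BATCH_BUCKET_SIZE excluded the smallest decode bucket from warmup. With example values:

    # Example values only; both quantities come from the server configuration.
    BATCH_BUCKET_SIZE = 8
    max_decode_batch_size = 32

    old_sizes = list(range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE))
    new_sizes = list(range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE))

    print(old_sizes)  # [32, 24, 16]    -> batch size 8 was never warmed up
    print(new_sizes)  # [32, 24, 16, 8] -> the smallest bucket is now warmed too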
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 92c3cf0d..8ef8ec82 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -1,8 +1,40 @@
 import torch
 import os
 from typing import Dict, Optional
+from loguru import logger
+from text_generation_server.utils.log import log_master
+
+ATTENTION = os.environ["ATTENTION"]
+# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
+    "1",
+    "true",
+}
+PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
+_expected = {"paged", "flashdecoding", "flashinfer", "default"}
+assert (
+    ATTENTION in _expected
+), f"Attention is not valid {ATTENTION}, expected {_expected}"
+log_master(logger.info, f"Using Attention = {ATTENTION}")
+
+if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+    raise RuntimeError("Prefix caching is only supported with flashinfer")
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
+assert TGI_WIGGLE_ROOM > 0
+assert TGI_WIGGLE_ROOM < 1
+
+# This is overridden by the cli
+BLOCK_SIZE: int
+if ATTENTION == "flashdecoding":
+    BLOCK_SIZE = 256
+elif ATTENTION == "flashinfer":
+    BLOCK_SIZE = 1
+else:
+    BLOCK_SIZE = 16
+
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 88734bdc..5c9955c2 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -117,10 +117,6 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
         from loguru import logger
-
-        logger.info(
-            f"Found {num_features} features in image of resolution {height}x{width}",
-        )
         return "<image>" * num_features
 
     elif config.model_type == "paligemma":
@@ -373,9 +369,9 @@ class VlmCausalLMBatch(CausalLMBatch):
                     (image_inputs["pixel_attention_mask"], dummy_attention), dim=0
                 )
             if "image_sizes" in image_inputs:
-                dummy_shape = list(image_inputs['image_sizes'].shape)
-                dummy_shape[0] = missing_inputs
-                dummy_sizes = torch.randint(dummy_shape)
+                dummy_shape = list(list(image_inputs['image_sizes'])[0])
+                dummy_shape = missing_inputs*[dummy_shape]
+                dummy_sizes = torch.IntTensor(dummy_shape)
                 new_image_inputs["image_sizes"] = torch.cat(
                     (image_inputs["image_sizes"], dummy_sizes), dim=0
                 )
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 8d109298..11a7d7a6 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -25,7 +25,6 @@ from text_generation_server.utils.adapter import AdapterInfo
 
 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
-    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
     from text_generation_server.models.vlm_causal_lm import (
         VlmCausalLMBatch,
     )
@@ -35,7 +34,6 @@ try:
         PaliGemmaBatch,
         VlmCausalLMBatch,
         IdeficsCausalLMBatch,
-        MllamaCausalLMBatch,
     }
 except (ImportError, NotImplementedError):
     # These imports can fail on CPU/Non flash.
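Illustrative note (not part of the patched sources): taken together, the Dockerfile defaults, the globals.py parsing and the router's new Attention::Default variant amount to the configuration sketched below for Gaudi (env variable names as used in the Dockerfile above):

    import os

    # Sketch of the effective Gaudi defaults.
    ATTENTION = os.getenv("ATTENTION", "default")  # Dockerfile: ENV ATTENTION=default
    PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {"1", "true"}

    if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
        raise RuntimeError("Prefix caching is only supported with flashinfer")

    # "default" falls through to the paged block size, matching
    # Attention::Default => 16 in router/src/lib.rs.
    BLOCK_SIZE = {"flashdecoding": 256, "flashinfer": 1}.get(ATTENTION, 16)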