Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 04:14:52 +00:00

feat(server): auto max_batch_total_tokens for flash att models

Parent: 3b71c38558
Commit: b165f8b7b7
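In short: during warmup each shard now measures how many KV-cache blocks fit in the remaining GPU memory and reports the corresponding token budget; the router only falls back to `--max-batch-total-tokens` (default 16000) for models that cannot report one. A minimal Python sketch of that resolution step (illustrative only, not code from this commit; the function name is made up):

from typing import Optional

def resolve_max_batch_total_tokens(
    warmup_result: Optional[int], cli_value: Optional[int]
) -> int:
    """Mirror of the router-side decision: the profiled value wins, otherwise fall back."""
    if warmup_result is None:
        # Older (non flash attention) models cannot report a budget:
        # keep the user-provided flag, or the previous default of 16000.
        return cli_value if cli_value is not None else 16000
    if cli_value is not None:
        # Flash attention models ignore the flag; the router logs a deprecation warning.
        print("`--max-batch-total-tokens` is deprecated for Flash Attention models")
    return warmup_result

assert resolve_max_batch_total_tokens(None, None) == 16000
assert resolve_max_batch_total_tokens(None, 32000) == 32000
assert resolve_max_batch_total_tokens(119856, 32000) == 119856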
@@ -184,8 +184,8 @@ struct Args {
     /// depends on other parameters like if you're using quantization, flash attention
     /// or the model implementation, text-generation-inference cannot infer this number
     /// automatically.
-    #[clap(default_value = "16000", long, env)]
-    max_batch_total_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
 
     /// This setting defines how many tokens can be passed before forcing the waiting
     /// queries to be put on the batch (if the size of the batch allows for it).
@@ -428,7 +428,7 @@ fn shard_manager(
     }
 
     // Start process
-    tracing::info!("Starting shard {rank}");
+    tracing::info!("Starting shard");
     let mut p = match Command::new("text-generation-server")
         .args(shard_args)
         .envs(envs)
@@ -493,17 +493,17 @@ fn shard_manager(
         if shutdown.load(Ordering::SeqCst) {
             p.kill().unwrap();
             let _ = p.wait();
-            tracing::info!("Shard {rank} terminated");
+            tracing::info!("Shard terminated");
             return;
         }
 
         // Shard is ready
         if uds.exists() && !ready {
-            tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
+            tracing::info!("Shard ready in {:?}", start_time.elapsed());
             status_sender.send(ShardStatus::Ready).unwrap();
             ready = true;
         } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
-            tracing::info!("Waiting for shard {rank} to be ready...");
+            tracing::info!("Waiting for shard to be ready...");
             wait_time = Instant::now();
         }
         sleep(Duration::from_millis(100));
@@ -860,8 +860,6 @@ fn spawn_webserver(
         args.max_total_tokens.to_string(),
         "--max-batch-prefill-tokens".to_string(),
         args.max_batch_prefill_tokens.to_string(),
-        "--max-batch-total-tokens".to_string(),
-        args.max_batch_total_tokens.to_string(),
         "--waiting-served-ratio".to_string(),
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
@@ -878,6 +876,12 @@ fn spawn_webserver(
         args.model_id,
     ];
 
+    // Model optional max batch total tokens
+    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
+        router_args.push("--max-batch-total-tokens".to_string());
+        router_args.push(max_batch_total_tokens.to_string());
+    }
+
     // Model optional revision
     if let Some(ref revision) = args.revision {
         router_args.push("--revision".to_string());
@@ -1036,18 +1040,7 @@ fn main() -> Result<(), LauncherError> {
             args.max_batch_prefill_tokens, args.max_input_length
         )));
     }
-    if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-            args.max_batch_prefill_tokens, args.max_batch_total_tokens
-        )));
-    }
-    if args.max_total_tokens as u32 > args.max_batch_total_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-            args.max_total_tokens, args.max_batch_total_tokens
-        )));
-    }
 
     if args.validation_workers == 0 {
         return Err(LauncherError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
@@ -1065,6 +1058,21 @@ fn main() -> Result<(), LauncherError> {
         tracing::info!("Sharding model on {num_shard} processes");
     }
 
+    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
+        if args.max_batch_prefill_tokens > *max_batch_total_tokens {
+            return Err(LauncherError::ArgumentValidation(format!(
+                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                args.max_batch_prefill_tokens, max_batch_total_tokens
+            )));
+        }
+        if args.max_total_tokens as u32 > *max_batch_total_tokens {
+            return Err(LauncherError::ArgumentValidation(format!(
+                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                args.max_total_tokens, max_batch_total_tokens
+            )));
+        }
+    }
+
     // Signal handler
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
@@ -198,9 +198,10 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
-    /// Maximum number of tokens that the client will send
-    uint32 max_total_tokens = 2;
 }
 
 /// Empty response
-message WarmupResponse {}
+message WarmupResponse {
+    /// Maximum number of tokens supported by the model
+    optional uint32 max_supported_total_tokens = 1;
+}
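The response field is declared as an explicit proto3 `optional` so that "no value reported" (older, non flash attention models) can be distinguished from an actual value. A small sketch of the presence check, assuming Python stubs generated from this proto (the import path below is hypothetical):

from text_generation_server.pb import generate_pb2  # hypothetical stub module path

resp = generate_pb2.WarmupResponse()                 # older model: field left unset
print(resp.HasField("max_supported_total_tokens"))   # False

resp = generate_pb2.WarmupResponse(max_supported_total_tokens=119856)
print(resp.HasField("max_supported_total_tokens"))   # True
print(resp.max_supported_total_tokens)               # 119856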
@@ -103,8 +103,7 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
 
@@ -143,13 +142,9 @@ impl Client {
             max_tokens: 0,
         };
 
-        let request = tonic::Request::new(WarmupRequest {
-            batch: Some(batch),
-            max_total_tokens,
-        })
-        .inject_context();
-        self.stub.warmup(request).await?.into_inner();
-        Ok(())
+        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
+        let response = self.stub.warmup(request).await?.into_inner();
+        Ok(response.max_supported_total_tokens)
     }
 
     /// Generate one token for each request in the given batch
@@ -95,14 +95,11 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
-            })
+            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
@@ -37,8 +37,8 @@ struct Args {
     waiting_served_ratio: f32,
     #[clap(default_value = "4096", long, env)]
     max_batch_prefill_tokens: u32,
-    #[clap(default_value = "16000", long, env)]
-    max_batch_total_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -110,18 +110,22 @@ fn main() -> Result<(), RouterError> {
     if max_input_length as u32 > max_batch_prefill_tokens {
         return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
     }
-    if max_batch_prefill_tokens > max_batch_total_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
-    }
-    if max_total_tokens as u32 > max_batch_total_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
-    }
 
     if validation_workers == 0 {
         return Err(RouterError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
         ));
     }
 
+    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
+        if max_batch_prefill_tokens > *max_batch_total_tokens {
+            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
+        }
+        if max_total_tokens as u32 > *max_batch_total_tokens {
+            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
+        }
+    }
+
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
     // Finally, convert to AllowOrigin
@@ -210,14 +214,29 @@ fn main() -> Result<(), RouterError> {
 
     // Warmup model
     tracing::info!("Warming up model");
-    sharded_client
-        .warmup(
-            max_input_length as u32,
-            max_batch_prefill_tokens,
-            max_batch_total_tokens,
-        )
+    let max_supported_batch_total_tokens = match sharded_client
+        .warmup(max_input_length as u32, max_batch_prefill_tokens)
         .await
-        .map_err(RouterError::Warmup)?;
+        .map_err(RouterError::Warmup)?
+    {
+        // Older models do not support automatic max-batch-total-tokens
+        None => max_batch_total_tokens.unwrap_or(16000),
+        // Flash attention models return their max supported total tokens
+        Some(max_supported_batch_total_tokens) => {
+            // Warn if user added his own max-batch-total-tokens as we will ignore it
+            if max_batch_total_tokens.is_some() {
+                tracing::warn!(
+                    "`--max-batch-total-tokens` is deprecated for Flash \
+                    Attention models."
+                );
+            }
+            tracing::info!(
+                "Model can support up to {max_supported_batch_total_tokens} \
+                max batch total tokens."
+            );
+            max_supported_batch_total_tokens
+        }
+    };
     tracing::info!("Connected");
 
     let addr = match hostname.parse() {
@@ -240,7 +259,7 @@ fn main() -> Result<(), RouterError> {
         max_total_tokens,
         waiting_served_ratio,
         max_batch_prefill_tokens,
-        max_batch_total_tokens,
+        max_supported_batch_total_tokens,
         max_waiting_tokens,
         sharded_client,
         tokenizer,
@@ -710,14 +710,13 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
-    def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
+    def warmup(self, batch: FlashCausalLMBatch):
         global CACHE_MANAGER
 
         torch.cuda.empty_cache()
         try:
             CACHE_MANAGER = CacheManager(
-                # Adds some wiggle room
-                math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
+                batch.blocks,
                 self.num_layers,
                 self.num_kv_heads,
                 self.head_size,
@@ -727,11 +726,46 @@ class FlashCausalLM(Model):
             _, batch = self.generate_token(batch)
         except Exception as e:
             raise RuntimeError(
-                f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
-                f"prefill tokens. "
-                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
+                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
+        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
+        torch.cuda.synchronize()
+        peak_memory = torch.cuda.max_memory_allocated(self.device)
+
+        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
+
+        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+
+        # FIXME:
+        # remove wiggle room
+        # when world size > 1, some aggregation ops end up taking more memory than expected
+        safety = 1 - (0.02 * self.world_size)
+        num_blocks = (
+            int((total_gpu_memory * safety - peak_memory) // total_cache_size)
+            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
+            + batch.blocks
+        )
+
+        del CACHE_MANAGER
+        del batch
+        torch.cuda.empty_cache()
+
+        CACHE_MANAGER = CacheManager(
+            num_blocks,
+            self.num_layers,
+            self.num_kv_heads,
+            self.head_size,
+            self.dtype,
+            self.device,
+        )
+
+        return int(num_blocks * BLOCK_SIZE)
+
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
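To make the arithmetic above concrete, here is a worked sketch of the same block estimation with made-up numbers (an 80 GiB card, a Llama-7B-like shape, fp16, and an assumed BLOCK_SIZE of 16; none of these values come from the diff):

# Worked sketch of the block estimation performed in warmup above.
# All concrete numbers are hypothetical; BLOCK_SIZE = 16 is an assumption.
BLOCK_SIZE = 16            # assumed tokens per KV-cache block
num_layers = 32            # hypothetical model shape
num_kv_heads = 32
head_size = 128
dtype_size = 2             # bytes per element in fp16
world_size = 1

total_gpu_memory = 80 * 1024**3   # hypothetical 80 GiB device
peak_memory = 20 * 1024**3        # hypothetical peak after the warmup forward pass
batch_blocks = 16                 # blocks already allocated for the warmup batch

# Bytes needed for one block across all layers (K and V, hence the * 2)
cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # 8 MiB here

# Leave a small per-shard safety margin, fill the rest of the GPU with cache blocks
safety = 1 - (0.02 * world_size)
num_blocks = int((total_gpu_memory * safety - peak_memory) // total_cache_size) + batch_blocks

print(num_blocks)               # 7491
print(num_blocks * BLOCK_SIZE)  # 119856 -> reported as max_supported_total_tokens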
@@ -991,7 +1025,6 @@ class FlashCausalLM(Model):
 
         if stopped:
             del batch
-            torch.cuda.empty_cache()
             # No need to return a batch if we know that all requests stopped
             return generations, None
 
@@ -58,8 +58,9 @@ class Model(ABC):
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError
 
-    def warmup(self, batch: B, max_total_tokens: int):
+    def warmup(self, batch: B) -> Optional[int]:
         self.generate_token(batch)
+        return None
 
     def decode_token(
         self,
@@ -60,12 +60,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
-        self.model.warmup(batch, request.max_total_tokens)
+        max_supported_total_tokens = self.model.warmup(batch)
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
-        return generate_pb2.WarmupResponse()
+        return generate_pb2.WarmupResponse(
+            max_supported_total_tokens=max_supported_total_tokens
+        )
 
     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
@@ -73,7 +75,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )
 
         generations, next_batch = self.model.generate_token(batch)
-        self.cache.set(next_batch)
+
+        if next_batch is not None:
+            self.cache.set(next_batch)
+        else:
+            torch.cuda.empty_cache()
 
         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
@@ -102,7 +108,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = batches[0]
 
         generations, next_batch = self.model.generate_token(batch)
-        self.cache.set(next_batch)
+
+        if next_batch is not None:
+            self.cache.set(next_batch)
+        else:
+            torch.cuda.empty_cache()
 
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],