/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

/// Single shard Client
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v2::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use rand::{distributions::Uniform, Rng};
use std::cmp;
use std::env;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
    stub: TextGenerationServiceClient<Channel>,
}

impl Client {
    /// Returns a client connected to the given url
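    ///
    /// A minimal usage sketch (the address is illustrative):
    ///
    /// ```ignore
    /// let mut client = Client::connect("http://127.0.0.1:8080".parse()?).await?;
    /// let info = client.info().await?;
    /// ```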
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
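        // The placeholder URI below is required by tonic's builder but never
        // dialed: the custom connector opens the unix socket at `path` for
        // every connection instead.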
        let channel = Channel::from_shared("http://[::]:50051".to_string())
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
            .await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a list of uris or unix sockets of all shards
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
        let response = self.stub.service_discovery(request).await?;
        let urls = response
            .into_inner()
            .urls
            .into_iter()
            // Remove unix socket prefix
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

    /// Get model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let request = tonic::Request::new(InfoRequest {}).inject_context();
        let response = self.stub.info(request).await?.into_inner();
        Ok(response)
    }

    /// Get model health
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let request = tonic::Request::new(HealthRequest {}).inject_context();
        let response = self.stub.health(request).await?.into_inner();
        Ok(response)
    }

    /// Clear the past generations cache
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
        self.stub.clear_cache(request).await?;
        Ok(())
    }

    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
        let request = tonic::Request::new(FilterBatchRequest {
            batch_id,
            request_ids,
        })
        .inject_context();
        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
        Ok(filtered_batch.batch)
    }

    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
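    ///
    /// Behavior is tunable via environment variables read below:
    /// `WARMUP_ENABLED`, `PREFILL_BATCH_BUCKET_SIZE`, `PAD_SEQUENCE_TO_MULTIPLE_OF`
    /// (and `SKIP_TOKENIZER_IN_TGI` in `get_random_input`).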
    #[instrument(skip_all)]
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
        max_total_tokens: u32,
        max_batch_total_tokens: Option<u32>,
    ) -> Result<Option<u32>> {
        // Warmup is enabled by default and can be disabled with WARMUP_ENABLED=false
        let warmup_enabled: bool = env::var("WARMUP_ENABLED")
            .ok()
            .map_or(true, |value| value.to_lowercase() == "true");
        if !warmup_enabled {
            return Ok(None);
        }

        let read_env_var = |key: &str, default: u32| -> u32 {
            env::var(key)
                .ok()
                .map_or(default, |value| value.parse::<u32>().unwrap())
        };

        // get all possible prefill batch sizes
        let max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
        let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
        let batch_sizes: Vec<u32> = (prefill_bucket_size..=max_prefill_batch_size)
            .step_by(prefill_bucket_size as usize)
            .collect();
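        // Worked example with illustrative values: max_prefill_tokens = 8192
        // and max_input_length = 1024 give max_prefill_batch_size = 8; with
        // the default bucket of 4, batch_sizes = [4, 8].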

        // get all possible sequence lengths for prefill
        let seq_bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
        let mut seq_lengths: Vec<u32> = (seq_bucket_size..=max_input_length)
            .step_by(seq_bucket_size as usize)
            .collect();
        // Always cover the maximum input length itself, even when it is not a
        // multiple of the bucket size
        if let Some(&last) = seq_lengths.last() {
            if last < max_input_length {
                seq_lengths.push(max_input_length);
            }
        }
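        // e.g. max_input_length = 1000 with the default bucket of 128 gives
        // seq_lengths = [128, 256, 384, 512, 640, 768, 896, 1000].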

        // execute batch for each combination of batch size and sequence length
        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
        for batch_size in &batch_sizes {
            for seq_length in &seq_lengths {
                shapes.push((*batch_size, *seq_length));
            }
        }
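        // e.g. batch_sizes = [4, 8] and seq_lengths = [128, 256] produce
        // shapes = [(4, 128), (4, 256), (8, 128), (8, 256)].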

        let mut id_counter: u64 = 0;
        let num_batches = match max_batch_total_tokens {
            Some(val) => {
                if val == max_total_tokens {
                    1
                } else {
                    2
                }
            }
            None => 2, // If max_batch_total_tokens is None, create two batches
        };
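        // When max_batch_total_tokens equals max_total_tokens, the budget only
        // ever holds a single request, decode runs at batch size 1 and a
        // concatenation can never occur, so one warmup batch suffices.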

        for shape in shapes.iter() {
            // create two batches in order to trigger concatenate operation
            // in case decode bs=1 create one batch
            let batches: Vec<Batch> = vec![
                self.create_warmup_batch(
                    *shape,
                    &mut id_counter,
                    max_input_length,
                    max_total_tokens,
                    seq_bucket_size,
                    false,
                );
                num_batches
            ];
            let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
            let _response = self.stub.warmup(request).await?.into_inner();
        }

        // Send batches with default params to warm up greedy search
        let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
        for batch_size in &batch_sizes {
            greedy_shapes.push((*batch_size, seq_bucket_size));
        }
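        // Only the smallest sequence bucket is paired with each batch size
        // here; this pass just needs to exercise the default-parameter
        // (greedy) decode path once per batch size.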
        for greedy_shape in greedy_shapes.iter() {
            let batches: Vec<Batch> = vec![
                self.create_warmup_batch(
                    *greedy_shape,
                    &mut id_counter,
                    max_input_length,
                    max_total_tokens,
                    seq_bucket_size,
                    true,
                );
                num_batches
            ];
            let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
            let _response = self.stub.warmup(request).await?.into_inner();
        }

        Ok(None) // No support for maximum total tokens
    }

    #[instrument(skip_all)]
    fn create_warmup_batch(
        &mut self,
        shape: (u32, u32),
        id_counter: &mut u64,
        max_input_length: u32,
        max_total_tokens: u32,
        seq_bucket_size: u32,
        default_params: bool,
    ) -> Batch {
        *id_counter += 1;
        let (batch_size, input_length) = shape;
        let mut requests = Vec::new();
        for request_id in 0..batch_size {
            let req_params = if default_params {
                // greedy defaults
                Some(NextTokenChooserParameters {
                    temperature: 1.0,
                    top_k: 0,
                    top_p: 1.0,
                    typical_p: 1.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.0,
                    watermark: false,
                })
            } else {
                // sampling parameters exercise the non-greedy decode path
                Some(NextTokenChooserParameters {
                    temperature: 0.9,
                    top_k: 10,
                    top_p: 0.9,
                    typical_p: 0.9,
                    do_sample: true,
                    seed: 0,
                    repetition_penalty: 1.2,
                    watermark: false,
                })
            };
            requests.push(Request {
                id: *id_counter + request_id as u64,
                inputs: self.get_random_input(input_length, seq_bucket_size),
                truncate: max_input_length,
                parameters: req_params,
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: cmp::min(10, max_total_tokens - max_input_length),
                    stop_sequences: vec![],
                    ignore_eos_token: true,
                }),
                prefill_logprobs: false,
                top_n_tokens: 0,
            });
        }

        Batch {
            id: *id_counter,
            size: requests.len() as u32,
            requests,
            max_tokens: max_total_tokens,
        }
    }

    #[instrument(skip_all)]
    fn get_random_input(
        &mut self,
        input_length: u32,
        seq_bucket_size: u32,
    ) -> String {
        let skip_tokenizer_in_tgi: bool = env::var("SKIP_TOKENIZER_IN_TGI")
            .ok()
            .map_or(false, |value| value.to_lowercase() == "true");
        if skip_tokenizer_in_tgi {
            // generate random tokens
            let mut rng = rand::thread_rng();
            let range = Uniform::new(2, 8192);
            let tokens = if input_length % seq_bucket_size == 0 {
                input_length - seq_bucket_size / 2
            } else {
                input_length - (input_length % seq_bucket_size) / 2
            };
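            // The count lands below the bucket boundary so that the padded
            // input still falls in the intended bucket, e.g. input_length = 256
            // with seq_bucket_size = 128 gives tokens = 256 - 64 = 192.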
            (0..tokens)
                .map(|_| rng.sample(&range).to_string())
                .collect::<Vec<String>>()
                .join(", ")
        } else {
            // repeat test string to get expected input shape
            let mut bucket_id = input_length / seq_bucket_size;
            if input_length % seq_bucket_size != 0 {
                bucket_id += 1;
            }
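            // "_test " is assumed to tokenize to roughly two tokens, hence the
            // division by 2, e.g. input_length = 200 and seq_bucket_size = 128
            // give bucket_id = 2 and repeats = 64.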
            let repeats = cmp::max(1, (bucket_id - 1) * seq_bucket_size / 2);
            "_test ".to_string().repeat(repeats as usize)
        }
    }

    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
    pub async fn prefill(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
        Ok((response.generations, response.batch))
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
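    ///
    /// A typical driver first calls `prefill` on a new batch, then loops over
    /// `decode` until the server stops returning a cached batch (sketch,
    /// error handling elided):
    ///
    /// ```ignore
    /// let (_generations, mut cached) = client.prefill(batch).await?;
    /// while let Some(batch) = cached {
    ///     let (_generations, next) = client.decode(vec![batch]).await?;
    ///     cached = next;
    /// }
    /// ```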
    #[instrument(skip_all, fields(size = batches.iter().map(|batch| batch.size).sum::<u32>()))]
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
        Ok((response.generations, response.batch))
    }
}