text-generation-inference/router/client/src/client.rs

234 lines
7.7 KiB
Rust
Raw Normal View History

2022-10-18 13:19:03 +00:00
/// Single shard Client
2023-12-11 11:46:30 +00:00
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v2::*;
2022-10-08 10:30:12 +00:00
use crate::Result;
2023-02-13 12:02:45 +00:00
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use std::time::Duration;
2022-10-08 10:30:12 +00:00
use tonic::transport::{Channel, Uri};
2023-02-13 12:02:45 +00:00
use tracing::instrument;
2022-10-08 10:30:12 +00:00
2022-10-18 13:19:03 +00:00
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
2022-10-08 10:30:12 +00:00
pub struct Client {
2022-10-17 12:59:00 +00:00
stub: TextGenerationServiceClient<Channel>,
2022-10-08 10:30:12 +00:00
}
impl Client {
2022-10-17 12:59:00 +00:00
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
2022-10-08 10:30:12 +00:00
2022-10-17 12:59:00 +00:00
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
2022-10-08 10:30:12 +00:00
}
2022-10-17 12:59:00 +00:00
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
2022-10-08 10:30:12 +00:00
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
2022-10-17 12:59:00 +00:00
.await?;
2022-10-08 10:30:12 +00:00
2022-10-17 12:59:00 +00:00
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
2022-10-08 10:30:12 +00:00
}
2022-10-18 13:19:03 +00:00
/// Returns a list of uris or unix sockets of all shards
2022-10-08 10:30:12 +00:00
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
2023-02-13 12:02:45 +00:00
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await?;
2022-10-08 10:30:12 +00:00
let urls = response
.into_inner()
.urls
.into_iter()
2022-10-18 13:19:03 +00:00
// Remove unix socket prefix
2022-10-08 10:30:12 +00:00
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
2022-10-18 13:19:03 +00:00
/// Clear the past generations cache
2022-10-08 10:30:12 +00:00
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
2023-02-13 12:02:45 +00:00
self.stub.clear_cache(request).await?;
2022-10-08 10:30:12 +00:00
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
2023-10-23 13:51:12 +00:00
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs: "_test ".to_string().repeat(max_input_length as usize),
2023-10-23 13:51:12 +00:00
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
watermark: true,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
Rebased #617 (#868) # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
2023-08-28 09:43:47 +00:00
top_n_tokens: 20,
});
n_tokens += max_input_length;
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
2022-10-18 13:19:03 +00:00
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
2022-10-18 13:19:03 +00:00
/// and the next cached batch
2023-02-13 12:02:45 +00:00
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
2023-02-13 12:02:45 +00:00
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
2022-10-08 10:30:12 +00:00
}
/// Generate one token for each request in the given cached batches
2022-10-18 13:19:03 +00:00
///
/// Returns Generation for each request in batches
2022-10-18 13:19:03 +00:00
/// and the next cached batch
2023-02-13 12:02:45 +00:00
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
2022-10-08 10:30:12 +00:00
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
2023-02-13 12:02:45 +00:00
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
/// Timings reported by a shard for a prefill call.
///
/// Each field is built from the matching `*_ns` nanosecond counter of the prefill
/// response. The type is a small plain value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefillTimings {
    /// Forward timing reported by the shard (from `forward_ns`).
    pub forward: Duration,
    /// Decode timing reported by the shard (from `decode_ns`).
    pub decode: Duration,
    /// Total timing reported by the shard (from `total_ns`).
    pub total: Duration,
}
impl PrefillTimings {
    /// Converts the three raw nanosecond counters of a prefill response into
    /// [`Duration`]s, keeping them in the same order.
    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        let [forward, decode, total] = [forward_ns, decode_ns, total_ns].map(Duration::from_nanos);
        Self {
            forward,
            decode,
            total,
        }
    }
}
/// Timings reported by a shard for a decode call.
///
/// Each field is built from the matching `*_ns` nanosecond counter of the decode
/// response. The type is a small plain value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecodeTimings {
    /// Concat timing reported by the shard (from `concat_ns`); `None` when the
    /// response carried no value (presumably when no batches had to be concatenated).
    pub concat: Option<Duration>,
    /// Forward timing reported by the shard (from `forward_ns`).
    pub forward: Duration,
    /// Decode timing reported by the shard (from `decode_ns`).
    pub decode: Duration,
    /// Total timing reported by the shard (from `total_ns`).
    pub total: Duration,
}
impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(|v| Duration::from_nanos(v)),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
2022-10-08 10:30:12 +00:00
}