/// Single shard Client
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v2::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
    stub: TextGenerationServiceClient<Channel>,
}
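
// Note (commentary, not upstream code): `Client` can derive `Clone` cheaply
// because tonic's `Channel` is a reference-counted handle; clones multiplex
// requests over the same underlying connection rather than opening a new one.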

impl Client {
    /// Returns a client connected to the given url
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let channel = Channel::from_shared("http://[::]:50051".to_string())
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
            .await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }
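
    // Note (commentary, not upstream code): the `http://[::]:50051` URI above
    // is only a placeholder tonic needs to build the endpoint; no TCP
    // connection is opened to it, because the custom connector dials the unix
    // socket instead. A minimal usage sketch (the socket path is
    // hypothetical):
    //
    //     let mut client = Client::connect_uds("/tmp/tgi-shard-0".to_string()).await?;
    //     let info = client.info().await?;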

    /// Returns a list of uris or unix sockets of all shards
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
        let response = self.stub.service_discovery(request).await?;
        let urls = response
            .into_inner()
            .urls
            .into_iter()
            // Remove unix socket prefix
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }
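
    // Note (commentary, not upstream code): in a sharded deployment the
    // caller typically connects to one shard first and uses
    // `service_discovery` to learn the sockets of the remaining shards:
    //
    //     let urls = client.service_discovery().await?;
    //     // ...then open one `Client` per returned url/socket.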

    /// Get model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let request = tonic::Request::new(InfoRequest {}).inject_context();
        let response = self.stub.info(request).await?.into_inner();
        Ok(response)
    }

    /// Get model health
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let request = tonic::Request::new(HealthRequest {}).inject_context();
        let response = self.stub.health(request).await?.into_inner();
        Ok(response)
    }

    /// Clear the past generations cache
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
        self.stub.clear_cache(request).await?;
        Ok(())
    }

    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
        let request = tonic::Request::new(FilterBatchRequest {
            batch_id,
            request_ids,
        })
        .inject_context();
        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
        Ok(filtered_batch.batch)
    }
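
    // Note (commentary, not upstream code): filtering is how finished
    // requests are evicted without rebuilding the whole batch; only the ids
    // in `request_ids` are kept in the server-side cache, and the returned
    // `CachedBatch` (if any) describes the requests that remain.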

    /// Warmup on a max size batch
    ///
    /// Returns the maximum number of tokens supported by the hardware
    #[instrument(skip_all)]
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
        max_total_tokens: u32,
        max_batch_size: Option<usize>,
    ) -> Result<Option<u32>> {
        let mut n_tokens = 0;
        let mut requests = Vec::new();
        // Create requests
        while n_tokens < max_prefill_tokens {
            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);

            let mut inputs = String::new();
            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
            if n_tokens == 0 {
                // 1 request is enough to test vision heads.
                // Sending images on other queries easily messes up truncation.
                inputs.push_str("");
            }

            requests.push(Request {
                id: 0,
                // We truncate the input on the server side to be sure that it has the correct size
                inputs,
                truncate,
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
                    top_k: 10,
                    top_p: 0.9,
                    typical_p: 0.9,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.2,
                    frequency_penalty: 0.1,
                    watermark: true,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: max_total_tokens - truncate,
                    stop_sequences: vec![],
                    ignore_eos_token: true,
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
            });
            n_tokens += max_input_length;

            // Check max_batch_size
            if Some(requests.len()) == max_batch_size {
                break;
            }
        }
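
        // Worked example (commentary, not upstream code): with
        // max_input_length = 1024 and max_prefill_tokens = 4096, the loop
        // above builds 4 requests of 1024 input tokens each (4 * 1024 = 4096),
        // stopping earlier only if max_batch_size caps the batch first.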

        let batch = Batch {
            id: 0,
            size: requests.len() as u32,
            requests,
            max_tokens: 0,
        };

        let request = tonic::Request::new(WarmupRequest {
            batch: Some(batch),
            max_input_length,
            max_prefill_tokens,
            max_total_tokens,
        })
        .inject_context();
        let response = self.stub.warmup(request).await?.into_inner();
        Ok(response.max_supported_total_tokens)
    }
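
    // Note (usage sketch, not upstream code; the limits are illustrative):
    //
    //     let max_supported_total_tokens = client
    //         .warmup(1024, 4096, 2048, None)
    //         .await?;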

    /// Generate one token for each request in the given batch
    ///
    /// Returns a Generation for each request in the batch
    /// and the next cached batch
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
    pub async fn prefill(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
        ))
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns a Generation for each request in the batches
    /// and the next cached batch
    #[instrument(skip_all, fields(size = batches.iter().map(|batch| batch.size).sum::<u32>()))]
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            DecodeTimings::new(
                response.concat_ns,
                response.forward_ns,
                response.decode_ns,
                response.total_ns,
            ),
        ))
    }
}
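
// Note (a minimal sketch of the continuous-batching loop these two calls
// support; commentary, not upstream code):
//
//     // Prefill runs the first forward pass and yields one token per request.
//     let (generations, mut cached, _timings) = client.prefill(batch).await?;
//     // ...consume `generations` here...
//     // Decode then produces one token per step until the server no longer
//     // returns a cached batch (i.e. every request has finished).
//     while let Some(batch) = cached {
//         let (generations, next, _timings) = client.decode(vec![batch]).await?;
//         // ...consume `generations` here...
//         cached = next;
//     }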

pub struct PrefillTimings {
    pub forward: Duration,
    pub decode: Duration,
    pub total: Duration,
}

impl PrefillTimings {
    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}

pub struct DecodeTimings {
    pub concat: Option<Duration>,
    pub forward: Duration,
    pub decode: Duration,
    pub total: Duration,
}

impl DecodeTimings {
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}