From 57e57e6fee9484112f85c92249aca0aa7a91fea5 Mon Sep 17 00:00:00 2001
From: Nikola Borisov
Date: Wed, 30 Aug 2023 15:20:47 -0700
Subject: [PATCH] Return num input tokens (#3)

Return the number of input tokens in the details message.

Co-authored-by: Yessen Kanapin
---
 Dockerfile                               |  7 ++++++-
 clients/python/text_generation/types.py |  2 ++
 router/src/infer.rs                     |  4 ++++
 router/src/lib.rs                       |  6 ++++++
 router/src/server.rs                    | 10 ++++++++--
 .../models/flash_causal_lm.py           | 14 +++++++++-----
 6 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 73a009ae..e507454c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -182,12 +182,17 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/co
 RUN pip install einops --no-cache-dir
 
 # Install server
+COPY server/requirements.txt server/requirements.txt
+COPY server/pyproject.toml server/pyproject.toml
+COPY server/poetry.lock server/poetry.lock
+RUN cd server && \
+    pip install -r requirements.txt
+
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
-    pip install -r requirements.txt && \
     pip install ".[bnb, accelerate, quantize]" --no-cache-dir
 
 # Install benchmarker
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 548f0b63..23b9308f 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -211,6 +211,8 @@ class StreamDetails(BaseModel):
     finish_reason: FinishReason
     # Number of generated tokens
     generated_tokens: int
+    # Number of input tokens
+    input_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
 
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 188ddc64..aac77920 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -147,6 +147,7 @@ impl Infer {
         let mut result_generated_text = None;
         let mut result_start = None;
         let mut result_queued = None;
+        let mut number_input_tokens = 0;
 
         // Iterate on stream
         while let Some(response) = stream.next().await {
@@ -155,6 +156,7 @@
                 InferStreamResponse::Prefill(tokens) => {
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
+                    number_input_tokens = tokens.ids.len() as u32;
                     result_prefill = tokens
                         .ids
                         .into_iter()
@@ -188,6 +190,7 @@
         Ok(InferResponse {
             prefill: result_prefill,
             tokens: result_tokens,
+            input_tokens: number_input_tokens,
             generated_text,
             queued,
             start,
@@ -581,6 +584,7 @@ pub(crate) struct InferResponse {
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
+    pub(crate) input_tokens: u32,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7dff7a11..7ec80de2 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -231,6 +231,8 @@ pub(crate) struct BestOfSequence {
     pub finish_reason: FinishReason,
     #[schema(example = 1)]
     pub generated_tokens: u32,
+    #[schema(example = 100)]
+    pub input_tokens: u32,
     #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
@@ -243,6 +245,8 @@ pub(crate) struct Details {
     pub finish_reason: FinishReason,
     #[schema(example = 1)]
     pub generated_tokens: u32,
+    #[schema(example = 100)]
+    pub input_tokens: u32,
     #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
@@ -265,6 +269,8 @@ pub(crate) struct StreamDetails {
     pub finish_reason: FinishReason,
     #[schema(example = 1)]
     pub generated_tokens: u32,
+    #[schema(example = 100)]
+    pub input_tokens: u32,
     #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 9af94951..1b1879cd 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -191,6 +191,7 @@ async fn generate(
                         response.generated_text.finish_reason,
                     ),
                     generated_tokens: response.generated_text.generated_tokens,
+                    input_tokens: response.input_tokens,
                     prefill: response.prefill,
                     tokens: response.tokens,
                     seed: response.generated_text.seed,
@@ -202,6 +203,7 @@
             Some(Details {
                 finish_reason: FinishReason::from(response.generated_text.finish_reason),
                 generated_tokens: response.generated_text.generated_tokens,
+                input_tokens: response.input_tokens,
                 prefill: response.prefill,
                 tokens: response.tokens,
                 seed: response.generated_text.seed,
@@ -380,12 +382,15 @@ async fn generate_stream(
         // Keep permit as long as generate_stream lives
         Ok((_permit, mut response_stream)) => {
             // Server-Sent Event stream
+            let mut number_input_tokens = 0;
             while let Some(response) = response_stream.next().await {
                 match response {
                     Ok(response) => {
                         match response {
-                            // Prefill is ignored
-                            InferStreamResponse::Prefill(_) => {}
+                            // Prefill is only used for initial num input tokens
+                            InferStreamResponse::Prefill(prefill_tokens) => {
+                                number_input_tokens = prefill_tokens.ids.len() as u32;
+                            }
                             // Yield event for every new token
                             InferStreamResponse::Token(token) => {
                                 tracing::debug!(parent: &span, "Token: {:?}", token);
@@ -411,6 +416,7 @@
                     true => Some(StreamDetails {
                         finish_reason: FinishReason::from(generated_text.finish_reason),
                         generated_tokens: generated_text.generated_tokens,
+                        input_tokens: number_input_tokens,
                         seed: generated_text.seed,
                     }),
                     false => None,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7de51358..1e0292d1 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -985,14 +985,18 @@ class FlashCausalLM(Model):
                 generated_text = None
 
             # Prefill
-            if prefill and request.prefill_logprobs:
+            if prefill:
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
 
-                # Remove generated token to only have prefill and add nan for first prompt token
-                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    out_start_index : out_end_index - 1
-                ]
+                if request.prefill_logprobs:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
+                        out_start_index : out_end_index - 1
+                    ]
+                else:
+                    request_prefill_logprobs = []
+
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
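
For reference, a minimal usage sketch (not part of the patch) of the new field through the bundled `text_generation` Python client. The endpoint URL and prompt below are placeholders, and the server must be running a build that includes this change; since only StreamDetails gains `input_tokens` on the client side here, the streaming API is used:

    from text_generation import Client

    # Assumed local TGI endpoint; adjust the URL for your deployment.
    client = Client("http://127.0.0.1:8080")

    text = ""
    for response in client.generate_stream("What is Deep Learning?", max_new_tokens=20):
        if not response.token.special:
            text += response.token.text
        # The final stream message carries the details block, now including
        # the number of input (prompt) tokens counted during prefill.
        if response.details is not None:
            print("input_tokens:", response.details.input_tokens)
            print("generated_tokens:", response.details.generated_tokens)

    print(text)

The non-streaming /generate route also returns `input_tokens` inside Details per the router change above, but this patch does not add the field to the client's Details model, so there it is only visible in the raw JSON response.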