From 3eaf7ee2e398931d7d826bfe36f8ffceab4bcc79 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 12 Dec 2022 17:44:37 +0100
Subject: [PATCH] Add reason to response

---
 proto/generate.proto                        |  2 ++
 router/src/batcher.rs                       |  2 ++
 router/src/lib.rs                           |  1 +
 router/src/server.rs                        |  1 +
 server/text_generation/models/causal_lm.py  |  7 +++++--
 server/text_generation/models/seq2seq_lm.py |  7 +++++--
 server/text_generation/models/types.py      |  6 +++++-
 server/text_generation/utils.py             | 10 +++++-----
 8 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/proto/generate.proto b/proto/generate.proto
index 78abcea4..0c67de03 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -74,6 +74,8 @@ message GeneratedText {
   string output = 2;
   /// Number of generated tokens
   uint32 tokens = 3;
+  /// Finish reason
+  string finish_reason = 4;
 }
 
 message GenerateRequest {
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index 0c85a406..a72a6e44 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let response = InferResponse {
             output: output.output,
             tokens: output.tokens,
+            finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),
@@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 pub(crate) struct InferResponse {
     pub(crate) output: String,
     pub(crate) tokens: u32,
+    pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index b7ea4a8f..b6c694ee 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -65,6 +65,7 @@ pub(crate) struct GenerateRequest {
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
+    pub finish_reason: String,
 }
 
 #[derive(Serialize)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 45952ed0..59296269 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -146,6 +146,7 @@ async fn generate(
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output,
+        finish_reason: response.finish_reason,
     }];
     Ok((headers, Json(response)))
 }
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 49bf7e5d..72095858 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -325,14 +325,17 @@ class CausalLM(Model):
             all_tokens = torch.cat([all_tokens, next_token])
 
             # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
+            stop, reason = stopping_criteria(all_tokens)
+            if stop:
                 # Decode all tokens
                 output = self.tokenizer.decode(
                     all_tokens.squeeze(-1), skip_special_tokens=True
                 )
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 46e11a9a..93e31b4a 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -425,12 +425,15 @@ class Seq2SeqLM(Model):
             decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])
 
             # Evaluate stopping criteria
-            if stopping_criteria(decoder_tokens):
+            stop, reason = stopping_criteria(decoder_tokens)
+            if stop:
                 # Decode tokens
                 output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:
diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py
index 7c25bf67..91ec75d0 100644
--- a/server/text_generation/models/types.py
+++ b/server/text_generation/models/types.py
@@ -32,8 +32,12 @@ class GeneratedText:
     request: generate_pb2.Request
     output: str
     tokens: int
+    reason: str
 
     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
-            request=self.request, output=self.output, tokens=self.tokens
+            request=self.request,
+            output=self.output,
+            tokens=self.tokens,
+            finish_reason=self.reason,
         )
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 2d6a6505..661df05e 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -10,7 +10,7 @@ from functools import partial
 from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
-from typing import List
+from typing import List, Optional, Tuple
 from transformers import AutoTokenizer
 from transformers.generation.logits_process import (
     LogitsProcessorList,
@@ -99,17 +99,17 @@ class StoppingCriteria:
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
 
-    def __call__(self, all_ids):
+    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True
+            return True, "length"
 
         last_token = all_ids[-1]
         for stop_sequence_criteria in self.stop_sequence_criterias:
             if stop_sequence_criteria(last_token):
-                return True
+                return True, "stop_sequence"
 
-        return False
+        return False, None
 
     @classmethod
     def from_pb(
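
Note (illustration, not part of the patch): a rough sketch of how a client could read the finish_reason that the router now serializes in each GeneratedText element. The host, port, and request payload below are assumptions made for the example; only the response shape (a JSON array whose elements carry generated_text and finish_reason) comes from the router changes above.

import requests

# Hypothetical local deployment of the router; URL and request schema are
# assumptions for illustration and are not defined by this patch.
response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={"inputs": "Hello", "parameters": {"max_new_tokens": 20}},
)
response.raise_for_status()

for generation in response.json():
    # Each element now exposes finish_reason next to generated_text:
    # "length" when max_new_tokens was reached, "stop_sequence" when a
    # stop sequence matched (see StoppingCriteria.__call__ above).
    print(generation["finish_reason"], generation["generated_text"])

Returning a (stop, reason) tuple from StoppingCriteria.__call__ keeps the reason next to where it is decided; the models only forward it through GeneratedText, and the proto, batcher, and server layers stay a plain pass-through.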