Add reason to response

OlivierDehaene 2022-12-12 17:44:37 +01:00
parent ed8ecb7ab5
commit 3eaf7ee2e3
8 changed files with 26 additions and 10 deletions

View File

@@ -74,6 +74,8 @@ message GeneratedText {
     string output = 2;
     /// Number of generated tokens
     uint32 tokens = 3;
+    /// Finish reason
+    string finish_reason = 4;
 }

 message GenerateRequest {
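
Note: once the Python stubs are regenerated from this proto, the new field reads like any other string field. A minimal sketch, with a hypothetical import path (the `request` field is omitted for brevity):

# Assumed import path for the regenerated stubs; adjust to wherever they live in the server package.
from text_generation.pb import generate_pb2

msg = generate_pb2.GeneratedText(output="Hello world", tokens=3, finish_reason="length")
print(msg.finish_reason)  # -> "length"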

View File

@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let response = InferResponse {
             output: output.output,
             tokens: output.tokens,
+            finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),
@@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 pub(crate) struct InferResponse {
     pub(crate) output: String,
     pub(crate) tokens: u32,
+    pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,

View File

@@ -65,6 +65,7 @@ pub(crate) struct GenerateRequest {
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
+    pub finish_reason: String,
 }

 #[derive(Serialize)]

View File

@@ -146,6 +146,7 @@ async fn generate(
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output,
+        finish_reason: response.finish_reason,
     }];
     Ok((headers, Json(response)))
 }
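
With the `Serialize` derive picking up the new struct field, the JSON returned by this route now carries `finish_reason` next to `generated_text`. A sketch of a client call, where the URL, route, and request payload are assumptions; only the response fields are grounded in this diff:

import requests

# Hypothetical endpoint and payload; the handler returns a JSON list of GeneratedText objects.
resp = requests.post(
    "http://localhost:3000/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
)
for generation in resp.json():
    print(generation["generated_text"], generation["finish_reason"])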

View File

@@ -325,14 +325,17 @@ class CausalLM(Model):
             all_tokens = torch.cat([all_tokens, next_token])

             # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
+            stop, reason = stopping_criteria(all_tokens)
+            if stop:
                 # Decode all tokens
                 output = self.tokenizer.decode(
                     all_tokens.squeeze(-1), skip_special_tokens=True
                 )
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:

View File

@@ -425,12 +425,15 @@ class Seq2SeqLM(Model):
             decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])

             # Evaluate stopping criteria
-            if stopping_criteria(decoder_tokens):
+            stop, reason = stopping_criteria(decoder_tokens)
+            if stop:
                 # Decode tokens
                 output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:
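
Both model classes unpack the criteria result the same way, and the unpack is not cosmetic: now that `__call__` returns a tuple, testing the call result directly would always be truthy, so only the unpacked flag can drive the branch. A two-line illustration of that Python behaviour:

# A (bool, reason) tuple is always truthy, even when the flag is False.
result = (False, None)
print(bool(result))   # True  -- non-empty tuples are truthy
stop, reason = result
print(bool(stop))     # False -- the unpacked flag is what must be tested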

View File

@@ -32,8 +32,12 @@ class GeneratedText:
     request: generate_pb2.Request
     output: str
     tokens: int
+    reason: str

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
-            request=self.request, output=self.output, tokens=self.tokens
+            request=self.request,
+            output=self.output,
+            tokens=self.tokens,
+            finish_reason=self.reason,
         )
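
A short usage sketch of the updated dataclass, assuming it and a previously received `generate_pb2.Request` are already in scope (the literal values are illustrative only):

generated = GeneratedText(
    request=request,          # a generate_pb2.Request received earlier
    output="Hello world",
    tokens=3,
    reason="stop_sequence",
)
pb = generated.to_pb()
assert pb.finish_reason == "stop_sequence"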

View File

@@ -10,7 +10,7 @@ from functools import partial
 from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
-from typing import List
+from typing import List, Optional, Tuple
 from transformers import AutoTokenizer
 from transformers.generation.logits_process import (
     LogitsProcessorList,
@@ -99,17 +99,17 @@ class StoppingCriteria:
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0

-    def __call__(self, all_ids):
+    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True
+            return True, "length"

         last_token = all_ids[-1]
         for stop_sequence_criteria in self.stop_sequence_criterias:
             if stop_sequence_criteria(last_token):
-                return True
+                return True, "stop_sequence"

-        return False
+        return False, None

     @classmethod
     def from_pb(
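
The new contract of `__call__` can be summarised with a small self-contained stand-in; this is a sketch only, not the class from this repository (the real one builds its stop-sequence criteria in `from_pb`):

from typing import List, Optional, Tuple

class ToyStoppingCriteria:
    """Stand-in mirroring the new (stop, reason) return contract."""

    def __init__(self, max_new_tokens: int, stop_token_id: int):
        self.max_new_tokens = max_new_tokens
        self.stop_token_id = stop_token_id
        self.current_tokens = 0

    def __call__(self, all_ids: List[int]) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        if all_ids[-1] == self.stop_token_id:
            return True, "stop_sequence"
        return False, None

criteria = ToyStoppingCriteria(max_new_tokens=4, stop_token_id=0)
print(criteria([5]))     # (False, None)
print(criteria([5, 0]))  # (True, 'stop_sequence')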