Add reason to response

OlivierDehaene 2022-12-12 17:44:37 +01:00
parent ed8ecb7ab5
commit 3eaf7ee2e3
8 changed files with 26 additions and 10 deletions

View File

@@ -74,6 +74,8 @@ message GeneratedText {
     string output = 2;
     /// Number of generated tokens
     uint32 tokens = 3;
+    /// Finish reason
+    string finish_reason = 4;
 }

 message GenerateRequest {

View File

@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let response = InferResponse {
             output: output.output,
             tokens: output.tokens,
+            finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),
@@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 pub(crate) struct InferResponse {
     pub(crate) output: String,
     pub(crate) tokens: u32,
+    pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,

View File

@@ -65,6 +65,7 @@ pub(crate) struct GenerateRequest {
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
+    pub finish_reason: String,
 }

 #[derive(Serialize)]

View File

@@ -146,6 +146,7 @@ async fn generate(
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output,
+        finish_reason: response.finish_reason,
     }];
     Ok((headers, Json(response)))
 }
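
With this change the handler's JSON body gains the new field next to generated_text. A minimal sketch of a client reading it; the /generate route, the port, and the inputs request field are assumptions not shown in this diff, while the two response fields come from the Serialize-derived GeneratedText struct above:

# Sketch of a client call. Route, port, and request schema are assumed;
# the response fields come from the GeneratedText struct in the hunk above.
import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:3000/generate",  # assumed route and port
    data=json.dumps({"inputs": "Hello"}).encode(),  # assumed request schema
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

# The handler returns vec![GeneratedText { .. }], serialized as a JSON list.
print(body[0]["generated_text"])
print(body[0]["finish_reason"])  # "length" or "stop_sequence", per the criteria below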

View File

@@ -325,14 +325,17 @@ class CausalLM(Model):
             all_tokens = torch.cat([all_tokens, next_token])

             # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
+            stop, reason = stopping_criteria(all_tokens)
+            if stop:
                 # Decode all tokens
                 output = self.tokenizer.decode(
                     all_tokens.squeeze(-1), skip_special_tokens=True
                 )
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:

View File

@@ -425,12 +425,15 @@ class Seq2SeqLM(Model):
             decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])

             # Evaluate stopping criteria
-            if stopping_criteria(decoder_tokens):
+            stop, reason = stopping_criteria(decoder_tokens)
+            if stop:
                 # Decode tokens
                 output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:

View File

@@ -32,8 +32,12 @@ class GeneratedText:
     request: generate_pb2.Request
     output: str
     tokens: int
+    reason: str

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
-            request=self.request, output=self.output, tokens=self.tokens
+            request=self.request,
+            output=self.output,
+            tokens=self.tokens,
+            finish_reason=self.reason,
         )
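
For illustration, the round trip from dataclass to protobuf now carries the reason through. The sample values below are made up, and None stands in for the incoming generate_pb2.Request (protobuf constructors treat None keyword arguments as unset):

# Hypothetical values; GeneratedText and generate_pb2 are the ones above.
text = GeneratedText(
    request=None,  # stands in for an incoming generate_pb2.Request
    output="Hello world",
    tokens=3,
    reason="length",
)
pb = text.to_pb()
assert pb.finish_reason == "length"  # the new proto field, tag 4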

View File

@@ -10,7 +10,7 @@ from functools import partial
 from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
-from typing import List
+from typing import List, Optional, Tuple
 from transformers import AutoTokenizer
 from transformers.generation.logits_process import (
     LogitsProcessorList,
@@ -99,17 +99,17 @@ class StoppingCriteria:
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0

-    def __call__(self, all_ids):
+    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True
+            return True, "length"

         last_token = all_ids[-1]
         for stop_sequence_criteria in self.stop_sequence_criterias:
             if stop_sequence_criteria(last_token):
-                return True
+                return True, "stop_sequence"

-        return False
+        return False, None

     @classmethod
     def from_pb(
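
The call-site changes in CausalLM and Seq2SeqLM above follow directly from this new (stop, reason) contract. A self-contained sketch of that contract, using a simplified token-membership test in place of the real stop-sequence criteria:

from typing import List, Optional, Tuple

class StoppingCriteriaSketch:
    # Simplified re-statement of the contract above: __call__ returns
    # (stop, reason), where reason is "length", "stop_sequence", or None.
    def __init__(self, max_new_tokens: int, stop_token_ids: List[int]):
        self.max_new_tokens = max_new_tokens
        self.stop_token_ids = stop_token_ids  # stand-in for StopSequenceCriteria
        self.current_tokens = 0

    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        if all_ids[-1] in self.stop_token_ids:
            return True, "stop_sequence"
        return False, None

criteria = StoppingCriteriaSketch(max_new_tokens=3, stop_token_ids=[0])
print(criteria([5]))     # (False, None)
print(criteria([5, 0]))  # (True, 'stop_sequence')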