mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 07:42:06 +00:00

Add reason to response

parent ed8ecb7ab5
commit 3eaf7ee2e3
@@ -74,6 +74,8 @@ message GeneratedText {
     string output = 2;
     /// Number of generated tokens
     uint32 tokens = 3;
+    /// Finish reason
+    string finish_reason = 4;
 }

 message GenerateRequest {
@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
        let response = InferResponse {
            output: output.output,
            tokens: output.tokens,
+           finish_reason: output.finish_reason,
            queued: entry.time,
            start: entry.batch_time.unwrap(), // unwrap is always valid
            end: Instant::now(),
@@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 pub(crate) struct InferResponse {
     pub(crate) output: String,
     pub(crate) tokens: u32,
+    pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,
@@ -65,6 +65,7 @@ pub(crate) struct GenerateRequest {
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
+    pub finish_reason: String,
 }

 #[derive(Serialize)]
@@ -146,6 +146,7 @@ async fn generate(
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output,
+        finish_reason: response.finish_reason,
     }];
     Ok((headers, Json(response)))
 }
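With this change the HTTP route serializes the finish reason alongside the generated text. Below is a minimal client-side sketch of what that looks like to a caller; the endpoint path, port, and request payload shape are assumptions, only the "generated_text" and "finish_reason" keys come from the serialized GeneratedText struct above.

import requests  # hypothetical client sketch, not part of this commit

# Assumed local router address and request shape; adjust to the deployment.
resp = requests.post(
    "http://localhost:3000/generate",
    json={"inputs": "Hello", "parameters": {"max_new_tokens": 20}},
)
# The handler returns a JSON list of GeneratedText objects.
for generation in resp.json():
    print(generation["generated_text"], generation["finish_reason"])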
@@ -325,14 +325,17 @@ class CausalLM(Model):
             all_tokens = torch.cat([all_tokens, next_token])

             # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
+            stop, reason = stopping_criteria(all_tokens)
+            if stop:
                 # Decode all tokens
                 output = self.tokenizer.decode(
                     all_tokens.squeeze(-1), skip_special_tokens=True
                 )
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:
@@ -425,12 +425,15 @@ class Seq2SeqLM(Model):
             decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])

             # Evaluate stopping criteria
-            if stopping_criteria(decoder_tokens):
+            stop, reason = stopping_criteria(decoder_tokens)
+            if stop:
                 # Decode tokens
                 output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:
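Both CausalLM and Seq2SeqLM now follow the same pattern around the criteria call. A condensed sketch of that shared loop body, for reference; the names stopping_criteria, generated_texts, and GeneratedText come from the diff, and the surrounding batch loop is elided.

# Condensed, illustrative sketch of the shared pattern after this change:
stop, reason = stopping_criteria(all_tokens)  # __call__ now returns (bool, Optional[str])
if stop:
    output = self.tokenizer.decode(all_tokens.squeeze(-1), skip_special_tokens=True)
    generated_texts.append(
        GeneratedText(request, output, stopping_criteria.current_tokens, reason)
    )
else:
    ...  # keep the request in the next batch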
@@ -32,8 +32,12 @@ class GeneratedText:
     request: generate_pb2.Request
     output: str
     tokens: int
+    reason: str

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
-            request=self.request, output=self.output, tokens=self.tokens
+            request=self.request,
+            output=self.output,
+            tokens=self.tokens,
+            finish_reason=self.reason,
         )
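A small usage sketch of the updated dataclass follows; pb_request is a placeholder for the original generate_pb2.Request, and the reason values mirror the ones produced by StoppingCriteria below.

# Usage sketch (placeholders, not code from the repository):
# pb_request stands in for the generate_pb2.Request the server received.
generated = GeneratedText(
    request=pb_request,
    output="Hello world",
    tokens=5,
    reason="length",  # or "stop_sequence"
)
pb = generated.to_pb()  # generate_pb2.GeneratedText with finish_reason populated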
@@ -10,7 +10,7 @@ from functools import partial
 from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
-from typing import List
+from typing import List, Optional, Tuple
 from transformers import AutoTokenizer
 from transformers.generation.logits_process import (
     LogitsProcessorList,
@@ -99,17 +99,17 @@ class StoppingCriteria:
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0

-    def __call__(self, all_ids):
+    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True
+            return True, "length"

         last_token = all_ids[-1]
         for stop_sequence_criteria in self.stop_sequence_criterias:
             if stop_sequence_criteria(last_token):
-                return True
+                return True, "stop_sequence"

-        return False
+        return False, None

     @classmethod
     def from_pb(
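To illustrate the new contract, here is a self-contained sketch of a criteria object with the same return shape and a caller unpacking it; the class itself is hypothetical, only the (stop, reason) tuple and the "length" / "stop_sequence" values come from this change.

from typing import List, Optional, Tuple


class MaxNewTokensCriteria:
    """Hypothetical stand-in mirroring the new __call__ contract:
    return (stop, reason) instead of a bare bool."""

    def __init__(self, max_new_tokens: int):
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, all_ids: List[int]) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        return False, None


criteria = MaxNewTokensCriteria(max_new_tokens=3)
for token_id in [11, 22, 33, 44]:
    stop, reason = criteria([token_id])
    if stop:
        print(f"stopped after {criteria.current_tokens} tokens: {reason}")
        break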