Add reason to response

Repository: https://github.com/huggingface/text-generation-inference.git
Commit: 3eaf7ee2e3 (parent: ed8ecb7ab5)
@@ -74,6 +74,8 @@ message GeneratedText {
     string output = 2;
     /// Number of generated tokens
     uint32 tokens = 3;
+    /// Finish reason
+    string finish_reason = 4;
 }

 message GenerateRequest {
@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let response = InferResponse {
             output: output.output,
             tokens: output.tokens,
+            finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),
@@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 pub(crate) struct InferResponse {
     pub(crate) output: String,
     pub(crate) tokens: u32,
+    pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,
@@ -65,6 +65,7 @@ pub(crate) struct GenerateRequest {
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
+    pub finish_reason: String,
 }

 #[derive(Serialize)]
@@ -146,6 +146,7 @@ async fn generate(
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output,
+        finish_reason: response.finish_reason,
     }];
     Ok((headers, Json(response)))
 }
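Note: since the router's GeneratedText derives Serialize and the generate route returns a Vec<GeneratedText>, clients should now see a finish_reason field alongside generated_text in the JSON body. A minimal client-side sketch of parsing that shape (the payload below is hand-written for illustration, not captured from a running server):

import json

# Illustrative body: the route serializes a Vec<GeneratedText>, so the top
# level is a JSON array of objects with generated_text and finish_reason.
body = '[{"generated_text": "Hello world", "finish_reason": "length"}]'

for item in json.loads(body):
    print(item["generated_text"], "->", item["finish_reason"])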
@@ -325,14 +325,17 @@ class CausalLM(Model):
             all_tokens = torch.cat([all_tokens, next_token])

             # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
+            stop, reason = stopping_criteria(all_tokens)
+            if stop:
                 # Decode all tokens
                 output = self.tokenizer.decode(
                     all_tokens.squeeze(-1), skip_special_tokens=True
                 )
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
                 # add to the next batch
             else:
@@ -425,12 +425,15 @@ class Seq2SeqLM(Model):
             decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])

             # Evaluate stopping criteria
-            if stopping_criteria(decoder_tokens):
+            stop, reason = stopping_criteria(decoder_tokens)
+            if stop:
                 # Decode tokens
                 output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
                 # add to the next batch
             else:
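Both CausalLM and Seq2SeqLM now unpack the (stop, reason) tuple and pass the reason into the finished GeneratedText. A model-agnostic sketch of that control flow; the names here (FakeGeneratedText, finish_or_continue) are stand-ins for illustration, not the server's actual classes:

from dataclasses import dataclass
from typing import List


@dataclass
class FakeGeneratedText:
    # Stand-in for the server's GeneratedText dataclass
    output: str
    tokens: int
    reason: str


def finish_or_continue(requests, stopping_criterias, all_tokens, decode):
    # Toy per-request loop: finished requests carry a finish reason,
    # unfinished ones are kept for the next batch.
    generated_texts: List[FakeGeneratedText] = []
    next_batch_requests = []
    for request, criteria, tokens in zip(requests, stopping_criterias, all_tokens):
        stop, reason = criteria(tokens)
        if stop:
            generated_texts.append(
                FakeGeneratedText(decode(tokens), criteria.current_tokens, reason)
            )
        else:
            next_batch_requests.append(request)
    return generated_texts, next_batch_requests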
@@ -32,8 +32,12 @@ class GeneratedText:
     request: generate_pb2.Request
     output: str
     tokens: int
+    reason: str

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
-            request=self.request, output=self.output, tokens=self.tokens
+            request=self.request,
+            output=self.output,
+            tokens=self.tokens,
+            finish_reason=self.reason,
         )
@@ -10,7 +10,7 @@ from functools import partial
 from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
-from typing import List
+from typing import List, Optional, Tuple
 from transformers import AutoTokenizer
 from transformers.generation.logits_process import (
     LogitsProcessorList,
@@ -99,17 +99,17 @@ class StoppingCriteria:
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0

-    def __call__(self, all_ids):
+    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True
+            return True, "length"

         last_token = all_ids[-1]
         for stop_sequence_criteria in self.stop_sequence_criterias:
             if stop_sequence_criteria(last_token):
-                return True
+                return True, "stop_sequence"

-        return False
+        return False, None

     @classmethod
     def from_pb(
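For reference, a minimal sketch of the new (stop, reason) contract in isolation. The stop-sequence handling is simplified to a single EOS-token check here (an assumption for brevity); the real StoppingCriteria builds its stop sequence criteria via its from_pb classmethod:

from typing import Optional, Tuple


class SimpleStoppingCriteria:
    # Illustrative stand-in, not the server's StoppingCriteria
    def __init__(self, eos_token_id: int, max_new_tokens: int):
        self.eos_token_id = eos_token_id
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
        # Count one more generated token
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        # Simplified stop check: stop when the last token is EOS
        if int(all_ids[-1]) == self.eos_token_id:
            return True, "stop_sequence"
        return False, None


# Callers now unpack a tuple instead of a bare bool
criteria = SimpleStoppingCriteria(eos_token_id=2, max_new_tokens=3)
print(criteria([5]))     # (False, None)
print(criteria([5, 2]))  # (True, 'stop_sequence')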