mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Add WIP support for returning top tokens
Initial support returning the most probable tokens. Note that it is currently only implemented for seq-to-seq models. It is also always enabled, regardless of whether it is used or not.
This commit is contained in:
parent
e605c2a43e
commit
8a4d2076a6
@ -73,6 +73,9 @@ async fn generate_runs(
|
||||
// Create a dummy sequence
|
||||
let sequence = create_sequence(sequence_length, tokenizer);
|
||||
|
||||
// TODO: Implement top_n_tokens
|
||||
let top_n_tokens= 0;
|
||||
|
||||
for b in batch_size {
|
||||
// Warmups on batch size
|
||||
for _ in 0..warmups {
|
||||
@ -82,6 +85,7 @@ async fn generate_runs(
|
||||
b,
|
||||
decode_length,
|
||||
parameters.clone(),
|
||||
top_n_tokens,
|
||||
&mut client,
|
||||
)
|
||||
.await?;
|
||||
@ -97,6 +101,7 @@ async fn generate_runs(
|
||||
b,
|
||||
decode_length,
|
||||
parameters.clone(),
|
||||
top_n_tokens,
|
||||
&mut client,
|
||||
)
|
||||
.await?;
|
||||
@ -130,6 +135,7 @@ async fn prefill(
|
||||
batch_size: u32,
|
||||
decode_length: u32,
|
||||
parameters: NextTokenChooserParameters,
|
||||
top_n_tokens: u32,
|
||||
client: &mut ShardedClient,
|
||||
) -> Result<(Prefill, CachedBatch), ClientError> {
|
||||
// Create requests
|
||||
@ -145,6 +151,7 @@ async fn prefill(
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true, // Will not stop even if a eos token is generated
|
||||
}),
|
||||
top_n_tokens: top_n_tokens,
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
@ -179,6 +179,9 @@ class BestOfSequence(BaseModel):
|
||||
prefill: List[InputToken]
|
||||
# Generated tokens
|
||||
tokens: List[Token]
|
||||
# Most likely tokens
|
||||
# TODO: Make this optional?
|
||||
top_tokens: List[List[Token]]
|
||||
|
||||
|
||||
# `generate` details
|
||||
@ -193,6 +196,9 @@ class Details(BaseModel):
|
||||
prefill: List[InputToken]
|
||||
# Generated tokens
|
||||
tokens: List[Token]
|
||||
# Most likely tokens
|
||||
# TODO: Make this optional?
|
||||
top_tokens: List[List[Token]]
|
||||
# Additional sequences when using the `best_of` parameter
|
||||
best_of_sequences: Optional[List[BestOfSequence]]
|
||||
|
||||
@ -219,6 +225,9 @@ class StreamDetails(BaseModel):
|
||||
class StreamResponse(BaseModel):
|
||||
# Generated token
|
||||
token: Token
|
||||
# Most likely tokens
|
||||
# TODO: Make this optional?
|
||||
top_tokens: List[Token]
|
||||
# Complete generated text
|
||||
# Only available when the generation is finished
|
||||
generated_text: Optional[str]
|
||||
|
@ -91,6 +91,8 @@ message Request {
|
||||
StoppingCriteriaParameters stopping_parameters = 5;
|
||||
/// Return prefill logprobs
|
||||
bool prefill_logprobs = 6;
|
||||
/// Return most likely n tokens
|
||||
uint32 top_n_tokens = 7;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
@ -141,6 +143,17 @@ message PrefillTokens {
|
||||
repeated string texts = 3;
|
||||
}
|
||||
|
||||
message TopToken {
|
||||
/// Token ID
|
||||
uint32 token_id = 3;
|
||||
/// Logprob
|
||||
float token_logprob = 4;
|
||||
/// Text
|
||||
string token_text = 5;
|
||||
/// Is it a special token
|
||||
bool token_is_special = 6;
|
||||
}
|
||||
|
||||
message Generation {
|
||||
/// Request ID
|
||||
uint64 request_id = 1;
|
||||
@ -156,6 +169,8 @@ message Generation {
|
||||
bool token_is_special = 6;
|
||||
/// Complete generated text
|
||||
optional GeneratedText generated_text = 7;
|
||||
/// Top tokens
|
||||
repeated TopToken top_tokens = 8;
|
||||
}
|
||||
|
||||
message FilterBatchRequest {
|
||||
|
@ -131,6 +131,7 @@ impl Client {
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
}
|
||||
|
@ -50,6 +50,7 @@ impl Health {
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0
|
||||
};
|
||||
let batch = Batch {
|
||||
id: BATCH_ID,
|
||||
|
@ -138,12 +138,15 @@ impl Infer {
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<InferResponse, InferError> {
|
||||
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
||||
|
||||
// Create stream and keep semaphore permit as long as generate lives
|
||||
let (_permit, mut stream) = self.generate_stream(request).await?;
|
||||
|
||||
// Return values
|
||||
let mut result_prefill = Vec::new();
|
||||
let mut result_tokens = Vec::new();
|
||||
let mut result_top_tokens = Vec::new();
|
||||
let mut result_generated_text = None;
|
||||
let mut result_start = None;
|
||||
let mut result_queued = None;
|
||||
@ -164,7 +167,13 @@ impl Infer {
|
||||
.collect();
|
||||
}
|
||||
// Push last token
|
||||
InferStreamResponse::Token(token) => result_tokens.push(token),
|
||||
InferStreamResponse::Intermediate{
|
||||
token,
|
||||
top_tokens,
|
||||
} => {
|
||||
result_tokens.push(token);
|
||||
result_top_tokens.push(top_tokens);
|
||||
}
|
||||
// Final message
|
||||
// Set return values
|
||||
InferStreamResponse::End {
|
||||
@ -172,8 +181,11 @@ impl Infer {
|
||||
generated_text,
|
||||
start,
|
||||
queued,
|
||||
top_tokens,
|
||||
|
||||
} => {
|
||||
result_tokens.push(token);
|
||||
result_top_tokens.push(top_tokens);
|
||||
result_generated_text = Some(generated_text);
|
||||
result_start = Some(start);
|
||||
result_queued = Some(queued)
|
||||
@ -191,6 +203,7 @@ impl Infer {
|
||||
generated_text,
|
||||
queued,
|
||||
start,
|
||||
top_tokens: if use_top_tokens { result_top_tokens } else { Vec::new() },
|
||||
})
|
||||
} else {
|
||||
let err = InferError::IncompleteGeneration;
|
||||
@ -520,6 +533,18 @@ fn send_responses(
|
||||
special: generation.token_is_special,
|
||||
};
|
||||
|
||||
|
||||
// generation.top_tokens
|
||||
let mut top_tokens = Vec::new();
|
||||
for top_token in generation.top_tokens {
|
||||
top_tokens.push(Token{
|
||||
id: top_token.token_id,
|
||||
text: top_token.token_text,
|
||||
logprob: top_token.token_logprob,
|
||||
special: top_token.token_is_special,
|
||||
})
|
||||
}
|
||||
|
||||
if let Some(generated_text) = generation.generated_text {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
@ -527,6 +552,7 @@ fn send_responses(
|
||||
entry.response_tx.send_timeout(
|
||||
Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text,
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
@ -536,7 +562,7 @@ fn send_responses(
|
||||
} else {
|
||||
// Send message
|
||||
entry.response_tx.send_timeout(
|
||||
Ok(InferStreamResponse::Token(token)),
|
||||
Ok(InferStreamResponse::Intermediate{token, top_tokens}),
|
||||
Duration::from_millis(10),
|
||||
)?;
|
||||
}
|
||||
@ -566,10 +592,14 @@ pub(crate) enum InferStreamResponse {
|
||||
// Optional first message
|
||||
Prefill(PrefillTokens),
|
||||
// Intermediate messages
|
||||
Token(Token),
|
||||
Intermediate {
|
||||
token: Token,
|
||||
top_tokens: Vec<Token>,
|
||||
},
|
||||
// Last message
|
||||
End {
|
||||
token: Token,
|
||||
top_tokens: Vec<Token>,
|
||||
generated_text: GeneratedText,
|
||||
start: Instant,
|
||||
queued: Instant,
|
||||
@ -583,6 +613,7 @@ pub(crate) struct InferResponse {
|
||||
pub(crate) generated_text: GeneratedText,
|
||||
pub(crate) queued: Instant,
|
||||
pub(crate) start: Instant,
|
||||
pub(crate) top_tokens: Vec<Vec<Token>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
|
||||
example = "null"
|
||||
)]
|
||||
pub seed: Option<u64>,
|
||||
#[serde(default)]
|
||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
||||
pub top_n_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
fn default_max_new_tokens() -> u32 {
|
||||
@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
|
||||
details: false,
|
||||
decoder_input_details: false,
|
||||
seed: None,
|
||||
top_n_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -235,6 +239,7 @@ pub(crate) struct BestOfSequence {
|
||||
pub seed: Option<u64>,
|
||||
pub prefill: Vec<PrefillToken>,
|
||||
pub tokens: Vec<Token>,
|
||||
pub top_tokens: Vec<Vec<Token>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
@ -249,6 +254,7 @@ pub(crate) struct Details {
|
||||
pub tokens: Vec<Token>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub best_of_sequences: Option<Vec<BestOfSequence>>,
|
||||
pub top_tokens: Vec<Vec<Token>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
@ -272,6 +278,8 @@ pub(crate) struct StreamDetails {
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct StreamResponse {
|
||||
pub token: Token,
|
||||
#[schema(nullable = true, default = "null")]
|
||||
pub top_tokens: Option<Vec<Token>>,
|
||||
#[schema(nullable = true, default = "null", example = "test")]
|
||||
pub generated_text: Option<String>,
|
||||
#[schema(nullable = true, default = "null")]
|
||||
|
@ -235,6 +235,9 @@ impl State {
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(entry.request.parameters.clone()),
|
||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||
// TODO: Actually fill this from the request
|
||||
top_n_tokens: entry.request.top_n_tokens,
|
||||
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
@ -328,6 +331,7 @@ mod tests {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
},
|
||||
top_n_tokens: 0,
|
||||
},
|
||||
response_tx,
|
||||
span: info_span!("entry"),
|
||||
|
@ -158,7 +158,7 @@ async fn generate(
|
||||
add_prompt = Some(req.inputs.clone());
|
||||
}
|
||||
|
||||
let details = req.parameters.details || req.parameters.decoder_input_details;
|
||||
let details: bool = req.parameters.details || req.parameters.decoder_input_details;
|
||||
|
||||
// Inference
|
||||
let (response, best_of_responses) = match req.parameters.best_of {
|
||||
@ -191,6 +191,7 @@ async fn generate(
|
||||
generated_tokens: response.generated_text.generated_tokens,
|
||||
prefill: response.prefill,
|
||||
tokens: response.tokens,
|
||||
top_tokens: response.top_tokens,
|
||||
seed: response.generated_text.seed,
|
||||
}
|
||||
})
|
||||
@ -204,6 +205,7 @@ async fn generate(
|
||||
tokens: response.tokens,
|
||||
seed: response.generated_text.seed,
|
||||
best_of_sequences,
|
||||
top_tokens: response.top_tokens,
|
||||
})
|
||||
}
|
||||
false => None,
|
||||
@ -374,7 +376,12 @@ async fn generate_stream(
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
} else {
|
||||
<<<<<<< HEAD
|
||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
=======
|
||||
let top_n_tokens = req.0.parameters.top_n_tokens;
|
||||
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
>>>>>>> 7c014c7 (Add WIP support for returning top tokens)
|
||||
// Keep permit as long as generate_stream lives
|
||||
Ok((_permit, mut response_stream)) => {
|
||||
// Server-Sent Event stream
|
||||
@ -385,12 +392,16 @@ async fn generate_stream(
|
||||
// Prefill is ignored
|
||||
InferStreamResponse::Prefill(_) => {}
|
||||
// Yield event for every new token
|
||||
InferStreamResponse::Token(token) => {
|
||||
InferStreamResponse::Intermediate{
|
||||
token,
|
||||
top_tokens,
|
||||
} => {
|
||||
tracing::debug!(parent: &span, "Token: {:?}", token);
|
||||
|
||||
// StreamResponse
|
||||
let stream_token = StreamResponse {
|
||||
token,
|
||||
top_tokens: top_n_tokens.and(Some(top_tokens)),
|
||||
generated_text: None,
|
||||
details: None,
|
||||
};
|
||||
@ -403,6 +414,7 @@ async fn generate_stream(
|
||||
generated_text,
|
||||
start,
|
||||
queued,
|
||||
top_tokens,
|
||||
} => {
|
||||
// Token details
|
||||
let details = match details {
|
||||
@ -451,6 +463,7 @@ async fn generate_stream(
|
||||
|
||||
let stream_token = StreamResponse {
|
||||
token,
|
||||
top_tokens:top_n_tokens.and(Some(top_tokens)),
|
||||
generated_text: Some(output_text),
|
||||
details
|
||||
};
|
||||
|
@ -142,6 +142,8 @@ impl Validation {
|
||||
seed,
|
||||
watermark,
|
||||
decoder_input_details,
|
||||
// TODO: Validate top_n_tokens
|
||||
top_n_tokens,
|
||||
..
|
||||
} = request.parameters;
|
||||
|
||||
@ -263,6 +265,7 @@ impl Validation {
|
||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
||||
parameters,
|
||||
stopping_parameters,
|
||||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||
})
|
||||
}
|
||||
|
||||
@ -336,6 +339,7 @@ pub(crate) struct ValidGenerateRequest {
|
||||
pub decoder_input_details: bool,
|
||||
pub parameters: NextTokenChooserParameters,
|
||||
pub stopping_parameters: StoppingCriteriaParameters,
|
||||
pub top_n_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
@ -645,6 +645,7 @@ class CausalLM(Model):
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
top_tokens,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
@ -1013,6 +1013,7 @@ class FlashCausalLM(Model):
|
||||
next_token_text,
|
||||
next_token_id in self.all_special_ids,
|
||||
generated_text,
|
||||
top_tokens,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from text_generation_server.utils.tokens import get_top_tokens
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
@ -647,6 +648,16 @@ class Seq2SeqLM(Model):
|
||||
all_decoder_input_ids.view(1, -1), logits[-1:, :]
|
||||
)
|
||||
|
||||
top_tokens = get_top_tokens(
|
||||
request.top_n_tokens,
|
||||
logprobs,
|
||||
self.all_special_ids,
|
||||
self.decode_token,
|
||||
all_decoder_input_ids,
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
)
|
||||
|
||||
# Append next token to decoder tokens
|
||||
all_decoder_input_ids = torch.cat(
|
||||
[all_decoder_input_ids, next_token_id.squeeze(1)]
|
||||
@ -706,6 +717,7 @@ class Seq2SeqLM(Model):
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
top_tokens,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from functools import total_ordering
|
||||
import torch
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
@ -71,6 +72,30 @@ class PrefillTokens:
|
||||
return len(self.token_ids)
|
||||
|
||||
|
||||
@dataclass(eq=True)
|
||||
@total_ordering
|
||||
class TopToken:
|
||||
token_id: int
|
||||
token_logprob: float
|
||||
token_text: str
|
||||
token_is_special: bool
|
||||
|
||||
def __gt__(self, other):
|
||||
# We tiebreak equal logprobs with the _lower_ token_id to align with
|
||||
# greedy ordering (torch.argmax)
|
||||
return self.token_logprob > other.token_logprob or (
|
||||
self.token_logprob == other.token_logprob and self.token_id < other.token_id
|
||||
)
|
||||
|
||||
def to_pb(self) -> generate_pb2.TopToken:
|
||||
return generate_pb2.TopToken(
|
||||
token_id=self.token_id,
|
||||
token_logprob=self.token_logprob,
|
||||
token_text=self.token_text,
|
||||
token_is_special=self.token_is_special,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Generation:
|
||||
request_id: int
|
||||
@ -80,6 +105,8 @@ class Generation:
|
||||
token_text: str
|
||||
token_is_special: bool
|
||||
generated_text: Optional[GeneratedText]
|
||||
# Optional for now, since it's not yet supported for every model.
|
||||
top_tokens: Optional[List[TopToken]]
|
||||
|
||||
def to_pb(self) -> generate_pb2.Generation:
|
||||
return generate_pb2.Generation(
|
||||
@ -94,4 +121,7 @@ class Generation:
|
||||
generated_text=self.generated_text.to_pb()
|
||||
if self.generated_text is not None
|
||||
else None,
|
||||
top_tokens=[toptoken.to_pb() for toptoken in self.top_tokens]
|
||||
if self.top_tokens
|
||||
else None,
|
||||
)
|
||||
|
@ -1,13 +1,14 @@
|
||||
import re
|
||||
from typing import Callable, List, Tuple, Optional
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.types import TopToken
|
||||
from text_generation_server.pb.generate_pb2 import FinishReason
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
from text_generation_server.utils.logits_process import (
|
||||
@ -339,3 +340,46 @@ class HeterogeneousSampling:
|
||||
self.greedy_indices = new_greedy_indices
|
||||
self.sampling_mapping = new_sampling_mapping
|
||||
return self
|
||||
|
||||
|
||||
def get_top_tokens(
|
||||
requested_n: int,
|
||||
logprobs,
|
||||
special_tokens: List[int],
|
||||
decode_fn: Callable[[List[int], int, int], str],
|
||||
decoder_input_ids: List[int],
|
||||
prefix_offset: int,
|
||||
read_offset: int,
|
||||
) -> List[TopToken]:
|
||||
if not requested_n:
|
||||
return []
|
||||
|
||||
flat_scores = logprobs[-1]
|
||||
# Ensure top_n doesn't exceed vocab size
|
||||
top_n = min(requested_n, flat_scores.size(-1))
|
||||
# Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
|
||||
nth_highest = torch.topk(flat_scores, top_n)[0][-1]
|
||||
if nth_highest == -float("inf"):
|
||||
nth_highest = torch.finfo(flat_scores.dtype).min
|
||||
# Get indices (token ids) of all scores >= nth highest value,
|
||||
# cap length at 4 * top_n as a precaution
|
||||
top_n_indices = (flat_scores >= nth_highest).nonzero()[: (top_n * 4)]
|
||||
top_tokens = []
|
||||
for tid_tensor in top_n_indices:
|
||||
tid_item = tid_tensor[0].item()
|
||||
token_text, _, _ = decode_fn(
|
||||
torch.cat([decoder_input_ids, tid_tensor]),
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
)
|
||||
top_tokens.append(
|
||||
TopToken(
|
||||
token_id=tid_item,
|
||||
token_logprob=logprobs[-1, tid_tensor],
|
||||
token_text=token_text,
|
||||
token_is_special=tid_item in special_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
top_tokens.sort(reverse=True)
|
||||
return top_tokens
|
||||
|
Loading…
Reference in New Issue
Block a user