Add WIP support for returning top tokens

Initial support returning the most probable tokens.
Note that it is currently only implemented for seq-to-seq models. It is
also always enabled, regardless of whether it is used or not.
This commit is contained in:
Vincent Brouwers 2023-07-14 19:48:15 +00:00 committed by Nicolas Patry
parent e605c2a43e
commit 8a4d2076a6
15 changed files with 187 additions and 6 deletions

View File

@ -73,6 +73,9 @@ async fn generate_runs(
// Create a dummy sequence
let sequence = create_sequence(sequence_length, tokenizer);
// TODO: Implement top_n_tokens
let top_n_tokens= 0;
for b in batch_size {
// Warmups on batch size
for _ in 0..warmups {
@ -82,6 +85,7 @@ async fn generate_runs(
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
@ -97,6 +101,7 @@ async fn generate_runs(
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
@ -130,6 +135,7 @@ async fn prefill(
batch_size: u32,
decode_length: u32,
parameters: NextTokenChooserParameters,
top_n_tokens: u32,
client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> {
// Create requests
@ -145,6 +151,7 @@ async fn prefill(
stop_sequences: vec![],
ignore_eos_token: true, // Will not stop even if a eos token is generated
}),
top_n_tokens: top_n_tokens,
})
.collect();

View File

@ -179,6 +179,9 @@ class BestOfSequence(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
# TODO: Make this optional?
top_tokens: List[List[Token]]
# `generate` details
@ -193,6 +196,9 @@ class Details(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
# TODO: Make this optional?
top_tokens: List[List[Token]]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
@ -219,6 +225,9 @@ class StreamDetails(BaseModel):
class StreamResponse(BaseModel):
# Generated token
token: Token
# Most likely tokens
# TODO: Make this optional?
top_tokens: List[Token]
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]

View File

@ -91,6 +91,8 @@ message Request {
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
}
message Batch {
@ -141,6 +143,17 @@ message PrefillTokens {
repeated string texts = 3;
}
message TopToken {
/// Token ID
uint32 token_id = 3;
/// Logprob
float token_logprob = 4;
/// Text
string token_text = 5;
/// Is it a special token
bool token_is_special = 6;
}
message Generation {
/// Request ID
uint64 request_id = 1;
@ -156,6 +169,8 @@ message Generation {
bool token_is_special = 6;
/// Complete generated text
optional GeneratedText generated_text = 7;
/// Top tokens
repeated TopToken top_tokens = 8;
}
message FilterBatchRequest {

View File

@ -131,6 +131,7 @@ impl Client {
ignore_eos_token: false,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
}

View File

@ -50,6 +50,7 @@ impl Health {
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0
};
let batch = Batch {
id: BATCH_ID,

View File

@ -138,12 +138,15 @@ impl Infer {
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives
let (_permit, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_top_tokens = Vec::new();
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
@ -164,7 +167,13 @@ impl Infer {
.collect();
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
InferStreamResponse::Intermediate{
token,
top_tokens,
} => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
}
// Final message
// Set return values
InferStreamResponse::End {
@ -172,8 +181,11 @@ impl Infer {
generated_text,
start,
queued,
top_tokens,
} => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
@ -191,6 +203,7 @@ impl Infer {
generated_text,
queued,
start,
top_tokens: if use_top_tokens { result_top_tokens } else { Vec::new() },
})
} else {
let err = InferError::IncompleteGeneration;
@ -520,6 +533,18 @@ fn send_responses(
special: generation.token_is_special,
};
// generation.top_tokens
let mut top_tokens = Vec::new();
for top_token in generation.top_tokens {
top_tokens.push(Token{
id: top_token.token_id,
text: top_token.token_text,
logprob: top_token.token_logprob,
special: top_token.token_is_special,
})
}
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
@ -527,6 +552,7 @@ fn send_responses(
entry.response_tx.send_timeout(
Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
@ -536,7 +562,7 @@ fn send_responses(
} else {
// Send message
entry.response_tx.send_timeout(
Ok(InferStreamResponse::Token(token)),
Ok(InferStreamResponse::Intermediate{token, top_tokens}),
Duration::from_millis(10),
)?;
}
@ -566,10 +592,14 @@ pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
// Intermediate messages
Token(Token),
Intermediate {
token: Token,
top_tokens: Vec<Token>,
},
// Last message
End {
token: Token,
top_tokens: Vec<Token>,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
@ -583,6 +613,7 @@ pub(crate) struct InferResponse {
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) top_tokens: Vec<Vec<Token>>,
}
#[derive(Debug, Error)]

View File

@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
example = "null"
)]
pub seed: Option<u64>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
}
fn default_max_new_tokens() -> u32 {
@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
details: false,
decoder_input_details: false,
seed: None,
top_n_tokens: None,
}
}
@ -235,6 +239,7 @@ pub(crate) struct BestOfSequence {
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@ -249,6 +254,7 @@ pub(crate) struct Details {
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@ -272,6 +278,8 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
#[schema(nullable = true, default = "null")]
pub top_tokens: Option<Vec<Token>>,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]

View File

@ -235,6 +235,9 @@ impl State {
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
// TODO: Actually fill this from the request
top_n_tokens: entry.request.top_n_tokens,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -328,6 +331,7 @@ mod tests {
max_new_tokens: 1,
stop_sequences: vec![],
},
top_n_tokens: 0,
},
response_tx,
span: info_span!("entry"),

View File

@ -158,7 +158,7 @@ async fn generate(
add_prompt = Some(req.inputs.clone());
}
let details = req.parameters.details || req.parameters.decoder_input_details;
let details: bool = req.parameters.details || req.parameters.decoder_input_details;
// Inference
let (response, best_of_responses) = match req.parameters.best_of {
@ -191,6 +191,7 @@ async fn generate(
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
top_tokens: response.top_tokens,
seed: response.generated_text.seed,
}
})
@ -204,6 +205,7 @@ async fn generate(
tokens: response.tokens,
seed: response.generated_text.seed,
best_of_sequences,
top_tokens: response.top_tokens,
})
}
false => None,
@ -374,7 +376,12 @@ async fn generate_stream(
tracing::error!("{err}");
yield Ok(Event::from(err));
} else {
<<<<<<< HEAD
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
=======
let top_n_tokens = req.0.parameters.top_n_tokens;
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
>>>>>>> 7c014c7 (Add WIP support for returning top tokens)
// Keep permit as long as generate_stream lives
Ok((_permit, mut response_stream)) => {
// Server-Sent Event stream
@ -385,12 +392,16 @@ async fn generate_stream(
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
InferStreamResponse::Intermediate{
token,
top_tokens,
} => {
tracing::debug!(parent: &span, "Token: {:?}", token);
// StreamResponse
let stream_token = StreamResponse {
token,
top_tokens: top_n_tokens.and(Some(top_tokens)),
generated_text: None,
details: None,
};
@ -403,6 +414,7 @@ async fn generate_stream(
generated_text,
start,
queued,
top_tokens,
} => {
// Token details
let details = match details {
@ -451,6 +463,7 @@ async fn generate_stream(
let stream_token = StreamResponse {
token,
top_tokens:top_n_tokens.and(Some(top_tokens)),
generated_text: Some(output_text),
details
};

View File

@ -142,6 +142,8 @@ impl Validation {
seed,
watermark,
decoder_input_details,
// TODO: Validate top_n_tokens
top_n_tokens,
..
} = request.parameters;
@ -263,6 +265,7 @@ impl Validation {
truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters,
stopping_parameters,
top_n_tokens: top_n_tokens.unwrap_or(0),
})
}
@ -336,6 +339,7 @@ pub(crate) struct ValidGenerateRequest {
pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
pub top_n_tokens: u32,
}
#[derive(Error, Debug)]

View File

@ -645,6 +645,7 @@ class CausalLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -1013,6 +1013,7 @@ class FlashCausalLM(Model):
next_token_text,
next_token_id in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -1,3 +1,4 @@
from text_generation_server.utils.tokens import get_top_tokens
import torch
from dataclasses import dataclass
@ -647,6 +648,16 @@ class Seq2SeqLM(Model):
all_decoder_input_ids.view(1, -1), logits[-1:, :]
)
top_tokens = get_top_tokens(
request.top_n_tokens,
logprobs,
self.all_special_ids,
self.decode_token,
all_decoder_input_ids,
prefix_offset,
read_offset,
)
# Append next token to decoder tokens
all_decoder_input_ids = torch.cat(
[all_decoder_input_ids, next_token_id.squeeze(1)]
@ -706,6 +717,7 @@ class Seq2SeqLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -1,3 +1,4 @@
from functools import total_ordering
import torch
from abc import ABC, abstractmethod
@ -71,6 +72,30 @@ class PrefillTokens:
return len(self.token_ids)
@dataclass(eq=True)
@total_ordering
class TopToken:
token_id: int
token_logprob: float
token_text: str
token_is_special: bool
def __gt__(self, other):
# We tiebreak equal logprobs with the _lower_ token_id to align with
# greedy ordering (torch.argmax)
return self.token_logprob > other.token_logprob or (
self.token_logprob == other.token_logprob and self.token_id < other.token_id
)
def to_pb(self) -> generate_pb2.TopToken:
return generate_pb2.TopToken(
token_id=self.token_id,
token_logprob=self.token_logprob,
token_text=self.token_text,
token_is_special=self.token_is_special,
)
@dataclass
class Generation:
request_id: int
@ -80,6 +105,8 @@ class Generation:
token_text: str
token_is_special: bool
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[List[TopToken]]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@ -94,4 +121,7 @@ class Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=[toptoken.to_pb() for toptoken in self.top_tokens]
if self.top_tokens
else None,
)

View File

@ -1,13 +1,14 @@
import re
from typing import Callable, List, Tuple, Optional
import torch
from transformers import (
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional
from text_generation_server.pb import generate_pb2
from text_generation_server.models.types import TopToken
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from text_generation_server.utils.logits_process import (
@ -339,3 +340,46 @@ class HeterogeneousSampling:
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self
def get_top_tokens(
requested_n: int,
logprobs,
special_tokens: List[int],
decode_fn: Callable[[List[int], int, int], str],
decoder_input_ids: List[int],
prefix_offset: int,
read_offset: int,
) -> List[TopToken]:
if not requested_n:
return []
flat_scores = logprobs[-1]
# Ensure top_n doesn't exceed vocab size
top_n = min(requested_n, flat_scores.size(-1))
# Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
nth_highest = torch.topk(flat_scores, top_n)[0][-1]
if nth_highest == -float("inf"):
nth_highest = torch.finfo(flat_scores.dtype).min
# Get indices (token ids) of all scores >= nth highest value,
# cap length at 4 * top_n as a precaution
top_n_indices = (flat_scores >= nth_highest).nonzero()[: (top_n * 4)]
top_tokens = []
for tid_tensor in top_n_indices:
tid_item = tid_tensor[0].item()
token_text, _, _ = decode_fn(
torch.cat([decoder_input_ids, tid_tensor]),
prefix_offset,
read_offset,
)
top_tokens.append(
TopToken(
token_id=tid_item,
token_logprob=logprobs[-1, tid_tensor],
token_text=token_text,
token_is_special=tid_item in special_tokens,
)
)
top_tokens.sort(reverse=True)
return top_tokens