Mirror of https://github.com/huggingface/text-generation-inference.git
Add WIP support for returning top tokens
Initial support for returning the most probable tokens. Note that it is currently only implemented for seq-to-seq models, and that it is always enabled regardless of whether it is used or not.
This commit is contained in:
parent e605c2a43e
commit 8a4d2076a6
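
The diff below threads a new `top_n_tokens` request parameter through the router and the Python server, and returns the candidates as `top_tokens` in the response. As a rough, hedged sketch of how a client could exercise it once a server with this commit is running (the parameter and response field names come from the diff; the local URL, prompt, and model choice are placeholder assumptions, and per the note above only seq-to-seq models populate the field):

# Sketch only: assumes text-generation-inference is serving a seq-to-seq model
# on localhost:8080 with this commit applied.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "translate English to German: Hello",
        "parameters": {"max_new_tokens": 4, "details": True, "top_n_tokens": 5},
    },
)
resp.raise_for_status()
details = resp.json()["details"]

# `top_tokens` holds one list per generated position; each entry is a candidate
# token with the same id/text/logprob/special fields as the `tokens` entries.
for step, candidates in enumerate(details["top_tokens"]):
    print(step, [(t["text"], round(t["logprob"], 2)) for t in candidates])
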
@@ -73,6 +73,9 @@ async fn generate_runs(
     // Create a dummy sequence
     let sequence = create_sequence(sequence_length, tokenizer);

+    // TODO: Implement top_n_tokens
+    let top_n_tokens = 0;
+
     for b in batch_size {
         // Warmups on batch size
         for _ in 0..warmups {
@@ -82,6 +85,7 @@ async fn generate_runs(
                 b,
                 decode_length,
                 parameters.clone(),
+                top_n_tokens,
                 &mut client,
             )
             .await?;
@@ -97,6 +101,7 @@ async fn generate_runs(
             b,
             decode_length,
             parameters.clone(),
+            top_n_tokens,
             &mut client,
         )
         .await?;
@@ -130,6 +135,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     parameters: NextTokenChooserParameters,
+    top_n_tokens: u32,
     client: &mut ShardedClient,
 ) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
@@ -145,6 +151,7 @@ async fn prefill(
             stop_sequences: vec![],
             ignore_eos_token: true, // Will not stop even if a eos token is generated
         }),
+        top_n_tokens: top_n_tokens,
     })
     .collect();

@@ -179,6 +179,9 @@ class BestOfSequence(BaseModel):
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
+    # Most likely tokens
+    # TODO: Make this optional?
+    top_tokens: List[List[Token]]


 # `generate` details
@@ -193,6 +196,9 @@ class Details(BaseModel):
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
+    # Most likely tokens
+    # TODO: Make this optional?
+    top_tokens: List[List[Token]]
     # Additional sequences when using the `best_of` parameter
     best_of_sequences: Optional[List[BestOfSequence]]

@@ -219,6 +225,9 @@ class StreamDetails(BaseModel):
 class StreamResponse(BaseModel):
     # Generated token
     token: Token
+    # Most likely tokens
+    # TODO: Make this optional?
+    top_tokens: List[Token]
     # Complete generated text
     # Only available when the generation is finished
     generated_text: Optional[str]
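
Assuming a `text_generation` client `Response` whose `details` already carries the new field, a consumer could line the candidates up with the sampled tokens like this (a hedged sketch; only the `tokens`/`top_tokens` field names and their list-of-lists shape come from the types above, the rest is illustrative):

# Sketch: `response` is assumed to be a text_generation Response object whose
# server populates `top_tokens` (per the commit message, seq-to-seq models only).
def show_alternatives(response):
    details = response.details
    # tokens[i] is the sampled token at step i; top_tokens[i] lists the n most
    # likely candidates at that same step, so the two lists align by position.
    for chosen, candidates in zip(details.tokens, details.top_tokens):
        alts = ", ".join(f"{t.text!r}:{t.logprob:.2f}" for t in candidates)
        print(f"{chosen.text!r}  <-  [{alts}]")
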
@@ -91,6 +91,8 @@ message Request {
     StoppingCriteriaParameters stopping_parameters = 5;
     /// Return prefill logprobs
     bool prefill_logprobs = 6;
+    /// Return most likely n tokens
+    uint32 top_n_tokens = 7;
 }

 message Batch {
@@ -141,6 +143,17 @@ message PrefillTokens {
     repeated string texts = 3;
 }

+message TopToken {
+    /// Token ID
+    uint32 token_id = 3;
+    /// Logprob
+    float token_logprob = 4;
+    /// Text
+    string token_text = 5;
+    /// Is it a special token
+    bool token_is_special = 6;
+}
+
 message Generation {
     /// Request ID
     uint64 request_id = 1;
@@ -156,6 +169,8 @@ message Generation {
     bool token_is_special = 6;
     /// Complete generated text
     optional GeneratedText generated_text = 7;
+    /// Top tokens
+    repeated TopToken top_tokens = 8;
 }

 message FilterBatchRequest {
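
From the Python server's side, the generated `generate_pb2` stubs expose the new proto fields directly; a minimal sketch (the field names are the ones declared above, the values are placeholders, and only the fields relevant here are set):

# Sketch: constructing the new protobuf messages from the generated stubs.
from text_generation_server.pb import generate_pb2

# A shard request asking for the 5 most likely tokens at each decoding step.
request = generate_pb2.Request(
    prefill_logprobs=True,
    top_n_tokens=5,
)

# One candidate token as it would travel back inside a Generation message.
top = generate_pb2.TopToken(
    token_id=42,
    token_logprob=-0.25,
    token_text=" world",
    token_is_special=False,
)
generation = generate_pb2.Generation(top_tokens=[top])
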
@@ -131,6 +131,7 @@ impl Client {
                     ignore_eos_token: false,
                 }),
                 prefill_logprobs: true,
+                top_n_tokens: 20,
             });
             n_tokens += max_input_length;
         }
@@ -50,6 +50,7 @@ impl Health {
                 stop_sequences: vec![],
                 ignore_eos_token: false,
             }),
+            top_n_tokens: 0,
         };
         let batch = Batch {
             id: BATCH_ID,
@@ -138,12 +138,15 @@ impl Infer {
         &self,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
+        let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
+
         // Create stream and keep semaphore permit as long as generate lives
         let (_permit, mut stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();
         let mut result_tokens = Vec::new();
+        let mut result_top_tokens = Vec::new();
         let mut result_generated_text = None;
         let mut result_start = None;
         let mut result_queued = None;
@@ -164,7 +167,13 @@ impl Infer {
                     .collect();
                 }
                 // Push last token
-                InferStreamResponse::Token(token) => result_tokens.push(token),
+                InferStreamResponse::Intermediate {
+                    token,
+                    top_tokens,
+                } => {
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
+                }
                 // Final message
                 // Set return values
                 InferStreamResponse::End {
@@ -172,8 +181,11 @@ impl Infer {
                     generated_text,
                     start,
                     queued,
+                    top_tokens,
+
                 } => {
                     result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
                     result_generated_text = Some(generated_text);
                     result_start = Some(start);
                     result_queued = Some(queued)
@@ -191,6 +203,7 @@ impl Infer {
                 generated_text,
                 queued,
                 start,
+                top_tokens: if use_top_tokens { result_top_tokens } else { Vec::new() },
             })
         } else {
             let err = InferError::IncompleteGeneration;
@@ -520,6 +533,18 @@ fn send_responses(
         special: generation.token_is_special,
     };

+
+    // generation.top_tokens
+    let mut top_tokens = Vec::new();
+    for top_token in generation.top_tokens {
+        top_tokens.push(Token {
+            id: top_token.token_id,
+            text: top_token.token_text,
+            logprob: top_token.token_logprob,
+            special: top_token.token_is_special,
+        })
+    }
+
     if let Some(generated_text) = generation.generated_text {
         // Generation has ended
         stopped = true;
@@ -527,6 +552,7 @@ fn send_responses(
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::End {
                 token,
+                top_tokens,
                 generated_text,
                 queued: entry.queue_time,
                 start: entry.batch_time.unwrap(),
@@ -536,7 +562,7 @@ fn send_responses(
     } else {
         // Send message
         entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Token(token)),
+            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
             Duration::from_millis(10),
         )?;
     }
@@ -566,10 +592,14 @@ pub(crate) enum InferStreamResponse {
     // Optional first message
     Prefill(PrefillTokens),
     // Intermediate messages
-    Token(Token),
+    Intermediate {
+        token: Token,
+        top_tokens: Vec<Token>,
+    },
     // Last message
     End {
         token: Token,
+        top_tokens: Vec<Token>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
@@ -583,6 +613,7 @@ pub(crate) struct InferResponse {
     pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
+    pub(crate) top_tokens: Vec<Vec<Token>>,
 }

 #[derive(Debug, Error)]
@@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
         example = "null"
     )]
     pub seed: Option<u64>,
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
+    pub top_n_tokens: Option<u32>,
 }

 fn default_max_new_tokens() -> u32 {
@@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
         details: false,
         decoder_input_details: false,
         seed: None,
+        top_n_tokens: None,
     }
 }

@@ -235,6 +239,7 @@ pub(crate) struct BestOfSequence {
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
     pub tokens: Vec<Token>,
+    pub top_tokens: Vec<Vec<Token>>,
 }

 #[derive(Serialize, ToSchema)]
@@ -249,6 +254,7 @@ pub(crate) struct Details {
     pub tokens: Vec<Token>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
+    pub top_tokens: Vec<Vec<Token>>,
 }

 #[derive(Serialize, ToSchema)]
@@ -272,6 +278,8 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
     pub token: Token,
+    #[schema(nullable = true, default = "null")]
+    pub top_tokens: Option<Vec<Token>>,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]
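
For the streaming path, `StreamResponse` gains a nullable `top_tokens` field, so each SSE event carries the candidates for the token it reports. A hedged sketch of reading it (the `/generate_stream` endpoint and `data:` framing are the existing text-generation-inference conventions; the URL and prompt are placeholders):

import json
import requests

with requests.post(
    "http://127.0.0.1:8080/generate_stream",
    json={"inputs": "Hello", "parameters": {"max_new_tokens": 3, "top_n_tokens": 3}},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue
        event = json.loads(line[len(b"data:"):])
        # `top_tokens` is nullable and only set when top_n_tokens was requested.
        print(event["token"]["text"], event.get("top_tokens"))
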
@@ -235,6 +235,9 @@ impl State {
                 truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+                // TODO: Actually fill this from the request
+                top_n_tokens: entry.request.top_n_tokens,
+
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
@@ -328,6 +331,7 @@ mod tests {
                 max_new_tokens: 1,
                 stop_sequences: vec![],
             },
+            top_n_tokens: 0,
         },
         response_tx,
         span: info_span!("entry"),
@@ -158,7 +158,7 @@ async fn generate(
         add_prompt = Some(req.inputs.clone());
     }

-    let details = req.parameters.details || req.parameters.decoder_input_details;
+    let details: bool = req.parameters.details || req.parameters.decoder_input_details;

     // Inference
     let (response, best_of_responses) = match req.parameters.best_of {
@@ -191,6 +191,7 @@ async fn generate(
                     generated_tokens: response.generated_text.generated_tokens,
                     prefill: response.prefill,
                     tokens: response.tokens,
+                    top_tokens: response.top_tokens,
                     seed: response.generated_text.seed,
                 }
             })
@@ -204,6 +205,7 @@ async fn generate(
                 tokens: response.tokens,
                 seed: response.generated_text.seed,
                 best_of_sequences,
+                top_tokens: response.top_tokens,
             })
         }
         false => None,
@@ -374,7 +376,12 @@ async fn generate_stream(
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         } else {
+<<<<<<< HEAD
             match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
+=======
+            let top_n_tokens = req.0.parameters.top_n_tokens;
+            match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
+>>>>>>> 7c014c7 (Add WIP support for returning top tokens)
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, mut response_stream)) => {
                     // Server-Sent Event stream
@@ -385,12 +392,16 @@ async fn generate_stream(
                         // Prefill is ignored
                         InferStreamResponse::Prefill(_) => {}
                         // Yield event for every new token
-                        InferStreamResponse::Token(token) => {
+                        InferStreamResponse::Intermediate {
+                            token,
+                            top_tokens,
+                        } => {
                             tracing::debug!(parent: &span, "Token: {:?}", token);

                             // StreamResponse
                             let stream_token = StreamResponse {
                                 token,
+                                top_tokens: top_n_tokens.and(Some(top_tokens)),
                                 generated_text: None,
                                 details: None,
                             };
@@ -403,6 +414,7 @@ async fn generate_stream(
                             generated_text,
                             start,
                             queued,
+                            top_tokens,
                         } => {
                             // Token details
                             let details = match details {
@@ -451,6 +463,7 @@ async fn generate_stream(

                             let stream_token = StreamResponse {
                                 token,
+                                top_tokens: top_n_tokens.and(Some(top_tokens)),
                                 generated_text: Some(output_text),
                                 details
                             };
@@ -142,6 +142,8 @@ impl Validation {
             seed,
             watermark,
             decoder_input_details,
+            // TODO: Validate top_n_tokens
+            top_n_tokens,
             ..
         } = request.parameters;

@@ -263,6 +265,7 @@ impl Validation {
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,
             stopping_parameters,
+            top_n_tokens: top_n_tokens.unwrap_or(0),
         })
     }

@@ -336,6 +339,7 @@ pub(crate) struct ValidGenerateRequest {
     pub decoder_input_details: bool,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
+    pub top_n_tokens: u32,
 }

 #[derive(Error, Debug)]
@@ -645,6 +645,7 @@ class CausalLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )

             generations.append(generation)
@@ -1013,6 +1013,7 @@ class FlashCausalLM(Model):
                 next_token_text,
                 next_token_id in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )

             generations.append(generation)
@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import get_top_tokens
 import torch

 from dataclasses import dataclass
@@ -647,6 +648,16 @@ class Seq2SeqLM(Model):
                 all_decoder_input_ids.view(1, -1), logits[-1:, :]
             )

+            top_tokens = get_top_tokens(
+                request.top_n_tokens,
+                logprobs,
+                self.all_special_ids,
+                self.decode_token,
+                all_decoder_input_ids,
+                prefix_offset,
+                read_offset,
+            )
+
             # Append next token to decoder tokens
             all_decoder_input_ids = torch.cat(
                 [all_decoder_input_ids, next_token_id.squeeze(1)]
@@ -706,6 +717,7 @@ class Seq2SeqLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )

             generations.append(generation)
@@ -1,3 +1,4 @@
+from functools import total_ordering
 import torch

 from abc import ABC, abstractmethod
@@ -71,6 +72,30 @@ class PrefillTokens:
         return len(self.token_ids)


+@dataclass(eq=True)
+@total_ordering
+class TopToken:
+    token_id: int
+    token_logprob: float
+    token_text: str
+    token_is_special: bool
+
+    def __gt__(self, other):
+        # We tiebreak equal logprobs with the _lower_ token_id to align with
+        # greedy ordering (torch.argmax)
+        return self.token_logprob > other.token_logprob or (
+            self.token_logprob == other.token_logprob and self.token_id < other.token_id
+        )
+
+    def to_pb(self) -> generate_pb2.TopToken:
+        return generate_pb2.TopToken(
+            token_id=self.token_id,
+            token_logprob=self.token_logprob,
+            token_text=self.token_text,
+            token_is_special=self.token_is_special,
+        )
+
+
 @dataclass
 class Generation:
     request_id: int
@@ -80,6 +105,8 @@ class Generation:
     token_text: str
     token_is_special: bool
     generated_text: Optional[GeneratedText]
+    # Optional for now, since it's not yet supported for every model.
+    top_tokens: Optional[List[TopToken]]

     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -94,4 +121,7 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
+            top_tokens=[toptoken.to_pb() for toptoken in self.top_tokens]
+            if self.top_tokens
+            else None,
         )
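
The ordering defined on `TopToken` is what `get_top_tokens` relies on when it calls `sort(reverse=True)`: highest logprob first, ties broken toward the lower token id (mirroring `torch.argmax`). A small self-contained check of that behavior (the values are made up):

from text_generation_server.models.types import TopToken

candidates = [
    TopToken(token_id=7, token_logprob=-1.0, token_text=" b", token_is_special=False),
    TopToken(token_id=3, token_logprob=-0.5, token_text=" a", token_is_special=False),
    TopToken(token_id=2, token_logprob=-0.5, token_text=" c", token_is_special=False),
]
candidates.sort(reverse=True)
# Highest logprob first; the -0.5 tie resolves to the lower id (2 before 3).
assert [t.token_id for t in candidates] == [2, 3, 7]
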
@@ -1,13 +1,14 @@
 import re
+from typing import Callable, List, Tuple, Optional
 import torch

 from transformers import (
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
 )
-from typing import List, Tuple, Optional

 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.types import TopToken
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
@@ -339,3 +340,46 @@ class HeterogeneousSampling:
         self.greedy_indices = new_greedy_indices
         self.sampling_mapping = new_sampling_mapping
         return self
+
+
+def get_top_tokens(
+    requested_n: int,
+    logprobs,
+    special_tokens: List[int],
+    decode_fn: Callable[[List[int], int, int], str],
+    decoder_input_ids: List[int],
+    prefix_offset: int,
+    read_offset: int,
+) -> List[TopToken]:
+    if not requested_n:
+        return []
+
+    flat_scores = logprobs[-1]
+    # Ensure top_n doesn't exceed vocab size
+    top_n = min(requested_n, flat_scores.size(-1))
+    # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
+    nth_highest = torch.topk(flat_scores, top_n)[0][-1]
+    if nth_highest == -float("inf"):
+        nth_highest = torch.finfo(flat_scores.dtype).min
+    # Get indices (token ids) of all scores >= nth highest value,
+    # cap length at 4 * top_n as a precaution
+    top_n_indices = (flat_scores >= nth_highest).nonzero()[: (top_n * 4)]
+    top_tokens = []
+    for tid_tensor in top_n_indices:
+        tid_item = tid_tensor[0].item()
+        token_text, _, _ = decode_fn(
+            torch.cat([decoder_input_ids, tid_tensor]),
+            prefix_offset,
+            read_offset,
+        )
+        top_tokens.append(
+            TopToken(
+                token_id=tid_item,
+                token_logprob=logprobs[-1, tid_tensor],
+                token_text=token_text,
+                token_is_special=tid_item in special_tokens,
+            )
+        )
+
+    top_tokens.sort(reverse=True)
+    return top_tokens
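
A hedged, self-contained way to see what `get_top_tokens` produces: a four-token toy vocabulary and a stub decode function standing in for `self.decode_token` (the stub, the vocabulary, and the logits are all invented for illustration; real callers pass the model's tokenizer machinery and offsets, as in the `Seq2SeqLM` hunk above):

import torch
from text_generation_server.utils.tokens import get_top_tokens

VOCAB = [" a", " b", " c", "</s>"]

def fake_decode(token_ids, prefix_offset, read_offset):
    # Stand-in for decode_token: returns (text, prefix_offset, read_offset).
    return VOCAB[int(token_ids[-1])], prefix_offset, read_offset

# Log-probabilities for a single decoding step over the toy vocabulary.
logprobs = torch.log_softmax(torch.tensor([[2.0, 1.5, -3.0, -3.0]]), dim=-1)

top = get_top_tokens(
    2,                      # requested_n: keep the two most likely candidates
    logprobs,
    special_tokens=[3],     # treat </s> as special
    decode_fn=fake_decode,
    decoder_input_ids=torch.tensor([0]),
    prefix_offset=0,
    read_offset=0,
)
print([(t.token_id, t.token_text) for t in top])  # -> [(0, ' a'), (1, ' b')]
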