Modifying the protobuf.

Nicolas Patry 2023-11-29 16:20:11 +00:00
parent 866af9b9fd
commit a478b276eb
10 changed files with 123 additions and 143 deletions


@@ -135,43 +135,27 @@ message GeneratedText {
     optional uint64 seed = 4;
 }
-message PrefillTokens {
-    /// Prefill Token IDs
+message Tokens {
+    /// Token IDs
     repeated uint32 ids = 1;
-    /// Prefill Logprobs
+    /// Logprobs
     repeated float logprobs = 2;
-    /// Prefill tokens
+    /// tokens
     repeated string texts = 3;
-}
-
-message TopTokens {
-    /// Top Token IDs
-    repeated uint32 ids = 1;
-    /// Top Logprobs
-    repeated float logprobs = 2;
-    /// Top Token Texts
-    repeated string texts = 3;
-    /// If the tokens are special
-    repeated bool is_special = 6;
+    /// special
+    repeated bool is_special = 4;
 }
 message Generation {
     /// Request ID
     uint64 request_id = 1;
     /// Prefill tokens (optional)
-    PrefillTokens prefill_tokens = 2;
-    /// Token ID
-    uint32 token_id = 3;
-    /// Logprob
-    float token_logprob = 4;
-    /// Text
-    string token_text = 5;
-    /// Is it a special token
-    bool token_is_special = 6;
+    Tokens prefill_tokens = 2;
+    Tokens tokens = 3;
     /// Complete generated text
-    optional GeneratedText generated_text = 7;
+    optional GeneratedText generated_text = 4;
     /// Top tokens
-    TopTokens top_tokens = 8;
+    repeated Tokens top_tokens = 5;
 }
 message FilterBatchRequest {
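
Taken together, this hunk folds the old PrefillTokens and TopTokens messages into a single Tokens message (parallel arrays of ids, logprobs, texts, and is_special flags), and Generation now carries one Tokens for the tokens produced in a step plus a repeated Tokens for the per-token top candidates, so a single Generation can describe more than one generated token. As a minimal sketch (not part of the commit), this is how the regenerated Python stubs that the server code already imports as generate_pb2 would be populated; the values are made up:

    from text_generation_server.pb import generate_pb2

    # Two tokens produced in one decoding step, expressed as parallel arrays.
    step = generate_pb2.Tokens(
        ids=[42, 7],
        logprobs=[-0.12, -1.30],
        texts=["Hello", " world"],
        is_special=[False, False],
    )

    generation = generate_pb2.Generation(
        request_id=1,
        tokens=step,   # replaces token_id / token_logprob / token_text / token_is_special
        top_tokens=[   # repeated: one Tokens entry per generated token
            generate_pb2.Tokens(ids=[42], logprobs=[-0.12], texts=["Hello"], is_special=[False]),
        ],
    )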


@@ -10,7 +10,7 @@ pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-    PrefillTokens, Request, StoppingCriteriaParameters,
+    Tokens, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;


@@ -9,7 +9,7 @@ use std::sync::{
     Arc,
 };
 use text_generation_client::{
-    Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+    Batch, CachedBatch, ClientError, GeneratedText, Generation, Tokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
@@ -167,21 +167,21 @@ impl Infer {
             .collect();
         }
         // Push last token
-        InferStreamResponse::Intermediate { token, top_tokens } => {
-            result_tokens.push(token);
-            result_top_tokens.push(top_tokens);
+        InferStreamResponse::Intermediate { tokens, top_tokens } => {
+            result_tokens.extend(tokens);
+            result_top_tokens.extend(top_tokens);
         }
         // Final message
         // Set return values
         InferStreamResponse::End {
-            token,
+            tokens,
             generated_text,
             start,
             queued,
             top_tokens,
         } => {
-            result_tokens.push(token);
-            result_top_tokens.push(top_tokens);
+            result_tokens.extend(tokens);
+            result_top_tokens.extend(top_tokens);
             result_generated_text = Some(generated_text);
             result_start = Some(start);
             result_queued = Some(queued)
@@ -523,18 +523,27 @@ fn send_responses(
     }
     // Create last Token
-    let token = Token {
-        id: generation.token_id,
-        text: generation.token_text,
-        logprob: generation.token_logprob,
-        special: generation.token_is_special,
+    let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
+        tokens_.ids.into_iter()
+            .zip(tokens_.logprobs.into_iter())
+            .zip(tokens_.texts.into_iter())
+            .zip(tokens_.is_special.into_iter())
+            .map(|(((id, logprob), text), special)| Token {
+                id,
+                text,
+                logprob,
+                special,
+            }).collect()
+    } else {
+        vec![]
     };
     // generation.top_tokens
     let mut top_tokens = Vec::new();
-    if let Some(top_tokens_) = generation.top_tokens {
-        top_tokens.extend(
+    for top_tokens_ in generation.top_tokens {
+        let mut local_top_tokens = Vec::new();
+        local_top_tokens.extend(
             top_tokens_
                 .ids
                 .into_iter()
@@ -547,7 +556,8 @@ fn send_responses(
                     logprob,
                     special,
                 }),
-        )
+        );
+        top_tokens.push(local_top_tokens);
     }
     if let Some(generated_text) = generation.generated_text {
@@ -555,7 +565,7 @@ fn send_responses(
         stopped = true;
         // Send message
         entry.response_tx.send(Ok(InferStreamResponse::End {
-            token,
+            tokens,
             top_tokens,
             generated_text,
             queued: entry.queue_time,
@@ -565,7 +575,7 @@ fn send_responses(
         // Send message
         entry
             .response_tx
-            .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+            .send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?;
     }
     Ok(stopped)
 }
@@ -591,16 +601,16 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
-    Prefill(PrefillTokens),
+    Prefill(Tokens),
     // Intermediate messages
     Intermediate {
-        token: Token,
-        top_tokens: Vec<Token>,
+        tokens: Vec<Token>,
+        top_tokens: Vec<Vec<Token>>,
     },
     // Last message
     End {
-        token: Token,
-        top_tokens: Vec<Token>,
+        tokens: Vec<Token>,
+        top_tokens: Vec<Vec<Token>>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
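
On the router side, send_responses now rebuilds per-token structs by zipping the parallel arrays of a Tokens message, and InferStreamResponse carries Vec<Token> (and Vec<Vec<Token>> for top tokens) instead of a single Token. For illustration only, the same pairing expressed in Python with made-up values:

    # The parallel arrays of one Tokens message line up index by index;
    # the Rust code above performs the equivalent zip before forwarding Token structs.
    ids = [42, 7]
    logprobs = [-0.12, -1.30]
    texts = ["Hello", " world"]
    is_special = [False, False]

    tokens = [
        {"id": i, "logprob": lp, "text": t, "special": s}
        for i, lp, t, s in zip(ids, logprobs, texts, is_special)
    ]
    print(tokens)  # two Token-like records reconstructed from one Tokens message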


@@ -279,9 +279,10 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
-    pub token: Token,
+    pub tokens: Vec<Token>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
-    pub top_tokens: Vec<Token>,
+    pub top_tokens: Vec<Vec<Token>>,
+    pub text: String,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]


@@ -388,14 +388,15 @@ async fn generate_stream(
         InferStreamResponse::Prefill(_) => {}
         // Yield event for every new token
         InferStreamResponse::Intermediate{
-            token,
+            tokens,
             top_tokens,
         } => {
-            tracing::debug!(parent: &span, "Token: {:?}", token);
+            tracing::debug!(parent: &span, "Tokens: {:?}", tokens);
             // StreamResponse
             let stream_token = StreamResponse {
-                token,
+                tokens,
+                text,
                 top_tokens,
                 generated_text: None,
                 details: None,
@@ -405,7 +406,8 @@ async fn generate_stream(
         }
         // Yield event for last token and compute timings
         InferStreamResponse::End {
-            token,
+            tokens,
+            text,
             generated_text,
             start,
             queued,
@@ -457,8 +459,9 @@ async fn generate_stream(
             tracing::info!(parent: &span, "Success");
             let stream_token = StreamResponse {
-                token,
+                tokens,
                 top_tokens,
+                text,
                 generated_text: Some(output_text),
                 details
             };
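
Because StreamResponse now exposes tokens: Vec<Token> plus the step's text, one server-sent event can carry several tokens. A hedged sketch of a client loop under that shape; the /generate_stream endpoint, port, and exact JSON field names are assumptions based on the struct above, not something this commit defines:

    import json
    import requests

    with requests.post(
        "http://localhost:3000/generate_stream",
        json={"inputs": "Hello", "parameters": {"max_new_tokens": 16}},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data:"):
                continue
            event = json.loads(line[len(b"data:"):].strip())
            # One event may now contain several tokens instead of exactly one.
            for token in event["tokens"]:
                print(token["text"], end="", flush=True)
            if event.get("generated_text") is not None:
                print()  # final event: full generated_text and details are attached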


@@ -10,10 +10,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -676,8 +675,8 @@ class CausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -691,7 +690,7 @@ class CausalLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -703,10 +702,12 @@ class CausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )


@@ -14,10 +14,9 @@ from typing import Optional, Tuple, List, Type, Union, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
@@ -952,14 +951,17 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             _next_token_ids = next_token_ids[index: index+n_accepted_ids]
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
+            all_input_ids.extend(_next_token_ids)
+            next_token_texts = []
             for j in range(index, index + n_accepted_ids):
                 # Generated token
-                all_input_ids.append(next_token_ids[j])
                 next_token_text, prefix_offset, read_offset = self.decode_token(
                     all_input_ids,
                     prefix_offset,
                     read_offset,
                 )
+                next_token_texts.append(next_token_text)
             # Evaluate stopping criteria
@@ -1013,8 +1015,8 @@ class FlashCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, request_prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -1028,7 +1030,7 @@ class FlashCausalLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -1037,16 +1039,15 @@ class FlashCausalLM(Model):
             else:
                 top_tokens = None
-            next_token_ids = _next_token_ids[0]
-            next_token_logprob = _next_token_logprobs[0]
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id,
-                next_token_logprob,
-                next_token_text,
-                next_token_id in self.all_special_ids,
+                Tokens(
+                    _next_token_ids,
+                    _next_token_logprobs,
+                    next_token_texts,
+                    [nid in self.all_special_ids for nid in _next_token_ids],
+                ),
                 generated_text,
                 top_tokens,
             )


@@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
 )
@@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -802,10 +802,12 @@ class IdeficsCausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )


@@ -11,8 +11,7 @@ from text_generation_server.models.types import (
     GeneratedText,
     Batch,
     Generation,
-    PrefillTokens,
-    TopTokens,
+    Tokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -733,10 +732,11 @@ class Seq2SeqLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                prefill_tokens = PrefillTokens(
+                prefill_tokens = Tokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
                     [self.tokenizer.bos_token],
+                    [False],
                 )
             else:
                 prefill_tokens = None
@@ -750,7 +750,7 @@ class Seq2SeqLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -762,10 +762,12 @@ class Seq2SeqLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )


@@ -58,33 +58,15 @@ class GeneratedText:
 @dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(
-            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
-        )
-    def __len__(self):
-        return len(self.token_ids)
-
-@dataclass
-class TopTokens:
+class Tokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
     is_special: List[bool]
-    def to_pb(self) -> generate_pb2.TopTokens:
-        return generate_pb2.TopTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts,
-            is_special=self.is_special,
+    def to_pb(self) -> generate_pb2.Tokens:
+        return generate_pb2.Tokens(
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
         )
     def __len__(self):
@@ -94,14 +76,11 @@ class TopTokens:
 @dataclass
 class Generation:
     request_id: int
-    prefill_tokens: Optional[PrefillTokens]
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+    prefill_tokens: Optional[Tokens]
+    tokens: Tokens
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[TopTokens]
+    top_tokens: Optional[List[Tokens]]
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -109,10 +88,7 @@ class Generation:
             prefill_tokens=self.prefill_tokens.to_pb()
             if self.prefill_tokens is not None
             else None,
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
+            tokens=self.tokens.to_pb(),
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
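
With the Python dataclasses unified, every model now builds a Generation the same way: wrap the step's ids, logprobs, texts, and specialness flags in one Tokens (single-step models pass one-element lists, as in the causal_lm.py hunk above; the flash path passes the whole accepted list). A standalone sketch of that call pattern with illustrative values; only the Tokens and Generation signatures come from this commit:

    from text_generation_server.models.types import Generation, Tokens

    # Stand-ins for what a model computes on one decoding step.
    request_id = 0
    next_token_id, next_token_logprob, next_token_text = 42, -0.12, "Hello"
    all_special_ids = {0, 1, 2}

    generation = Generation(
        request_id,
        None,                 # prefill_tokens: Optional[Tokens], only set on the prefill step
        Tokens(
            [next_token_id],  # one-element lists for a single accepted token
            [next_token_logprob],
            [next_token_text],
            [next_token_id in all_special_ids],
        ),
        None,                 # generated_text: Optional[GeneratedText], set when the request finishes
        None,                 # top_tokens: Optional[List[Tokens]]
    )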