Modifying the protobuf.

This commit is contained in:
Nicolas Patry 2023-11-29 16:20:11 +00:00
parent 866af9b9fd
commit a478b276eb
10 changed files with 123 additions and 143 deletions

View File

@ -135,43 +135,27 @@ message GeneratedText {
optional uint64 seed = 4;
}
message PrefillTokens {
/// Prefill Token IDs
message Tokens {
/// Token IDs
repeated uint32 ids = 1;
/// Prefill Logprobs
/// Logprobs
repeated float logprobs = 2;
/// Prefill tokens
/// tokens
repeated string texts = 3;
}
message TopTokens {
/// Top Token IDs
repeated uint32 ids = 1;
/// Top Logprobs
repeated float logprobs = 2;
/// Top Token Texts
repeated string texts = 3;
/// If the tokens are special
repeated bool is_special = 6;
/// special
repeated bool is_special = 4;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
PrefillTokens prefill_tokens = 2;
/// Token ID
uint32 token_id = 3;
/// Logprob
float token_logprob = 4;
/// Text
string token_text = 5;
/// Is it a special token
bool token_is_special = 6;
Tokens prefill_tokens = 2;
Tokens tokens = 3;
/// Complete generated text
optional GeneratedText generated_text = 7;
optional GeneratedText generated_text = 4;
/// Top tokens
TopTokens top_tokens = 8;
repeated Tokens top_tokens = 5;
}
message FilterBatchRequest {
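
The hunk above folds the old PrefillTokens and TopTokens messages into a single Tokens message, and Generation now carries one Tokens for the step's generated tokens (replacing the scalar token_id / token_logprob / token_text / token_is_special fields) plus a repeated Tokens for top_tokens. A minimal Python sketch of what a producer of the new schema might build, using the generate_pb2 stubs imported elsewhere in this diff; the concrete ids, logprobs and texts are invented:

from text_generation_server.pb import generate_pb2

# One decoding step can now carry several tokens; ids/logprobs/texts/is_special
# are parallel lists inside a single Tokens message.
step = generate_pb2.Tokens(
    ids=[314, 1234],
    logprobs=[-0.12, -1.05],
    texts=[" the", " cat"],
    is_special=[False, False],
)

generation = generate_pb2.Generation(
    request_id=0,
    tokens=step,        # field 3: the tokens produced by this step
    top_tokens=[step],  # field 5: one Tokens message per generated position
)                       # prefill_tokens (2) and generated_text (4) stay unset here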

View File

@ -10,7 +10,7 @@ pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
PrefillTokens, Request, StoppingCriteriaParameters,
Tokens, Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -9,7 +9,7 @@ use std::sync::{
Arc,
};
use text_generation_client::{
Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
Batch, CachedBatch, ClientError, GeneratedText, Generation, Tokens, ShardedClient,
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
@ -167,21 +167,21 @@ impl Infer {
.collect();
}
// Push last token
InferStreamResponse::Intermediate { token, top_tokens } => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
InferStreamResponse::Intermediate { tokens, top_tokens } => {
result_tokens.extend(tokens);
result_top_tokens.extend(top_tokens);
}
// Final message
// Set return values
InferStreamResponse::End {
token,
tokens,
generated_text,
start,
queued,
top_tokens,
} => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
result_tokens.extend(tokens);
result_top_tokens.extend(top_tokens);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
@ -523,18 +523,27 @@ fn send_responses(
}
// Create last Token
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
tokens_.ids.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
}).collect()
} else {
vec![]
};
// generation.top_tokens
let mut top_tokens = Vec::new();
if let Some(top_tokens_) = generation.top_tokens {
top_tokens.extend(
for top_tokens_ in generation.top_tokens {
let mut local_top_tokens = Vec::new();
local_top_tokens.extend(
top_tokens_
.ids
.into_iter()
@ -547,7 +556,8 @@ fn send_responses(
logprob,
special,
}),
)
);
top_tokens.push(local_top_tokens);
}
if let Some(generated_text) = generation.generated_text {
@ -555,7 +565,7 @@ fn send_responses(
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
tokens,
top_tokens,
generated_text,
queued: entry.queue_time,
@ -565,7 +575,7 @@ fn send_responses(
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
.send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?;
}
Ok(stopped)
}
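
In send_responses above, the four parallel arrays of the Tokens message are zipped back into individual Token values, and generation.top_tokens (now repeated) is unpacked into one Vec<Token> per generated position. For illustration only, the same unpacking expressed in Python (the router itself is Rust; the helper names here are made up):

def tokens_to_records(tokens_pb):
    # Zip the parallel ids/logprobs/texts/is_special arrays into per-token records.
    if tokens_pb is None:
        return []
    return [
        {"id": id_, "text": text, "logprob": logprob, "special": special}
        for id_, logprob, text, special in zip(
            tokens_pb.ids, tokens_pb.logprobs, tokens_pb.texts, tokens_pb.is_special
        )
    ]

def top_tokens_to_records(generation_pb):
    # top_tokens is repeated, so the result is a list of lists (one per position).
    return [tokens_to_records(tokens_pb) for tokens_pb in generation_pb.top_tokens]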
@ -591,16 +601,16 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
Prefill(Tokens),
// Intermediate messages
Intermediate {
token: Token,
top_tokens: Vec<Token>,
tokens: Vec<Token>,
top_tokens: Vec<Vec<Token>>,
},
// Last message
End {
token: Token,
top_tokens: Vec<Token>,
tokens: Vec<Token>,
top_tokens: Vec<Vec<Token>>,
generated_text: GeneratedText,
start: Instant,
queued: Instant,

View File

@ -279,9 +279,10 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Token>,
pub top_tokens: Vec<Vec<Token>>,
pub text: String,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]

View File

@ -388,14 +388,15 @@ async fn generate_stream(
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Intermediate{
token,
tokens,
top_tokens,
} => {
tracing::debug!(parent: &span, "Token: {:?}", token);
tracing::debug!(parent: &span, "Tokens: {:?}", tokens);
// StreamResponse
let stream_token = StreamResponse {
token,
tokens,
text,
top_tokens,
generated_text: None,
details: None,
@ -405,7 +406,8 @@ async fn generate_stream(
}
// Yield event for last token and compute timings
InferStreamResponse::End {
token,
tokens,
text,
generated_text,
start,
queued,
@ -457,8 +459,9 @@ async fn generate_stream(
tracing::info!(parent: &span, "Success");
let stream_token = StreamResponse {
token,
tokens,
top_tokens,
text,
generated_text: Some(output_text),
details
};

View File

@ -10,10 +10,9 @@ from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Tokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -676,8 +675,8 @@ class CausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
)
else:
prefill_tokens = None
@ -691,7 +690,7 @@ class CausalLM(Model):
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
@ -703,10 +702,12 @@ class CausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)

View File

@ -14,10 +14,9 @@ from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Tokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.models.cache_manager import (
get_cache_manager,
@ -952,14 +951,17 @@ class FlashCausalLM(Model):
# Append next token to all tokens
_next_token_ids = next_token_ids[index: index+n_accepted_ids]
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
all_input_ids.extend(_next_token_ids)
next_token_texts = []
for j in range(index, index + n_accepted_ids):
# Generated token
all_input_ids.append(next_token_ids[j])
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
# Evaluate stopping criteria
@ -1013,8 +1015,8 @@ class FlashCausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts
prefill_tokens = Tokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[]
)
else:
prefill_tokens = None
@ -1028,7 +1030,7 @@ class FlashCausalLM(Model):
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
@ -1037,16 +1039,15 @@ class FlashCausalLM(Model):
else:
top_tokens = None
next_token_ids = _next_token_ids[0]
next_token_logprob = _next_token_logprobs[0]
generation = Generation(
request.id,
prefill_tokens,
next_token_id,
next_token_logprob,
next_token_text,
next_token_id in self.all_special_ids,
Tokens(
_next_token_ids,
_next_token_logprobs,
next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
)
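
The change above lets a single forward pass emit several accepted tokens per request (n_accepted_ids can be greater than one), all of which end up in one Tokens message instead of one Generation per token. A rough sketch of that slice-and-wrap step, assuming a hypothetical decode_one helper in place of self.decode_token:

from text_generation_server.models.types import Tokens

def build_step_tokens(next_token_ids, next_token_logprobs, index, n_accepted_ids,
                      all_special_ids, decode_one):
    # Slice out the tokens accepted for this request in this step.
    _ids = next_token_ids[index : index + n_accepted_ids]
    _logprobs = next_token_logprobs[index : index + n_accepted_ids]
    # Decode each accepted token to text (stand-in for self.decode_token).
    texts = [decode_one(token_id) for token_id in _ids]
    return Tokens(
        _ids,
        _logprobs,
        texts,
        [token_id in all_special_ids for token_id in _ids],
    )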

View File

@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Tokens,
Generation,
GeneratedText,
)
@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
)
else:
prefill_tokens = None
@ -802,10 +802,12 @@ class IdeficsCausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)

View File

@ -11,8 +11,7 @@ from text_generation_server.models.types import (
GeneratedText,
Batch,
Generation,
PrefillTokens,
TopTokens,
Tokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -733,10 +732,11 @@ class Seq2SeqLM(Model):
# Prefill
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
prefill_tokens = PrefillTokens(
prefill_tokens = Tokens(
[self.tokenizer.bos_token_id],
[float("nan")],
[self.tokenizer.bos_token],
[False],
)
else:
prefill_tokens = None
@ -750,7 +750,7 @@ class Seq2SeqLM(Model):
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
@ -762,10 +762,12 @@ class Seq2SeqLM(Model):
generation = Generation(
request.id,
prefill_tokens,
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)

View File

@ -58,33 +58,15 @@ class GeneratedText:
@dataclass
class PrefillTokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
def to_pb(self) -> generate_pb2.PrefillTokens:
return generate_pb2.PrefillTokens(
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
)
def __len__(self):
return len(self.token_ids)
@dataclass
class TopTokens:
class Tokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
is_special: List[bool]
def to_pb(self) -> generate_pb2.TopTokens:
return generate_pb2.TopTokens(
ids=self.token_ids,
logprobs=self.logprobs,
texts=self.texts,
is_special=self.is_special,
def to_pb(self) -> generate_pb2.Tokens:
return generate_pb2.Tokens(
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
)
def __len__(self):
@ -94,14 +76,11 @@ class TopTokens:
@dataclass
class Generation:
request_id: int
prefill_tokens: Optional[PrefillTokens]
token_id: int
token_logprob: float
token_text: str
token_is_special: bool
prefill_tokens: Optional[Tokens]
tokens: Tokens
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[TopTokens]
top_tokens: Optional[List[Tokens]]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@ -109,10 +88,7 @@ class Generation:
prefill_tokens=self.prefill_tokens.to_pb()
if self.prefill_tokens is not None
else None,
token_id=self.token_id,
token_logprob=self.token_logprob,
token_text=self.token_text,
token_is_special=self.token_is_special,
tokens=self.tokens.to_pb(),
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
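
With the unified dataclass above, prefill tokens, the step's generated tokens and each top-k list are all plain Tokens values, and Generation.to_pb serializes them into the new protobuf shape. A small usage example; the field values are invented, and prefill tokens simply pass is_special=[] as in the model files earlier in this diff:

from text_generation_server.models.types import Generation, Tokens

prefill = Tokens([1, 2], [float("nan"), -0.5], ["<s>", "Hello"], is_special=[])
step = Tokens([42], [-0.1], [" world"], [False])

generation = Generation(
    request_id=7,
    prefill_tokens=prefill,
    tokens=step,
    generated_text=None,
    top_tokens=[step],  # Optional[List[Tokens]]: one Tokens per generated position
)
pb = generation.to_pb()  # generate_pb2.Generation carrying the new `tokens` field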