Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-26 12:32:10 +00:00)

Commit a478b276eb: Modifying the protobuf.
Parent: 866af9b9fd
@@ -135,43 +135,27 @@ message GeneratedText {
     optional uint64 seed = 4;
 }
 
-message PrefillTokens {
-    /// Prefill Token IDs
+message Tokens {
+    /// Token IDs
     repeated uint32 ids = 1;
-    /// Prefill Logprobs
+    /// Logprobs
     repeated float logprobs = 2;
-    /// Prefill tokens
+    /// tokens
     repeated string texts = 3;
-}
-
-message TopTokens {
-    /// Top Token IDs
-    repeated uint32 ids = 1;
-    /// Top Logprobs
-    repeated float logprobs = 2;
-    /// Top Token Texts
-    repeated string texts = 3;
-    /// If the tokens are special
-    repeated bool is_special = 6;
+    /// special
+    repeated bool is_special = 4;
 }
 
 message Generation {
     /// Request ID
     uint64 request_id = 1;
     /// Prefill tokens (optional)
-    PrefillTokens prefill_tokens = 2;
-    /// Token ID
-    uint32 token_id = 3;
-    /// Logprob
-    float token_logprob = 4;
-    /// Text
-    string token_text = 5;
-    /// Is it a special token
-    bool token_is_special = 6;
+    Tokens prefill_tokens = 2;
+    Tokens tokens = 3;
     /// Complete generated text
-    optional GeneratedText generated_text = 7;
+    optional GeneratedText generated_text = 4;
     /// Top tokens
-    TopTokens top_tokens = 8;
+    repeated Tokens top_tokens = 5;
 }
 
 message FilterBatchRequest {
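The net effect of the proto change is that prefill tokens, generated tokens, and top tokens all share one parallel-array Tokens message, and a single Generation can now carry several generated tokens per step. As a rough illustration (plain Python dicts with invented values, not the generated protobuf classes), one decode step now looks like:

    # Illustration only: field names follow the new proto above, values are invented.
    generation = {
        "request_id": 0,
        "prefill_tokens": None,  # a Tokens message, filled only on the prefill step
        "tokens": {              # one entry per token accepted in this step
            "ids": [42, 7],
            "logprobs": [-0.12, -0.50],
            "texts": ["Hello", ","],
            "is_special": [False, False],
        },
        "generated_text": None,  # only set once generation finishes
        "top_tokens": [          # repeated Tokens: one candidate list per generated token
            {"ids": [42, 13], "logprobs": [-0.12, -2.3], "texts": ["Hello", "Hi"], "is_special": [False, False]},
            {"ids": [7, 11], "logprobs": [-0.50, -1.9], "texts": [",", "!"], "is_special": [False, False]},
        ],
    }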
@@ -10,7 +10,7 @@ pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-    PrefillTokens, Request, StoppingCriteriaParameters,
+    Tokens, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -9,7 +9,7 @@ use std::sync::{
     Arc,
 };
 use text_generation_client::{
-    Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+    Batch, CachedBatch, ClientError, GeneratedText, Generation, Tokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
@@ -167,21 +167,21 @@ impl Infer {
                         .collect();
                 }
                 // Push last token
-                InferStreamResponse::Intermediate { token, top_tokens } => {
-                    result_tokens.push(token);
-                    result_top_tokens.push(top_tokens);
+                InferStreamResponse::Intermediate { tokens, top_tokens } => {
+                    result_tokens.extend(tokens);
+                    result_top_tokens.extend(top_tokens);
                 }
                 // Final message
                 // Set return values
                 InferStreamResponse::End {
-                    token,
+                    tokens,
                     generated_text,
                     start,
                     queued,
                     top_tokens,
                 } => {
-                    result_tokens.push(token);
-                    result_top_tokens.push(top_tokens);
+                    result_tokens.extend(tokens);
+                    result_top_tokens.extend(top_tokens);
                     result_generated_text = Some(generated_text);
                     result_start = Some(start);
                     result_queued = Some(queued)
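Because Intermediate and End events now carry a batch of tokens rather than a single one, the non-streaming path accumulates with extend instead of push. A small standalone sketch of the same bookkeeping (hypothetical data, Python for illustration):

    # Each streamed event carries a list of tokens and, per token, a list of top tokens.
    events = [
        (["Hello"], [["Hello", "Hi"]]),
        ([",", " world"], [[",", "."], [" world", " earth"]]),
    ]

    result_tokens = []
    result_top_tokens = []
    for tokens, top_tokens in events:
        result_tokens.extend(tokens)          # flat list, one entry per generated token
        result_top_tokens.extend(top_tokens)  # stays aligned: one candidate list per token

    assert len(result_tokens) == len(result_top_tokens) == 3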
@@ -523,18 +523,27 @@ fn send_responses(
     }
 
     // Create last Token
-    let token = Token {
-        id: generation.token_id,
-        text: generation.token_text,
-        logprob: generation.token_logprob,
-        special: generation.token_is_special,
+    let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens{
+        tokens_.ids.into_iter()
+        .zip(tokens_.logprobs.into_iter())
+        .zip(tokens_.texts.into_iter())
+        .zip(tokens_.is_special.into_iter())
+        .map(|(((id, logprob), text), special)| Token {
+            id,
+            text,
+            logprob,
+            special,
+        }).collect()
+    }else{
+        vec![]
     };
 
     // generation.top_tokens
 
     let mut top_tokens = Vec::new();
-    if let Some(top_tokens_) = generation.top_tokens {
-        top_tokens.extend(
+    for top_tokens_ in generation.top_tokens{
+        let mut local_top_tokens = Vec::new();
+        local_top_tokens.extend(
             top_tokens_
                 .ids
                 .into_iter()
@@ -547,7 +556,8 @@ fn send_responses(
                     logprob,
                     special,
                 }),
-        )
+        );
+        top_tokens.push(local_top_tokens);
     }
 
     if let Some(generated_text) = generation.generated_text {
@@ -555,7 +565,7 @@ fn send_responses(
         stopped = true;
         // Send message
         entry.response_tx.send(Ok(InferStreamResponse::End {
-            token,
+            tokens,
             top_tokens,
             generated_text,
             queued: entry.queue_time,
@@ -565,7 +575,7 @@ fn send_responses(
         // Send message
         entry
             .response_tx
-            .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+            .send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?;
     }
     Ok(stopped)
 }
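The router-side conversion above just re-zips the parallel arrays of a Tokens message into per-token structs. The same logic in a few lines of standalone Python (invented values, for illustration only):

    # Parallel arrays as carried by a Tokens message (example data).
    ids = [42, 7, 13]
    logprobs = [-0.12, -1.5, -0.03]
    texts = ["Hello", ",", " world"]
    is_special = [False, False, False]

    # One record per token, mirroring the Rust zip/map chain above.
    tokens = [
        {"id": i, "logprob": lp, "text": t, "special": sp}
        for i, lp, t, sp in zip(ids, logprobs, texts, is_special)
    ]
    print(tokens[0])  # {'id': 42, 'logprob': -0.12, 'text': 'Hello', 'special': False}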
@@ -591,16 +601,16 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
-    Prefill(PrefillTokens),
+    Prefill(Tokens),
     // Intermediate messages
     Intermediate {
-        token: Token,
-        top_tokens: Vec<Token>,
+        tokens: Vec<Token>,
+        top_tokens: Vec<Vec<Token>>,
     },
     // Last message
     End {
-        token: Token,
-        top_tokens: Vec<Token>,
+        tokens: Vec<Token>,
+        top_tokens: Vec<Vec<Token>>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
@@ -279,9 +279,10 @@ pub(crate) struct StreamDetails {
 
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
-    pub token: Token,
+    pub tokens: Vec<Token>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
-    pub top_tokens: Vec<Token>,
+    pub top_tokens: Vec<Vec<Token>>,
+    pub text: String,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]
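With these fields, an intermediate server-sent event would serialize roughly as below. This is a hedged illustration with invented values: the exact JSON depends on the serde derives, and top_tokens is omitted entirely when empty because of the skip_serializing_if attribute shown above.

    # Rough shape of one intermediate StreamResponse event (illustrative values).
    stream_event = {
        "tokens": [
            {"id": 42, "text": "Hello", "logprob": -0.12, "special": False},
            {"id": 7, "text": ",", "logprob": -0.50, "special": False},
        ],
        "top_tokens": [  # one candidate list per entry in "tokens"
            [{"id": 42, "text": "Hello", "logprob": -0.12, "special": False}],
            [{"id": 7, "text": ",", "logprob": -0.50, "special": False}],
        ],
        "text": "Hello,",
        "generated_text": None,  # filled on the final event
        "details": None,
    }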
@@ -388,14 +388,15 @@ async fn generate_stream(
                InferStreamResponse::Prefill(_) => {}
                // Yield event for every new token
                InferStreamResponse::Intermediate{
-                    token,
+                    tokens,
                    top_tokens,
                } => {
-                    tracing::debug!(parent: &span, "Token: {:?}", token);
+                    tracing::debug!(parent: &span, "Tokens: {:?}", tokens);
 
                    // StreamResponse
                    let stream_token = StreamResponse {
-                        token,
+                        tokens,
+                        text,
                        top_tokens,
                        generated_text: None,
                        details: None,
@@ -405,7 +406,8 @@ async fn generate_stream(
                }
                // Yield event for last token and compute timings
                InferStreamResponse::End {
-                    token,
+                    tokens,
+                    text,
                    generated_text,
                    start,
                    queued,
@@ -457,8 +459,9 @@ async fn generate_stream(
                    tracing::info!(parent: &span, "Success");
 
                    let stream_token = StreamResponse {
-                        token,
+                        tokens,
                        top_tokens,
+                        text
                        generated_text: Some(output_text),
                        details
                    };
@@ -10,10 +10,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -676,8 +675,8 @@ class CausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -691,7 +690,7 @@ class CausalLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -703,10 +702,12 @@ class CausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
@@ -14,10 +14,9 @@ from typing import Optional, Tuple, List, Type, Union, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
@@ -952,14 +951,17 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             _next_token_ids = next_token_ids[index: index+n_accepted_ids]
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
             all_input_ids.extend(_next_token_ids)
 
+            next_token_texts = []
+            for j in range(index, index + n_accepted_ids):
                 # Generated token
                 all_input_ids.append(next_token_ids[j])
                 next_token_text, prefix_offset, read_offset = self.decode_token(
                     all_input_ids,
                     prefix_offset,
                     read_offset,
                 )
+                next_token_texts.append(next_token_text)
 
             # Evaluate stopping criteria
@@ -1013,8 +1015,8 @@ class FlashCausalLM(Model):
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
-                    prefill_tokens = PrefillTokens(
-                        prefill_token_ids, request_prefill_logprobs, prefill_texts
+                    prefill_tokens = Tokens(
+                        prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
                     )
                 else:
                     prefill_tokens = None
@@ -1028,7 +1030,7 @@ class FlashCausalLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -1037,16 +1039,15 @@ class FlashCausalLM(Model):
             else:
                 top_tokens = None
 
-            next_token_ids = _next_token_ids[0]
-            next_token_logprob = _next_token_logprobs[0]
-
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id,
-                next_token_logprob,
-                next_token_text,
-                next_token_id in self.all_special_ids,
+                Tokens(
+                    _next_token_ids,
+                    _next_token_logprobs,
+                    next_token_texts,
+                    [nid in self.all_special_ids for nid in _next_token_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
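For the flash model this is where the consolidated message pays off with speculative decoding: all n_accepted_ids tokens for a request are decoded and shipped in one Tokens entry instead of only the first one. A standalone sketch of that packing, reusing the variable names from the diff but with invented data and a stand-in for decode_token:

    # Flat per-batch results (invented data); this request starts at `index`
    # and accepted `n_accepted_ids` speculative tokens.
    next_token_ids = [42, 7, 13, 99]
    next_token_logprobs = [-0.1, -0.5, -0.2, -3.0]
    index, n_accepted_ids = 0, 3
    all_special_ids = {0, 1, 2}

    _next_token_ids = next_token_ids[index: index + n_accepted_ids]
    _next_token_logprobs = next_token_logprobs[index: index + n_accepted_ids]

    # Stand-in for self.decode_token(...): the real code decodes incrementally.
    vocab = {42: "Hello", 7: ",", 13: " world", 99: "!"}
    next_token_texts = [vocab[i] for i in _next_token_ids]

    tokens = dict(
        ids=_next_token_ids,
        logprobs=_next_token_logprobs,
        texts=next_token_texts,
        is_special=[i in all_special_ids for i in _next_token_ids],
    )
    print(tokens["texts"])  # ['Hello', ',', ' world']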
@@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
 )
@@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -802,10 +802,12 @@ class IdeficsCausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
@@ -11,8 +11,7 @@ from text_generation_server.models.types import (
     GeneratedText,
     Batch,
     Generation,
-    PrefillTokens,
-    TopTokens,
+    Tokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -733,10 +732,11 @@ class Seq2SeqLM(Model):
 
             # Prefill
             if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                prefill_tokens = PrefillTokens(
+                prefill_tokens = Tokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
                     [self.tokenizer.bos_token],
+                    [false]
                 )
             else:
                 prefill_tokens = None
@@ -750,7 +750,7 @@ class Seq2SeqLM(Model):
                 special_toptokens = [
                     token_id in self.all_special_ids for token_id in top_token_ids
                 ]
-                top_tokens = TopTokens(
+                top_tokens = Tokens(
                     top_token_ids,
                     top_token_logprobs,
                     toptoken_texts,
@@ -762,10 +762,12 @@ class Seq2SeqLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                Tokens(
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
@@ -58,33 +58,15 @@ class GeneratedText:
 
 
 @dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(
-            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
-        )
-
-    def __len__(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class TopTokens:
+class Tokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
     is_special: List[bool]
 
-    def to_pb(self) -> generate_pb2.TopTokens:
-        return generate_pb2.TopTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts,
-            is_special=self.is_special,
+    def to_pb(self) -> generate_pb2.Tokens:
+        return generate_pb2.Tokens(
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
         )
 
     def __len__(self):
@@ -94,14 +76,11 @@ class TopTokens:
 @dataclass
 class Generation:
     request_id: int
-    prefill_tokens: Optional[PrefillTokens]
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+    prefill_tokens: Optional[Tokens]
+    tokens: Tokens
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[TopTokens]
+    top_tokens: Optional[List[Tokens]]
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -109,10 +88,7 @@ class Generation:
             prefill_tokens=self.prefill_tokens.to_pb()
             if self.prefill_tokens is not None
            else None,
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
+            tokens=self.tokens.to_pb(),
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
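For reference, a hedged usage sketch of the consolidated dataclass after this change. The class names, field names, and to_pb call follow the diff above, but running it requires the text_generation_server package with its generated generate_pb2 stubs, so treat it as a sketch rather than a test.

    from text_generation_server.models.types import Generation, Tokens

    # One decode step that produced a single ordinary token.
    step_tokens = Tokens(
        token_ids=[42],
        logprobs=[-0.12],
        texts=["Hello"],
        is_special=[False],
    )

    generation = Generation(
        request_id=0,
        prefill_tokens=None,   # Optional[Tokens]
        tokens=step_tokens,    # Tokens, replaces the old token_id/token_logprob/... fields
        generated_text=None,
        top_tokens=None,       # Optional[List[Tokens]]
    )

    tokens_pb = step_tokens.to_pb()  # generate_pb2.Tokens with parallel ids/logprobs/texts/is_special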