Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Commit be481a4799 ("Address comments."), parent e808222dbf
@@ -1,6 +1,6 @@
 syntax = "proto3";

-package generate.v1;
+package generate.v2;

 service TextGenerationService {
     /// Model Info
@@ -32,6 +32,7 @@ message InfoResponse {
     string dtype = 2;
     string device_type = 3;
     optional uint32 window_size = 4;
+    optional uint32 speculate = 5;
 }

 /// Empty request
@@ -1,6 +1,6 @@
 /// Single shard Client
-use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
-use crate::pb::generate::v1::*;
+use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
+use crate::pb::generate::v2::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
 use std::cmp::min;
@@ -6,9 +6,9 @@ mod pb;
 mod sharded_client;

 pub use client::Client;
-pub use pb::generate::v1::HealthResponse;
-pub use pb::generate::v1::InfoResponse as ShardInfo;
-pub use pb::generate::v1::{
+pub use pb::generate::v2::HealthResponse;
+pub use pb::generate::v2::InfoResponse as ShardInfo;
+pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
     Request, StoppingCriteriaParameters, Tokens,
 };
@@ -167,21 +167,21 @@ impl Infer {
                     .collect();
                 }
                 // Push last token
-                InferStreamResponse::Intermediate { tokens, top_tokens } => {
-                    result_tokens.extend(tokens);
-                    result_top_tokens.extend(top_tokens);
+                InferStreamResponse::Intermediate { token, top_tokens } => {
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
                 }
                 // Final message
                 // Set return values
                 InferStreamResponse::End {
-                    tokens,
+                    token,
                     generated_text,
                     start,
                     queued,
                     top_tokens,
                 } => {
-                    result_tokens.extend(tokens);
-                    result_top_tokens.extend(top_tokens);
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
                     result_generated_text = Some(generated_text);
                     result_start = Some(start);
                     result_queued = Some(queued)
@@ -515,7 +515,6 @@ fn send_responses(

     let mut stopped = false;

-    tracing::info!("Generation: {:?}", generation);
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
         entry
@@ -524,68 +523,63 @@ fn send_responses(
     }

     // Create last Token
-    let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
-        tokens_
+    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
+    let n = tokens_.ids.len();
+    metrics::histogram!(
+        "tgi_request_skipped_tokens",
+        (n - 1) as f64
+    );
+    for (i, (((id, logprob), text), special)) in tokens_
         .ids
         .into_iter()
         .zip(tokens_.logprobs.into_iter())
         .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
-        .map(|(((id, logprob), text), special)| Token {
+        .zip(tokens_.is_special.into_iter()).enumerate() {
+        let token = Token {
             id,
             text,
             logprob,
             special,
-        })
-        .collect()
-    } else {
-        vec![]
-    };
-
-    // generation.top_tokens
-
-    let mut top_tokens = Vec::new();
-    for top_tokens_ in generation.top_tokens {
-        let mut local_top_tokens = Vec::new();
-        local_top_tokens.extend(
+        };
+        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i){
             top_tokens_
                 .ids
-                .into_iter()
-                .zip(top_tokens_.logprobs.into_iter())
-                .zip(top_tokens_.texts.into_iter())
-                .zip(top_tokens_.is_special.into_iter())
-                .map(|(((id, logprob), text), special)| Token {
+                .iter()
+                .zip(top_tokens_.logprobs.iter())
+                .zip(top_tokens_.texts.iter())
+                .zip(top_tokens_.is_special.iter())
+                .map(|(((&id, &logprob), text), &special)| Token {
                     id,
-                    text,
+                    text: text.to_string(),
                     logprob,
                     special,
-                }),
-        );
-        top_tokens.push(local_top_tokens);
-    }
-    // Force top_tokens to be the same size as tokens, both are going to be
-    // zipped later
-    if top_tokens.len() != tokens.len() {
-        top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
-    }
-
-    if let Some(generated_text) = generation.generated_text {
-        // Generation has ended
-        stopped = true;
-        // Send message
-        entry.response_tx.send(Ok(InferStreamResponse::End {
-            tokens,
-            top_tokens,
-            generated_text,
-            queued: entry.queue_time,
-            start: entry.batch_time.unwrap(),
-        }))?;
-    } else {
-        // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?;
+                }).collect()
+        }else{
+            vec![]
+        };
+        match (&generation.generated_text, i){
+            (Some(generated_text), i) if i == n - 1 => {
+                // Generation has ended
+                stopped = true;
+                // Send message
+                entry.response_tx.send(Ok(InferStreamResponse::End {
+                    token,
+                    top_tokens,
+                    generated_text: generated_text.clone(),
+                    queued: entry.queue_time,
+                    start: entry.batch_time.unwrap(),
+                }))?;
+            }
+            _ => {
+                // Send message
+                entry
+                    .response_tx
+                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+            }
+        }
     }
     Ok(stopped)
 }
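Note: the hunk above changes send_responses from one message per generation to one message per decoded token, which is what speculative decoding needs since a single Generation can now carry several accepted tokens as parallel arrays. A minimal Python sketch of that fan-out, with illustrative names only (Token and the ids/logprobs/texts/is_special lists mirror the fields of the Tokens proto message; the dicts stand in for the Intermediate and End responses):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Token:
        id: int
        text: str
        logprob: float
        special: bool

    def fan_out(ids: List[int], logprobs: List[float], texts: List[str],
                is_special: List[bool], generated_text: Optional[str]) -> List[dict]:
        # One message per token; if the generation finished, the last token
        # becomes the End message instead of an Intermediate one.
        messages = []
        n = len(ids)
        for i, (id_, logprob, text, special) in enumerate(zip(ids, logprobs, texts, is_special)):
            token = Token(id_, text, logprob, special)
            if generated_text is not None and i == n - 1:
                messages.append({"type": "end", "token": token, "generated_text": generated_text})
            else:
                messages.append({"type": "intermediate", "token": token})
        return messages

    # Three accepted tokens and a finished generation produce two intermediate
    # messages followed by one end message.
    out = fan_out([1, 2, 3], [-0.1, -0.2, -0.3], ["a", "b", "c"], [False, False, True], "abc")
    assert [m["type"] for m in out] == ["intermediate", "intermediate", "end"]

The tgi_request_skipped_tokens histogram in the real code records n - 1, i.e. how many extra tokens were accepted beyond the single token that would have been produced without speculation.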
@@ -613,13 +607,13 @@ pub(crate) enum InferStreamResponse {
     Prefill(Tokens),
     // Intermediate messages
     Intermediate {
-        tokens: Vec<Token>,
-        top_tokens: Vec<Vec<Token>>,
+        token: Token,
+        top_tokens: Vec<Token>,
     },
     // Last message
     End {
-        tokens: Vec<Token>,
-        top_tokens: Vec<Vec<Token>>,
+        token: Token,
+        top_tokens: Vec<Token>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
|
@ -388,31 +388,39 @@ async fn generate_stream(
|
||||
InferStreamResponse::Prefill(_) => {}
|
||||
// Yield event for every new token
|
||||
InferStreamResponse::Intermediate{
|
||||
tokens,
|
||||
token,
|
||||
top_tokens,
|
||||
} => {
|
||||
tracing::debug!(parent: &span, "Token: {:?}", token);
|
||||
|
||||
for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) {
|
||||
// StreamResponse
|
||||
let stream_token = StreamResponse {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: None,
|
||||
details: None,
|
||||
};
|
||||
// StreamResponse
|
||||
let stream_token = StreamResponse {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: None,
|
||||
details: None,
|
||||
};
|
||||
|
||||
yield Ok(Event::default().json_data(stream_token).unwrap());
|
||||
}
|
||||
yield Ok(Event::default().json_data(stream_token).unwrap())
|
||||
}
|
||||
// Yield event for last token and compute timings
|
||||
InferStreamResponse::End {
|
||||
tokens,
|
||||
token,
|
||||
generated_text,
|
||||
start,
|
||||
queued,
|
||||
top_tokens,
|
||||
} => {
|
||||
// Token details
|
||||
let details = match details {
|
||||
true => Some(StreamDetails {
|
||||
finish_reason: FinishReason::from(generated_text.finish_reason),
|
||||
generated_tokens: generated_text.generated_tokens,
|
||||
seed: generated_text.seed,
|
||||
}),
|
||||
false => None,
|
||||
};
|
||||
|
||||
// Timings
|
||||
let total_time = start_time.elapsed();
|
||||
let validation_time = queued - start_time;
|
||||
@@ -440,45 +448,22 @@ async fn generate_stream(
                     // StreamResponse
                     end_reached = true;

-                    let n_tokens = tokens.len();
-                    for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() {
-                        // StreamResponse
-                        let stream_token = if i < n_tokens - 1 {
-                            StreamResponse {
-                                token,
-                                top_tokens,
-                                generated_text: None,
-                                details: None,
-                            }
-                        }else{
-                            let details = match details {
-                                true => Some(StreamDetails {
-                                    finish_reason: FinishReason::from(generated_text.finish_reason),
-                                    generated_tokens: generated_text.generated_tokens,
-                                    seed: generated_text.seed,
-                                }),
-                                false => None,
-                            };
-                            let output_text = if let Some(prompt) = &add_prompt {
-                                prompt.to_owned() + &generated_text.text
-                            }else{
-                                generated_text.text.to_owned()
-                            };
-
-                            tracing::debug!(parent: &span, "Output: {}", output_text);
-                            tracing::info!(parent: &span, "Success");
-
-                            StreamResponse {
-                                token,
-                                top_tokens,
-                                generated_text: Some(output_text),
-                                details
-                            }
-                        };
-                        yield Ok(Event::default().json_data(stream_token).unwrap());
-                    }
+                    let mut output_text = generated_text.text;
+                    if let Some(prompt) = add_prompt {
+                        output_text = prompt + &output_text;
+                    }
+
+                    tracing::debug!(parent: &span, "Output: {}", output_text);
+                    tracing::info!(parent: &span, "Success");
+
+                    let stream_token = StreamResponse {
+                        token,
+                        top_tokens,
+                        generated_text: Some(output_text),
+                        details
+                    };
+
+                    yield Ok(Event::default().json_data(stream_token).unwrap());
+                    break;
                 }
             }
@@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional

+from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -77,9 +78,6 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)

-SPECULATE = None
-
-
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -89,7 +87,6 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
-    global SPECULATE
     if dtype is None:
         # Keep it as default for now and let
         # every model resolve their own default dtype.
@@ -101,7 +98,10 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")

-    SPECULATE = 2
+    if speculate is not None:
+        set_speculate(speculate)
+    else:
+        set_speculate(2)

     if "facebook/galactica" in model_id:
         return GalacticaSharded(
@@ -144,18 +144,22 @@ def get_model(
         medusa_config = config_dict
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
-        SPECULATE = config_dict["medusa_num_heads"]
+        speculate_medusa = config_dict["medusa_num_heads"]
+        if speculate is not None:
+            if speculate > speculate_medusa:
+                raise RuntimeError(f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match")
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_medusa)

         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         method = "medusa"
     else:
-        if speculate is not None:
-            SPECULATE = speculate
-        else:
-            SPECULATE = 2
         method = "n-gram"
-    logger.info(f"Using speculation {method} with {SPECULATE} input ids.")
+    logger.info(f"Using speculation {method} with {get_speculate()} input ids.")

     model_type = config_dict["model_type"]
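Note: the get_model changes above resolve the speculation length in a fixed order: an explicit speculate argument wins, a medusa checkpoint falls back to its medusa_num_heads, and other models default to 2 for n-gram speculation. A hedged sketch condensing that resolution into one helper (resolve_speculate is an illustrative name; get_speculate and set_speculate come from the new utils module added at the end of this diff):

    from text_generation_server.utils.speculate import get_speculate, set_speculate

    def resolve_speculate(speculate, medusa_num_heads=None):
        # Illustrative condensation of the logic spread across get_model above.
        if medusa_num_heads is not None:
            if speculate is not None:
                if speculate > medusa_num_heads:
                    raise RuntimeError(
                        f"Speculate is set to `{speculate}` but this medusa model "
                        f"only has `{medusa_num_heads}` heads, please make them match"
                    )
                set_speculate(speculate)
            else:
                set_speculate(medusa_num_heads)
            return "medusa"
        set_speculate(speculate if speculate is not None else 2)
        return "n-gram"

    # No CLI value plus a 4-head medusa config -> 4 speculative input ids.
    assert resolve_speculate(None, medusa_num_heads=4) == "medusa"
    assert get_speculate() == 4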
@@ -11,7 +11,8 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -192,8 +193,7 @@ class FlashCausalLMBatch(Batch):

             # Paged attention
             # Remove one as the first token does not have a past
-            from text_generation_server.models import SPECULATE
-            speculative_length = SPECULATE
+            speculative_length = get_speculate()
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
@@ -483,7 +483,7 @@ class FlashCausalLMBatch(Batch):
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
-            speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[1]
+            speculative_length = b.speculative_ids.shape[1]
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -589,7 +589,7 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )

-        speculative_ids = None if batches[0].speculative_ids is None else torch.cat([b.speculative_ids for b in batches], dim=0)
+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)

         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
@@ -825,16 +825,15 @@ class FlashCausalLM(Model):
             next_token_logits = out

-        from text_generation_server.models import SPECULATE
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, SPECULATE, batch.speculative_ids, speculative_logits
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )

-        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
+        speculative_length = speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -1038,6 +1037,7 @@ class FlashCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )

                 prefill_tokens = Tokens(
+                    prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[]
                 )
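Note: the paged-attention sizing above now reserves KV-cache slots for speculative tokens as well. A quick worked example of the arithmetic; BLOCK_SIZE = 16 is assumed here for illustration, the actual value comes from the BLOCK_SIZE constant used in the hunk above:

    import math

    BLOCK_SIZE = 16         # assumed for this example

    input_length = 20       # prompt tokens
    max_new_tokens = 50     # generation budget
    speculative_length = 2  # get_speculate()

    # Minus one following the code's own comment ("the first token does not
    # have a past"); speculative tokens need their own slots on top of that.
    total_tokens = input_length + max_new_tokens - 1 + speculative_length  # 71
    needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)                   # ceil(71 / 16) = 5
    print(total_tokens, needed_blocks)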
server/text_generation_server/utils/speculate.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+SPECULATE = None
+
+
+def get_speculate():
+    global SPECULATE
+    return SPECULATE
+
+
+def set_speculate(speculate: int):
+    global SPECULATE
+    SPECULATE = speculate
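Note: a quick sanity check of the new module in a fresh process; the global stays None until get_model configures it, after which every call site (for example FlashCausalLMBatch above) reads the same value:

    from text_generation_server.utils.speculate import get_speculate, set_speculate

    assert get_speculate() is None  # nothing configured yet in a fresh process
    set_speculate(2)                # the n-gram default chosen in get_model
    assert get_speculate() == 2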