huggingface/text-generation-inference (mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00)

commit be481a4799 (parent e808222dbf)

    Address comments.
@@ -1,6 +1,6 @@
 syntax = "proto3";
 
-package generate.v1;
+package generate.v2;
 
 service TextGenerationService {
     /// Model Info
@@ -32,6 +32,7 @@ message InfoResponse {
     string dtype = 2;
     string device_type = 3;
     optional uint32 window_size = 4;
+    optional uint32 speculate = 5;
 }
 
 /// Empty request
@@ -1,6 +1,6 @@
 /// Single shard Client
-use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
-use crate::pb::generate::v1::*;
+use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
+use crate::pb::generate::v2::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
 use std::cmp::min;
@@ -6,9 +6,9 @@ mod pb;
 mod sharded_client;
 
 pub use client::Client;
-pub use pb::generate::v1::HealthResponse;
-pub use pb::generate::v1::InfoResponse as ShardInfo;
-pub use pb::generate::v1::{
+pub use pb::generate::v2::HealthResponse;
+pub use pb::generate::v2::InfoResponse as ShardInfo;
+pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
     Request, StoppingCriteriaParameters, Tokens,
 };
@@ -167,21 +167,21 @@ impl Infer {
                     .collect();
                 }
                 // Push last token
-                InferStreamResponse::Intermediate { tokens, top_tokens } => {
-                    result_tokens.extend(tokens);
-                    result_top_tokens.extend(top_tokens);
+                InferStreamResponse::Intermediate { token, top_tokens } => {
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
                 }
                 // Final message
                 // Set return values
                 InferStreamResponse::End {
-                    tokens,
+                    token,
                     generated_text,
                     start,
                     queued,
                     top_tokens,
                 } => {
-                    result_tokens.extend(tokens);
-                    result_top_tokens.extend(top_tokens);
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
                     result_generated_text = Some(generated_text);
                     result_start = Some(start);
                     result_queued = Some(queued)
@@ -515,7 +515,6 @@ fn send_responses(
 
     let mut stopped = false;
 
-    tracing::info!("Generation: {:?}", generation);
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
         entry
@@ -524,68 +523,63 @@ fn send_responses(
     }
 
     // Create last Token
-    let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
-        tokens_
+    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
+    let n = tokens_.ids.len();
+    metrics::histogram!(
+        "tgi_request_skipped_tokens",
+        (n - 1) as f64
+    );
+    for (i, (((id, logprob), text), special)) in tokens_
         .ids
         .into_iter()
         .zip(tokens_.logprobs.into_iter())
         .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
-        .map(|(((id, logprob), text), special)| Token {
+        .zip(tokens_.is_special.into_iter()).enumerate() {
+        let token = Token {
             id,
             text,
             logprob,
             special,
-        })
-        .collect()
-    } else {
-        vec![]
         };
-    // generation.top_tokens
+        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i){
 
-    let mut top_tokens = Vec::new();
-    for top_tokens_ in generation.top_tokens {
-        let mut local_top_tokens = Vec::new();
-        local_top_tokens.extend(
             top_tokens_
                 .ids
-                .into_iter()
-                .zip(top_tokens_.logprobs.into_iter())
-                .zip(top_tokens_.texts.into_iter())
-                .zip(top_tokens_.is_special.into_iter())
-                .map(|(((id, logprob), text), special)| Token {
+                .iter()
+                .zip(top_tokens_.logprobs.iter())
+                .zip(top_tokens_.texts.iter())
+                .zip(top_tokens_.is_special.iter())
+                .map(|(((&id, &logprob), text), &special)| Token {
                     id,
-                    text,
+                    text: text.to_string(),
                     logprob,
                     special,
-                }),
-        );
-        top_tokens.push(local_top_tokens);
+                }).collect()
+        }else{
+            vec![]
-    }
-    // Force top_tokens to be the same size as tokens, both are going to be
-    // zipped later
-    if top_tokens.len() != tokens.len() {
-        top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
-    }
-
-    if let Some(generated_text) = generation.generated_text {
+        };
+        match (&generation.generated_text, i){
+            (Some(generated_text), i) if i == n - 1 => {
                 // Generation has ended
                 stopped = true;
                 // Send message
                 entry.response_tx.send(Ok(InferStreamResponse::End {
-                    tokens,
+                    token,
                     top_tokens,
-                    generated_text,
+                    generated_text: generated_text.clone(),
                     queued: entry.queue_time,
                     start: entry.batch_time.unwrap(),
                 }))?;
-    } else {
+            }
+            _ => {
                 // Send message
                 entry
                     .response_tx
-                    .send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?;
+                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
             }
+        }
+    }
 
     Ok(stopped)
 }
 
@@ -613,13 +607,13 @@ pub(crate) enum InferStreamResponse {
     Prefill(Tokens),
     // Intermediate messages
    Intermediate {
-        tokens: Vec<Token>,
-        top_tokens: Vec<Vec<Token>>,
+        token: Token,
+        top_tokens: Vec<Token>,
     },
     // Last message
     End {
-        tokens: Vec<Token>,
-        top_tokens: Vec<Vec<Token>>,
+        token: Token,
+        top_tokens: Vec<Token>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
@@ -388,11 +388,11 @@ async fn generate_stream(
                 InferStreamResponse::Prefill(_) => {}
                 // Yield event for every new token
                 InferStreamResponse::Intermediate{
-                    tokens,
+                    token,
                     top_tokens,
                 } => {
+                    tracing::debug!(parent: &span, "Token: {:?}", token);
 
-                    for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) {
                     // StreamResponse
                     let stream_token = StreamResponse {
                         token,
@@ -401,18 +401,26 @@ async fn generate_stream(
                         details: None,
                     };
 
-                    yield Ok(Event::default().json_data(stream_token).unwrap());
-                    }
+                    yield Ok(Event::default().json_data(stream_token).unwrap())
                 }
                 // Yield event for last token and compute timings
                 InferStreamResponse::End {
-                    tokens,
+                    token,
                     generated_text,
                     start,
                     queued,
                     top_tokens,
                 } => {
                     // Token details
+                    let details = match details {
+                        true => Some(StreamDetails {
+                            finish_reason: FinishReason::from(generated_text.finish_reason),
+                            generated_tokens: generated_text.generated_tokens,
+                            seed: generated_text.seed,
+                        }),
+                        false => None,
+                    };
+
                     // Timings
                     let total_time = start_time.elapsed();
                     let validation_time = queued - start_time;
@@ -440,45 +448,22 @@ async fn generate_stream(
                     // StreamResponse
                     end_reached = true;
 
-                    let n_tokens = tokens.len();
-                    for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() {
-                        // StreamResponse
-                        let stream_token = if i < n_tokens - 1 {
-                            StreamResponse {
-                                token,
-                                top_tokens,
-                                generated_text: None,
-                                details: None,
+                    let mut output_text = generated_text.text;
+                    if let Some(prompt) = add_prompt {
+                        output_text = prompt + &output_text;
                     }
 
-                        }else{
-                            let details = match details {
-                                true => Some(StreamDetails {
-                                    finish_reason: FinishReason::from(generated_text.finish_reason),
-                                    generated_tokens: generated_text.generated_tokens,
-                                    seed: generated_text.seed,
-                                }),
-                                false => None,
-                            };
-                            let output_text = if let Some(prompt) = &add_prompt {
-                                prompt.to_owned() + &generated_text.text
-                            }else{
-                                generated_text.text.to_owned()
-                            };
 
                     tracing::debug!(parent: &span, "Output: {}", output_text);
                     tracing::info!(parent: &span, "Success");
 
-                            StreamResponse {
+                    let stream_token = StreamResponse {
                         token,
                         top_tokens,
                         generated_text: Some(output_text),
                         details
-                            }
                     };
-                            yield Ok(Event::default().json_data(stream_token).unwrap());
-                    }
-
+                    yield Ok(Event::default().json_data(stream_token).unwrap());
                     break;
                 }
             }
@@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
 
+from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -77,9 +78,6 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)
 
-SPECULATE = None
-
-
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -89,7 +87,6 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
-    global SPECULATE
     if dtype is None:
         # Keep it as default for now and let
         # every model resolve their own default dtype.
@@ -101,7 +98,10 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
-    SPECULATE = 2
+    if speculate is not None:
+        set_speculate(speculate)
+    else:
+        set_speculate(2)
 
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
@@ -144,18 +144,22 @@ def get_model(
         medusa_config = config_dict
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
-        SPECULATE = config_dict["medusa_num_heads"]
+        speculate_medusa = config_dict["medusa_num_heads"]
+        if speculate is not None:
+            if speculate > speculate_medusa:
+                raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_medusa)
+
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         method = "medusa"
     else:
-        if speculate is not None:
-            SPECULATE = speculate
-        else:
-            SPECULATE = 2
         method = "n-gram"
-    logger.info(f"Using speculation {method} with {SPECULATE} input ids.")
+    logger.info(f"Using speculation {method} with {get_speculate()} input ids.")
 
     model_type = config_dict["model_type"]
 
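Review note: the two hunks above spread the speculation depth across three sources (an explicit `speculate` argument, a medusa config's `medusa_num_heads`, and a default of 2). A hedged condensation of that precedence follows, for reading convenience only; `resolve_speculate` is a hypothetical helper, the commit itself inlines this logic in `get_model` and stores the result with `set_speculate`.

    # Hypothetical condensation of the precedence implemented in get_model() above.
    def resolve_speculate(speculate, medusa_num_heads=None, default=2):
        if medusa_num_heads is not None:
            if speculate is not None:
                if speculate > medusa_num_heads:
                    raise RuntimeError("speculate exceeds the medusa head count")
                return speculate           # explicit value wins when compatible
            return medusa_num_heads        # otherwise use the medusa head count
        return speculate if speculate is not None else default  # n-gram path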
@@ -12,6 +12,7 @@ from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -192,8 +193,7 @@ class FlashCausalLMBatch(Batch):
 
             # Paged attention
             # Remove one as the first token des not have a past
-            from text_generation_server.models import SPECULATE
-            speculative_length = SPECULATE
+            speculative_length = get_speculate()
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
@@ -483,7 +483,7 @@ class FlashCausalLMBatch(Batch):
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
-            speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[1]
+            speculative_length = b.speculative_ids.shape[1]
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -589,7 +589,7 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )
 
-        speculative_ids = None if batches[0].speculative_ids is None else torch.cat([b.speculative_ids for b in batches], dim=0)
+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
 
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
@@ -825,16 +825,15 @@ class FlashCausalLM(Model):
             next_token_logits = out
 
 
-        from text_generation_server.models import SPECULATE
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, SPECULATE, batch.speculative_ids, speculative_logits
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
        )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
 
-        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
+        speculative_length = speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -1038,6 +1037,7 @@ class FlashCausalLM(Model):
                 clean_up_tokenization_spaces=False,
                 skip_special_tokens=False,
             )
+
            prefill_tokens = Tokens(
                prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
            )
server/text_generation_server/utils/speculate.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+
+SPECULATE = None
+
+def get_speculate():
+    global SPECULATE
+    return SPECULATE
+
+def set_speculate(speculate: int):
+    global SPECULATE
+    SPECULATE = speculate
+
+
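For context, a minimal usage sketch of the new helper module added above. The values mirror the default wiring in this commit (`get_model()` calls `set_speculate`, `FlashCausalLMBatch` and `FlashCausalLM` read `get_speculate()`); the standalone snippet itself is illustrative only.

    from text_generation_server.utils.speculate import get_speculate, set_speculate

    set_speculate(2)             # done once in get_model(); 2 is the default when no value is passed
    assert get_speculate() == 2  # read later when sizing KV-cache blocks and choosing next tokens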