Address comments.

Author: Nicolas Patry
Date: 2023-12-05 15:21:42 +00:00
Parent: e808222dbf
Commit: be481a4799

8 changed files with 128 additions and 132 deletions

View File

@@ -1,6 +1,6 @@
 syntax = "proto3";

-package generate.v1;
+package generate.v2;

 service TextGenerationService {
     /// Model Info
@@ -32,6 +32,7 @@ message InfoResponse {
     string dtype = 2;
     string device_type = 3;
     optional uint32 window_size = 4;
+    optional uint32 speculate = 5;
 }

 /// Empty request
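
The new speculate field in InfoResponse lets each shard report its configured number of speculative input ids to the router. As a rough sketch only (the Info handler itself is not part of this diff, the dtype/device values below are made up, and the text_generation_server.pb.generate_pb2 module path is assumed from the generated bindings), a Python shard could fill the message like this, using the speculate helpers added later in this commit:

    # Illustrative only: constructing the v2 InfoResponse with the new field.
    from text_generation_server.pb import generate_pb2
    from text_generation_server.utils.speculate import get_speculate, set_speculate

    set_speculate(2)  # value resolved once at model load time (see get_model below)
    info = generate_pb2.InfoResponse(
        dtype="torch.float16",      # example value
        device_type="cuda",         # example value
        speculate=get_speculate(),  # new field 5 in generate.v2
    )
    print(info)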

View File

@@ -1,6 +1,6 @@
 /// Single shard Client
-use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
-use crate::pb::generate::v1::*;
+use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
+use crate::pb::generate::v2::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
 use std::cmp::min;

View File

@@ -6,9 +6,9 @@ mod pb;
 mod sharded_client;

 pub use client::Client;
-pub use pb::generate::v1::HealthResponse;
-pub use pb::generate::v1::InfoResponse as ShardInfo;
-pub use pb::generate::v1::{
+pub use pb::generate::v2::HealthResponse;
+pub use pb::generate::v2::InfoResponse as ShardInfo;
+pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
     Request, StoppingCriteriaParameters, Tokens,
 };

View File

@@ -167,21 +167,21 @@ impl Infer {
                     .collect();
             }
             // Push last token
-            InferStreamResponse::Intermediate { tokens, top_tokens } => {
-                result_tokens.extend(tokens);
-                result_top_tokens.extend(top_tokens);
+            InferStreamResponse::Intermediate { token, top_tokens } => {
+                result_tokens.push(token);
+                result_top_tokens.push(top_tokens);
             }
             // Final message
             // Set return values
             InferStreamResponse::End {
-                tokens,
+                token,
                 generated_text,
                 start,
                 queued,
                 top_tokens,
             } => {
-                result_tokens.extend(tokens);
-                result_top_tokens.extend(top_tokens);
+                result_tokens.push(token);
+                result_top_tokens.push(top_tokens);
                 result_generated_text = Some(generated_text);
                 result_start = Some(start);
                 result_queued = Some(queued)
@@ -515,7 +515,6 @@ fn send_responses(
     let mut stopped = false;

-    tracing::info!("Generation: {:?}", generation);
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
         entry
@@ -524,68 +523,63 @@ fn send_responses(
     }
     // Create last Token
-    let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
-        tokens_
+    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
+    let n = tokens_.ids.len();
+    metrics::histogram!(
+        "tgi_request_skipped_tokens",
+        (n - 1) as f64
+    );
+    for (i, (((id, logprob), text), special)) in tokens_
         .ids
         .into_iter()
         .zip(tokens_.logprobs.into_iter())
         .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
-        .map(|(((id, logprob), text), special)| Token {
+        .zip(tokens_.is_special.into_iter()).enumerate() {
+        let token = Token {
             id,
             text,
             logprob,
             special,
-        })
-        .collect()
-    } else {
-        vec![]
         };
-    // generation.top_tokens
-    let mut top_tokens = Vec::new();
-    for top_tokens_ in generation.top_tokens {
-        let mut local_top_tokens = Vec::new();
-        local_top_tokens.extend(
+        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i){
             top_tokens_
                 .ids
-                .into_iter()
-                .zip(top_tokens_.logprobs.into_iter())
-                .zip(top_tokens_.texts.into_iter())
-                .zip(top_tokens_.is_special.into_iter())
-                .map(|(((id, logprob), text), special)| Token {
+                .iter()
+                .zip(top_tokens_.logprobs.iter())
+                .zip(top_tokens_.texts.iter())
+                .zip(top_tokens_.is_special.iter())
+                .map(|(((&id, &logprob), text), &special)| Token {
                     id,
-                    text,
+                    text: text.to_string(),
                     logprob,
                     special,
-                }),
-        );
-        top_tokens.push(local_top_tokens);
-    }
-    // Force top_tokens to be the same size as tokens, both are going to be
-    // zipped later
-    if top_tokens.len() != tokens.len() {
-        top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
-    }
-    if let Some(generated_text) = generation.generated_text {
+                }).collect()
+        }else{
+            vec![]
+        };
+        match (&generation.generated_text, i){
+            (Some(generated_text), i) if i == n - 1 => {
                 // Generation has ended
                 stopped = true;
                 // Send message
                 entry.response_tx.send(Ok(InferStreamResponse::End {
-                    tokens,
+                    token,
                     top_tokens,
-                    generated_text,
+                    generated_text: generated_text.clone(),
                     queued: entry.queue_time,
                     start: entry.batch_time.unwrap(),
                 }))?;
-    } else {
+            }
+            _ => {
                 // Send message
                 entry
                     .response_tx
-                    .send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?;
+                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
             }
+        }
+    }
     Ok(stopped)
 }
@@ -613,13 +607,13 @@ pub(crate) enum InferStreamResponse {
     Prefill(Tokens),
     // Intermediate messages
     Intermediate {
-        tokens: Vec<Token>,
-        top_tokens: Vec<Vec<Token>>,
+        token: Token,
+        top_tokens: Vec<Token>,
     },
     // Last message
     End {
-        tokens: Vec<Token>,
-        top_tokens: Vec<Vec<Token>>,
+        token: Token,
+        top_tokens: Vec<Token>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
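
With speculative decoding a single Generation can now carry several accepted tokens, so send_responses fans them out as one stream message per token, records the extra accepted tokens in the tgi_request_skipped_tokens histogram, and only turns the last token into an End message; the enum above therefore carries a single Token per message instead of a Vec. A small illustrative sketch of that fan-out, written in Python rather than the router's Rust and using made-up names:

    # Illustrative sketch (not the router code): one generation with n accepted
    # tokens becomes n-1 Intermediate messages plus one End message.
    def fan_out(tokens, top_tokens, generated_text):
        n = len(tokens)
        messages = []
        for i, token in enumerate(tokens):
            tops = top_tokens[i] if i < len(top_tokens) else []
            if generated_text is not None and i == n - 1:
                messages.append(("End", token, tops, generated_text))
            else:
                messages.append(("Intermediate", token, tops))
        return messages

    # Three accepted tokens in one generation -> two Intermediate + one End.
    print(fan_out(["Hel", "lo", "!"], [[], [], []], "Hello!"))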

View File

@@ -388,11 +388,11 @@ async fn generate_stream(
                     InferStreamResponse::Prefill(_) => {}
                     // Yield event for every new token
                     InferStreamResponse::Intermediate{
-                        tokens,
+                        token,
                         top_tokens,
                     } => {
-                        for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) {
+                        tracing::debug!(parent: &span, "Token: {:?}", token);
                         // StreamResponse
                         let stream_token = StreamResponse {
                             token,
@@ -401,18 +401,26 @@
                             details: None,
                         };
-                            yield Ok(Event::default().json_data(stream_token).unwrap());
-                        }
+                        yield Ok(Event::default().json_data(stream_token).unwrap())
                     }
                     // Yield event for last token and compute timings
                     InferStreamResponse::End {
-                        tokens,
+                        token,
                         generated_text,
                         start,
                         queued,
                         top_tokens,
                     } => {
                         // Token details
+                        let details = match details {
+                            true => Some(StreamDetails {
+                                finish_reason: FinishReason::from(generated_text.finish_reason),
+                                generated_tokens: generated_text.generated_tokens,
+                                seed: generated_text.seed,
+                            }),
+                            false => None,
+                        };
                         // Timings
                         let total_time = start_time.elapsed();
                         let validation_time = queued - start_time;
@@ -440,45 +448,22 @@
                         // StreamResponse
                         end_reached = true;
-                        let n_tokens = tokens.len();
-                        for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() {
-                            // StreamResponse
-                            let stream_token = if i < n_tokens - 1 {
-                                StreamResponse {
-                                    token,
-                                    top_tokens,
-                                    generated_text: None,
-                                    details: None,
+                        let mut output_text = generated_text.text;
+                        if let Some(prompt) = add_prompt {
+                            output_text = prompt + &output_text;
                         }
-                            }else{
-                                let details = match details {
-                                    true => Some(StreamDetails {
-                                        finish_reason: FinishReason::from(generated_text.finish_reason),
-                                        generated_tokens: generated_text.generated_tokens,
-                                        seed: generated_text.seed,
-                                    }),
-                                    false => None,
-                                };
-                                let output_text = if let Some(prompt) = &add_prompt {
-                                    prompt.to_owned() + &generated_text.text
-                                }else{
-                                    generated_text.text.to_owned()
-                                };
                         tracing::debug!(parent: &span, "Output: {}", output_text);
                         tracing::info!(parent: &span, "Success");
-                                StreamResponse {
+                        let stream_token = StreamResponse {
                             token,
                             top_tokens,
                             generated_text: Some(output_text),
                             details
-                                }
                         };
-                            yield Ok(Event::default().json_data(stream_token).unwrap());
-                        }
+                        yield Ok(Event::default().json_data(stream_token).unwrap());
                         break;
                     }

View File

@@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional

+from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -77,9 +78,6 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)

-
-SPECULATE = None
-
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -89,7 +87,6 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
-    global SPECULATE
     if dtype is None:
         # Keep it as default for now and let
         # every model resolve their own default dtype.
@@ -101,7 +98,10 @@
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")

-    SPECULATE = 2
+    if speculate is not None:
+        set_speculate(speculate)
+    else:
+        set_speculate(2)

     if "facebook/galactica" in model_id:
         return GalacticaSharded(
@@ -144,18 +144,22 @@
         medusa_config = config_dict
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
-        SPECULATE = config_dict["medusa_num_heads"]
+        speculate_medusa = config_dict["medusa_num_heads"]
+        if speculate is not None:
+            if speculate > speculate_medusa:
+                raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_medusa)
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         method = "medusa"
     else:
-        if speculate is not None:
-            SPECULATE = speculate
-        else:
-            SPECULATE = 2
         method = "n-gram"
-    logger.info(f"Using speculation {method} with {SPECULATE} input ids.")
+    logger.info(f"Using speculation {method} with {get_speculate()} input ids.")

     model_type = config_dict["model_type"]
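
The net effect of these hunks is a single resolution order for the speculation depth: an explicit speculate argument wins but may not exceed a Medusa model's head count, a Medusa config otherwise supplies its own value, and everything else falls back to n-gram speculation with 2 input ids. A condensed, illustrative sketch of that logic (resolve_speculate is not a real function in the codebase):

    def resolve_speculate(cli_speculate=None, medusa_num_heads=None, default=2):
        if medusa_num_heads is not None:
            if cli_speculate is not None:
                if cli_speculate > medusa_num_heads:
                    raise RuntimeError(
                        f"Speculate is set to `{cli_speculate}` but this medusa model "
                        f"only has `{medusa_num_heads}` heads, please make them match"
                    )
                return cli_speculate
            return medusa_num_heads
        return cli_speculate if cli_speculate is not None else default

    assert resolve_speculate() == 2                    # n-gram default
    assert resolve_speculate(cli_speculate=3) == 3     # explicit value wins
    assert resolve_speculate(medusa_num_heads=4) == 4  # medusa supplies its depth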

View File

@@ -12,6 +12,7 @@ from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -192,8 +193,7 @@ class FlashCausalLMBatch(Batch):
             # Paged attention
             # Remove one as the first token des not have a past
-            from text_generation_server.models import SPECULATE
-            speculative_length = SPECULATE
+            speculative_length = get_speculate()
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
@@ -483,7 +483,7 @@ class FlashCausalLMBatch(Batch):
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
-            speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[1]
+            speculative_length = b.speculative_ids.shape[1]
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -589,7 +589,7 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )

-        speculative_ids = None if batches[0].speculative_ids is None else torch.cat([b.speculative_ids for b in batches], dim=0)
+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
@@ -825,16 +825,15 @@ class FlashCausalLM(Model):
             next_token_logits = out

-        from text_generation_server.models import SPECULATE
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, SPECULATE, batch.speculative_ids, speculative_logits
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )

-        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
+        speculative_length = speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -1038,6 +1037,7 @@ class FlashCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
+
                 prefill_tokens = Tokens(
                     prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
                 )
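
The paged-attention bookkeeping now always reserves room for speculative ids: get_speculate() extra slots per request when allocating blocks, and speculative_ids.shape[1] when concatenating batches. A worked example of the block math from the block-allocation hunk, assuming the BLOCK_SIZE of 16 used by the flash models (needed_blocks below is a standalone helper, not the batch code itself):

    import math

    BLOCK_SIZE = 16  # assumed paged-attention block size

    def needed_blocks(input_length, max_new_tokens, speculative_length):
        # Mirrors the hunk above: prompt tokens (minus the first, which has no
        # past), plus generated tokens, plus room for the speculative ids.
        total_tokens = input_length + max_new_tokens - 1 + speculative_length
        return math.ceil(total_tokens / BLOCK_SIZE)

    print(needed_blocks(20, 61, 0))  # 80 tokens -> 5 blocks without speculation
    print(needed_blocks(20, 61, 2))  # 82 tokens -> 6 blocks with two speculative ids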

View File

@@ -0,0 +1,12 @@
+SPECULATE = None
+
+
+def get_speculate():
+    global SPECULATE
+    return SPECULATE
+
+
+def set_speculate(speculate: int):
+    global SPECULATE
+    SPECULATE = speculate
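
The new speculate helper is just a process-wide setting: get_model() stores the resolved value once and the batching code reads it back wherever it needs the speculation depth. Usage, as exercised elsewhere in this commit:

    from text_generation_server.utils.speculate import get_speculate, set_speculate

    set_speculate(2)             # resolved once in get_model(), e.g. the n-gram default
    assert get_speculate() == 2  # later read by FlashCausalLMBatch and generate_token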