Address comments.

Nicolas Patry 2023-12-05 15:21:42 +00:00
parent e808222dbf
commit be481a4799
8 changed files with 128 additions and 132 deletions

View File

@@ -1,6 +1,6 @@
 syntax = "proto3";
-package generate.v1;
+package generate.v2;
 service TextGenerationService {
     /// Model Info
@@ -32,6 +32,7 @@ message InfoResponse {
    string dtype = 2;
    string device_type = 3;
    optional uint32 window_size = 4;
+   optional uint32 speculate = 5;
 }
 /// Empty request

View File

@@ -1,6 +1,6 @@
 /// Single shard Client
-use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
-use crate::pb::generate::v1::*;
+use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
+use crate::pb::generate::v2::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
 use std::cmp::min;

View File

@@ -6,9 +6,9 @@ mod pb;
 mod sharded_client;
 pub use client::Client;
-pub use pb::generate::v1::HealthResponse;
-pub use pb::generate::v1::InfoResponse as ShardInfo;
-pub use pb::generate::v1::{
+pub use pb::generate::v2::HealthResponse;
+pub use pb::generate::v2::InfoResponse as ShardInfo;
+pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
     Request, StoppingCriteriaParameters, Tokens,
 };

View File

@@ -167,21 +167,21 @@ impl Infer {
                     .collect();
             }
             // Push last token
-            InferStreamResponse::Intermediate { tokens, top_tokens } => {
-                result_tokens.extend(tokens);
-                result_top_tokens.extend(top_tokens);
+            InferStreamResponse::Intermediate { token, top_tokens } => {
+                result_tokens.push(token);
+                result_top_tokens.push(top_tokens);
             }
             // Final message
             // Set return values
             InferStreamResponse::End {
-                tokens,
+                token,
                 generated_text,
                 start,
                 queued,
                 top_tokens,
             } => {
-                result_tokens.extend(tokens);
-                result_top_tokens.extend(top_tokens);
+                result_tokens.push(token);
+                result_top_tokens.push(top_tokens);
                 result_generated_text = Some(generated_text);
                 result_start = Some(start);
                 result_queued = Some(queued)
@@ -515,7 +515,6 @@ fn send_responses(
     let mut stopped = false;
-    tracing::info!("Generation: {:?}", generation);
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
         entry
@@ -524,68 +523,63 @@ fn send_responses(
     }
     // Create last Token
-    let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
-        tokens_
-            .ids
-            .into_iter()
-            .zip(tokens_.logprobs.into_iter())
-            .zip(tokens_.texts.into_iter())
-            .zip(tokens_.is_special.into_iter())
-            .map(|(((id, logprob), text), special)| Token {
-                id,
-                text,
-                logprob,
-                special,
-            })
-            .collect()
-    } else {
-        vec![]
-    };
-    // generation.top_tokens
-    let mut top_tokens = Vec::new();
-    for top_tokens_ in generation.top_tokens {
-        let mut local_top_tokens = Vec::new();
-        local_top_tokens.extend(
-            top_tokens_
-                .ids
-                .into_iter()
-                .zip(top_tokens_.logprobs.into_iter())
-                .zip(top_tokens_.texts.into_iter())
-                .zip(top_tokens_.is_special.into_iter())
-                .map(|(((id, logprob), text), special)| Token {
-                    id,
-                    text,
-                    logprob,
-                    special,
-                }),
-        );
-        top_tokens.push(local_top_tokens);
-    }
-    // Force top_tokens to be the same size as tokens, both are going to be
-    // zipped later
-    if top_tokens.len() != tokens.len() {
-        top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
-    }
-    if let Some(generated_text) = generation.generated_text {
-        // Generation has ended
-        stopped = true;
-        // Send message
-        entry.response_tx.send(Ok(InferStreamResponse::End {
-            tokens,
-            top_tokens,
-            generated_text,
-            queued: entry.queue_time,
-            start: entry.batch_time.unwrap(),
-        }))?;
-    } else {
-        // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?;
-    }
+    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
+    let n = tokens_.ids.len();
+    metrics::histogram!(
+        "tgi_request_skipped_tokens",
+        (n - 1) as f64
+    );
+    for (i, (((id, logprob), text), special)) in tokens_
+        .ids
+        .into_iter()
+        .zip(tokens_.logprobs.into_iter())
+        .zip(tokens_.texts.into_iter())
+        .zip(tokens_.is_special.into_iter())
+        .enumerate()
+    {
+        let token = Token {
+            id,
+            text,
+            logprob,
+            special,
+        };
+        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
+            top_tokens_
+                .ids
+                .iter()
+                .zip(top_tokens_.logprobs.iter())
+                .zip(top_tokens_.texts.iter())
+                .zip(top_tokens_.is_special.iter())
+                .map(|(((&id, &logprob), text), &special)| Token {
+                    id,
+                    text: text.to_string(),
+                    logprob,
+                    special,
+                })
+                .collect()
+        } else {
+            vec![]
+        };
+        match (&generation.generated_text, i) {
+            (Some(generated_text), i) if i == n - 1 => {
+                // Generation has ended
+                stopped = true;
+                // Send message
+                entry.response_tx.send(Ok(InferStreamResponse::End {
+                    token,
+                    top_tokens,
+                    generated_text: generated_text.clone(),
+                    queued: entry.queue_time,
+                    start: entry.batch_time.unwrap(),
+                }))?;
+            }
+            _ => {
+                // Send message
+                entry
+                    .response_tx
+                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+            }
+        }
+    }
     Ok(stopped)
 }
@@ -613,13 +607,13 @@ pub(crate) enum InferStreamResponse {
     Prefill(Tokens),
     // Intermediate messages
     Intermediate {
-        tokens: Vec<Token>,
-        top_tokens: Vec<Vec<Token>>,
+        token: Token,
+        top_tokens: Vec<Token>,
     },
     // Last message
     End {
-        tokens: Vec<Token>,
-        top_tokens: Vec<Vec<Token>>,
+        token: Token,
+        top_tokens: Vec<Token>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
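
With speculation, one `Generation` can carry several decoded tokens, so the new `send_responses` loops over them, streams an `Intermediate` message per token, and switches to `End` on the final token of a finished generation (also recording the `n - 1` skipped tokens in a histogram). A minimal Python sketch of that dispatch, using a tuple-based `emit` callback instead of the real Rust channel types:

from dataclasses import dataclass
from typing import Callable, List, Optional

@dataclass
class Token:
    id: int
    text: str
    logprob: float
    special: bool

def send_responses(tokens: List[Token],
                   top_tokens: List[List[Token]],
                   generated_text: Optional[str],
                   emit: Callable[[tuple], None]) -> bool:
    # One Generation may now carry several tokens; the Rust code also records
    # len(tokens) - 1 in the tgi_request_skipped_tokens histogram.
    stopped = False
    n = len(tokens)
    for i, token in enumerate(tokens):
        tops = top_tokens[i] if i < len(top_tokens) else []
        if generated_text is not None and i == n - 1:
            # Last token of a finished generation becomes the End message.
            stopped = True
            emit(("end", token, tops, generated_text))
        else:
            emit(("intermediate", token, tops))
    return stopped

# Example: two decoded tokens, generation finished on the second one.
messages = []
send_responses(
    [Token(1, "Hello", -0.1, False), Token(2, " world", -0.2, False)],
    [[], []],
    "Hello world",
    messages.append,
)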

View File

@@ -388,31 +388,39 @@ async fn generate_stream(
                 InferStreamResponse::Prefill(_) => {}
                 // Yield event for every new token
                 InferStreamResponse::Intermediate {
-                    tokens,
+                    token,
                     top_tokens,
                 } => {
-                    for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) {
-                        // StreamResponse
-                        let stream_token = StreamResponse {
-                            token,
-                            top_tokens,
-                            generated_text: None,
-                            details: None,
-                        };
-                        yield Ok(Event::default().json_data(stream_token).unwrap());
-                    }
+                    tracing::debug!(parent: &span, "Token: {:?}", token);
+                    // StreamResponse
+                    let stream_token = StreamResponse {
+                        token,
+                        top_tokens,
+                        generated_text: None,
+                        details: None,
+                    };
+                    yield Ok(Event::default().json_data(stream_token).unwrap())
                 }
                 // Yield event for last token and compute timings
                 InferStreamResponse::End {
-                    tokens,
+                    token,
                     generated_text,
                     start,
                     queued,
                     top_tokens,
                 } => {
                     // Token details
+                    let details = match details {
+                        true => Some(StreamDetails {
+                            finish_reason: FinishReason::from(generated_text.finish_reason),
+                            generated_tokens: generated_text.generated_tokens,
+                            seed: generated_text.seed,
+                        }),
+                        false => None,
+                    };
                     // Timings
                     let total_time = start_time.elapsed();
                     let validation_time = queued - start_time;
@@ -440,45 +448,22 @@ async fn generate_stream(
                     // StreamResponse
                     end_reached = true;
-                    let n_tokens = tokens.len();
-                    for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() {
-                        // StreamResponse
-                        let stream_token = if i < n_tokens - 1 {
-                            StreamResponse {
-                                token,
-                                top_tokens,
-                                generated_text: None,
-                                details: None,
-                            }
-                        } else {
-                            let details = match details {
-                                true => Some(StreamDetails {
-                                    finish_reason: FinishReason::from(generated_text.finish_reason),
-                                    generated_tokens: generated_text.generated_tokens,
-                                    seed: generated_text.seed,
-                                }),
-                                false => None,
-                            };
-                            let output_text = if let Some(prompt) = &add_prompt {
-                                prompt.to_owned() + &generated_text.text
-                            } else {
-                                generated_text.text.to_owned()
-                            };
-                            tracing::debug!(parent: &span, "Output: {}", output_text);
-                            tracing::info!(parent: &span, "Success");
-                            StreamResponse {
-                                token,
-                                top_tokens,
-                                generated_text: Some(output_text),
-                                details
-                            }
-                        };
-                        yield Ok(Event::default().json_data(stream_token).unwrap());
-                    }
+                    let mut output_text = generated_text.text;
+                    if let Some(prompt) = add_prompt {
+                        output_text = prompt + &output_text;
+                    }
+                    tracing::debug!(parent: &span, "Output: {}", output_text);
+                    tracing::info!(parent: &span, "Success");
+                    let stream_token = StreamResponse {
+                        token,
+                        top_tokens,
+                        generated_text: Some(output_text),
+                        details
+                    };
+                    yield Ok(Event::default().json_data(stream_token).unwrap());
                     break;
                 }
             }
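
Since `generate_stream` now emits one SSE event per token and attaches `generated_text` plus `details` only to the last event, a client can treat every `data:` line as a single-token `StreamResponse`. A rough client-side sketch, assuming a TGI router reachable at `http://localhost:8080`; the endpoint path matches the router, but the host, port and request parameters are placeholders:

import json
import requests  # third-party: pip install requests

def stream_tokens(prompt: str, base_url: str = "http://localhost:8080"):
    """Consume /generate_stream and yield (token_text, generated_text) per SSE event."""
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 20}}
    with requests.post(f"{base_url}/generate_stream", json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data:"):
                continue
            event = json.loads(line[len(b"data:"):])
            # Each event now carries a single `token`; only the last one has `generated_text`.
            yield event["token"]["text"], event.get("generated_text")

if __name__ == "__main__":
    for text, final in stream_tokens("What is speculative decoding?"):
        print(text, end="", flush=True)
        if final is not None:
            print(f"\n--- full output: {final}")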

View File

@@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
+from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -77,9 +78,6 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)
-SPECULATE = None
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -89,7 +87,6 @@ def get_model(
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
-    global SPECULATE
     if dtype is None:
         # Keep it as default for now and let
         # every model resolve their own default dtype.
@@ -101,7 +98,10 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
-    SPECULATE = 2
+    if speculate is not None:
+        set_speculate(speculate)
+    else:
+        set_speculate(2)
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
@@ -144,18 +144,22 @@ def get_model(
         medusa_config = config_dict
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
-        SPECULATE = config_dict["medusa_num_heads"]
+        speculate_medusa = config_dict["medusa_num_heads"]
+        if speculate is not None:
+            if speculate > speculate_medusa:
+                raise RuntimeError(
+                    f"Speculate is set to `{speculate}` but this medusa model only has "
+                    f"`{speculate_medusa}` heads, please make them match"
+                )
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_medusa)
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         method = "medusa"
     else:
-        if speculate is not None:
-            SPECULATE = speculate
-        else:
-            SPECULATE = 2
         method = "n-gram"
-    logger.info(f"Using speculation {method} with {SPECULATE} input ids.")
+    logger.info(f"Using speculation {method} with {get_speculate()} input ids.")
     model_type = config_dict["model_type"]
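
The module-level `SPECULATE` global gives way to the `get_speculate`/`set_speculate` helpers, and the value is resolved as follows: an explicit `--speculate` flag wins but may not exceed a Medusa checkpoint's `medusa_num_heads`; otherwise the Medusa head count is used; otherwise the n-gram default of 2. A condensed sketch of that resolution, written as a standalone helper rather than the real `get_model` body:

from typing import Optional

def resolve_speculate(speculate: Optional[int], medusa_num_heads: Optional[int]) -> int:
    """Return the number of speculative input ids, mirroring the logic in get_model();
    the caller would hand the result to set_speculate()."""
    if medusa_num_heads is not None:
        if speculate is None:
            return medusa_num_heads
        if speculate > medusa_num_heads:
            raise RuntimeError(
                f"Speculate is set to `{speculate}` but this medusa model only has "
                f"`{medusa_num_heads}` heads, please make them match"
            )
        return speculate
    # No Medusa heads: use the explicit value or the n-gram default of 2.
    return speculate if speculate is not None else 2

assert resolve_speculate(None, None) == 2
assert resolve_speculate(3, None) == 3
assert resolve_speculate(None, 4) == 4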

View File

@@ -11,7 +11,8 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 from text_generation_server.models import Model
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -192,8 +193,7 @@ class FlashCausalLMBatch(Batch):
             # Paged attention
             # Remove one as the first token des not have a past
-            from text_generation_server.models import SPECULATE
-            speculative_length = SPECULATE
+            speculative_length = get_speculate()
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
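
`get_speculate()` now feeds straight into the paged-attention block budget: each request reserves KV slots for its prompt, for `max_new_tokens - 1` (the first token has no past), and for the speculative ids. A small worked example of the arithmetic; `BLOCK_SIZE = 16` is assumed here purely for illustration:

import math

BLOCK_SIZE = 16  # assumed KV-cache block size for this example

def needed_blocks(input_length: int, max_new_tokens: int, speculative_length: int) -> int:
    # Mirrors the batch construction: total tokens that may need a KV slot.
    total_tokens = input_length + max_new_tokens - 1 + speculative_length
    return math.ceil(total_tokens / BLOCK_SIZE)

# 30 prompt tokens, up to 20 new tokens, 2 speculative ids -> 51 tokens -> 4 blocks of 16.
print(needed_blocks(30, 20, 2))  # 4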
@@ -483,7 +483,7 @@ class FlashCausalLMBatch(Batch):
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
-            speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[1]
+            speculative_length = b.speculative_ids.shape[1]
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -589,7 +589,7 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )
-        speculative_ids = None if batches[0].speculative_ids is None else torch.cat([b.speculative_ids for b in batches], dim=0)
+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
@@ -825,16 +825,15 @@ class FlashCausalLM(Model):
             next_token_logits = out
-        from text_generation_server.models import SPECULATE
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, SPECULATE, batch.speculative_ids, speculative_logits
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
-        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
+        speculative_length = speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -1038,6 +1037,7 @@ class FlashCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
                 prefill_tokens = Tokens(
                     prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
                 )

View File

@@ -0,0 +1,12 @@
+SPECULATE = None
+
+
+def get_speculate():
+    global SPECULATE
+    return SPECULATE
+
+
+def set_speculate(speculate: int):
+    global SPECULATE
+    SPECULATE = speculate
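
The new `speculate` utility is just a process-wide value with a getter and a setter, so `get_model` can decide once at startup and the batching code can read it anywhere. A short usage sketch (assumes the `text_generation_server` package is importable):

from text_generation_server.utils.speculate import get_speculate, set_speculate

# get_model() calls set_speculate() once while loading the model...
set_speculate(2)

# ...and the batch construction / decoding code reads the value back wherever it needs it.
assert get_speculate() == 2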