Address comments.

This commit is contained in:
Nicolas Patry 2023-12-05 15:21:42 +00:00
parent e808222dbf
commit be481a4799
8 changed files with 128 additions and 132 deletions

View File

@ -1,6 +1,6 @@
syntax = "proto3";
package generate.v1;
package generate.v2;
service TextGenerationService {
/// Model Info
@ -32,6 +32,7 @@ message InfoResponse {
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
optional uint32 speculate = 5;
}
/// Empty request

View File

@ -1,6 +1,6 @@
/// Single shard Client
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v2::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;

View File

@ -6,9 +6,9 @@ mod pb;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
Request, StoppingCriteriaParameters, Tokens,
};

View File

@ -167,21 +167,21 @@ impl Infer {
.collect();
}
// Push last token
InferStreamResponse::Intermediate { tokens, top_tokens } => {
result_tokens.extend(tokens);
result_top_tokens.extend(top_tokens);
InferStreamResponse::Intermediate { token, top_tokens } => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
}
// Final message
// Set return values
InferStreamResponse::End {
tokens,
token,
generated_text,
start,
queued,
top_tokens,
} => {
result_tokens.extend(tokens);
result_top_tokens.extend(top_tokens);
result_tokens.push(token);
result_top_tokens.push(top_tokens);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
@ -515,7 +515,6 @@ fn send_responses(
let mut stopped = false;
tracing::info!("Generation: {:?}", generation);
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
entry
@ -524,68 +523,63 @@ fn send_responses(
}
// Create last Token
let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
tokens_
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!(
"tgi_request_skipped_tokens",
(n - 1) as f64
);
for (i, (((id, logprob), text), special)) in tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
.zip(tokens_.is_special.into_iter()).enumerate() {
let token = Token {
id,
text,
logprob,
special,
})
.collect()
} else {
vec![]
};
// generation.top_tokens
let mut top_tokens = Vec::new();
for top_tokens_ in generation.top_tokens {
let mut local_top_tokens = Vec::new();
local_top_tokens.extend(
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i){
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text,
text: text.to_string(),
logprob,
special,
}),
);
top_tokens.push(local_top_tokens);
}
// Force top_tokens to be the same size as tokens, both are going to be
// zipped later
if top_tokens.len() != tokens.len() {
top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
}).collect()
}else{
vec![]
};
match (&generation.generated_text, i){
(Some(generated_text), i) if i == n - 1 => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: generated_text.clone(),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
tokens,
top_tokens,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
} else {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?;
}
Ok(stopped)
}
@ -613,13 +607,13 @@ pub(crate) enum InferStreamResponse {
Prefill(Tokens),
// Intermediate messages
Intermediate {
tokens: Vec<Token>,
top_tokens: Vec<Vec<Token>>,
token: Token,
top_tokens: Vec<Token>,
},
// Last message
End {
tokens: Vec<Token>,
top_tokens: Vec<Vec<Token>>,
token: Token,
top_tokens: Vec<Token>,
generated_text: GeneratedText,
start: Instant,
queued: Instant,

View File

@ -388,31 +388,39 @@ async fn generate_stream(
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Intermediate{
tokens,
token,
top_tokens,
} => {
tracing::debug!(parent: &span, "Token: {:?}", token);
for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) {
// StreamResponse
let stream_token = StreamResponse {
token,
top_tokens,
generated_text: None,
details: None,
};
// StreamResponse
let stream_token = StreamResponse {
token,
top_tokens,
generated_text: None,
details: None,
};
yield Ok(Event::default().json_data(stream_token).unwrap());
}
yield Ok(Event::default().json_data(stream_token).unwrap())
}
// Yield event for last token and compute timings
InferStreamResponse::End {
tokens,
token,
generated_text,
start,
queued,
top_tokens,
} => {
// Token details
let details = match details {
true => Some(StreamDetails {
finish_reason: FinishReason::from(generated_text.finish_reason),
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
}),
false => None,
};
// Timings
let total_time = start_time.elapsed();
let validation_time = queued - start_time;
@ -440,45 +448,22 @@ async fn generate_stream(
// StreamResponse
end_reached = true;
let n_tokens = tokens.len();
for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() {
// StreamResponse
let stream_token = if i < n_tokens - 1 {
StreamResponse {
token,
top_tokens,
generated_text: None,
details: None,
}
}else{
let details = match details {
true => Some(StreamDetails {
finish_reason: FinishReason::from(generated_text.finish_reason),
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
}),
false => None,
};
let output_text = if let Some(prompt) = &add_prompt {
prompt.to_owned() + &generated_text.text
}else{
generated_text.text.to_owned()
};
tracing::debug!(parent: &span, "Output: {}", output_text);
tracing::info!(parent: &span, "Success");
StreamResponse {
token,
top_tokens,
generated_text: Some(output_text),
details
}
};
yield Ok(Event::default().json_data(stream_token).unwrap());
let mut output_text = generated_text.text;
if let Some(prompt) = add_prompt {
output_text = prompt + &output_text;
}
tracing::debug!(parent: &span, "Output: {}", output_text);
tracing::info!(parent: &span, "Success");
let stream_token = StreamResponse {
token,
top_tokens,
generated_text: Some(output_text),
details
};
yield Ok(Event::default().json_data(stream_token).unwrap());
break;
}
}

View File

@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from typing import Optional
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
@ -77,9 +78,6 @@ except ImportError as e:
if MISTRAL:
__all__.append(FlashMistral)
SPECULATE = None
def get_model(
model_id: str,
revision: Optional[str],
@ -89,7 +87,6 @@ def get_model(
dtype: Optional[str],
trust_remote_code: bool,
) -> Model:
global SPECULATE
if dtype is None:
# Keep it as default for now and let
# every model resolve their own default dtype.
@ -101,7 +98,10 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")
SPECULATE = 2
if speculate is not None:
set_speculate(speculate)
else:
set_speculate(2)
if "facebook/galactica" in model_id:
return GalacticaSharded(
@ -144,18 +144,22 @@ def get_model(
medusa_config = config_dict
model_id = config_dict["base_model_name_or_path"]
revision = "main"
SPECULATE = config_dict["medusa_num_heads"]
speculate_medusa = config_dict["medusa_num_heads"]
if speculate is not None:
if speculate > speculate_medusa:
raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
else:
set_speculate(speculate)
else:
set_speculate(speculate_medusa)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
method = "medusa"
else:
if speculate is not None:
SPECULATE = speculate
else:
SPECULATE = 2
method = "n-gram"
logger.info(f"Using speculation {method} with {SPECULATE} input ids.")
logger.info(f"Using speculation {method} with {get_speculate()} input ids.")
model_type = config_dict["model_type"]

View File

@ -12,6 +12,7 @@ from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
Batch,
Tokens,
@ -192,8 +193,7 @@ class FlashCausalLMBatch(Batch):
# Paged attention
# Remove one as the first token des not have a past
from text_generation_server.models import SPECULATE
speculative_length = SPECULATE
speculative_length = get_speculate()
total_tokens = input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
@ -483,7 +483,7 @@ class FlashCausalLMBatch(Batch):
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[1]
speculative_length = b.speculative_ids.shape[1]
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
@ -589,7 +589,7 @@ class FlashCausalLMBatch(Batch):
device=batches[0].next_token_chooser.device,
)
speculative_ids = None if batches[0].speculative_ids is None else torch.cat([b.speculative_ids for b in batches], dim=0)
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
@ -825,16 +825,15 @@ class FlashCausalLM(Model):
next_token_logits = out
from text_generation_server.models import SPECULATE
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, SPECULATE, batch.speculative_ids, speculative_logits
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
speculative_length = speculative_ids.shape[1]
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -1038,6 +1037,7 @@ class FlashCausalLM(Model):
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
)

View File

@ -0,0 +1,12 @@
SPECULATE = None
def get_speculate():
global SPECULATE
return SPECULATE
def set_speculate(speculate: int):
global SPECULATE
SPECULATE = speculate