commit c90afea3de (parent 8420ee8fa2)
Author: OlivierDehaene
Date:   2022-12-15 16:50:11 +01:00

5 changed files with 57 additions and 46 deletions

---- changed file: Rust HTTP server (generate handler, shutdown signal) ----

@@ -64,14 +64,14 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
/// Generate method
#[instrument(
    skip(state),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token
    )
)]
async fn generate(
    state: Extension<ServerState>,
@@ -123,7 +123,7 @@ async fn generate(
                tokens,
            })
        }
-        false => None
+        false => None,
    };

    // Timings
@@ -164,7 +164,6 @@ async fn generate(
    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
    tracing::info!("Output: {}", response.output_text);

    // Send response
    let response = vec![GeneratedText {
        generated_text: response.output_text,
@@ -219,7 +218,7 @@ async fn shutdown_signal() {
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
@@ -227,7 +226,7 @@ async fn shutdown_signal() {
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},

---- changed file: Python tests for causal LM generation (BLOOM fixtures) ----

@@ -128,7 +128,9 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
    assert generated_texts[0].request == default_bloom_batch.requests[0]
    assert (
        generated_texts[0].generated_tokens
@@ -170,7 +172,9 @@ def test_causal_lm_generate_token_completion_multi(
    assert next_batch is None

    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
    assert (
        generated_texts[0].generated_tokens
@@ -259,7 +263,9 @@ def test_batch_concatenate(
    assert next_batch is not None

    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
    assert generated_texts[0].request == default_bloom_batch.requests[0]
    assert (
        generated_texts[0].generated_tokens
@@ -279,7 +285,9 @@ def test_batch_concatenate(
    assert next_batch is None

    assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
    assert (
        generated_texts[0].generated_tokens
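
The only substantive difference in these test hunks is mechanical: an over-long one-line assert is wrapped in parentheses so the comparison can span several lines, which is how a formatter in the black style keeps lines under its length limit without changing behaviour. A minimal, self-contained illustration (the Result class and the values are made up for the example, not taken from the test suite):

class Result:
    output_text = "TestTestTestTestTestTestTestTestTestTestTest"


generated_texts = [Result()]

# One-line form (before formatting):
assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"

# Parenthesised form (after formatting) -- same expression, same semantics,
# the parentheses only let it be split across lines:
assert (
    generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
)

# Caveat: never wrap an assert together with a message in one pair of
# parentheses; `assert (cond, "msg")` asserts a non-empty tuple and is always true.
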

---- changed file: Python CausalLM / CausalLMBatch model code ----

@@ -47,7 +47,7 @@ class CausalLMBatch:
    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
@@ -148,8 +148,8 @@ class CausalLMBatch:
            # We need to slice the attention mask to remove padding from previous steps
            attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]

            for j, past in enumerate(batch.past_key_values):
                past_keys, past_values = past
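
The slice here copies only the rightmost max_sequence_length columns of each incoming batch's attention mask into the merged, left-padded mask, which is what the comment means by removing padding from previous steps. A minimal sketch of the idea with toy shapes (tensor sizes and variable names are illustrative, not the repository's):

import torch

# Incoming batch whose mask is wider than its longest real sequence:
# the leftmost columns are padding left over from earlier steps.
batch_mask = torch.tensor([[0, 0, 0, 1, 1],
                           [0, 0, 1, 1, 1]])
max_sequence_length = 3  # longest real sequence in this batch

# Destination mask for the merged batch (left-padded, new overall width).
merged = torch.zeros((2, 4), dtype=torch.long)
start_index, end_index = 0, 2

# Keep only the last `max_sequence_length` columns of the incoming mask
# and write them into the right-hand side of the merged mask.
merged[start_index:end_index, -max_sequence_length:] = batch_mask[:, -max_sequence_length:]

print(merged)
# tensor([[0, 0, 1, 1],
#         [0, 1, 1, 1]])
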
@@ -197,22 +197,22 @@ class CausalLMBatch:
                # We slice the past keys and values to remove the padding from previous batches
                if batch.keys_head_dim_last:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
-                        -(batch.max_sequence_length - 1):,
+                        -(batch.max_sequence_length - 1) :,
                        :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                else:
                    past_key_values[j][0][
                        start_index:end_index,
                        :,
                        :,
-                        -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

            start_index += batch.size
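
The if/else exists because the cached keys do not always store the sequence dimension in the same place: when keys_head_dim_last is true the keys are laid out like [batch, heads, seq, head_dim] and the stale positions are trimmed along dim 2, otherwise the sequence dimension comes last and the same trim moves to dim 3, while values keep the sequence dimension at dim 2 in both cases. A rough shape-only sketch (the dimension sizes and the [batch, heads, ...] labels are illustrative assumptions, not taken from any particular model):

import torch

batch, heads, head_dim, seq = 2, 4, 8, 5
keep = 3  # e.g. max_sequence_length - 1 cached positions to keep

# keys_head_dim_last == True: keys are [batch, heads, seq, head_dim],
# so stale positions are dropped along dim 2.
keys_a = torch.randn(batch, heads, seq, head_dim)
print(keys_a[:, :, -keep:, :].shape)   # torch.Size([2, 4, 3, 8])

# keys_head_dim_last == False: keys are [batch, heads, head_dim, seq],
# so the same trim happens along the last dimension instead.
keys_b = torch.randn(batch, heads, head_dim, seq)
print(keys_b[:, :, :, -keep:].shape)   # torch.Size([2, 4, 8, 3])

# Values keep the sequence dimension at dim 2 in either case.
values = torch.randn(batch, heads, seq, head_dim)
print(values[:, :, -keep:, :].shape)   # torch.Size([2, 4, 3, 8])
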
@@ -268,7 +268,7 @@ class CausalLM(Model):
        return CausalLMBatch

    def forward(
        self, input_ids, attention_mask, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        outputs = self.model.forward(
@@ -280,7 +280,7 @@ class CausalLM(Model):
        return outputs.logits, outputs.past_key_values

    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
@@ -320,13 +320,13 @@ class CausalLM(Model):
        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_logprobs,
        ) in enumerate(iterator):
            # Select next token
            tokens, logprobs = next_token_chooser(all_input_ids, logits)
@@ -355,7 +355,9 @@ class CausalLM(Model):
                token_ids = all_input_ids[-new_input_length:]
                tokens = self.tokenizer.batch_decode(token_ids)
                # Add NaN for the first prompt token
-                logprobs = [float('nan')] + all_logprobs[-new_input_length:].squeeze(1).tolist()
+                logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(
+                    1
+                ).tolist()

                # Add to the list of finished generations with the original request
                generated_texts.append(
@@ -366,7 +368,7 @@ class CausalLM(Model):
                        tokens=tokens,
                        token_ids=token_ids.squeeze(1).tolist(),
                        logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                    )
                )
            # add to the next batch
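
The NaN prepended to logprobs keeps the per-token lists aligned: a causal model only assigns a log-probability to tokens it predicts from a preceding context, so the very first prompt token has no score of its own, and the Seq2Seq path below does the same for the BOS token. A small self-contained illustration (the token strings and values are invented for the example):

import math

# Decoded tokens for a finished sequence, including its first token.
tokens = ["Hello", ",", " world", "!"]

# The model produced a log-probability for every position except the first,
# which had nothing before it to condition on.
predicted_logprobs = [-1.31, -0.42, -0.07]

# Prepend NaN so tokens, token_ids and logprobs line up one-to-one.
logprobs = [float("nan")] + predicted_logprobs

assert len(logprobs) == len(tokens)
assert math.isnan(logprobs[0])
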

---- changed file: Python Seq2SeqLM model code ----

@@ -450,7 +450,9 @@ class Seq2SeqLM(Model):
                tokens = self.tokenizer.batch_decode(token_ids)
                print(tokens)
                # Add NaN for the bos token
-                logprobs = [float('nan')] + decoder_logprobs[-new_decoder_input_length:].tolist()
+                logprobs = [float("nan")] + decoder_logprobs[
+                    -new_decoder_input_length:
+                ].tolist()
                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(
@@ -460,7 +462,7 @@ class Seq2SeqLM(Model):
                        tokens=tokens,
                        token_ids=token_ids.tolist(),
                        logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                    )
                )
            # add to the next batch

---- changed file: Python stopping-criteria utilities ----

@@ -99,7 +99,7 @@ class StopSequenceCriteria:
class StoppingCriteria:
    def __init__(
        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
    ):
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
@@ -119,7 +119,7 @@ class StoppingCriteria:
    @classmethod
    def from_pb(
        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
    ) -> "StoppingCriteria":
        stop_sequence_criterias = []
        for stop_sequence in pb.stop_sequences:
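
From the attributes visible here, StoppingCriteria pairs a list of StopSequenceCriteria with a max_new_tokens budget, so generation halts either on a matched stop sequence or when the budget runs out. The snippet below is a simplified, self-contained re-implementation of that idea for illustration only; the call signatures and reason strings are assumptions, not the library's actual API, which this diff does not show:

from typing import List, Optional, Tuple


class StopSequenceCriteria:
    """Signals a match when the generated text ends with the stop sequence."""

    def __init__(self, stop_sequence: str):
        self.stop_sequence = stop_sequence

    def __call__(self, text: str) -> bool:
        return text.endswith(self.stop_sequence)


class StoppingCriteria:
    """Stops on a stop sequence or after max_new_tokens generated tokens."""

    def __init__(self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens: int = 20):
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, generated_text: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        for criteria in self.stop_sequence_criterias:
            if criteria(generated_text):
                return True, "stop_sequence"
        return False, None


criteria = StoppingCriteria([StopSequenceCriteria("\n\n")], max_new_tokens=3)
print(criteria("Hello"))        # (False, None)
print(criteria("Hello\n\n"))    # (True, 'stop_sequence')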