diff --git a/router/src/server.rs b/router/src/server.rs
index 0ded187d..2e6c473f 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -64,14 +64,14 @@ async fn health(state: Extension) -> Result<(), (StatusCode, Json,
@@ -123,7 +123,7 @@ async fn generate(
                 tokens,
             })
         }
-        false => None
+        false => None,
     };
 
     // Timings
@@ -164,7 +164,6 @@ async fn generate(
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
     tracing::info!("Output: {}", response.output_text);
 
-    // Send response
     let response = vec![GeneratedText {
         generated_text: response.output_text,
@@ -219,7 +218,7 @@ async fn shutdown_signal() {
     };
 
     #[cfg(unix)]
-    let terminate = async {
+    let terminate = async {
         signal::unix::signal(signal::unix::SignalKind::terminate())
             .expect("failed to install signal handler")
             .recv()
@@ -227,7 +226,7 @@ async fn shutdown_signal() {
     };
 
     #[cfg(not(unix))]
-    let terminate = std::future::pending::<()>();
+    let terminate = std::future::pending::<()>();
 
     tokio::select! {
         _ = ctrl_c => {},
diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index ee16b95f..2a6e670e 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -128,7 +128,9 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -170,7 +172,9 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -259,7 +263,9 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
@@ -279,7 +285,9 @@ def test_batch_concatenate(
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
         generated_texts[0].generated_tokens
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 7fa027a5..5d62137f 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -47,7 +47,7 @@ class CausalLMBatch:
 
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -148,8 +148,8 @@ class CausalLMBatch:
 
             # We need to slice the attention mask to remove padding from previous steps
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]
 
             for j, past in enumerate(batch.past_key_values):
                 past_keys, past_values = past
@@ -197,22 +197,22 @@ class CausalLMBatch:
                 # We slice the past keys and values to remove the padding from previous batches
                 if batch.keys_head_dim_last:
                     past_key_values[j][0][
-                        start_index:end_index,
-                        :,
-                        -(batch.max_sequence_length - 1):,
-                        :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                        start_index:end_index,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                        :,
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
                     past_key_values[j][0][
-                        start_index:end_index,
-                        :,
-                        :,
-                        -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        start_index:end_index,
+                        :,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
 
                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
 
             start_index += batch.size
@@ -268,7 +268,7 @@ class CausalLM(Model):
         return CausalLMBatch
 
     def forward(
-        self, input_ids, attention_mask, past_key_values: Optional = None
+        self, input_ids, attention_mask, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         outputs = self.model.forward(
@@ -280,7 +280,7 @@ class CausalLM(Model):
         return outputs.logits, outputs.past_key_values
 
     def generate_token(
-        self, batch: CausalLMBatch
+        self, batch: CausalLMBatch
     ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
@@ -320,13 +320,13 @@ class CausalLM(Model):
 
         # For each member of the batch
         for i, (
-            request,
-            input_length,
-            logits,
-            next_token_chooser,
-            stopping_criteria,
-            all_input_ids,
-            all_logprobs,
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+            all_logprobs,
         ) in enumerate(iterator):
             # Select next token
             tokens, logprobs = next_token_chooser(all_input_ids, logits)
@@ -355,7 +355,9 @@ class CausalLM(Model):
                 token_ids = all_input_ids[-new_input_length:]
                 tokens = self.tokenizer.batch_decode(token_ids)
                 # Add NaN for the first prompt token
-                logprobs = [float('nan')] + all_logprobs[-new_input_length:].squeeze(1).tolist()
+                logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(
+                    1
+                ).tolist()
 
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
@@ -366,7 +368,7 @@ class CausalLM(Model):
                         tokens=tokens,
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                     )
                 )
             # add to the next batch
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 35c18518..065ecbff 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -450,7 +450,9 @@ class Seq2SeqLM(Model):
                 tokens = self.tokenizer.batch_decode(token_ids)
                 print(tokens)
                 # Add NaN for the bos token
-                logprobs = [float('nan')] + decoder_logprobs[-new_decoder_input_length:].tolist()
+                logprobs = [float("nan")] + decoder_logprobs[
+                    -new_decoder_input_length:
+                ].tolist()
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(
@@ -460,7 +462,7 @@ class Seq2SeqLM(Model):
                         tokens=tokens,
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
-                        reason=reason
+                        reason=reason,
                     )
                 )
             # add to the next batch
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 8ecf4be0..6363321a 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -99,7 +99,7 @@ class StopSequenceCriteria:
 
 class StoppingCriteria:
     def __init__(
-        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
+        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
     ):
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
@@ -119,7 +119,7 @@ class StoppingCriteria:
 
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
+        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
     ) -> "StoppingCriteria":
         stop_sequence_criterias = []
         for stop_sequence in pb.stop_sequences: