From 48d095733aa2a284c20cae1aec78955e7d7df8d4 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 27 Jan 2023 19:52:14 +0100
Subject: [PATCH] black

---
 server/Makefile                            |  2 +-
 server/text_generation/models/causal_lm.py | 87 ++++++++++++----------
 server/text_generation/models/types.py     | 24 +++---
 server/text_generation/server.py           |  8 +-
 4 files changed, 65 insertions(+), 56 deletions(-)

diff --git a/server/Makefile b/server/Makefile
index 82fff0db..6961178b 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -1,6 +1,6 @@
 gen-server:
 	# Compile protos
-	#pip install grpcio-tools==1.49.1 --no-cache-dir
+	pip install grpcio-tools==1.49.1 --no-cache-dir
 	mkdir text_generation/pb || true
 	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
 	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 402290b2..ce16dfb1 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -47,10 +47,10 @@ class CausalLMBatch(Batch):
 
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -145,8 +145,8 @@ class CausalLMBatch(Batch):
 
             # We need to slice the attention mask to remove padding from previous steps
             attention_mask[
-            start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]
 
             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
@@ -192,22 +192,22 @@ class CausalLMBatch(Batch):
                 # We slice the past keys and values to remove the padding from previous batches
                 if batch.keys_head_dim_last:
                     past_key_values[j][0][
-                    start_index:end_index,
-                    :,
-                    -(batch.max_sequence_length - 1):,
-                    :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                        start_index:end_index,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                        :,
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
                     past_key_values[j][0][
-                    start_index:end_index,
-                    :,
-                    :,
-                    -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        start_index:end_index,
+                        :,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
 
                 past_key_values[j][1][
-                start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
 
                 start_index += batch.size
 
@@ -271,7 +271,7 @@ class CausalLM(Model):
         )
 
     def forward(
-            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         outputs = self.model.forward(
@@ -284,7 +284,7 @@ class CausalLM(Model):
         return outputs.logits, outputs.past_key_values
 
     def generate_token(
-            self, batch: CausalLMBatch
+        self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
@@ -325,12 +325,12 @@ class CausalLM(Model):
 
         # For each member of the batch
         for i, (
-                request,
-                input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                all_input_ids,
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
         ) in enumerate(iterator):
             # Select next token
             tokens, logprobs = next_token_chooser(all_input_ids, logits)
@@ -354,9 +354,12 @@ class CausalLM(Model):
             if stop:
                 # Decode generated tokens
                 generated_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:, 0]
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
                 output_text = request.inputs + generated_text
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason
+                )
 
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -380,21 +383,29 @@ class CausalLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 0:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(1, all_input_ids[
-                    -new_input_length:-1]).squeeze(
-                    1).tolist()
+                prefill_logprobs = [float("nan")] + logprobs[
+                    -new_input_length:-1
+                ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(prefill_token_ids,
-                                                            clean_up_tokenization_spaces=False,
-                                                            skip_special_tokens=False)
-                prefill_tokens = PrefillTokens(prefill_token_ids,
-                                               prefill_logprobs,
-                                               prefill_texts)
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
+                )
             else:
                 prefill_tokens = None
 
-            result = Generation(request.id, prefill_tokens, next_token_id_squeezed, next_token_logprob, next_token_text,
-                                generated_text)
+            result = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
 
             results.append(result)
 
diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py
index 0ad8cc87..de21b20b 100644
--- a/server/text_generation/models/types.py
+++ b/server/text_generation/models/types.py
@@ -17,10 +17,10 @@ class Batch(ABC):
     @classmethod
     @abstractmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "Batch":
         raise NotImplementedError
 
@@ -41,8 +41,8 @@ class GeneratedText:
         return generate_pb2.GeneratedText(
             text=self.text,
             generated_tokens=self.generated_tokens,
-            finish_reason=self.finish_reason
-            seed=self.seed
+            finish_reason=self.finish_reason,
+            seed=self.seed,
         )
 
 
@@ -54,9 +54,7 @@ class PrefillTokens:
 
     def to_pb(self) -> generate_pb2.PrefillTokens:
         return generate_pb2.PrefillTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
         )
 
 
@@ -72,9 +70,13 @@ class Generation:
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
             request_id=self.request_id,
-            prefill_tokens=self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None,
+            prefill_tokens=self.prefill_tokens.to_pb()
+            if self.prefill_tokens is not None
+            else None,
             token_id=self.token_id,
             token_logprob=self.token_logprob,
             token_text=self.token_text,
-            generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,
+            generated_text=self.generated_text.to_pb()
+            if self.generated_text is not None
+            else None,
         )
diff --git a/server/text_generation/server.py b/server/text_generation/server.py
index 1cf8de95..a2bad8a7 100644
--- a/server/text_generation/server.py
+++ b/server/text_generation/server.py
@@ -36,9 +36,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
         )
 
@@ -62,9 +60,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )
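
Not part of the patch above: a minimal sketch of the slice-spacing rule that black applies in most of the whitespace-only hunks. When a slice bound is a non-trivial expression (an attribute access such as batch.max_sequence_length, or arithmetic such as -(batch.max_sequence_length - 1)), black pads the colon like a binary operator, while simple names such as -new_input_length keep the compact form. The Batch class and tensors below are made up for illustration; only torch is assumed.

    import torch

    class Batch:
        # Stand-in for the real CausalLMBatch; only this attribute matters here.
        max_sequence_length = 8

    batch = Batch()
    attention_mask = torch.ones(4, 16)
    n = 8

    # Simple bound: black keeps the compact spelling, as in logprobs[-new_input_length:-1].
    compact = attention_mask[:, -n:]

    # Complex bound: black adds a space before the colon, as in
    # attention_mask[:, -batch.max_sequence_length :] in the hunks above.
    padded = attention_mask[:, -batch.max_sequence_length :]

    assert torch.equal(compact, padded)  # both select the last 8 columns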