OlivierDehaene 2023-01-27 19:52:14 +01:00
parent d917ae8955
commit 8c2ddfe838
4 changed files with 65 additions and 56 deletions

View File

@@ -1,6 +1,6 @@
 gen-server:
	# Compile protos
-	#pip install grpcio-tools==1.49.1 --no-cache-dir
+	pip install grpcio-tools==1.49.1 --no-cache-dir
	mkdir text_generation/pb || true
	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;

View File

@@ -47,10 +47,10 @@ class CausalLMBatch(Batch):
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -145,8 +145,8 @@ class CausalLMBatch(Batch):
             # We need to slice the attention mask to remove padding from previous steps
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]

             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
@@ -192,22 +192,22 @@ class CausalLMBatch(Batch):
                 # We slice the past keys and values to remove the padding from previous batches
                 if batch.keys_head_dim_last:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
-                        -(batch.max_sequence_length - 1):,
+                        -(batch.max_sequence_length - 1) :,
                         :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
-                        -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

             start_index += batch.size
@@ -271,7 +271,7 @@ class CausalLM(Model):
         )

     def forward(
-            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         outputs = self.model.forward(
@@ -284,7 +284,7 @@ class CausalLM(Model):
         return outputs.logits, outputs.past_key_values

     def generate_token(
-            self, batch: CausalLMBatch
+        self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
@@ -325,12 +325,12 @@ class CausalLM(Model):
         # For each member of the batch
         for i, (
-                request,
-                input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                all_input_ids,
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
         ) in enumerate(iterator):
             # Select next token
             tokens, logprobs = next_token_chooser(all_input_ids, logits)
@@ -354,9 +354,12 @@ class CausalLM(Model):
             if stop:
                 # Decode generated tokens
                 generated_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:, 0]
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
                 output_text = request.inputs + generated_text
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason
+                )

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -380,21 +383,29 @@ class CausalLM(Model):

             # Prefill
             if stopping_criteria.current_tokens == 0:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(1, all_input_ids[
-                    -new_input_length:-1]).squeeze(
-                    1).tolist()
+                prefill_logprobs = [float("nan")] + logprobs[
+                    -new_input_length:-1
+                ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(prefill_token_ids,
-                                                            clean_up_tokenization_spaces=False,
-                                                            skip_special_tokens=False)
-                prefill_tokens = PrefillTokens(prefill_token_ids,
-                                               prefill_logprobs,
-                                               prefill_texts)
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
+                )
             else:
                 prefill_tokens = None

-            result = Generation(request.id, prefill_tokens, next_token_id_squeezed, next_token_logprob, next_token_text,
-                                generated_text)
+            result = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
             results.append(result)

View File

@@ -17,10 +17,10 @@ class Batch(ABC):
     @classmethod
     @abstractmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "Batch":
         raise NotImplementedError
@@ -41,8 +41,8 @@ class GeneratedText:
         return generate_pb2.GeneratedText(
             text=self.text,
             generated_tokens=self.generated_tokens,
-            finish_reason=self.finish_reason
-            seed=self.seed
+            finish_reason=self.finish_reason,
+            seed=self.seed,
         )
@@ -54,9 +54,7 @@ class PrefillTokens:
     def to_pb(self) -> generate_pb2.PrefillTokens:
         return generate_pb2.PrefillTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
         )
@@ -72,9 +70,13 @@ class Generation:
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
             request_id=self.request_id,
-            prefill_tokens=self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None,
+            prefill_tokens=self.prefill_tokens.to_pb()
+            if self.prefill_tokens is not None
+            else None,
             token_id=self.token_id,
             token_logprob=self.token_logprob,
             token_text=self.token_text,
-            generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,
+            generated_text=self.generated_text.to_pb()
+            if self.generated_text is not None
+            else None,
         )

View File

@@ -36,9 +36,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )
@@ -62,9 +60,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )