OlivierDehaene 2023-01-27 19:52:14 +01:00
parent d917ae8955
commit 8c2ddfe838
4 changed files with 65 additions and 56 deletions

View File

@@ -1,6 +1,6 @@
 gen-server:
 	# Compile protos
-	#pip install grpcio-tools==1.49.1 --no-cache-dir
+	pip install grpcio-tools==1.49.1 --no-cache-dir
 	mkdir text_generation/pb || true
 	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
 	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
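
For readers following along, the gen-server target compiles ../proto/generate.proto into Python stubs and then rewrites the generated absolute "import ... pb2" statements into relative imports so the stubs resolve inside the text_generation.pb package. Below is a rough Python equivalent of that sed rewrite, for illustration only; the helper name and the standalone script are assumptions, and the Makefile's sed call is what the project actually runs.

# Illustrative equivalent of the sed rewrite above (not part of the commit).
import pathlib
import re


def rewrite_pb_imports(pb_dir: str = "text_generation/pb") -> None:
    # Turn "import generate_pb2 ..." into "from . import generate_pb2 ..." in the
    # generated stubs, mirroring: sed -e 's/^\(import.*pb2\)/from . \1/g'
    for path in pathlib.Path(pb_dir).glob("*.py"):
        source = path.read_text()
        source = re.sub(r"^(import.*pb2)", r"from . \1", source, flags=re.MULTILINE)
        path.write_text(source)


if __name__ == "__main__":
    rewrite_pb_imports()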

View File

@@ -47,10 +47,10 @@ class CausalLMBatch(Batch):

     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -145,8 +145,8 @@ class CausalLMBatch(Batch):

             # We need to slice the attention mask to remove padding from previous steps
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]

             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
@@ -192,22 +192,22 @@ class CausalLMBatch(Batch):

                 # We slice the past keys and values to remove the padding from previous batches
                 if batch.keys_head_dim_last:
                     past_key_values[j][0][
-                            start_index:end_index,
-                            :,
-                            -(batch.max_sequence_length - 1):,
-                            :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                        start_index:end_index,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                        :,
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
                     past_key_values[j][0][
-                            start_index:end_index,
-                            :,
-                            :,
-                            -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        start_index:end_index,
+                        :,
+                        :,
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
                 past_key_values[j][1][
-                        start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

             start_index += batch.size
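
The hunk above is formatting-only: black adds a space before the slice colon when the bound is an expression. The underlying logic copies one batch's past keys and values into the padded, concatenated tensors. A toy, self-contained sketch of that slice-assignment pattern follows; the shapes and variable names are invented for illustration and are not the repository's code.

# Toy sketch of the padded copy the hunk above reformats (illustrative shapes only).
import torch

num_heads, head_dim = 4, 8
max_sequence_length = 4  # sequence length of the batch being merged in
start_index, end_index = 3, 5  # rows assigned to that batch in the concatenated tensor

# Destination buffer for the concatenated past keys, already zero-padded
past_keys_out = torch.zeros(5, num_heads, 6, head_dim)
# Past keys of the batch being merged
past_keys_in = torch.randn(end_index - start_index, num_heads, max_sequence_length - 1, head_dim)

# Keep only the last (max_sequence_length - 1) positions, as in the diff
past_keys_out[
    start_index:end_index, :, -(max_sequence_length - 1) :, :
] = past_keys_in[:, :, -(max_sequence_length - 1) :, :]
print(past_keys_out[start_index:end_index].abs().sum() > 0)  # copied rows are now populated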
@@ -271,7 +271,7 @@ class CausalLM(Model):
         )

     def forward(
-            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         outputs = self.model.forward(
@@ -284,7 +284,7 @@ class CausalLM(Model):
         return outputs.logits, outputs.past_key_values

     def generate_token(
-            self, batch: CausalLMBatch
+        self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
@@ -325,12 +325,12 @@ class CausalLM(Model):

         # For each member of the batch
         for i, (
-                request,
-                input_length,
-                logits,
-                next_token_chooser,
-                stopping_criteria,
-                all_input_ids,
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
         ) in enumerate(iterator):
             # Select next token
             tokens, logprobs = next_token_chooser(all_input_ids, logits)
@@ -354,9 +354,12 @@
             if stop:
                 # Decode generated tokens
                 generated_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:, 0]
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
                 output_text = request.inputs + generated_text
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason
+                )

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -380,21 +383,29 @@

             # Prefill
             if stopping_criteria.current_tokens == 0:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(1, all_input_ids[
-                    -new_input_length:-1]).squeeze(
-                    1).tolist()
+                prefill_logprobs = [float("nan")] + logprobs[
+                    -new_input_length:-1
+                ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(prefill_token_ids,
-                                                            clean_up_tokenization_spaces=False,
-                                                            skip_special_tokens=False)
-                prefill_tokens = PrefillTokens(prefill_token_ids,
-                                               prefill_logprobs,
-                                               prefill_texts)
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
+                )
             else:
                 prefill_tokens = None

-            result = Generation(request.id, prefill_tokens, next_token_id_squeezed, next_token_logprob, next_token_text,
-                                generated_text)
+            result = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )

             results.append(result)
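
The prefill hunk above likewise only re-wraps the gather expression into black's style. As a quick illustration of what that expression computes, here is a standalone toy version; the tensors are random and the shapes are assumptions based on the surrounding code, not the actual model outputs.

# Toy sketch of the prefill-logprob gather reformatted above (not repository code).
import torch

vocab_size, new_input_length = 10, 4
# Per-position log-probabilities, shape [seq_len, vocab], and prompt token ids, shape [seq_len, 1]
logprobs = torch.log_softmax(torch.randn(new_input_length, vocab_size), dim=-1)
all_input_ids = torch.randint(vocab_size, (new_input_length, 1))

# For each position in the slice, gather the logprob stored at that position's token id;
# the leading nan stands in for the first prompt token, which has no prediction.
prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(
    1, all_input_ids[-new_input_length:-1]
).squeeze(1).tolist()
print(len(prefill_logprobs) == new_input_length)  # True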

View File

@@ -17,10 +17,10 @@ class Batch(ABC):

     @classmethod
     @abstractmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "Batch":
         raise NotImplementedError
@@ -41,8 +41,8 @@ class GeneratedText:
         return generate_pb2.GeneratedText(
             text=self.text,
             generated_tokens=self.generated_tokens,
-            finish_reason=self.finish_reason
-            seed=self.seed
+            finish_reason=self.finish_reason,
+            seed=self.seed,
         )
@@ -54,9 +54,7 @@ class PrefillTokens:
     def to_pb(self) -> generate_pb2.PrefillTokens:
         return generate_pb2.PrefillTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
         )
@@ -72,9 +70,13 @@ class Generation:
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
             request_id=self.request_id,
-            prefill_tokens=self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None,
+            prefill_tokens=self.prefill_tokens.to_pb()
+            if self.prefill_tokens is not None
+            else None,
             token_id=self.token_id,
             token_logprob=self.token_logprob,
             token_text=self.token_text,
-            generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,
+            generated_text=self.generated_text.to_pb()
+            if self.generated_text is not None
+            else None,
         )
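
The Generation.to_pb hunk is a pure re-wrap of the conditional sub-message conversion, i.e. "x.to_pb() if x is not None else None" for optional fields. A minimal sketch of that pattern with hypothetical stand-in classes (plain dicts instead of the real protobuf messages) is shown below; none of these names exist in the repository.

# Hypothetical stand-ins that only mirror the optional-field pattern from the diff above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeGeneratedText:
    text: str

    def to_pb(self) -> dict:
        return {"text": self.text}


@dataclass
class FakeGeneration:
    request_id: int
    generated_text: Optional[FakeGeneratedText]

    def to_pb(self) -> dict:
        return {
            "request_id": self.request_id,
            # Same conditional conversion as in Generation.to_pb
            "generated_text": self.generated_text.to_pb()
            if self.generated_text is not None
            else None,
        }


print(FakeGeneration(0, None).to_pb())
print(FakeGeneration(1, FakeGeneratedText("hello")).to_pb())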

View File

@@ -36,9 +36,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )
@@ -62,9 +60,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )