mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
black
parent d917ae8955
commit 8c2ddfe838
@@ -1,6 +1,6 @@
 gen-server:
 	# Compile protos
-	#pip install grpcio-tools==1.49.1 --no-cache-dir
+	pip install grpcio-tools==1.49.1 --no-cache-dir
 	mkdir text_generation/pb || true
 	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
 	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
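The last recipe line rewrites the generated gRPC stubs to use package-relative imports so that text_generation/pb works as a Python package. A minimal, runnable sketch of the same substitution in Python; the sample import line is an assumption about what grpc_tools typically emits, not copied from the repository:

import re

# Same rewrite as the sed expression above: turn a top-level `import ...pb2`
# statement into a package-relative import.
sample = "import generate_pb2 as generate__pb2"  # assumed generated stub line
print(re.sub(r"^(import.*pb2)", r"from . \1", sample))
# -> from . import generate_pb2 as generate__pb2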
@@ -145,8 +145,8 @@ class CausalLMBatch(Batch):

             # We need to slice the attention mask to remove padding from previous steps
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]

             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
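The only change in this hunk, and in the slicing hunks that follow, is whitespace: black applies PEP 8's rule of putting a space on both sides of a slice colon when a bound is a non-trivial expression such as -batch.max_sequence_length. A small self-contained illustration with made-up values, using a plain list in place of the real tensors:

# Toy stand-in for one row of the attention mask; values are made up.
row = [0, 0, 0, 1, 1, 1, 1, 1]
max_sequence_length = 5

before_black = row[-max_sequence_length:]   # pre-commit spelling
after_black = row[-max_sequence_length :]   # post-commit spelling (black / PEP 8)
assert before_black == after_black == [1, 1, 1, 1, 1]  # behaviour is identical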
@@ -194,20 +194,20 @@ class CausalLMBatch(Batch):
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
-                        -(batch.max_sequence_length - 1):,
+                        -(batch.max_sequence_length - 1) :,
                         :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
-                        -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

             start_index += batch.size

@@ -354,9 +354,12 @@ class CausalLM(Model):
             if stop:
                 # Decode generated tokens
                 generated_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:, 0]
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
                 output_text = request.inputs + generated_text
                 generated_text = GeneratedText(
                     output_text, stopping_criteria.current_tokens, reason
                 )

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -380,21 +383,29 @@ class CausalLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 0:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(1, all_input_ids[
-                    -new_input_length:-1]).squeeze(
-                    1).tolist()
+                prefill_logprobs = [float("nan")] + logprobs[
+                    -new_input_length:-1
+                ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(prefill_token_ids,
-                                                            clean_up_tokenization_spaces=False,
-                                                            skip_special_tokens=False)
-                prefill_tokens = PrefillTokens(prefill_token_ids,
-                                               prefill_logprobs,
-                                               prefill_texts)
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
+                )
             else:
                 prefill_tokens = None

-            result = Generation(request.id, prefill_tokens, next_token_id_squeezed, next_token_logprob, next_token_text,
-                                generated_text)
+            result = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )

             results.append(result)

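The reformatted prefill expression above behaves exactly as before: it drops the last position (per the diff's comment, this removes the newly generated token), gathers one logprob per remaining prompt position, and prepends NaN for the first prompt token. A short sketch with made-up shapes, assuming torch is available as it is in the server:

import torch

# Made-up stand-ins: a 5-token prompt and a 10-token vocabulary.
new_input_length = 5
logprobs = torch.log_softmax(torch.randn(new_input_length, 10), dim=-1)
all_input_ids = torch.randint(0, 10, (new_input_length, 1))

# Same gather as in the diff: one logprob per prompt position except the last,
# with NaN standing in for the first prompt token.
prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(
    1, all_input_ids[-new_input_length:-1]
).squeeze(1).tolist()

assert len(prefill_logprobs) == new_input_length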
@@ -41,8 +41,8 @@ class GeneratedText:
         return generate_pb2.GeneratedText(
             text=self.text,
             generated_tokens=self.generated_tokens,
-            finish_reason=self.finish_reason
-            seed=self.seed
+            finish_reason=self.finish_reason,
+            seed=self.seed,
         )

||||
@ -54,9 +54,7 @@ class PrefillTokens:
|
||||
|
||||
def to_pb(self) -> generate_pb2.PrefillTokens:
|
||||
return generate_pb2.PrefillTokens(
|
||||
ids=self.token_ids,
|
||||
logprobs=self.logprobs,
|
||||
texts=self.texts
|
||||
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
|
||||
)
|
||||
|
||||
|
||||
@ -72,9 +70,13 @@ class Generation:
|
||||
def to_pb(self) -> generate_pb2.Generation:
|
||||
return generate_pb2.Generation(
|
||||
request_id=self.request_id,
|
||||
prefill_tokens=self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None,
|
||||
prefill_tokens=self.prefill_tokens.to_pb()
|
||||
if self.prefill_tokens is not None
|
||||
else None,
|
||||
token_id=self.token_id,
|
||||
token_logprob=self.token_logprob,
|
||||
token_text=self.token_text,
|
||||
generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,
|
||||
generated_text=self.generated_text.to_pb()
|
||||
if self.generated_text is not None
|
||||
else None,
|
||||
)
|
||||
|
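Black wraps the two conditional expressions in the hunk above because the one-line keyword arguments exceed its 88-column limit; the values are unchanged. A tiny stand-alone illustration with a made-up variable rather than the repository's PrefillTokens class:

prefill_tokens = None  # made-up stand-in for an optional object with a to_pb() method

# Pre-black, one line:  value=prefill_tokens.to_pb() if prefill_tokens is not None else None,
# Post-black, the same expression wrapped across three lines inside the call.
value = (
    prefill_tokens.to_pb()
    if prefill_tokens is not None
    else None
)
print(value)  # None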
@@ -36,9 +36,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )

@@ -62,9 +60,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )