Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

Commit 8c2ddfe838: "black"
Parent: d917ae8955
@@ -1,6 +1,6 @@
 gen-server:
	# Compile protos
-	#pip install grpcio-tools==1.49.1 --no-cache-dir
+	pip install grpcio-tools==1.49.1 --no-cache-dir
	mkdir text_generation/pb || true
	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
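Note on the recipe above: the sed step rewrites the absolute imports that grpc_tools.protoc emits into package-relative ones, so the generated stubs resolve inside the text_generation/pb package. A minimal sketch of its effect on a generated stub (the exact module and file names are illustrative, based on typical protoc output for generate.proto):

    # Before, roughly as emitted by protoc into text_generation/pb/generate_pb2_grpc.py:
    import generate_pb2 as generate__pb2

    # After the sed rewrite, a relative import that works inside the package:
    from . import generate_pb2 as generate__pb2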
@@ -357,6 +357,9 @@ class CausalLM(Model):
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
                 output_text = request.inputs + generated_text
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason
+                )
 
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -380,21 +383,29 @@ class CausalLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 0:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(1, all_input_ids[
-                    -new_input_length:-1]).squeeze(
-                    1).tolist()
+                prefill_logprobs = [float("nan")] + logprobs[
+                    -new_input_length:-1
+                ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(prefill_token_ids,
-                                                            clean_up_tokenization_spaces=False,
-                                                            skip_special_tokens=False)
-                prefill_tokens = PrefillTokens(prefill_token_ids,
-                                               prefill_logprobs,
-                                               prefill_texts)
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
+                )
             else:
                 prefill_tokens = None
 
-            result = Generation(request.id, prefill_tokens, next_token_id_squeezed, next_token_logprob, next_token_text,
-                                generated_text)
+            result = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
 
             results.append(result)
 
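Aside: the reshuffled prefill_logprobs expression above is purely black reformatting; it still prepends nan for the first prompt token and gathers one log-probability per remaining prefill position. A standalone sketch of that gather pattern, with illustrative tensor names and shapes (assumptions for the sake of the example, not values from this diff):

    import torch

    # Illustrative shapes only: per-position log-probabilities and prompt token ids.
    logprobs = torch.log_softmax(torch.randn(6, 10), dim=-1)  # (seq_len, vocab)
    all_input_ids = torch.randint(0, 10, (6, 1))              # (seq_len, 1)
    new_input_length = 4

    # nan for the first prompt token, then the log-probability assigned at each
    # remaining prefill position to the token id recorded there.
    prefill_logprobs = [float("nan")] + logprobs[
        -new_input_length:-1
    ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()

    assert len(prefill_logprobs) == new_input_length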
@@ -41,8 +41,8 @@ class GeneratedText:
         return generate_pb2.GeneratedText(
             text=self.text,
             generated_tokens=self.generated_tokens,
-            finish_reason=self.finish_reason
-            seed=self.seed
+            finish_reason=self.finish_reason,
+            seed=self.seed,
         )
 
 
@@ -54,9 +54,7 @@ class PrefillTokens:
 
     def to_pb(self) -> generate_pb2.PrefillTokens:
         return generate_pb2.PrefillTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
         )
 
 
@@ -72,9 +70,13 @@ class Generation:
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
             request_id=self.request_id,
-            prefill_tokens=self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None,
+            prefill_tokens=self.prefill_tokens.to_pb()
+            if self.prefill_tokens is not None
+            else None,
             token_id=self.token_id,
             token_logprob=self.token_logprob,
             token_text=self.token_text,
-            generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,
+            generated_text=self.generated_text.to_pb()
+            if self.generated_text is not None
+            else None,
         )
@@ -36,9 +36,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )
 
@@ -62,9 +60,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
         )
 