Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)

Commit 8c2ddfe838 (parent d917ae8955): black
@@ -1,6 +1,6 @@
 gen-server:
 	# Compile protos
-	#pip install grpcio-tools==1.49.1 --no-cache-dir
+	pip install grpcio-tools==1.49.1 --no-cache-dir
 	mkdir text_generation/pb || true
 	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
 	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
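The `find … -exec sed` rule in this Makefile rewrites the absolute imports that `grpc_tools.protoc` emits so the generated modules resolve as relative imports inside the `text_generation.pb` package. A minimal Python sketch of the same rewrite, with an illustrative input line rather than one taken from the generated stubs:

import re

# Python equivalent of the sed expression s/^\(import.*pb2\)/from . \1/g:
# prefix any module-level "import ...pb2" statement with "from . " so the
# generated files import each other relative to the text_generation.pb package.
line = "import generate_pb2 as generate__pb2"  # illustrative, not from this diff
print(re.sub(r"^(import.*pb2)", r"from . \1", line))
# -> from . import generate_pb2 as generate__pb2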
@@ -47,10 +47,10 @@ class CausalLMBatch(Batch):
 
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -145,8 +145,8 @@ class CausalLMBatch(Batch):
 
             # We need to slice the attention mask to remove padding from previous steps
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]
 
             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
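The only substantive edits in this hunk and the next are the spaces Black inserts before the slice colon: it follows PEP 8 in treating the colon like a binary operator, so simple bounds stay tight while complex expressions (attribute access, parenthesised arithmetic) get a space on each side. A small standalone illustration, with made-up variables:

x = list(range(10))
n = 4

a = x[-n:]         # simple bound: Black keeps the tight form
b = x[-(n - 1) :]  # complex expression: Black adds a space before ":"
assert a[1:] == b == [7, 8, 9]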
@@ -192,22 +192,22 @@ class CausalLMBatch(Batch):
                 # We slice the past keys and values to remove the padding from previous batches
                 if batch.keys_head_dim_last:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
-                        -(batch.max_sequence_length - 1):,
+                        -(batch.max_sequence_length - 1) :,
                         :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
-                        -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        -(batch.max_sequence_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
 
                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
 
             start_index += batch.size
 
@@ -271,7 +271,7 @@ class CausalLM(Model):
         )
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         outputs = self.model.forward(
@@ -284,7 +284,7 @@ class CausalLM(Model):
         return outputs.logits, outputs.past_key_values
 
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         context_manager = (
@@ -325,12 +325,12 @@ class CausalLM(Model):
 
         # For each member of the batch
         for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
            tokens, logprobs = next_token_chooser(all_input_ids, logits)
@@ -354,9 +354,12 @@ class CausalLM(Model):
            if stop:
                # Decode generated tokens
                generated_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:, 0]
+                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                output_text = request.inputs + generated_text
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, reason
+                )
 
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
@@ -380,21 +383,29 @@ class CausalLM(Model):
            # Prefill
            if stopping_criteria.current_tokens == 0:
                # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(1, all_input_ids[
-                    -new_input_length:-1]).squeeze(
-                    1).tolist()
+                prefill_logprobs = [float("nan")] + logprobs[
+                    -new_input_length:-1
+                ].gather(1, all_input_ids[-new_input_length:-1]).squeeze(1).tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(prefill_token_ids,
-                                                            clean_up_tokenization_spaces=False,
-                                                            skip_special_tokens=False)
-                prefill_tokens = PrefillTokens(prefill_token_ids,
-                                               prefill_logprobs,
-                                               prefill_texts)
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts
+                )
            else:
                prefill_tokens = None
 
-            result = Generation(request.id, prefill_tokens, next_token_id_squeezed, next_token_logprob, next_token_text,
-                                generated_text)
+            result = Generation(
+                request.id,
+                prefill_tokens,
+                next_token_id_squeezed,
+                next_token_logprob,
+                next_token_text,
+                generated_text,
+            )
 
            results.append(result)
 
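The reflowed `prefill_logprobs` expression above gathers, for each prompt position, the log-probability assigned to the token that actually follows, then prepends `nan` for the first token, which has no preceding logits. A self-contained sketch of that gather pattern, using made-up shapes rather than the server's real tensors:

import torch

# logprobs: log-softmax scores per position; token_ids: the observed ids,
# shaped [seq_len, 1] to match how all_input_ids is indexed in the diff above.
logprobs = torch.log_softmax(torch.randn(5, 32), dim=-1)
token_ids = torch.randint(0, 32, (5, 1))

# gather(1, token_ids) picks, per row, the score of the realised token;
# squeeze/tolist flatten the [seq_len, 1] result into plain floats.
per_token_logprobs = logprobs.gather(1, token_ids).squeeze(1).tolist()
print(per_token_logprobs)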
@@ -17,10 +17,10 @@ class Batch(ABC):
    @classmethod
    @abstractmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Batch":
        raise NotImplementedError
 
@@ -41,8 +41,8 @@ class GeneratedText:
        return generate_pb2.GeneratedText(
            text=self.text,
            generated_tokens=self.generated_tokens,
-            finish_reason=self.finish_reason
-            seed=self.seed
+            finish_reason=self.finish_reason,
+            seed=self.seed,
        )
 
 
@@ -54,9 +54,7 @@ class PrefillTokens:
 
    def to_pb(self) -> generate_pb2.PrefillTokens:
        return generate_pb2.PrefillTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
        )
 
 
@@ -72,9 +70,13 @@ class Generation:
    def to_pb(self) -> generate_pb2.Generation:
        return generate_pb2.Generation(
            request_id=self.request_id,
-            prefill_tokens=self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None,
+            prefill_tokens=self.prefill_tokens.to_pb()
+            if self.prefill_tokens is not None
+            else None,
            token_id=self.token_id,
            token_logprob=self.token_logprob,
            token_text=self.token_text,
-            generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,
+            generated_text=self.generated_text.to_pb()
+            if self.generated_text is not None
+            else None,
        )
@@ -36,9 +36,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        self.cache.set(next_batch)
 
        return generate_pb2.PrefillResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
        )
 
@@ -62,9 +60,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        self.cache.set(next_batch)
 
        return generate_pb2.DecodeResponse(
-            generations=[
-                generation.to_pb() for generation in generations
-            ],
+            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
        )
 