mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
feat(server): avoid manipulating position_ids for non-applicable models
Currently position_ids are maintained/updated in the CausalLM case but this is unnecessary for models like BLOOM which don't use them.
This commit is contained in:
parent
b927244eb5
commit
a07ef4c656
@ -1,3 +1,5 @@
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
@ -26,7 +28,7 @@ class CausalLMBatch(Batch):
|
||||
# Decoder values
|
||||
input_ids: torch.Tensor
|
||||
attention_mask: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
position_ids: Optional[torch.Tensor]
|
||||
past_key_values: Optional[List[Tuple]]
|
||||
|
||||
# All tokens
|
||||
@ -62,6 +64,7 @@ class CausalLMBatch(Batch):
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
device: torch.device,
|
||||
use_position_ids: bool = False,
|
||||
) -> "CausalLMBatch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
@ -105,10 +108,13 @@ class CausalLMBatch(Batch):
|
||||
)
|
||||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
|
||||
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
|
||||
|
||||
if use_position_ids:
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
|
||||
else:
|
||||
position_ids = None
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
@ -204,6 +210,7 @@ class CausalLMBatch(Batch):
|
||||
batch_left_offset : -batch.padding_right_offset,
|
||||
]
|
||||
|
||||
if batch.position_ids is not None:
|
||||
# Create empty tensor
|
||||
# position_ids is always of shape [batch_size, 1]
|
||||
if position_ids is None:
|
||||
@ -328,6 +335,10 @@ class CausalLM(Model):
|
||||
tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
|
||||
)
|
||||
|
||||
self.use_position_ids = "position_ids" in set(
|
||||
inspect.signature(self.model.forward).parameters.keys()
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return CausalLMBatch
|
||||
@ -494,7 +505,8 @@ class CausalLM(Model):
|
||||
if len(next_batch_keep_indices) != len(batch):
|
||||
# Apply indices to attention mask, past key values and other items that need to be cached
|
||||
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
|
||||
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
|
||||
next_batch_position_ids = batch.position_ids[next_batch_keep_indices] \
|
||||
if batch.position_ids is not None else None
|
||||
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
|
||||
next_batch_past_key_values = [
|
||||
[
|
||||
@ -522,6 +534,7 @@ class CausalLM(Model):
|
||||
next_batch_attention_mask[:, -batch.padding_right_offset] = 1
|
||||
|
||||
# Update position_ids
|
||||
if next_batch_position_ids is not None:
|
||||
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
|
||||
|
||||
next_batch = CausalLMBatch(
|
||||
|
@ -23,6 +23,7 @@ class Model(ABC):
|
||||
self.all_special_ids = set(tokenizer.all_special_ids)
|
||||
self.device = device
|
||||
self.decode_buffer = decode_buffer
|
||||
self.use_position_ids = False
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
@ -39,8 +39,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
return generate_pb2.ClearCacheResponse()
|
||||
|
||||
async def Prefill(self, request, context):
|
||||
kwargs = {"use_position_ids": True} if self.model.use_position_ids else {}
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.device
|
||||
request.batch, self.model.tokenizer, self.model.device, **kwargs,
|
||||
)
|
||||
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
|
Loading…
Reference in New Issue
Block a user