feat(server): avoid manipulating position_ids for non-applicable models

Currently, position_ids are maintained and updated on the generic CausalLM path, but this is unnecessary for models like BLOOM that don't use them.
Nick Hill 2023-04-17 16:48:01 -07:00
parent b927244eb5
commit a07ef4c656
3 changed files with 27 additions and 12 deletions

View File

@@ -1,3 +1,5 @@
import inspect
import torch
from dataclasses import dataclass
@@ -26,7 +28,7 @@ class CausalLMBatch(Batch):
# Decoder values
input_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
position_ids: Optional[torch.Tensor]
past_key_values: Optional[List[Tuple]]
# All tokens
@@ -62,6 +64,7 @@ class CausalLMBatch(Batch):
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
use_position_ids: bool = False,
) -> "CausalLMBatch":
inputs = []
next_token_choosers = []
@@ -105,11 +108,14 @@
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
if use_position_ids:
    position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
    position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
else:
    position_ids = None
return cls(
batch_id=pb.id,
requests=pb.requests,
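A note on the position_ids computation now guarded by use_position_ids: the cumulative sum over the attention mask numbers the real tokens from 0, while left-padded slots get a harmless dummy value. A standalone sketch with invented values:

import torch

# Hypothetical left-padded attention mask: sequence lengths 3 and 5
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Same arithmetic as in from_pb: running count of real tokens, shifted to start at 0
position_ids = attention_mask.long().cumsum(-1) - 1
# Padding slots would be -1 here; overwrite them with a dummy position of 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])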
@@ -204,11 +210,12 @@
batch_left_offset : -batch.padding_right_offset,
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
if position_ids is None:
    position_ids = batch.position_ids.new_empty((total_batch_size, 1))
position_ids[start_index:end_index] = batch.position_ids
if batch.position_ids is not None:
    # Create empty tensor
    # position_ids is always of shape [batch_size, 1]
    if position_ids is None:
        position_ids = batch.position_ids.new_empty((total_batch_size, 1))
    position_ids[start_index:end_index] = batch.position_ids
for j, past in enumerate(batch.past_key_values):
past_keys, past_values = past
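In the concatenation path above, the merged position_ids tensor is only allocated when the batches being combined actually carry one; since every batch in a merge comes from the same model, position_ids is either present in all of them or in none. A rough sketch of that lazy-allocation pattern (the function name and shapes are illustrative, not from the diff):

import torch
from typing import List, Optional, Tuple

def merge_position_ids(
    batches: List[Tuple[int, Optional[torch.Tensor]]], total_batch_size: int
) -> Optional[torch.Tensor]:
    """Concatenate per-batch position_ids, or return None when the model doesn't use them."""
    merged: Optional[torch.Tensor] = None
    start = 0
    for size, ids in batches:
        if ids is not None:
            if merged is None:
                # At decode time position_ids has shape [batch_size, 1]
                merged = ids.new_empty((total_batch_size, 1))
            merged[start : start + size] = ids
        start += size
    return merged

# Two decode-time batches of sizes 2 and 1
print(merge_position_ids([(2, torch.tensor([[4], [7]])), (1, torch.tensor([[3]]))], 3))
# tensor([[4], [7], [3]])
print(merge_position_ids([(2, None), (1, None)], 3))  # None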
@@ -328,6 +335,10 @@ class CausalLM(Model):
tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
)
self.use_position_ids = "position_ids" in set(
    inspect.signature(self.model.forward).parameters.keys()
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch
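The inspect.signature check above is the feature-detection heart of this change: a model opts into position_ids simply by declaring that parameter on its forward method, with no per-model special casing. A minimal illustration of the same check against stand-in modules (the toy classes are hypothetical, not part of the diff):

import inspect

import torch

class WithPositions(torch.nn.Module):
    # GPT-style forward: accepts explicit position_ids
    def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None):
        ...

class WithoutPositions(torch.nn.Module):
    # BLOOM-style forward: no position_ids argument (BLOOM uses ALiBi instead)
    def forward(self, input_ids, attention_mask=None, past_key_values=None):
        ...

def uses_position_ids(model: torch.nn.Module) -> bool:
    # Same test as CausalLM.__init__ above: look for the parameter by name
    return "position_ids" in inspect.signature(model.forward).parameters

print(uses_position_ids(WithPositions()))     # True
print(uses_position_ids(WithoutPositions()))  # False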
@@ -494,7 +505,8 @@ class CausalLM(Model):
if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
next_batch_position_ids = batch.position_ids[next_batch_keep_indices] \
    if batch.position_ids is not None else None
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_past_key_values = [
[
@@ -522,7 +534,8 @@
next_batch_attention_mask[:, -batch.padding_right_offset] = 1
# Update position_ids
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
if next_batch_position_ids is not None:
    next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
next_batch = CausalLMBatch(
batch_id=batch.batch_id,
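At each decode step the update above advances every surviving sequence by one position, and with this change it is skipped entirely when the model never looks at position_ids. A tiny sketch of the per-step increment (tensor values invented):

import torch

# Position of the most recently processed token for three sequences
next_batch_position_ids = torch.tensor([[4], [9], [6]])

if next_batch_position_ids is not None:
    # Keep only the last column and advance by one, mirroring the decode-step update above
    next_batch_position_ids = next_batch_position_ids[:, -1:] + 1

print(next_batch_position_ids)
# tensor([[ 5],
#         [10],
#         [ 7]])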

View File

@@ -23,6 +23,7 @@ class Model(ABC):
self.all_special_ids = set(tokenizer.all_special_ids)
self.device = device
self.decode_buffer = decode_buffer
self.use_position_ids = False
@property
@abstractmethod

View File

@@ -39,8 +39,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.ClearCacheResponse()
async def Prefill(self, request, context):
kwargs = {"use_position_ids": True} if self.model.use_position_ids else {}
batch = self.model.batch_type.from_pb(
    request.batch, self.model.tokenizer, self.model.device
    request.batch, self.model.tokenizer, self.model.device, **kwargs,
)
generations, next_batch = self.model.generate_token(batch)
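On the serving side the flag is forwarded opportunistically: the keyword is only passed when the model asked for it, presumably so that batch types whose from_pb never grew a use_position_ids parameter keep working unchanged. A simplified sketch of that call pattern (the free function below stands in for the gRPC servicer method):

def prefill(model, request):
    # Only mention use_position_ids when the model actually consumes position_ids
    kwargs = {"use_position_ids": True} if model.use_position_ids else {}
    batch = model.batch_type.from_pb(
        request.batch, model.tokenizer, model.device, **kwargs
    )
    return model.generate_token(batch)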