feat(server): avoid manipulating position_ids for non-applicable models

Currently, position_ids are maintained and updated in the CausalLM case, but this is unnecessary for models like BLOOM, which don't use them.
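For reference, the detection added in CausalLM.__init__ below boils down to inspecting the model's forward signature. A minimal standalone sketch of that idea (the helper name here is hypothetical and not part of this change):

    import inspect

    def accepts_position_ids(model) -> bool:
        # True for models such as GPT-2 or GPT-NeoX, whose forward() takes a
        # position_ids argument; False for BLOOM in the transformers versions
        # this change targets, whose forward() has no such parameter.
        return "position_ids" in inspect.signature(model.forward).parameters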
Nick Hill 2023-04-17 16:48:01 -07:00
parent b927244eb5
commit a07ef4c656
3 changed files with 27 additions and 12 deletions


@@ -1,3 +1,5 @@
+import inspect
+
 import torch
 
 from dataclasses import dataclass
@@ -26,7 +28,7 @@ class CausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     attention_mask: torch.Tensor
-    position_ids: torch.Tensor
+    position_ids: Optional[torch.Tensor]
     past_key_values: Optional[List[Tuple]]
 
     # All tokens
@ -62,6 +64,7 @@ class CausalLMBatch(Batch):
pb: generate_pb2.Batch, pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
device: torch.device, device: torch.device,
use_position_ids: bool = False,
) -> "CausalLMBatch": ) -> "CausalLMBatch":
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
@@ -105,11 +108,14 @@
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
 
-        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
-        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
+        if use_position_ids:
+            position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
+            position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
+        else:
+            position_ids = None
 
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
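As an aside (not part of the diff), the cumsum/masked_fill recipe kept above produces position ids that skip left padding; a tiny worked example with made-up values:

    import torch

    # Left-padded batch: the first row has two padding positions.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])
    # Real tokens count up from 0; masked-out slots get a harmless placeholder.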
@@ -204,11 +210,12 @@ class CausalLMBatch(Batch):
                 batch_left_offset : -batch.padding_right_offset,
             ]
 
-            # Create empty tensor
-            # position_ids is always of shape [batch_size, 1]
-            if position_ids is None:
-                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
-            position_ids[start_index:end_index] = batch.position_ids
+            if batch.position_ids is not None:
+                # Create empty tensor
+                # position_ids is always of shape [batch_size, 1]
+                if position_ids is None:
+                    position_ids = batch.position_ids.new_empty((total_batch_size, 1))
+                position_ids[start_index:end_index] = batch.position_ids
 
             for j, past in enumerate(batch.past_key_values):
                 past_keys, past_values = past
@@ -328,6 +335,10 @@ class CausalLM(Model):
             tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
+        self.use_position_ids = "position_ids" in set(
+            inspect.signature(self.model.forward).parameters.keys()
+        )
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
@@ -494,7 +505,8 @@ class CausalLM(Model):
         if len(next_batch_keep_indices) != len(batch):
             # Apply indices to attention mask, past key values and other items that need to be cached
             next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
-            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
+            next_batch_position_ids = batch.position_ids[next_batch_keep_indices] \
+                if batch.position_ids is not None else None
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
@@ -522,7 +534,8 @@ class CausalLM(Model):
         next_batch_attention_mask[:, -batch.padding_right_offset] = 1
 
         # Update position_ids
-        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
+        if next_batch_position_ids is not None:
+            next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
 
         next_batch = CausalLMBatch(
             batch_id=batch.batch_id,


@@ -23,6 +23,7 @@ class Model(ABC):
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
         self.decode_buffer = decode_buffer
+        self.use_position_ids = False
 
     @property
     @abstractmethod


@@ -39,8 +39,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()
 
     async def Prefill(self, request, context):
+        kwargs = {"use_position_ids": True} if self.model.use_position_ids else {}
         batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.device
+            request.batch, self.model.tokenizer, self.model.device, **kwargs,
         )
 
         generations, next_batch = self.model.generate_token(batch)