Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
feat(server): avoid manipulating position_ids for non-applicable models
Currently, position_ids are maintained and updated in the CausalLM path, but this is unnecessary for models like BLOOM, which don't use them (BLOOM encodes position through ALiBi attention biases rather than position embeddings).
parent b927244eb5
commit a07ef4c656
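The crux of the change is a capability check: position_ids are only built and threaded through a batch when the model's forward signature actually declares that parameter. Below is a hedged, standalone sketch of that check; the choice of BloomForCausalLM and GPT2LMHeadModel as examples is mine, not part of the commit (requires transformers and torch installed; no weights are loaded).

import inspect

from transformers import BloomForCausalLM, GPT2LMHeadModel

def accepts_position_ids(model_class) -> bool:
    # True when the model's forward() declares a position_ids parameter
    return "position_ids" in inspect.signature(model_class.forward).parameters

print(accepts_position_ids(GPT2LMHeadModel))   # True: GPT-2 uses learned position embeddings
print(accepts_position_ids(BloomForCausalLM))  # False: BLOOM encodes position via ALiBi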
@@ -1,3 +1,5 @@
+import inspect
+
 import torch
 
 from dataclasses import dataclass
@@ -26,7 +28,7 @@ class CausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     attention_mask: torch.Tensor
-    position_ids: torch.Tensor
+    position_ids: Optional[torch.Tensor]
     past_key_values: Optional[List[Tuple]]
 
     # All tokens
@@ -62,6 +64,7 @@ class CausalLMBatch(Batch):
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
+        use_position_ids: bool = False,
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -105,11 +108,14 @@ class CausalLMBatch(Batch):
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
         attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
 
-        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
-        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
 
+        if use_position_ids:
+            position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
+            position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
+        else:
+            position_ids = None
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
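For models that do want position_ids, from_pb derives them from the attention mask: cumsum(-1) - 1 numbers the real tokens 0, 1, 2, ... even under left padding, and masked_fill_ parks the padding slots at the arbitrary value 1. A minimal torch demonstration of those same two lines (the tensor values are invented for illustration):

import torch

# Two left-padded sequences: 0 marks padding, 1 marks real tokens
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

The padded positions never influence the model because the attention mask zeroes them out; they just need some valid index.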
@@ -204,11 +210,12 @@ class CausalLMBatch(Batch):
                 batch_left_offset : -batch.padding_right_offset,
             ]
 
-            # Create empty tensor
-            # position_ids is always of shape [batch_size, 1]
-            if position_ids is None:
-                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
-            position_ids[start_index:end_index] = batch.position_ids
+            if batch.position_ids is not None:
+                # Create empty tensor
+                # position_ids is always of shape [batch_size, 1]
+                if position_ids is None:
+                    position_ids = batch.position_ids.new_empty((total_batch_size, 1))
+                position_ids[start_index:end_index] = batch.position_ids
 
             for j, past in enumerate(batch.past_key_values):
                 past_keys, past_values = past
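In concatenate, the merged position_ids tensor is allocated lazily with new_empty (decode-time position_ids are always [batch_size, 1]) and filled batch by batch; with this change the whole block is skipped when a batch carries position_ids=None. A small sketch of the preallocate-and-copy pattern, with made-up batch sizes:

import torch

# Decode-phase position_ids from two batches, each of shape [batch_size, 1]
batch_position_ids = [torch.tensor([[4], [7]]), torch.tensor([[2]])]
total_batch_size = 3

position_ids = None
start_index = 0
for ids in batch_position_ids:
    end_index = start_index + ids.shape[0]
    if position_ids is None:
        # Lazily allocate the merged tensor with the source's dtype and device
        position_ids = ids.new_empty((total_batch_size, 1))
    position_ids[start_index:end_index] = ids
    start_index = end_index

print(position_ids)  # tensor([[4], [7], [2]])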
@@ -328,6 +335,10 @@ class CausalLM(Model):
             tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
         )
 
+        self.use_position_ids = "position_ids" in set(
+            inspect.signature(self.model.forward).parameters.keys()
+        )
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
@@ -494,7 +505,8 @@ class CausalLM(Model):
         if len(next_batch_keep_indices) != len(batch):
             # Apply indices to attention mask, past key values and other items that need to be cached
             next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
-            next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
+            next_batch_position_ids = batch.position_ids[next_batch_keep_indices] \
+                if batch.position_ids is not None else None
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
@@ -522,7 +534,8 @@ class CausalLM(Model):
         next_batch_attention_mask[:, -batch.padding_right_offset] = 1
 
         # Update position_ids
-        next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
+        if next_batch_position_ids is not None:
+            next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
 
         next_batch = CausalLMBatch(
             batch_id=batch.batch_id,
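Taken together, the two generate_token hunks make the per-step bookkeeping conditional: when requests finish, the surviving rows are selected with next_batch_keep_indices, and the kept positions are advanced by one for the next decode step. Both operations in isolation, on invented values:

import torch

position_ids = torch.tensor([[4], [7], [2]])  # one position per sequence during decode
next_batch_keep_indices = [0, 2]              # the middle request finished; drop its row

# Filter to surviving requests, then advance each position for the next token
position_ids = position_ids[next_batch_keep_indices]
position_ids = position_ids[:, -1:] + 1

print(position_ids)  # tensor([[5], [3]])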
@@ -23,6 +23,7 @@ class Model(ABC):
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
         self.decode_buffer = decode_buffer
+        self.use_position_ids = False
 
     @property
     @abstractmethod
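The abstract Model base defaults use_position_ids to False, so every model type that never opts in (the seq2seq path, for instance) keeps its existing behavior, while CausalLM overrides the flag in its own __init__ via the signature check. A toy sketch of that default-plus-override pattern; the class and function names here are placeholders, not the real ones:

import inspect

class Base:
    def __init__(self):
        self.use_position_ids = False  # safe default: never pass position_ids

class PositionAware(Base):
    def __init__(self, forward_fn):
        super().__init__()
        # Opt in only when the wrapped forward actually accepts position_ids
        self.use_position_ids = (
            "position_ids" in inspect.signature(forward_fn).parameters
        )

def fake_forward(input_ids, position_ids=None):  # stand-in for a model forward
    return input_ids

print(Base().use_position_ids)                       # False
print(PositionAware(fake_forward).use_position_ids)  # True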
@@ -39,8 +39,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()
 
     async def Prefill(self, request, context):
+        kwargs = {"use_position_ids": True} if self.model.use_position_ids else {}
         batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.device
+            request.batch, self.model.tokenizer, self.model.device, **kwargs,
         )
 
         generations, next_batch = self.model.generate_token(batch)
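Finally, the gRPC Prefill handler forwards the flag only when the model asked for it, presumably so that batch classes whose from_pb predates the new parameter are never handed an unexpected keyword. The conditional **kwargs expansion, reduced to a runnable toy (the stub function is mine, not the real from_pb):

def legacy_from_pb(batch, tokenizer, device):
    # A from_pb without the new parameter: must not receive use_position_ids
    return "ok"

use_position_ids = False
kwargs = {"use_position_ids": True} if use_position_ids else {}

# Expanding an empty dict adds no arguments, so the old signature still works
print(legacy_from_pb("batch", "tokenizer", "cpu", **kwargs))  # ok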