stuff
parent cbbc046a79
commit 5677540881
@@ -177,7 +177,8 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base

+RUN git clone https://github.com/bigcode-project/bigcode-inference-benchmark.git && \
+    cd bigcode-inference-benchmark && git checkout text_gen_inference

 ENV HUGGINGFACE_HUB_CACHE=/usr/data/.hf_cache/
 ENV PYTHONPATH=/usr/src/server/
@@ -8,6 +8,7 @@ from typing import Optional

 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -155,6 +156,8 @@ def get_model(
         raise ValueError("sharded is not supported for AutoModel")

     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+        if os.environ.get("VECTORIZED_LM") is not None:
+            return VectorizedCausalLM(model_id, revision, quantize=quantize)
         return CausalLM(model_id, revision, quantize=quantize)
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
         return Seq2SeqLM(model_id, revision, quantize=quantize)
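A minimal usage sketch of the VECTORIZED_LM toggle added above (illustrative only; the variable name and the get_model branch come from this hunk, everything else is an assumption):

import os

# get_model only checks that the variable is set, so any value opts causal-LM
# architectures into the VectorizedCausalLM path instead of CausalLM.
os.environ["VECTORIZED_LM"] = "1"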
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
-from loguru import logger

 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -54,7 +53,6 @@ class CausalLMBatch(Batch):
     keys_head_dim_last: bool = True

     def to_pb(self) -> generate_pb2.Batch:
-        #logger.info(f"to_pb, id={self.batch_id}, requests={self.requests}, size={len(self)}, max_tokens={self.max_tokens}")
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -69,7 +67,6 @@ class CausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "CausalLMBatch":
-        #logger.info(f"from_pb, pb={pb}, tokenizer={tokenizer}, device={device}")
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
@@ -144,7 +141,6 @@ class CausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
-        logger.info(f"filter, requests={requests}")
         if len(requests) == 0:
             raise ValueError("Batch must have at least one request")
         if len(requests) == len(self):
@@ -242,7 +238,6 @@ class CausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
-        logger.info(f"concatenate, batches={batches}")
         # Used for padding
         total_batch_size = 0
         max_input_length = 0
@@ -20,19 +20,18 @@ tracer = trace.get_tracer(__name__)


 @dataclass
-class CausalLMBatch(Batch):
+class VectorizedCausalLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
     requests_idx_mapping: Dict[int, int]

     # Decoder values
-    input_ids: torch.Tensor
     attention_mask: torch.Tensor
     position_ids: torch.Tensor
     past_key_values: Optional[List[Tuple]]

     # All tokens
-    all_input_ids: List[torch.Tensor]
+    input_ids: torch.Tensor

     # Lengths of all generations present in the batch
     input_lengths: List[int]
@@ -45,16 +44,11 @@ class CausalLMBatch(Batch):

     # Metadata used for padding
     max_input_length: int
-    padding_right_offset: int

     # Maximum number of tokens this batch will grow to
     max_tokens: int

-    # Past metadata
-    keys_head_dim_last: bool = True
-
     def to_pb(self) -> generate_pb2.Batch:
-        #logger.info(f"to_pb, id={self.batch_id}, requests={self.requests}, size={len(self)}, max_tokens={self.max_tokens}")
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -68,8 +62,7 @@ class CausalLMBatch(Batch):
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
-    ) -> "CausalLMBatch":
-        #logger.info(f"from_pb, pb={pb}, tokenizer={tokenizer}, device={device}")
+    ) -> "VectorizedCausalLMBatch":
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
@@ -82,11 +75,14 @@ class CausalLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
+            next_token_chooser=NextTokenChooser.from_pb(r.parameters, device)
+            # TODO: Implement
+            assert len(next_token_chooser.warpers)==0
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
             offsets.append(None)
             token_offsets.append(None)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(next_token_chooser)
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -109,17 +105,19 @@ class CausalLMBatch(Batch):
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()

-        input_ids = tokenized_inputs["input_ids"]
-        # Allocate maximum attention_mask
-        attention_mask = input_ids.new_zeros(
-            (pb.size, max_input_length + padding_right_offset)
-        )
-        # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
+        input_shape=(pb.size, max_input_length + padding_right_offset)

-        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
-        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
-        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        # Allocate maximum attention_mask
+        attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
+        # Copy tokenizer attention_mask into fully allocated attention_mask
+        attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"])
+        attention_mask[:, max_input_length:].fill_(1)
+
+        position_ids = attention_mask.cumsum(-1).sub_(1)
+        position_ids[:, :max_input_length].relu_()
+
+        input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
+        input_ids[:, :max_input_length].copy_(tokenized_inputs["input_ids"])

         max_tokens = len(inputs) * max_input_length + max_decode_tokens

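A self-contained sketch of the pre-allocation scheme introduced above, with made-up sizes: the attention mask and input ids are allocated once for the prompt plus all future decode tokens, and the tokenizer output is copied into the left slice (illustrative only, not part of the commit).

import torch

batch_size, max_input_length, padding_right_offset = 2, 5, 3
input_shape = (batch_size, max_input_length + padding_right_offset)

# Stand-ins for the tokenizer output of the prompts.
prompt_ids = torch.randint(0, 100, (batch_size, max_input_length))
prompt_mask = torch.ones(batch_size, max_input_length, dtype=torch.bool)

# Allocate once, then fill the prompt slice; the right padding is reserved for decoding.
attention_mask = torch.empty(input_shape, dtype=torch.bool)
attention_mask[:, :max_input_length].copy_(prompt_mask)
attention_mask[:, max_input_length:].fill_(1)

position_ids = attention_mask.cumsum(-1).sub_(1)
position_ids[:, :max_input_length].relu_()  # clamp padded positions to zero

input_ids = torch.empty(input_shape, dtype=torch.int64)
input_ids[:, :max_input_length].copy_(prompt_ids)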
@@ -127,327 +125,148 @@ class CausalLMBatch(Batch):
             batch_id=pb.id,
             requests=pb.requests,
             requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
-            all_input_ids=list(all_input_ids),
+            input_ids=input_ids,
             input_lengths=input_lengths.tolist(),
             offsets=offsets,
             token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
-            padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
         )

     @tracer.start_as_current_span("filter")
-    def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
-        logger.info(f"filter, requests={requests}")
-        if len(requests) == 0:
-            raise ValueError("Batch must have at least one request")
-        if len(requests) == len(self):
-            return self
-
-        keep_indices = []
-
-        # New values after filtering
-        requests_idx_mapping = {}
-        input_lengths = []
-        offsets = []
-        token_offsets = []
-        all_input_ids = []
-        max_input_length = 0
-
-        next_token_choosers = []
-        stopping_criterias = []
-
-        total_remaining_decode_tokens = 0
-        new_padding_right_offset = 0
-
-        for i, r in enumerate(requests):
-            idx = self.requests_idx_mapping[r.id]
-            requests_idx_mapping[r.id] = i
-            keep_indices.append(idx)
-
-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
-            all_input_ids.append(self.all_input_ids[idx])
-
-            request_input_length = self.input_lengths[idx]
-            input_lengths.append(request_input_length)
-            max_input_length = max(max_input_length, request_input_length)
-
-            next_token_choosers.append(self.next_token_choosers[idx])
-            stopping_criteria = self.stopping_criterias[idx]
-            stopping_criterias.append(stopping_criteria)
-            remaining_decode_tokens = (
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            )
-            total_remaining_decode_tokens += remaining_decode_tokens
-            new_padding_right_offset = max(
-                new_padding_right_offset, remaining_decode_tokens
-            )
-
-        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
-        input_ids = self.input_ids[keep_indices]
-        position_ids = self.position_ids[keep_indices]
-        self.attention_mask = self.attention_mask[
-            keep_indices,
-            -(self.padding_right_offset + max_input_length) : (
-                self.attention_mask.shape[1] - self.padding_right_offset
-            )
-            + new_padding_right_offset,
-        ]
-
-        # Ensure that past_key_values tensors can be updated in-place
-        if type(self.past_key_values[0]) == tuple:
-            self.past_key_values = [list(layer) for layer in self.past_key_values]
-
-        # Update tensors in-place to allow incremental garbage collection
-        past_kv_length = max_input_length - 1
-        for layer in self.past_key_values:
-            past_keys, past_values = layer
-            if len(past_keys.shape) == 3:
-                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
-                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
-                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
-            if self.keys_head_dim_last:
-                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
-            else:
-                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
-            del past_keys
-            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
-            del past_values
-
-        max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
-
-        self.requests = requests
-        self.requests_idx_mapping = requests_idx_mapping
-        self.input_ids = input_ids
-        self.position_ids = position_ids
-        self.all_input_ids = all_input_ids
-        self.input_lengths = input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
-        self.next_token_choosers = next_token_choosers
-        self.stopping_criterias = stopping_criterias
-        self.max_input_length = max_input_length
-        self.padding_right_offset = new_padding_right_offset
-        self.max_tokens = max_tokens
-
-        return self
+    def filter(self, requests: List[generate_pb2.Request]) -> Optional["VectorizedCausalLMBatch"]:
+        raise NotImplementedError()

     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
-        logger.info(f"concatenate, batches={batches}")
-        # Used for padding
-        total_batch_size = 0
-        max_input_length = 0
-        padding_right_offset = 0
-        for batch in batches:
-            total_batch_size += len(batch)
-            max_input_length = max(max_input_length, batch.max_input_length)
-            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
-
-        # Batch attributes
-        requests = []
-        requests_idx_mapping = {}
-        input_lengths = []
-        offsets = []
-        token_offsets = []
-        all_input_ids = []
-        next_token_choosers = []
-        stopping_criterias = []
-        max_tokens = 0
-
-        # Batch tensors
-        input_ids = None
-        attention_mask = None
-        position_ids = None
-        past_key_values = []
-
-        # Used for slicing correctly inside the tensors
-        # Equivalent to a cumsum on batch sizes
-        start_index = 0
-        for i, batch in enumerate(batches):
-            requests.extend(batch.requests)
-            input_lengths.extend(batch.input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
-            all_input_ids.extend(batch.all_input_ids)
-            next_token_choosers.extend(batch.next_token_choosers)
-            stopping_criterias.extend(batch.stopping_criterias)
-
-            if i == 0:
-                requests_idx_mapping = batch.requests_idx_mapping
-            else:
-                # We need to offset the mapping for each batch by the cumulative batch size
-                for k, v in batch.requests_idx_mapping.items():
-                    requests_idx_mapping[k] = v + start_index
-
-            # Slicing end index for this batch
-            end_index = start_index + len(batch)
-
-            # We only concatenate batches that did at least one step
-            if batch.past_key_values is None:
-                raise ValueError("only concatenate prefilled batches")
-
-            # Create empty tensor
-            # input_ids is always of shape [batch_size, 1]
-            # We do not need to pad it
-            if input_ids is None:
-                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
-            # Copy to correct indices
-            input_ids[start_index:end_index] = batch.input_ids
-
-            # Create padded tensor
-            if attention_mask is None:
-                attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_input_length + padding_right_offset),
-                )
-
-            # We need to slice the attention mask to remove padding from previous steps
-            # and to remove unused allocated space
-            left_offset = max_input_length - batch.max_input_length
-            batch_left_offset = (
-                batch.attention_mask.shape[1]
-                - batch.max_input_length
-                - batch.padding_right_offset
-            )
-            attention_mask[
-                start_index:end_index,
-                left_offset:-padding_right_offset,
-            ] = batch.attention_mask[
-                :,
-                batch_left_offset : -batch.padding_right_offset,
-            ]
-
-            # Create empty tensor
-            # position_ids is always of shape [batch_size, 1]
-            if position_ids is None:
-                position_ids = batch.position_ids.new_empty((total_batch_size, 1))
-            position_ids[start_index:end_index] = batch.position_ids
-
-            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
-            # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
-            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
-            # And ensure that we can update tensors in-place
-            if type(batch.past_key_values[0]) == tuple:
-                batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
-                    for layer in batch.past_key_values
-                ]
-            elif len(batch.past_key_values[0][0].shape) == 3:
-                for layer in batch.past_key_values:
-                    for k, t in enumerate(layer):
-                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])
-
-            # Add eventual padding tokens that were added while concatenating
-            max_tokens += batch.max_tokens + (
-                max_input_length - batch.max_input_length
-            ) * len(batch)
-
-            start_index = end_index
-
-        first_past_kvs = batches[0].past_key_values
-        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
-
-        padded_past_values_shape = (
-            total_batch_size,
-            num_heads,
-            max_input_length - 1,
-            head_dim,
-        )
-
-        if batches[0].keys_head_dim_last:
-            padded_past_keys_shape = padded_past_values_shape
-        else:
-            # seq_length is last for BLOOM
-            padded_past_keys_shape = (
-                total_batch_size,
-                num_heads,
-                head_dim,
-                max_input_length - 1,
-            )
-
-        # Iterate over attention layers
-        # Concatenate past key values layer by layer to allow incremental garbage collection
-        for j in range(len(first_past_kvs)):
-            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
-            start_index = 0
-            for batch in batches:
-                past_keys = batch.past_key_values[j][0]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][0] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the keys to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
-                else:
-                    # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
-                del past_keys
-
-                start_index = end_index
-
-            padded_past_values = first_past_kvs[j][1].new_zeros(
-                padded_past_values_shape
-            )
-            start_index = 0
-            for batch in batches:
-                past_values = batch.past_key_values[j][1]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][1] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the past values to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
-                del past_values
-
-                # Update values
-                start_index = end_index
-
-            past_key_values.append([padded_past_keys, padded_past_values])
-
-        return cls(
-            batch_id=batches[0].batch_id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            max_input_length=max_input_length,
-            padding_right_offset=padding_right_offset,
-            keys_head_dim_last=batches[0].keys_head_dim_last,
-            max_tokens=max_tokens,
-        )
+    def concatenate(cls, batches: List["VectorizedCausalLMBatch"]) -> "VectorizedCausalLMBatch":
+        raise NotImplementedError()

     def __len__(self):
         return len(self.requests)


-class CausalLM(Model):
+class VectorizedNextTokenChooser:
+    def __init__(
+        self,
+        batch_size:int,
+        watermark=None,
+        temperature=None,
+        repetition_penalty=None,
+        top_k=None,
+        top_p=None,
+        typical_p=None,
+        do_sample=None,
+        seed:int=0,
+        device="cpu",
+    ):
+        self.batch_size=batch_size
+
+        do_sample=self._standardize(do_sample, False)
+
+        watermark=self._standardize(watermark, False)
+        if any(watermark):
+            raise NotImplementedError("Watermarking not implemented")
+
+        repetition_penalty=self._standardize(repetition_penalty, 1.0)
+        if any([x!=1.0 for x in repetition_penalty]):
+            self.repetition_penalty=torch.tensor([repetition_penalty], dtype=torch.float32, device=device).unsqueeze(1)
+        else:
+            self.repetition_penalty=None
+
+        temperature=self._standardize(temperature, 1.0)
+        if any([x!=1.0 for x in temperature]):
+            do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
+            self.temperature=torch.tensor([temperature], dtype=torch.float32, device=device).unsqueeze(1)
+        else:
+            self.temperature=None
+
+        top_k=self._standardize(top_k, 0)
+        if any([x!=0 for x in top_k]):
+            do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
+            self.top_k=torch.tensor([top_k], dtype=torch.float32, device=device).unsqueeze(1)
+        else:
+            self.top_k=None
+
+
+        top_p=self._standardize(top_p, 1.0)
+        if any([x<1.0 for x in top_p]):
+            raise NotImplementedError("Top P not implemented")
+
+        typical_p=self._standardize(typical_p, 1.0)
+        if any([x<1.0 for x in typical_p]):
+            raise NotImplementedError("Typical P not implemented")
+
+        self.do_sample = any(do_sample)
+        if self.do_sample and not all(do_sample):
+            raise NotImplementedError("Mixed greedy and probabilistic sampling not supported")
+
+    def _standardize(self, values, default):
+        if isinstance(values, list):
+            values=values.copy()
+        else:
+            values=[values]*self.batch_size
+        assert len(values)==self.batch_size
+        for i, v in enumerate(values):
+            if v is None:
+                values[i]=default
+        return values
+
+    def __call__(self, input_ids, scores):
+        # Only process the last token
+        scores=scores[: -1, :]
+
+        if self.repetition_penalty is not None:
+            score = torch.gather(scores, 1, input_ids)
+            # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+            score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty)
+            scores.scatter_(1, input_ids, score)
+
+        if self.temperature is not None:
+            scores.div_(self.temperature)
+
+        if self.top_k is not None:
+            top_k = min(self.top_k, scores.size(-1))  # Safety check
+            # Remove all tokens with a probability less than the last token of the top-k
+            indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
+            scores = scores.masked_fill(indices_to_remove, self.filter_value)
+
+        # Compute logprobs
+        logprobs = torch.log_softmax(scores, dim=-1)
+
+        if self.do_sample:
+            raise NotImplementedError()
+        else:
+            next_token_ids = torch.argmax(scores, dim=-1)
+
+        return next_token_ids, logprobs
+
+    @classmethod
+    def from_pb(
+        cls,
+        pb: List[generate_pb2.NextTokenChooserParameters],
+        device: torch.device,
+    ) -> "VectorizedNextTokenChooser":
+        # TODO: Seeds are ignored
+        return VectorizedNextTokenChooser(
+            watermark=[pb_.watermark for pb_ in pb],
+            temperature=[pb_.temperature for pb_ in pb],
+            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
+            top_k=[pb_.top_k for pb_ in pb],
+            top_p=[pb_.top_p for pb_ in pb],
+            typical_p=[pb_.typical_p for pb_ in pb],
+            do_sample=[pb_.do_sample for pb_ in pb],
+            seed=0,
+            device=device,
+        )
+
+
+class VectorizedCausalLM(Model):
     def __init__(
         self,
         model_id: str,
@@ -457,6 +276,7 @@ class CausalLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
+            # TODO: Choose dtype (fp16?)
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
             if quantize:
@@ -482,7 +302,7 @@ class CausalLM(Model):
             else self.model.config.eos_token_id
         )

-        super(CausalLM, self).__init__(
+        super().__init__(
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
@@ -491,94 +311,58 @@ class CausalLM(Model):
         )

     @property
-    def batch_type(self) -> Type[CausalLMBatch]:
-        return CausalLMBatch
+    def batch_type(self) -> Type[VectorizedCausalLMBatch]:
+        return VectorizedCausalLMBatch

     def decode(self, generated_ids: List[int]) -> str:
         return self.tokenizer.decode(
             generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
         )

-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, outputs.past_key_values
-
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        # slice the attention mask to the correct shape
-        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+        self, batch: VectorizedCausalLMBatch
+    ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
+        key_length=batch.max_input_length
+        query_length=key_length if batch.past_key_values is None else 1

-        logits, past = self.forward(
-            batch.input_ids,
-            attention_mask,
-            batch.position_ids,
-            batch.past_key_values,
+        outputs = self.model.forward(
+            input_ids=batch.input_ids[:, key_length-query_length: key_length],
+            attention_mask=batch.attention_mask[:, : key_length],
+            position_ids=batch.position_ids[:, key_length-query_length: key_length],
+            past_key_values=batch.past_key_values,
         )
+        # TODO: Post-processing
+        next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
+
+        # Update batch
+        # TODO: Why do we need all input ids?
+        batch.input_ids[:, key_length].copy_(next_token_ids)
+        batch.past_key_values=outputs.past_key_values
+        batch.input_lengths=[length+1 for length in batch.input_lengths]
+        batch.max_input_length+=1
+
+        # TODO: self.decode_token, offsets?
+        next_token_ids=next_token_ids.cpu().tolist()
+        next_token_texts=self.tokenizer.batch_decode(next_token_ids)
+
+        # TODO: Vectorize some of this?
+
-        # Results
         generations: List[Generation] = []
-        stopped = True
+        next_batch=None

-        # Zipped iterator
-        iterator = zip(
-            batch.requests,
-            batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
-            logits,
-            batch.next_token_choosers,
-            batch.stopping_criterias,
-            batch.all_input_ids,
-        )
-
-        # For each member of the batch
-        for i, (
-            request,
-            input_length,
-            offset,
-            token_offset,
-            logits,
-            next_token_chooser,
-            stopping_criteria,
-            all_input_ids,
-        ) in enumerate(iterator):
-            # Select next token
-            next_token_id, logprobs = next_token_chooser(
-                all_input_ids.view(1, -1), logits
-            )
-
-            # Append next token to all tokens
-            all_input_ids = torch.cat([all_input_ids, next_token_id])
-            new_input_length = input_length + 1
-
-            # Generated token
-            next_token_logprob = logprobs[-1, next_token_id]
-            next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset, token_offset = self.decode_token(
-                all_input_ids[:, 0], offset, token_offset
-            )
-
-            # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token_id_squeezed,
+        for i, (next_token_id, next_token_text) in enumerate(zip(next_token_ids, next_token_texts)):
+            stopping_criterias=batch.stopping_criterias[i]
+            next_token_chooser=batch.next_token_choosers[i]
+            stop, reason = stopping_criterias(
+                next_token_id,
                 next_token_text,
             )

             if stop:
                 # Decode generated tokens
+                # TODO: Same as stopping_criteria.current_output?
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    batch.input_ids[i, -stopping_criterias.current_tokens :]
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -587,67 +371,24 @@ class CausalLM(Model):
                     seed = None

                 generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
+                    output_text, stopping_criterias.current_tokens, reason, seed
                 )
             else:
                 # Keep request in the batch
                 generated_text = None
-                stopped = False
+                next_batch = batch

-            # Prefill
-            if stopping_criteria.current_tokens == 1:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids[1:]
-                ).squeeze(1)[-new_input_length:-1].tolist()
-                prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
-                )
-            else:
-                prefill_tokens = None
-
             generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id_squeezed,
-                next_token_logprob,
+                batch.requests[i].id,
+                None,
+                next_token_id,
+                0,
                 next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                next_token_id in self.all_special_ids,
                 generated_text,
             )

             generations.append(generation)

-            # Update values
-            batch.input_ids[i, 0] = next_token_id
-            batch.all_input_ids[i] = all_input_ids
-            batch.input_lengths[i] = new_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
-            batch.max_input_length = max(batch.max_input_length, new_input_length)
-
-        # We finished all generations in the batch; there is no next batch
-        if stopped:
-            return generations, None
-
-        # Slice unused values from prefill
-        batch.input_ids = batch.input_ids[:, :1]
-
-        # Update attention_mask as we added a new token to input_ids
-        batch.attention_mask[:, -batch.padding_right_offset] = 1
-        # Decrease right offset
-        batch.padding_right_offset -= 1
-
-        # Update position_ids
-        batch.position_ids = batch.position_ids[:, -1:] + 1
-
-        # Update past key values
-        batch.past_key_values = past
-
-        return generations, batch
+        return generations, next_batch
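A self-contained sketch of the key_length/query_length slicing used in generate_token above: the first call (prefill) feeds the whole prompt, later calls feed only the newest token and rely on past_key_values. Buffer sizes are made up; this is illustrative, not part of the commit.

import torch

batch_size, allocated_length = 2, 8
input_ids = torch.zeros(batch_size, allocated_length, dtype=torch.int64)
attention_mask = torch.ones(batch_size, allocated_length, dtype=torch.bool)
position_ids = attention_mask.cumsum(-1).sub_(1)

max_input_length = 5        # current key length
past_key_values = None      # None on prefill, the model's cache afterwards

key_length = max_input_length
query_length = key_length if past_key_values is None else 1

step_input_ids = input_ids[:, key_length - query_length : key_length]   # (2, 5) on prefill, (2, 1) later
step_attention_mask = attention_mask[:, :key_length]
step_position_ids = position_ids[:, key_length - query_length : key_length]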