Joel Lamy-Poirier 2023-05-03 11:16:35 -04:00
parent cbbc046a79
commit 5677540881
4 changed files with 189 additions and 449 deletions

Dockerfile

@@ -177,7 +177,8 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base
+RUN git clone https://github.com/bigcode-project/bigcode-inference-benchmark.git && \
+cd bigcode-inference-benchmark && git checkout text_gen_inference
ENV HUGGINGFACE_HUB_CACHE=/usr/data/.hf_cache/
ENV PYTHONPATH=/usr/src/server/

server/text_generation_server/models/__init__.py

@@ -8,6 +8,7 @@ from typing import Optional
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -155,6 +156,8 @@ def get_model(
raise ValueError("sharded is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+if os.environ.get("VECTORIZED_LM") is not None:
+return VectorizedCausalLM(model_id, revision, quantize=quantize)
return CausalLM(model_id, revision, quantize=quantize)
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM(model_id, revision, quantize=quantize)

server/text_generation_server/models/causal_lm.py

@@ -4,7 +4,6 @@ from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
-from loguru import logger
from text_generation_server.models import Model
from text_generation_server.models.types import (
@@ -54,7 +53,6 @@ class CausalLMBatch(Batch):
keys_head_dim_last: bool = True
def to_pb(self) -> generate_pb2.Batch:
-#logger.info(f"to_pb, id={self.batch_id}, requests={self.requests}, size={len(self)}, max_tokens={self.max_tokens}")
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
@@ -69,7 +67,6 @@ class CausalLMBatch(Batch):
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "CausalLMBatch":
-#logger.info(f"from_pb, pb={pb}, tokenizer={tokenizer}, device={device}")
inputs = []
next_token_choosers = []
stopping_criterias = []
@@ -144,7 +141,6 @@ class CausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
-logger.info(f"filter, requests={requests}")
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
@@ -242,7 +238,6 @@ class CausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
-logger.info(f"concatenate, batches={batches}")
# Used for padding
total_batch_size = 0
max_input_length = 0

server/text_generation_server/models/vectorized_causal_lm.py

@@ -20,19 +20,18 @@ tracer = trace.get_tracer(__name__)
@dataclass
-class CausalLMBatch(Batch):
+class VectorizedCausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
-input_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
past_key_values: Optional[List[Tuple]]
# All tokens
-all_input_ids: List[torch.Tensor]
+input_ids: torch.Tensor
# Lengths of all generations present in the batch
input_lengths: List[int]
@@ -45,16 +44,11 @@ class CausalLMBatch(Batch):
# Metadata used for padding
max_input_length: int
-padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
-# Past metadata
-keys_head_dim_last: bool = True
def to_pb(self) -> generate_pb2.Batch:
-#logger.info(f"to_pb, id={self.batch_id}, requests={self.requests}, size={len(self)}, max_tokens={self.max_tokens}")
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
@@ -68,8 +62,7 @@ class CausalLMBatch(Batch):
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
-) -> "CausalLMBatch":
-#logger.info(f"from_pb, pb={pb}, tokenizer={tokenizer}, device={device}")
+) -> "VectorizedCausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
@@ -82,11 +75,14 @@ class CausalLMBatch(Batch):
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
+next_token_chooser=NextTokenChooser.from_pb(r.parameters, device)
+# TODO: Implement
+assert len(next_token_chooser.warpers)==0
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
offsets.append(None)
token_offsets.append(None)
-next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+next_token_choosers.append(next_token_chooser)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
@@ -109,17 +105,19 @@ class CausalLMBatch(Batch):
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
-input_ids = tokenized_inputs["input_ids"]
-# Allocate maximum attention_mask
-attention_mask = input_ids.new_zeros(
-(pb.size, max_input_length + padding_right_offset)
-)
-# Copy tokenizer attention_mask into fully allocated attention_mask
-attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
-position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
-position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
-all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+input_shape=(pb.size, max_input_length + padding_right_offset)
+# Allocate maximum attention_mask
+attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
+# Copy tokenizer attention_mask into fully allocated attention_mask
+attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"])
+attention_mask[:, max_input_length:].fill_(1)
+position_ids = attention_mask.cumsum(-1).sub_(1)
+position_ids[:, :max_input_length].relu_()
+input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
+input_ids[:, :max_input_length].copy_(tokenized_inputs["input_ids"])
max_tokens = len(inputs) * max_input_length + max_decode_tokens
@@ -127,327 +125,148 @@ class CausalLMBatch(Batch):
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
-input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
-all_input_ids=list(all_input_ids),
+input_ids=input_ids,
input_lengths=input_lengths.tolist(),
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
-padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
-def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]:
-logger.info(f"filter, requests={requests}")
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
if len(requests) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
input_lengths = []
offsets = []
token_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
keep_indices.append(idx)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
position_ids = self.position_ids[keep_indices]
self.attention_mask = self.attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection
past_kv_length = max_input_length - 1
for layer in self.past_key_values:
past_keys, past_values = layer
if len(past_keys.shape) == 3:
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
if self.keys_head_dim_last:
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
else:
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
del past_keys
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.position_ids = position_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.offsets = offsets
self.token_offsets = token_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
return self
+def filter(self, requests: List[generate_pb2.Request]) -> Optional["VectorizedCausalLMBatch"]:
+raise NotImplementedError()
@classmethod
@tracer.start_as_current_span("concatenate")
-def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
-logger.info(f"concatenate, batches={batches}")
# Used for padding
total_batch_size = 0
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
offsets = []
token_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
# Batch tensors
input_ids = None
attention_mask = None
position_ids = None
past_key_values = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets)
token_offsets.extend(batch.token_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
raise ValueError("only concatenate prefilled batches")
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_input_length + padding_right_offset),
)
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
if position_ids is None:
position_ids = batch.position_ids.new_empty((total_batch_size, 1))
position_ids[start_index:end_index] = batch.position_ids
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
]
elif len(batch.past_key_values[0][0].shape) == 3:
for layer in batch.past_key_values:
for k, t in enumerate(layer):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
start_index = end_index
first_past_kvs = batches[0].past_key_values
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
padded_past_values_shape = (
total_batch_size,
num_heads,
max_input_length - 1,
head_dim,
)
if batches[0].keys_head_dim_last:
padded_past_keys_shape = padded_past_values_shape
else:
# seq_length is last for BLOOM
padded_past_keys_shape = (
total_batch_size,
num_heads,
head_dim,
max_input_length - 1,
)
# Iterate over attention layers
# Concatenate past key values layer by layer to allow incremental garbage collection
for j in range(len(first_past_kvs)):
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
start_index = 0
for batch in batches:
past_keys = batch.past_key_values[j][0]
# Clear reference to the original tensor
batch.past_key_values[j][0] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
padded_past_keys[
start_index:end_index, :, -past_seq_len:, :
] = past_keys[:, :, -past_seq_len:, :]
else:
# BLOOM case
padded_past_keys[
start_index:end_index, :, :, -past_seq_len:
] = past_keys[:, :, :, -past_seq_len:]
del past_keys
start_index = end_index
padded_past_values = first_past_kvs[j][1].new_zeros(
padded_past_values_shape
)
start_index = 0
for batch in batches:
past_values = batch.past_key_values[j][1]
# Clear reference to the original tensor
batch.past_key_values[j][1] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
padded_past_values[
start_index:end_index, :, -past_seq_len:, :
] = past_values[:, :, -past_seq_len:, :]
del past_values
# Update values
start_index = end_index
past_key_values.append([padded_past_keys, padded_past_values])
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
)
+def concatenate(cls, batches: List["VectorizedCausalLMBatch"]) -> "VectorizedCausalLMBatch":
+raise NotImplementedError()
def __len__(self):
return len(self.requests)
-class CausalLM(Model):
+class VectorizedNextTokenChooser:
def __init__(
self,
batch_size:int,
watermark=None,
temperature=None,
repetition_penalty=None,
top_k=None,
top_p=None,
typical_p=None,
do_sample=None,
seed:int=0,
device="cpu",
):
self.batch_size=batch_size
do_sample=self._standardize(do_sample, False)
watermark=self._standardize(watermark, False)
if any(watermark):
raise NotImplementedError("Watermarking not implemented")
repetition_penalty=self._standardize(repetition_penalty, 1.0)
if any([x!=1.0 for x in repetition_penalty]):
self.repetition_penalty=torch.tensor([repetition_penalty], dtype=torch.float32, device=device).unsqueeze(1)
else:
self.repetition_penalty=None
temperature=self._standardize(temperature, 1.0)
if any([x!=1.0 for x in temperature]):
do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
self.temperature=torch.tensor([temperature], dtype=torch.float32, device=device).unsqueeze(1)
else:
self.temperature=None
top_k=self._standardize(top_k, 0)
if any([x!=0 for x in top_k]):
do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
self.top_k=torch.tensor([top_k], dtype=torch.float32, device=device).unsqueeze(1)
else:
self.top_k=None
top_p=self._standardize(top_p, 1.0)
if any([x<1.0 for x in top_p]):
raise NotImplementedError("Top P not implemented")
typical_p=self._standardize(typical_p, 1.0)
if any([x<1.0 for x in typical_p]):
raise NotImplementedError("Typical P not implemented")
self.do_sample = any(do_sample)
if self.do_sample and not all(do_sample):
raise NotImplementedError("Mixed greedy and probabilistic sampling not supported")
def _standardize(self, values, default):
if isinstance(values, list):
values=values.copy()
else:
values=[values]*self.batch_size
assert len(values)==self.batch_size
for i, v in enumerate(values):
if v is None:
values[i]=default
return values
def __call__(self, input_ids, scores):
# Only process the last token
scores=scores[: -1, :]
if self.repetition_penalty is not None:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty)
scores.scatter_(1, input_ids, score)
if self.temperature is not None:
scores.div_(self.temperature)
if self.top_k is not None:
top_k = min(self.top_k, scores.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
scores = scores.masked_fill(indices_to_remove, self.filter_value)
# Compute logprobs
logprobs = torch.log_softmax(scores, dim=-1)
if self.do_sample:
raise NotImplementedError()
else:
next_token_ids = torch.argmax(scores, dim=-1)
return next_token_ids, logprobs
@classmethod
def from_pb(
cls,
pb: List[generate_pb2.NextTokenChooserParameters],
device: torch.device,
) -> "VectorizedNextTokenChooser":
# TODO: Seeds are ignored
return VectorizedNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
top_k=[pb_.top_k for pb_ in pb],
top_p=[pb_.top_p for pb_ in pb],
typical_p=[pb_.typical_p for pb_ in pb],
do_sample=[pb_.do_sample for pb_ in pb],
seed=0,
device=device,
)
class VectorizedCausalLM(Model):
def __init__(
self,
model_id: str,
@@ -457,6 +276,7 @@ class CausalLM(Model):
):
if torch.cuda.is_available():
device = torch.device("cuda")
+# TODO: Choose dtype (fp16?)
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
if quantize:
@@ -482,7 +302,7 @@ class CausalLM(Model):
else self.model.config.eos_token_id
)
-super(CausalLM, self).__init__(
+super().__init__(
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
@@ -491,94 +311,58 @@ class CausalLM(Model):
)
@property
-def batch_type(self) -> Type[CausalLMBatch]:
-return CausalLMBatch
+def batch_type(self) -> Type[VectorizedCausalLMBatch]:
+return VectorizedCausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
@tracer.start_as_current_span("generate_token")
def generate_token(
-self, batch: CausalLMBatch
-) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-# slice the attention mask to the correct shape
-attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-logits, past = self.forward(
-batch.input_ids,
-attention_mask,
-batch.position_ids,
-batch.past_key_values,
+self, batch: VectorizedCausalLMBatch
+) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
+key_length=batch.max_input_length
+query_length=key_length if batch.past_key_values is None else 1
+outputs = self.model.forward(
+input_ids=batch.input_ids[:, key_length-query_length: key_length],
+attention_mask=batch.attention_mask[:, : key_length],
+position_ids=batch.position_ids[:, key_length-query_length: key_length],
+past_key_values=batch.past_key_values,
)
# TODO: Post-processing
next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
# Update batch
# TODO: Why do we need all input ids?
batch.input_ids[:, key_length].copy_(next_token_ids)
batch.past_key_values=outputs.past_key_values
batch.input_lengths=[length+1 for length in batch.input_lengths]
batch.max_input_length+=1
# TODO: self.decode_token, offsets?
next_token_ids=next_token_ids.cpu().tolist()
next_token_texts=self.tokenizer.batch_decode(next_token_ids)
# TODO: Vectorize some of this?
# Results
generations: List[Generation] = []
-stopped = True
-# Zipped iterator
-iterator = zip(
-batch.requests,
-batch.input_lengths,
-batch.offsets,
batch.token_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
offset,
token_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, offset, token_offset = self.decode_token(
all_input_ids[:, 0], offset, token_offset
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
+next_batch=None
+for i, (next_token_id, next_token_text) in enumerate(zip(next_token_ids, next_token_texts)):
+stopping_criterias=batch.stopping_criterias[i]
+next_token_chooser=batch.next_token_choosers[i]
+stop, reason = stopping_criterias(
+next_token_id,
next_token_text,
)
if stop:
# Decode generated tokens
+# TODO: Same as stopping_criteria.current_output?
output_text = self.decode(
-all_input_ids[-stopping_criteria.current_tokens :, 0]
+batch.input_ids[i, -stopping_criterias.current_tokens :]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
@@ -587,67 +371,24 @@ class CausalLM(Model):
seed = None
generated_text = GeneratedText(
-output_text, stopping_criteria.current_tokens, reason, seed
+output_text, stopping_criterias.current_tokens, reason, seed
)
else:
# Keep request in the batch
generated_text = None
-stopped = False
# Prefill
if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids[1:]
).squeeze(1)[-new_input_length:-1].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
+next_batch = batch
generation = Generation(
-request.id,
-prefill_tokens,
-next_token_id_squeezed,
-next_token_logprob,
+batch.requests[i].id,
+None,
+next_token_id,
+0,
next_token_text,
-next_token_id_squeezed.item() in self.all_special_ids,
+next_token_id in self.all_special_ids,
generated_text,
)
generations.append(generation)
-# Update values
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
return generations, None
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
# Decrease right offset
batch.padding_right_offset -= 1
# Update position_ids
batch.position_ids = batch.position_ids[:, -1:] + 1
# Update past key values
batch.past_key_values = past
-return generations, batch
+return generations, next_batch
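
Usage note: the new code path is opt-in via the VECTORIZED_LM environment variable checked in get_model() above. The snippet below is a minimal sketch, assuming the get_model(model_id, revision, sharded, quantize) signature shown in models/__init__.py for this server version; the model id "gpt2" is only an illustrative choice that reaches the generic AutoModel branch.

# Sketch: enabling the vectorized path added in this commit (assumptions noted above).
import os

os.environ["VECTORIZED_LM"] = "1"  # get_model() only checks that the variable is set

from text_generation_server.models import get_model

# Hypothetical call for illustration; argument names follow the diff above.
model = get_model("gpt2", revision=None, sharded=False, quantize=None)
print(type(model).__name__)  # expected: VectorizedCausalLM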