From 5677540881a979ea1d70ffde209b48bf1b65a553 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 3 May 2023 11:16:35 -0400 Subject: [PATCH] stuff --- Dockerfile | 3 +- .../text_generation_server/models/__init__.py | 3 + .../models/causal_lm.py | 5 - .../models/vectorized_causal_lm.py | 627 +++++------------- 4 files changed, 189 insertions(+), 449 deletions(-) diff --git a/Dockerfile b/Dockerfile index ebfc0fab..9631d513 100644 --- a/Dockerfile +++ b/Dockerfile @@ -177,7 +177,8 @@ ENTRYPOINT ["./entrypoint.sh"] # Final image FROM base - +RUN git clone https://github.com/bigcode-project/bigcode-inference-benchmark.git && \ + cd bigcode-inference-benchmark && git checkout text_gen_inference ENV HUGGINGFACE_HUB_CACHE=/usr/data/.hf_cache/ ENV PYTHONPATH=/usr/src/server/ diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 6f0d0769..f2a472dc 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -8,6 +8,7 @@ from typing import Optional from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM +from text_generation_server.models.vectorized_causal_lm import VectorizedCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.bloom import BLOOM, BLOOMSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM @@ -155,6 +156,8 @@ def get_model( raise ValueError("sharded is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + if os.environ.get("VECTORIZED_LM") is not None: + return VectorizedCausalLM(model_id, revision, quantize=quantize) return CausalLM(model_id, revision, quantize=quantize) if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: return Seq2SeqLM(model_id, revision, quantize=quantize) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 73fa1930..ec32bdce 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from opentelemetry import trace from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Dict -from loguru import logger from text_generation_server.models import Model from text_generation_server.models.types import ( @@ -54,7 +53,6 @@ class CausalLMBatch(Batch): keys_head_dim_last: bool = True def to_pb(self) -> generate_pb2.Batch: - #logger.info(f"to_pb, id={self.batch_id}, requests={self.requests}, size={len(self)}, max_tokens={self.max_tokens}") return generate_pb2.Batch( id=self.batch_id, requests=self.requests, @@ -69,7 +67,6 @@ class CausalLMBatch(Batch): tokenizer: PreTrainedTokenizerBase, device: torch.device, ) -> "CausalLMBatch": - #logger.info(f"from_pb, pb={pb}, tokenizer={tokenizer}, device={device}") inputs = [] next_token_choosers = [] stopping_criterias = [] @@ -144,7 +141,6 @@ class CausalLMBatch(Batch): @tracer.start_as_current_span("filter") def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]: - logger.info(f"filter, requests={requests}") if len(requests) == 0: raise ValueError("Batch must have at least one request") if len(requests) == len(self): @@ -242,7 +238,6 @@ class CausalLMBatch(Batch): @classmethod 
@tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": - logger.info(f"concatenate, batches={batches}") # Used for padding total_batch_size = 0 max_input_length = 0 diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py index 73fa1930..f0b568b9 100644 --- a/server/text_generation_server/models/vectorized_causal_lm.py +++ b/server/text_generation_server/models/vectorized_causal_lm.py @@ -20,19 +20,18 @@ tracer = trace.get_tracer(__name__) @dataclass -class CausalLMBatch(Batch): +class VectorizedCausalLMBatch(Batch): batch_id: int requests: List[generate_pb2.Request] requests_idx_mapping: Dict[int, int] # Decoder values - input_ids: torch.Tensor attention_mask: torch.Tensor position_ids: torch.Tensor past_key_values: Optional[List[Tuple]] # All tokens - all_input_ids: List[torch.Tensor] + input_ids: torch.Tensor # Lengths of all generations present in the batch input_lengths: List[int] @@ -45,16 +44,11 @@ class CausalLMBatch(Batch): # Metadata used for padding max_input_length: int - padding_right_offset: int # Maximum number of tokens this batch will grow to max_tokens: int - # Past metadata - keys_head_dim_last: bool = True - def to_pb(self) -> generate_pb2.Batch: - #logger.info(f"to_pb, id={self.batch_id}, requests={self.requests}, size={len(self)}, max_tokens={self.max_tokens}") return generate_pb2.Batch( id=self.batch_id, requests=self.requests, @@ -68,8 +62,7 @@ class CausalLMBatch(Batch): pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device, - ) -> "CausalLMBatch": - #logger.info(f"from_pb, pb={pb}, tokenizer={tokenizer}, device={device}") + ) -> "VectorizedCausalLMBatch": inputs = [] next_token_choosers = [] stopping_criterias = [] @@ -82,11 +75,14 @@ class CausalLMBatch(Batch): padding_right_offset = 0 max_decode_tokens = 0 for i, r in enumerate(pb.requests): + next_token_chooser=NextTokenChooser.from_pb(r.parameters, device) + # TODO: Implement + assert len(next_token_chooser.warpers)==0 requests_idx_mapping[r.id] = i inputs.append(r.inputs) offsets.append(None) token_offsets.append(None) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(next_token_chooser) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) @@ -109,17 +105,19 @@ class CausalLMBatch(Batch): input_lengths = tokenized_inputs["attention_mask"].sum(1) max_input_length = input_lengths.max() - input_ids = tokenized_inputs["input_ids"] - # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] + input_shape=(pb.size, max_input_length + padding_right_offset) - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) + # Allocate maximum attention_mask + attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device) + # Copy tokenizer attention_mask into fully allocated attention_mask + attention_mask[:, :max_input_length].copy_(tokenized_inputs["attention_mask"]) + attention_mask[:, max_input_length:].fill_(1) + + position_ids = attention_mask.cumsum(-1).sub_(1) + position_ids[:, 
:max_input_length].relu_() + + input_ids = torch.empty(input_shape, dtype=torch.int64, device=device) + input_ids[:, :max_input_length].copy_(tokenized_inputs["input_ids"]) max_tokens = len(inputs) * max_input_length + max_decode_tokens @@ -127,327 +125,148 @@ class CausalLMBatch(Batch): batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=None, - all_input_ids=list(all_input_ids), + input_ids=input_ids, input_lengths=input_lengths.tolist(), offsets=offsets, token_offsets=token_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, max_tokens=max_tokens, ) @tracer.start_as_current_span("filter") - def filter(self, requests: List[generate_pb2.Request]) -> Optional["CausalLMBatch"]: - logger.info(f"filter, requests={requests}") - if len(requests) == 0: - raise ValueError("Batch must have at least one request") - if len(requests) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - input_lengths = [] - offsets = [] - token_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - for i, r in enumerate(requests): - idx = self.requests_idx_mapping[r.id] - requests_idx_mapping[r.id] = i - keep_indices.append(idx) - - offsets.append(self.offsets[idx]) - token_offsets.append(self.token_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] - self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] - - # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) == tuple: - self.past_key_values = [list(layer) for layer in self.past_key_values] - - # Update tensors in-place to allow incremental garbage collection - past_kv_length = max_input_length - 1 - for layer in self.past_key_values: - past_keys, past_values = layer - if len(past_keys.shape) == 3: - # Force past to be of dim [self_size, num_heads, ...] 
for easy indexing - past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) - past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) - if self.keys_head_dim_last: - layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] - else: - layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] - del past_keys - layer[1] = past_values[keep_indices, :, -past_kv_length:, :] - del past_values - - max_tokens = len(requests) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.position_ids = position_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.offsets = offsets - self.token_offsets = token_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - return self + def filter(self, requests: List[generate_pb2.Request]) -> Optional["VectorizedCausalLMBatch"]: + raise NotImplementedError() @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": - logger.info(f"concatenate, batches={batches}") - # Used for padding - total_batch_size = 0 - max_input_length = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - offsets = [] - token_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - max_tokens = 0 - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - past_key_values = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - offsets.extend(batch.offsets) - token_offsets.extend(batch.token_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - # Create padded tensor - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - (total_batch_size, max_input_length + padding_right_offset), - ) - - # We need to slice the attention mask to remove padding from previous steps - # and to remove unused allocated space - left_offset = max_input_length - batch.max_input_length - batch_left_offset = ( - 
batch.attention_mask.shape[1] - - batch.max_input_length - - batch.padding_right_offset - ) - attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - ] = batch.attention_mask[ - :, - batch_left_offset : -batch.padding_right_offset, - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((total_batch_size, 1)) - position_ids[start_index:end_index] = batch.position_ids - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if type(batch.past_key_values[0]) == tuple: - batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(len(batch), -1, *t.shape[-2:]) - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - - padded_past_values_shape = ( - total_batch_size, - num_heads, - max_input_length - 1, - head_dim, - ) - - if batches[0].keys_head_dim_last: - padded_past_keys_shape = padded_past_values_shape - else: - # seq_length is last for BLOOM - padded_past_keys_shape = ( - total_batch_size, - num_heads, - head_dim, - max_input_length - 1, - ) - - # Iterate over attention layers - # Concatenate past key values layer by layer to allow incremental garbage collection - for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if batch.keys_head_dim_last: - padded_past_keys[ - start_index:end_index, :, -past_seq_len:, : - ] = past_keys[:, :, -past_seq_len:, :] - else: - # BLOOM case - padded_past_keys[ - start_index:end_index, :, :, -past_seq_len: - ] = past_keys[:, :, :, -past_seq_len:] - del past_keys - - start_index = end_index - - padded_past_values = first_past_kvs[j][1].new_zeros( - padded_past_values_shape - ) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - padded_past_values[ - start_index:end_index, :, -past_seq_len:, : - ] = past_values[:, :, -past_seq_len:, :] - del past_values - - # Update values - start_index = end_index - - past_key_values.append([padded_past_keys, padded_past_values]) - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - 
position_ids=position_ids, - past_key_values=past_key_values, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - offsets=offsets, - token_offsets=token_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - ) + def concatenate(cls, batches: List["VectorizedCausalLMBatch"]) -> "VectorizedCausalLMBatch": + raise NotImplementedError() def __len__(self): return len(self.requests) -class CausalLM(Model): +class VectorizedNextTokenChooser: + def __init__( + self, + batch_size:int, + watermark=None, + temperature=None, + repetition_penalty=None, + top_k=None, + top_p=None, + typical_p=None, + do_sample=None, + seed:int=0, + device="cpu", + ): + self.batch_size=batch_size + + do_sample=self._standardize(do_sample, False) + + watermark=self._standardize(watermark, False) + if any(watermark): + raise NotImplementedError("Watermarking not implemented") + + repetition_penalty=self._standardize(repetition_penalty, 1.0) + if any([x!=1.0 for x in repetition_penalty]): + self.repetition_penalty=torch.tensor([repetition_penalty], dtype=torch.float32, device=device).unsqueeze(1) + else: + self.repetition_penalty=None + + temperature=self._standardize(temperature, 1.0) + if any([x!=1.0 for x in temperature]): + do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)] + self.temperature=torch.tensor([temperature], dtype=torch.float32, device=device).unsqueeze(1) + else: + self.temperature=None + + top_k=self._standardize(top_k, 0) + if any([x!=0 for x in top_k]): + do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)] + self.top_k=torch.tensor([top_k], dtype=torch.float32, device=device).unsqueeze(1) + else: + self.top_k=None + + + top_p=self._standardize(top_p, 1.0) + if any([x<1.0 for x in top_p]): + raise NotImplementedError("Top P not implemented") + + typical_p=self._standardize(typical_p, 1.0) + if any([x<1.0 for x in typical_p]): + raise NotImplementedError("Typical P not implemented") + + self.do_sample = any(do_sample) + if self.do_sample and not all(do_sample): + raise NotImplementedError("Mixed greedy and probabilistic sampling not supported") + + def _standardize(self, values, default): + if isinstance(values, list): + values=values.copy() + else: + values=[values]*self.batch_size + assert len(values)==self.batch_size + for i, v in enumerate(values): + if v is None: + values[i]=default + return values + + def __call__(self, input_ids, scores): + # Only process the last token + scores=scores[: -1, :] + + if self.repetition_penalty is not None: + score = torch.gather(scores, 1, input_ids) + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty) + scores.scatter_(1, input_ids, score) + + if self.temperature is not None: + scores.div_(self.temperature) + + if self.top_k is not None: + top_k = min(self.top_k, scores.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] + scores = scores.masked_fill(indices_to_remove, self.filter_value) + + # Compute logprobs + logprobs = torch.log_softmax(scores, dim=-1) + + if self.do_sample: + raise NotImplementedError() + else: + next_token_ids = 
torch.argmax(scores, dim=-1) + + return next_token_ids, logprobs + + @classmethod + def from_pb( + cls, + pb: List[generate_pb2.NextTokenChooserParameters], + device: torch.device, + ) -> "VectorizedNextTokenChooser": + # TODO: Seeds are ignored + return VectorizedNextTokenChooser( + watermark=[pb_.watermark for pb_ in pb], + temperature=[pb_.temperature for pb_ in pb], + repetition_penalty=[pb_.repetition_penalty for pb_ in pb], + top_k=[pb_.top_k for pb_ in pb], + top_p=[pb_.top_p for pb_ in pb], + typical_p=[pb_.typical_p for pb_ in pb], + do_sample=[pb_.do_sample for pb_ in pb], + seed=0, + device=device, + ) + + +class VectorizedCausalLM(Model): def __init__( self, model_id: str, @@ -457,6 +276,7 @@ class CausalLM(Model): ): if torch.cuda.is_available(): device = torch.device("cuda") + # TODO: Choose dtype (fp16?) dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: if quantize: @@ -482,7 +302,7 @@ class CausalLM(Model): else self.model.config.eos_token_id ) - super(CausalLM, self).__init__( + super().__init__( tokenizer=tokenizer, requires_padding=True, dtype=dtype, @@ -491,94 +311,58 @@ class CausalLM(Model): ) @property - def batch_type(self) -> Type[CausalLMBatch]: - return CausalLMBatch + def batch_type(self) -> Type[VectorizedCausalLMBatch]: + return VectorizedCausalLMBatch def decode(self, generated_ids: List[int]) -> str: return self.tokenizer.decode( generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False ) - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - outputs = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=True, - ) - return outputs.logits, outputs.past_key_values - @tracer.start_as_current_span("generate_token") def generate_token( - self, batch: CausalLMBatch - ) -> Tuple[List[Generation], Optional[CausalLMBatch]]: - # slice the attention mask to the correct shape - attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + self, batch: VectorizedCausalLMBatch + ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]: + key_length=batch.max_input_length + query_length=key_length if batch.past_key_values is None else 1 - logits, past = self.forward( - batch.input_ids, - attention_mask, - batch.position_ids, - batch.past_key_values, + outputs = self.model.forward( + input_ids=batch.input_ids[:, key_length-query_length: key_length], + attention_mask=batch.attention_mask[:, : key_length], + position_ids=batch.position_ids[:, key_length-query_length: key_length], + past_key_values=batch.past_key_values, ) + # TODO: Post-processing + next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1) + + # Update batch + # TODO: Why do we need all input ids? + batch.input_ids[:, key_length].copy_(next_token_ids) + batch.past_key_values=outputs.past_key_values + batch.input_lengths=[length+1 for length in batch.input_lengths] + batch.max_input_length+=1 + + # TODO: self.decode_token, offsets? + next_token_ids=next_token_ids.cpu().tolist() + next_token_texts=self.tokenizer.batch_decode(next_token_ids) + + # TODO: Vectorize some of this? 
- # Results generations: List[Generation] = [] - stopped = True + next_batch=None - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.offsets, - batch.token_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - input_length, - offset, - token_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits - ) - - # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, offset, token_offset = self.decode_token( - all_input_ids[:, 0], offset, token_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_squeezed, + for i, (next_token_id, next_token_text) in enumerate(zip(next_token_ids, next_token_texts)): + stopping_criterias=batch.stopping_criterias[i] + next_token_chooser=batch.next_token_choosers[i] + stop, reason = stopping_criterias( + next_token_id, next_token_text, ) - if stop: # Decode generated tokens + # TODO: Same as stopping_criteria.current_output? output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :, 0] + batch.input_ids[i, -stopping_criterias.current_tokens :] ) # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -587,67 +371,24 @@ class CausalLM(Model): seed = None generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed + output_text, stopping_criterias.current_tokens, reason, seed ) else: # Keep request in the batch generated_text = None - stopped = False + next_batch = batch - # Prefill - if stopping_criteria.current_tokens == 1: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + logprobs.gather( - 1, all_input_ids[1:] - ).squeeze(1)[-new_input_length:-1].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = PrefillTokens( - prefill_token_ids, prefill_logprobs, prefill_texts - ) - else: - prefill_tokens = None generation = Generation( - request.id, - prefill_tokens, - next_token_id_squeezed, - next_token_logprob, + batch.requests[i].id, + None, + next_token_id, + 0, next_token_text, - next_token_id_squeezed.item() in self.all_special_ids, + next_token_id in self.all_special_ids, generated_text, ) generations.append(generation) - # Update values - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.offsets[i] = offset - batch.token_offsets[i] = token_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - # We finished all generations in the batch; there is no next batch - if stopped: - return generations, None - - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask[:, -batch.padding_right_offset] = 1 - # Decrease right offset - batch.padding_right_offset -= 1 - - # Update position_ids - batch.position_ids 
= batch.position_ids[:, -1:] + 1 - - # Update past key values - batch.past_key_values = past - - return generations, batch + return generations, next_batch
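The vectorized path added in get_model() is opted into through the VECTORIZED_LM environment variable. Below is a small usage note, paraphrasing the gate from this patch; note that only `is not None` is checked, so any value enables it.

import os

# Paraphrase of the gate added to get_model() in this patch:
#     if os.environ.get("VECTORIZED_LM") is not None:
#         return VectorizedCausalLM(model_id, revision, quantize=quantize)
#     return CausalLM(model_id, revision, quantize=quantize)
#
# Any value of the variable (even an empty string) selects the vectorized model,
# because only `is not None` is tested. Set it before get_model() runs, e.g. in
# the launcher environment or the Dockerfile.
os.environ["VECTORIZED_LM"] = "1"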
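VectorizedCausalLMBatch.from_pb replaces the per-request all_input_ids lists with tensors pre-allocated at their final width (prompt length plus the right padding reserved for decoding), so later steps only write single columns in place. The following is a standalone sketch of that allocation pattern, assuming a left-padding Hugging Face tokenizer; the function name, checkpoint, and prompts are illustrative, not part of the patch.

import torch
from transformers import AutoTokenizer

def allocate_batch_buffers(inputs, tokenizer, max_new_tokens, device="cpu"):
    # Tokenize with left padding so every prompt ends at the same column.
    tokenized = tokenizer(
        inputs, return_tensors="pt", padding=True, return_token_type_ids=False
    ).to(device)
    input_lengths = tokenized["attention_mask"].sum(1)
    max_input_length = int(input_lengths.max())

    # Final width: prompt plus the largest number of tokens any request may decode
    # (the role played by padding_right_offset in from_pb).
    shape = (len(inputs), max_input_length + max_new_tokens)

    # Mask is 0 only on left padding; future decode positions are already 1.
    attention_mask = torch.empty(shape, dtype=torch.bool, device=device)
    attention_mask[:, :max_input_length].copy_(tokenized["attention_mask"])
    attention_mask[:, max_input_length:].fill_(1)

    # Position ids follow from the mask: cumsum - 1, clamped to 0 inside the padding.
    position_ids = attention_mask.cumsum(-1).sub_(1)
    position_ids[:, :max_input_length].relu_()

    # Token buffer: prompt on the left, decoded tokens written column by column later.
    input_ids = torch.empty(shape, dtype=torch.int64, device=device)
    input_ids[:, :max_input_length].copy_(tokenized["input_ids"])
    return input_ids, attention_mask, position_ids, input_lengths.tolist()

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
input_ids, attention_mask, position_ids, lengths = allocate_batch_buffers(
    ["Hello world", "Hi"], tokenizer, max_new_tokens=8
)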
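generate_token drives prefill and decode through the same pre-allocated buffers: query_length covers the full prompt on the first call and a single token afterwards, the attention mask is sliced to the current key length, and the chosen token is written into the next column of input_ids instead of being concatenated. A runnable sketch of that loop, assuming a Hugging Face causal LM with use_cache semantics; greedy argmax stands in for the post-processing still marked TODO in the diff, and "gpt2" is just an illustrative checkpoint.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

enc = tokenizer(["Hello world", "Hi"], return_tensors="pt", padding=True).to(device)
batch_size, max_input_length = enc["input_ids"].shape
max_new_tokens = 4
width = max_input_length + max_new_tokens

# Pre-allocated buffers, as in from_pb (see the previous sketch).
input_ids = torch.zeros((batch_size, width), dtype=torch.int64, device=device)
input_ids[:, :max_input_length] = enc["input_ids"]
attention_mask = torch.zeros((batch_size, width), dtype=torch.bool, device=device)
attention_mask[:, :max_input_length] = enc["attention_mask"].bool()
attention_mask[:, max_input_length:] = True
position_ids = attention_mask.cumsum(-1).sub_(1).clamp_(min=0)

past_key_values = None
key_length = max_input_length
for _ in range(max_new_tokens):
    # Prefill feeds the whole prompt; every later step feeds only the newest token.
    query_length = key_length if past_key_values is None else 1
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids[:, key_length - query_length : key_length],
            attention_mask=attention_mask[:, :key_length],
            position_ids=position_ids[:, key_length - query_length : key_length],
            past_key_values=past_key_values,
            use_cache=True,
        )
    # Next-token selection from the last position only (greedy here).
    next_token_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1)
    input_ids[:, key_length] = next_token_ids  # in-place column write
    past_key_values = outputs.past_key_values
    key_length += 1

print(tokenizer.batch_decode(input_ids[:, max_input_length:key_length]))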
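The repetition-penalty branch of VectorizedNextTokenChooser.__call__ penalizes every token each sequence has already produced, for the whole batch at once, through a gather / where / scatter_ round trip. A minimal sketch of that rule; it assumes scores holds the last-position logits with shape [batch, vocab] and keeps the penalty as a [batch, 1] column so it broadcasts across the gathered token scores.

import torch

def apply_repetition_penalty(scores: torch.Tensor, input_ids: torch.Tensor,
                             penalty: torch.Tensor) -> torch.Tensor:
    """Vectorized repetition penalty.

    scores:    [batch, vocab] logits for the next token.
    input_ids: [batch, seq] tokens already generated (token ids may repeat).
    penalty:   [batch, 1] per-sequence penalty (> 1.0 discourages repeats).
    """
    # Logits of every token each sequence has already produced.
    score = torch.gather(scores, 1, input_ids)
    # Negative logits are multiplied, positive ones divided, so the penalty
    # always pushes the probability of repeated tokens down.
    score = torch.where(score < 0, score * penalty, score / penalty)
    scores.scatter_(1, input_ids, score)
    return scores

scores = torch.randn(2, 10)
input_ids = torch.tensor([[1, 2, 2], [3, 4, 5]])
penalty = torch.tensor([[1.2], [1.0]])
apply_repetition_penalty(scores, input_ids, penalty)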
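The top-k branch keeps only each row's k largest logits before the softmax. The diff refers to self.filter_value without defining it, so this sketch assumes the conventional -inf fill, and it also assumes a single integer k shared by the batch (a per-request k tensor would need a different masking strategy, e.g. per-row thresholds).

import torch

def top_k_filter(scores: torch.Tensor, top_k: int,
                 filter_value: float = -float("inf")) -> torch.Tensor:
    """Keep only the top_k logits per row; everything else becomes filter_value."""
    top_k = min(top_k, scores.size(-1))  # safety check
    # k-th largest value of each row, kept as a column for broadcasting.
    kth_best = torch.topk(scores, top_k).values[..., -1, None]
    indices_to_remove = scores < kth_best
    return scores.masked_fill(indices_to_remove, filter_value)

logits = torch.randn(2, 10)
filtered = top_k_filter(logits, top_k=3)
probs = torch.softmax(filtered, dim=-1)  # only 3 tokens per row keep probability mass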