import copy import logging import time from abc import ABC from enum import Enum from typing import List, Optional, Tuple import torch from loguru import logger from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase from transformers.generation import GenerationConfig from optimum.neuron import NeuronModelForCausalLM from optimum.neuron.generation import TokenSelector from .model import get_export_kwargs_from_env from .pb.generate_pb2 import ( Batch, CachedBatch, FinishReason, GeneratedText, Generation, InfoResponse, Request, Tokens, ) # Disable optimum-neuron warnings as it seems to block the server after a while optimum_logger = logging.getLogger("optimum.neuron") optimum_logger.setLevel("CRITICAL") class Generator(ABC): """An abstract class to represent the workhorse behind TextGenerationService. Ideally, it should not rely on protobuf constructs, but in a first step it does. Implementations would typically need a model and a tokenizer to implement the Generator methods. """ @property def info(self) -> InfoResponse: """This should simply return the expected InfoResponse""" raise NotImplementedError def warmup(self, batch: Batch) -> int: """Verify if the hardware can support the target load. Args: batch (`Batch`): A batch corresponding to the maximum number of concurrent requests. Return: The maximum number of tokens the model supports. """ raise NotImplementedError def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: """Prefill is called whenever new requests need to be added. When this method returns successfully, a decode method will follow with both the current and newly prefilled batch(es). Args: batch (`Batch`): A batch containing the new requests. Return: A list of `Generation` for each request and a `CachedBatch` containing all pending requests. """ raise NotImplementedError def decode(self, batches: List[Batch]) -> Tuple[List[Generation], CachedBatch]: """Decode after a prefill or another decode.""" raise NotImplementedError def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch: """Remove requests that are not listed from the specified batch""" raise NotImplementedError def clear(self): """Remove all requests from the generator""" raise NotImplementedError @classmethod def from_pretrained(cls, model_id: str, revision: Optional[str]): """Factory method "a la transformers" """ raise NotImplementedError class Slot: """Represents a slot in a static batch""" class State(Enum): EMPTY = 0 PAUSE = 1 READY = 2 def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase): self._id = id self._tokenizer = tokenizer self.clear() def clear(self): """Clear the slot and mark it as available.""" self._state = Slot.State.EMPTY self._batch_id = None self._request_id = None self._inputs = "" self._truncate = 0 self._generation_config = None self._tokens = [] self._mask = torch.tensor([]) self._selector = None self._generated_tokens = 0 self._next_text_token_start = 0 self._next_text_token_end = 0 self._generated_text = "" self._next_text = "" @property def id(self) -> int: return self._id @property def state(self) -> "Slot.State": return self._state @property def batch_id(self) -> int: return self._batch_id @property def request_id(self) -> int: return self._request_id @property def cached_text(self) -> str: return self._inputs + self._generated_text @property def generation_config(self) -> GenerationConfig: return self._generation_config @property def generated_tokens(self) -> int: return self._generated_tokens def assign( self, batch_id: int, request: Request, generation_config: GenerationConfig ): """Assign a request to a slot. Args: request (`Request`): The request to be assigned. Contains the inputs and tokens selection parameters. generation_config (`transformers.GenerationConfig`): The base generation config (might be modified by the request generation parameters). """ self._state = Slot.State.READY self._batch_id = batch_id self._request_id = request.id self._inputs = request.inputs if request.truncate: self._truncate = request.truncate self._generation_config = copy.deepcopy(generation_config) # Update generation config with request parameters self._generation_config.do_sample = request.parameters.do_sample if self._generation_config.do_sample: if request.parameters.temperature != 0: self._generation_config.temperature = request.parameters.temperature if request.parameters.top_k != 0: self._generation_config.top_k = request.parameters.top_k if request.parameters.top_p != 0: self._generation_config.top_p = request.parameters.top_p if request.parameters.typical_p != 0: self._generation_config.typical_p = request.parameters.typical_p if request.parameters.repetition_penalty != 0: self._generation_config.repetition_penalty = ( request.parameters.repetition_penalty ) self.seed = request.parameters.seed self._generation_config.max_new_tokens = ( request.stopping_parameters.max_new_tokens ) self._max_new_tokens = self._generation_config.max_new_tokens stop_strings = request.stopping_parameters.stop_sequences if stop_strings: self._generation_config.stop_strings = stop_strings def reset( self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector, ): """Reset the slot for the next generation. Args: input_ids: (`torch.LongTensor`): The new input_ids to use to generate the next token. attention_mask: (`torch.LongTensor`): The new attention_mask to use to generate the next token. selector: (`optimum.neuron.generation.TokenSelector`): An object implementing the updated token selection logic. """ self._tokens = input_ids.clone() self._next_text_token_start = 0 self._next_text_token_end = torch.numel(self._tokens) self._next_text = "" self._mask = attention_mask.clone() self._selector = selector def pause(self, reset_on_pause: bool): """Mark the current slot as paused for generation. Note that the KV cache for this slot will still be filled. """ if reset_on_pause: # Drop the last token as it will be added back when resuming the slot self._generated_tokens -= 1 # Since generated tokens are now part of the prefill, we need to reevaluate # max_new_tokens for the next generation self._generation_config.max_new_tokens = ( self._max_new_tokens - self._generated_tokens ) self._state = Slot.State.PAUSE def resume(self): """Mark the slot as ready for generation.""" self._state = Slot.State.READY def _decode_next_tokens( self, ) -> str: """Hack to hopefully support generate_stream for the maximum number of tokenizers""" # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode # which decide to add a space or not depending on the surrounding ids. new_text = self._tokenizer.decode( self._tokens[self._next_text_token_start :], skip_special_tokens=False ) if new_text.endswith("�"): # utf-8 char at the end means it's a potential unfinished byte sequence # from byte fallback tokenization. return "" # Compare the generated text with the one using only the tokens producing the last one last_text = self._tokenizer.decode( self._tokens[self._next_text_token_start : self._next_text_token_end], skip_special_tokens=False, ) if len(new_text) == len(last_text): # Nothing new was actually generated return "" # Return the decoded text and store its token offsets self._next_text_token_start = self._next_text_token_end self._next_text_token_end = torch.numel(self._tokens) return new_text[len(last_text) :] def append(self, next_token: int) -> str: """Append a new generated token to this slot The new token is added to the list of generated tokens, which impacts directly the generated_text and stopped property. The new token is however not added immediately to the slot inputs: it will be added later on when it has effectively been used to produce the next token. Args: next_token (`int`): The newly generated token. Return: The corresponding decoded text (if any). """ self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])]) self._mask = torch.cat([self._mask, torch.LongTensor([1])]) self._generated_tokens += 1 next_text = self._decode_next_tokens() # Now that a new token has been generated, we can append the previous one to the generated text self._generated_text += self._next_text self._next_text = next_text return next_text def select( self, input_ids: torch.LongTensor, logits: torch.Tensor ) -> torch.LongTensor: """Select the next token from the candidate logits. Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation (not used in all generation modes). logits (`torch.Tensor` of shape `(batch_size, sequence_length)`): The logits corresponding to the generated tokens. Return: `torch.LongTensor`: A scalar torch.LongTensor` containing the selected token. """ return self._selector.select(input_ids, logits)[0] @property def stopped(self) -> bool: # Transformers stopping criteria expects a batch of input ids input_ids = torch.unsqueeze(self._tokens, dim=0) return self._selector.stopping_criteria(input_ids, None) @property def generated_text(self) -> str: return self._generated_text + self._next_text @property def next_token(self) -> int: return None if len(self._tokens) == 0 else self._tokens[-1] @property def attention_mask(self) -> torch.LongTensor: return self._mask @property def max_token(self) -> int: return self._generation_config.max_length @property def max_new_tokens(self) -> int: # The current value of max_new_tokens: might be different of the target max_new_tokens # if the slot has been paused and resumed. return self._generation_config.max_new_tokens @property def truncate(self) -> int: return self._truncate class NeuronGenerator(Generator): """A Generator for Neuron models.""" def __init__( self, model: NeuronModelForCausalLM, tokenizer: PreTrainedTokenizerBase, ): self.model = model self.rebuild_cache_on_prefill = not self.model.continuous_batching # Specify padding and truncation options for decoder-only architecture tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" tokenizer.truncation_side = "left" self.tokenizer = tokenizer self.special_tokens = self.tokenizer.all_special_ids self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)] self.batch_id = 0 @property def info(self) -> InfoResponse: """Returns the expected InfoResponse.""" dtype = getattr(self.model.config, "torch_dtype", "float32") return InfoResponse( requires_padding=True, dtype=str(dtype), device_type="xla", ) def warmup(self, batch: Batch) -> int: """Verify if the hardware can support the target load. Args: batch (`Batch`): A batch corresponding to the maximum number of concurrent requests. Return: The maximum number of tokens the model supports. """ # Just check that the warmup request parameters match the model capacity batch_size = self.model.batch_size if len(batch.requests) > batch_size: raise ValueError( f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE." ) self.prefill(batch) self.clear() return self.model.batch_size * self.model.max_length def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: """Prefill new requests. Args: batch (`Batch`): A batch containing the new requests. Return: A list of `Generation` for each request and a `CachedBatch` containing all pending requests. """ slots = {state: [] for state in Slot.State} for slot in self.slots: slots[slot.state].append(slot) active_slots = slots[Slot.State.READY] empty_slots = slots[Slot.State.EMPTY] if len(empty_slots) < len(batch.requests): raise ValueError( f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots." f" Please align max_batch_size with the static batch size: {self.model.batch_size}." ) # Assign each request to an empty slot logger.debug( f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)" ) new_slots = [] for request in batch.requests: slot = empty_slots.pop() slot.assign(self.batch_id, request, self.model.generation_config) new_slots.append(slot) logger.debug( f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}" ) if self.rebuild_cache_on_prefill: # We will clear pending slots and prefill all slots prefill_slots = self.slots seq_ids = None else: # We only need to pass inputs for the new requests prefill_slots = new_slots seq_ids = torch.tensor([slot.id for slot in prefill_slots]) # Reconstruct the full inputs (without padding) as seen by the model. # This comprises: # - the inputs for new requests, # - only when rebuilding the cache, the inputs and the generated text that has already # been cached (i.e. excluding the last generated token) for unfinished requests. inputs = [] max_length = 0 for slot in prefill_slots: inputs.append(slot.cached_text) # Apply truncation, making sure we fit into static dimensions if slot.truncate == 0: max_length = self.model.max_length elif slot.truncate > max_length and slot.truncate < self.model.max_length: max_length = slot.truncate # Tokenize with padding and truncation padded_inputs = self.tokenizer( inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length, ) input_ids = padded_inputs.input_ids attention_mask = padded_inputs.attention_mask # Pause previously active slots during generation next_tokens = [] for slot in active_slots: slot.pause(reset_on_pause=self.rebuild_cache_on_prefill) if self.rebuild_cache_on_prefill: # The slot will be reset, so we need to store its next token next_tokens.append(slot.next_token) # Each slot must be reset with the padded inputs and masks for i, slot in enumerate(prefill_slots): if slot.state != slot.state.EMPTY: if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]: # Apply per-request truncation input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id attention_mask[i, : -slot.truncate] = 0 slot_input_ids = input_ids[i : i + 1, :] # Padded input ids are also required to set logits processors and stopping criterias selector = TokenSelector.create( slot_input_ids, slot.generation_config, self.model, self.model.max_length, tokenizer=self.tokenizer, seed=slot.seed, ) slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64) slot_attention_mask = attention_mask[i] slot.reset(slot_input_ids, slot_attention_mask, selector) # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored, # as they have already been generated and sent back in the last decode. model_inputs = self.model.prepare_inputs_for_prefill( input_ids, attention_mask, seq_ids ) logits = self.model(**model_inputs)[0] generation, next_batch = self._generate_token( prefill_slots, self.batch_id, logits, input_ids ) self.batch_id += 1 # Reactivate previously active slots for the next decode for i, slot in enumerate(active_slots): slot.resume() if self.rebuild_cache_on_prefill: # Append back the next token slot.append(next_tokens[i]) logger.debug("Model ready for decoding") if next_batch is not None: logger.debug( f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}" ) return generation, next_batch def decode( self, batches: List[CachedBatch] ) -> Tuple[List[Generation], CachedBatch]: """Decode the specified prefilled requests. Args: batches (`List[CachedBatch]`): A list of previous batches containing the prefilled requests. Return: A list of `Generation` for each request and a `CachedBatch` containing all pending requests. """ # batches contains a list composed of: # - the batch id returned by the last decode, # - the batch id(s) returned by the last prefill(s) # Batches are always concatenated during prefill, so we can # just carry on with decoding. We adopt the id of the first # batch in the list as our next batch id. next_batch_id = batches[0].id request_ids = [] for batch in batches: request_ids += batch.request_ids cleared_request_ids = [] for slot in self.slots: if slot.state == slot.State.READY and slot.request_id not in request_ids: cleared_request_ids.append(slot.request_id) slot.clear() if len(cleared_request_ids) > 0: logger.info( f"Clearing slot for requests {cleared_request_ids} as they are not requested." ) active_slots = [slot for slot in self.slots if slot.state == slot.State.READY] if len(active_slots) < len(request_ids): raise ValueError( "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)" ) if self.model.continuous_batching: decode_slots = active_slots seq_ids = torch.tensor([slot.id for slot in decode_slots]) else: decode_slots = self.slots seq_ids = None # Reconstruct input_ids and attention_mask from decode slots n_slots = len(decode_slots) input_ids = torch.full( [n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64 ) max_length = 0 for slot in decode_slots: max_length = max(max_length, slot.attention_mask.size(-1)) attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64) for i, slot in enumerate(decode_slots): if slot.state != Slot.State.EMPTY: # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached) input_ids[i, 0] = slot.next_token attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask model_inputs = self.model.prepare_inputs_for_decode( input_ids, attention_mask, seq_ids ) logits = self.model(**model_inputs)[0] return self._generate_token(decode_slots, next_batch_id, logits, input_ids) def _generate_token( self, slots: List[Slot], next_batch_id: int, logits: torch.Tensor, input_ids: torch.LongTensor, ) -> Tuple[List[Generation], CachedBatch]: generations = [] active_slots = False for i, slot in enumerate(slots): if slot.state != Slot.State.READY: continue request_id = slot.request_id next_token_logits = logits[i : i + 1, -1, :] slot_input_ids = input_ids[i : i + 1, :] next_token = slot.select(slot_input_ids, next_token_logits) next_token_text = slot.append(next_token) generated_text = None finish_reason = None if next_token == self.tokenizer.eos_token_id: finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN elif slot.stopped: if slot.generated_tokens == slot.max_new_tokens: finish_reason = FinishReason.FINISH_REASON_LENGTH else: finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE if finish_reason is not None: # We must include the generated text for each finished sequence in the response generated_text = GeneratedText( text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason, ) logger.debug( f"Decode complete for request {request_id} with {slot.generated_tokens} tokens" ) # mark the slot as available slot.clear() else: active_slots = True generations.append( Generation( request_id=request_id, prefill_tokens=None, tokens=Tokens( ids=[next_token], logprobs=[0], texts=[next_token_text], is_special=[next_token in self.special_tokens], ), generated_text=generated_text, ) ) batch = None if active_slots: # Whatever initial batch these requests came from, we always return all pending requests in a single batch request_ids = [ slot.request_id for slot in self.slots if slot.state == Slot.State.READY ] batch = self._cached_batch(next_batch_id, request_ids) else: logger.debug("No more pending requests") return generations, batch def _cached_batch(self, batch_id: int, request_ids: List): size = len(request_ids) max_tokens = size * self.model.max_length return CachedBatch( id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens ) def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch: """Remove requests that are not listed from the specified batch Args: batch_id (`int`): The id of a cached batch. keep_ids(`List[int]`): The list of requests that must be kept. Return: A `CachedBatch` containing the pending requests. """ keep_slot_ids = [ slot.id for slot in self.slots if slot.request_id in keep_request_ids ] self._clear(keep_slot_ids) return self._cached_batch(batch_id, keep_request_ids) def clear(self, batch_id: Optional[int] = None): """Remove a subset or all requests from the generator""" keep_ids = [] if batch_id is not None: keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id] return self._clear(keep_ids) def _clear(self, keep_slot_ids: List): for slot in self.slots: if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids: logger.debug(f"Removing slot {slot.id} with request {slot.request_id}") slot.clear() @classmethod def from_pretrained(cls, model_id: str, revision: str = None): """Instantiate a NeuronGenerator. Args: model_id (`str`): A hub model id or the path to a local model. This path must also contain a Tokenizer. revision (`Optional[str]`, defaults to `None`): The revision of the model on the HuggingFace hub. Returns: A NeuronGenerator. """ config = AutoConfig.from_pretrained(model_id) neuron_config = getattr(config, "neuron", None) start = time.time() if neuron_config is None: export_kwargs = get_export_kwargs_from_env() logger.info(f"Exporting model to neuron with config: {export_kwargs}.") model = NeuronModelForCausalLM.from_pretrained( model_id, revision=revision, low_cpu_mem_usage=True, export=True, **export_kwargs, ) else: logger.info( "Loading model on neuron devices (this can take a few minutes)." ) model = NeuronModelForCausalLM.from_pretrained( model_id, low_cpu_mem_usage=True, revision=revision ) end = time.time() logger.info(f"Model successfully loaded in {end - start:.2f} s.") tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) return cls(model, tokenizer)