import copy
import logging
import time
from abc import ABC
from enum import Enum
from typing import List, Optional, Tuple

import torch
from loguru import logger
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from transformers.generation import GenerationConfig

from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.generation import TokenSelector

from .model import get_export_kwargs_from_env
from .pb.generate_pb2 import (
    Batch,
    CachedBatch,
    FinishReason,
    GeneratedText,
    Generation,
    InfoResponse,
    Request,
    Tokens,
)


# Disable optimum-neuron warnings as it seems to block the server after a while
optimum_logger = logging.getLogger("optimum.neuron")
optimum_logger.setLevel("CRITICAL")

class Generator(ABC):
    """An abstract class to represent the workhorse behind TextGenerationService.

    Ideally, it should not rely on protobuf constructs, but as a first step it does.
    Implementations would typically need a model and a tokenizer to implement the Generator methods.
    """

    @property
    def info(self) -> InfoResponse:
        """This should simply return the expected InfoResponse"""
        raise NotImplementedError

    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Return:
            The maximum number of tokens the model supports.
        """
        raise NotImplementedError

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill is called whenever new requests need to be added.

        When this method returns successfully, a decode method will follow
        with both the current and newly prefilled batch(es).

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        raise NotImplementedError

    def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        """Decode after a prefill or another decode."""
        raise NotImplementedError

    def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
        """Remove requests that are not listed from the specified batch"""
        raise NotImplementedError

    def clear(self):
        """Remove all requests from the generator"""
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, model_id: str, revision: Optional[str]):
        """Factory method "a la transformers" """
        raise NotImplementedError

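# Illustrative sketch only (not part of the original module): how the gRPC service layer is
# expected to drive a concrete Generator. `max_load_batch` and `first_batch` are placeholder
# protobuf Batch objects, not symbols defined in this file.
#
#   generator = NeuronGenerator.from_pretrained("my-model-id")
#   generator.warmup(max_load_batch)                 # check capacity once at startup
#   generations, cached = generator.prefill(first_batch)
#   while cached is not None:
#       generations, cached = generator.decode([cached])
#   generator.clear()
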
class Slot:
    """Represents a slot in a static batch"""

    class State(Enum):
        EMPTY = 0
        PAUSE = 1
        READY = 2

    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
        self._id = id
        self._tokenizer = tokenizer
        self.clear()

    def clear(self):
        """Clear the slot and mark it as available."""
        self._state = Slot.State.EMPTY
        self._batch_id = None
        self._request_id = None
        self._inputs = ""
        self._truncate = 0
        self._generation_config = None
        self._tokens = []
        self._mask = torch.tensor([])
        self._selector = None
        self._generated_tokens = 0
        self._next_text_token_start = 0
        self._next_text_token_end = 0
        self._generated_text = ""
        self._next_text = ""

    @property
    def id(self) -> int:
        return self._id

    @property
    def state(self) -> "Slot.State":
        return self._state

    @property
    def batch_id(self) -> int:
        return self._batch_id

    @property
    def request_id(self) -> int:
        return self._request_id

    @property
    def cached_text(self) -> str:
        return self._inputs + self._generated_text

    @property
    def generation_config(self) -> GenerationConfig:
        return self._generation_config

    @property
    def generated_tokens(self) -> int:
        return self._generated_tokens

    def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig):
        """Assign a request to a slot.

        Args:
            batch_id (`int`):
                The id of the batch containing the request.
            request (`Request`):
                The request to be assigned. Contains the inputs and tokens selection parameters.
            generation_config (`transformers.GenerationConfig`):
                The base generation config (might be modified by the request generation parameters).
        """
        self._state = Slot.State.READY
        self._batch_id = batch_id
        self._request_id = request.id
        self._inputs = request.inputs
        if request.truncate:
            self._truncate = request.truncate
        self._generation_config = copy.deepcopy(generation_config)
        # Update generation config with request parameters
        self._generation_config.do_sample = request.parameters.do_sample
        if self._generation_config.do_sample:
            if request.parameters.temperature != 0:
                self._generation_config.temperature = request.parameters.temperature
            if request.parameters.top_k != 0:
                self._generation_config.top_k = request.parameters.top_k
            if request.parameters.top_p != 0:
                self._generation_config.top_p = request.parameters.top_p
            if request.parameters.typical_p != 0:
                self._generation_config.typical_p = request.parameters.typical_p
        if request.parameters.repetition_penalty != 0:
            self._generation_config.repetition_penalty = request.parameters.repetition_penalty
        self.seed = request.parameters.seed
        self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
        self._max_new_tokens = self._generation_config.max_new_tokens
        stop_strings = request.stopping_parameters.stop_sequences
        if stop_strings:
            self._generation_config.stop_strings = stop_strings

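    # Illustrative sketch only (not part of the original module): how a request's sampling
    # parameters override the base GenerationConfig during assign(). Here `request` stands for
    # a protobuf Request with parameters(do_sample=True, temperature=0.7, top_k=0) and
    # stopping_parameters(max_new_tokens=64); zero-valued fields are treated as "unset" and
    # leave the base config untouched. `tokenizer` and `model` are placeholders.
    #
    #   slot = Slot(0, tokenizer)
    #   slot.assign(batch_id=0, request=request, generation_config=model.generation_config)
    #   assert slot.generation_config.temperature == 0.7                       # overridden
    #   assert slot.generation_config.top_k == model.generation_config.top_k   # unchanged
    #   assert slot.generation_config.max_new_tokens == 64
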
    def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector):
        """Reset the slot for the next generation.

        Args:
            input_ids (`torch.LongTensor`):
                The new input_ids to use to generate the next token.
            attention_mask (`torch.LongTensor`):
                The new attention_mask to use to generate the next token.
            selector (`optimum.neuron.generation.TokenSelector`):
                An object implementing the updated token selection logic.
        """
        self._tokens = input_ids.clone()
        self._next_text_token_start = 0
        self._next_text_token_end = torch.numel(self._tokens)
        self._next_text = ""
        self._mask = attention_mask.clone()
        self._selector = selector

    def pause(self, reset_on_pause: bool):
        """Mark the current slot as paused for generation.

        Note that the KV cache for this slot will still be filled.
        """
        if reset_on_pause:
            # Drop the last token as it will be added back when resuming the slot
            self._generated_tokens -= 1
            # Since generated tokens are now part of the prefill, we need to reevaluate
            # max_new_tokens for the next generation
            self._generation_config.max_new_tokens = self._max_new_tokens - self._generated_tokens
        self._state = Slot.State.PAUSE

    def resume(self):
        """Mark the slot as ready for generation."""
        self._state = Slot.State.READY

    def _decode_next_tokens(self) -> str:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
        # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        new_text = self._tokenizer.decode(self._tokens[self._next_text_token_start :], skip_special_tokens=False)
        if new_text.endswith("�"):
            # utf-8 char at the end means it's a potential unfinished byte sequence
            # from byte fallback tokenization.
            return ""

        # Compare the generated text with the one using only the tokens producing the last one
        last_text = self._tokenizer.decode(
            self._tokens[self._next_text_token_start : self._next_text_token_end],
            skip_special_tokens=False,
        )
        if len(new_text) == len(last_text):
            # Nothing new was actually generated
            return ""
        # Return the decoded text and store its token offsets
        self._next_text_token_start = self._next_text_token_end
        self._next_text_token_end = torch.numel(self._tokens)
        return new_text[len(last_text) :]

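    # Illustrative sketch only (not part of the original module): why _decode_next_tokens
    # decodes a growing window instead of each token in isolation. With SentencePiece-style
    # tokenizers, decoding a single token usually drops its leading-space marker, so naive
    # per-token decoding loses whitespace. Token ids below are made up for illustration.
    #
    #   tokenizer.decode([7826])          -> "Hello"
    #   tokenizer.decode([3186])          -> "world"          (leading space lost)
    #   tokenizer.decode([7826, 3186])    -> "Hello world"
    #   # Diffing the window decode against the previous window decode recovers " world".
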
    def append(self, next_token: int) -> str:
        """Append a new generated token to this slot

        The new token is added to the list of generated tokens, which directly impacts
        the generated_text and stopped properties.

        The new token is however not added immediately to the slot inputs: it will
        be added later on when it has effectively been used to produce the next token.

        Args:
            next_token (`int`):
                The newly generated token.

        Return:
            The corresponding decoded text (if any).
        """
        self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
        self._mask = torch.cat([self._mask, torch.LongTensor([1])])
        self._generated_tokens += 1
        next_text = self._decode_next_tokens()
        # Now that a new token has been generated, we can append the previous one to the generated text
        self._generated_text += self._next_text
        self._next_text = next_text
        return next_text

    def select(self, input_ids: torch.LongTensor, logits: torch.Tensor) -> torch.LongTensor:
        """Select the next token from the candidate logits.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation (not used in all generation modes).
            logits (`torch.Tensor` of shape `(batch_size, vocab_size)`):
                The logits corresponding to the generated tokens.

        Return:
            `torch.LongTensor`: A scalar `torch.LongTensor` containing the selected token.
        """
        return self._selector.select(input_ids, logits)[0]

    @property
    def stopped(self) -> bool:
        # Transformers stopping criteria expects a batch of input ids
        input_ids = torch.unsqueeze(self._tokens, dim=0)
        return self._selector.stopping_criteria(input_ids, None)

    @property
    def generated_text(self) -> str:
        return self._generated_text + self._next_text

    @property
    def next_token(self) -> int:
        return None if len(self._tokens) == 0 else self._tokens[-1]

    @property
    def attention_mask(self) -> torch.LongTensor:
        return self._mask

    @property
    def max_token(self) -> int:
        return self._generation_config.max_length

    @property
    def max_new_tokens(self) -> int:
        # The current value of max_new_tokens: it might differ from the target max_new_tokens
        # if the slot has been paused and resumed.
        return self._generation_config.max_new_tokens

    @property
    def truncate(self) -> int:
        return self._truncate

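# Illustrative sketch only (not part of the original module): the lifecycle of a Slot as
# driven by the NeuronGenerator below. `request`, `input_ids`, `attention_mask` and
# `selector` are placeholders for values produced during prefill.
#
#   slot = Slot(0, tokenizer)                            # EMPTY
#   slot.assign(batch_id, request, generation_config)    # READY
#   slot.reset(input_ids, attention_mask, selector)      # attach tokenized inputs
#   text = slot.append(next_token)                       # once per generated token
#   if slot.stopped:
#       slot.clear()                                     # back to EMPTY
#   # On models without continuous batching, ready slots are paused while a new request is
#   # prefilled, then resumed and their last generated token appended back.
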
class NeuronGenerator(Generator):
    """A Generator for Neuron models."""

    def __init__(
        self,
        model: NeuronModelForCausalLM,
        tokenizer: PreTrainedTokenizerBase,
    ):
        self.model = model
        self.rebuild_cache_on_prefill = not self.model.continuous_batching
        # Specify padding and truncation options for decoder-only architecture
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"
        tokenizer.truncation_side = "left"
        self.tokenizer = tokenizer
        self.special_tokens = self.tokenizer.all_special_ids
        self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
        self.batch_id = 0

    @property
    def info(self) -> InfoResponse:
        """Returns the expected InfoResponse."""
        dtype = getattr(self.model.config, "torch_dtype", "float32")
        return InfoResponse(
            requires_padding=True,
            dtype=str(dtype),
            device_type="xla",
        )

    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Return:
            The maximum number of tokens the model supports.
        """
        # Just check that the warmup request parameters match the model capacity
        batch_size = self.model.batch_size
        if len(batch.requests) > batch_size:
            raise ValueError(
                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
            )
        self.prefill(batch)
        self.clear()
        return self.model.batch_size * self.model.max_length

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill new requests.

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        slots = {state: [] for state in Slot.State}
        for slot in self.slots:
            slots[slot.state].append(slot)
        active_slots = slots[Slot.State.READY]
        empty_slots = slots[Slot.State.EMPTY]
        if len(empty_slots) < len(batch.requests):
            raise ValueError(
                f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
                f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
            )
        # Assign each request to an empty slot
        logger.debug(f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)")
        new_slots = []
        for request in batch.requests:
            slot = empty_slots.pop()
            slot.assign(self.batch_id, request, self.model.generation_config)
            new_slots.append(slot)
            logger.debug(
                f"Request {slot.request_id} assigned to slot {slot.id} with max_new_tokens {slot.max_new_tokens}"
            )
        if self.rebuild_cache_on_prefill:
            # We will clear pending slots and prefill all slots
            prefill_slots = self.slots
            seq_ids = None
        else:
            # We only need to pass inputs for the new requests
            prefill_slots = new_slots
            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
        # Reconstruct the full inputs (without padding) as seen by the model.
        # This comprises:
        # - the inputs for new requests,
        # - only when rebuilding the cache, the inputs and the generated text that has already
        #   been cached (i.e. excluding the last generated token) for unfinished requests.
        inputs = []
        max_length = 0
        for slot in prefill_slots:
            inputs.append(slot.cached_text)
            # Apply truncation, making sure we fit into static dimensions
            if slot.truncate == 0:
                max_length = self.model.max_length
            elif slot.truncate > max_length and slot.truncate < self.model.max_length:
                max_length = slot.truncate
        # Tokenize with padding and truncation
        padded_inputs = self.tokenizer(
            inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length
        )
        input_ids = padded_inputs.input_ids
        attention_mask = padded_inputs.attention_mask
        # Pause previously active slots during generation
        next_tokens = []
        for slot in active_slots:
            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
            if self.rebuild_cache_on_prefill:
                # The slot will be reset, so we need to store its next token
                next_tokens.append(slot.next_token)
        # Each slot must be reset with the padded inputs and masks
        for i, slot in enumerate(prefill_slots):
            if slot.state != Slot.State.EMPTY:
                if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
                    # Apply per-request truncation
                    input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
                    attention_mask[i, : -slot.truncate] = 0
                slot_input_ids = input_ids[i : i + 1, :]
                # Padded input ids are also required to set logits processors and stopping criteria
                selector = TokenSelector.create(
                    slot_input_ids,
                    slot.generation_config,
                    self.model,
                    self.model.max_length,
                    tokenizer=self.tokenizer,
                    seed=slot.seed,
                )
                slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
                slot_attention_mask = attention_mask[i]
                slot.reset(slot_input_ids, slot_attention_mask, selector)
        # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
        # as they have already been generated and sent back in the last decode.
        model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask, seq_ids)
        logits = self.model(**model_inputs)[0]
        generation, next_batch = self._generate_token(prefill_slots, self.batch_id, logits, input_ids)
        self.batch_id += 1
        # Reactivate previously active slots for the next decode
        for i, slot in enumerate(active_slots):
            slot.resume()
            if self.rebuild_cache_on_prefill:
                # Append back the next token
                slot.append(next_tokens[i])
        logger.debug("Model ready for decoding")
        if next_batch is not None:
            logger.debug(f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}")
        return generation, next_batch

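    # Illustrative sketch only (not part of the original module): what prepare_inputs_for_prefill
    # receives in the two modes above, assuming a compiled batch size of 4 where slots 1 and 3
    # hold the new requests (`padded_len` is a placeholder).
    #
    #   # continuous batching: only the new rows, addressed by their slot ids
    #   input_ids.shape == (2, padded_len); seq_ids == torch.tensor([1, 3])
    #   # static batching (cache rebuild): one row per slot, in slot order
    #   input_ids.shape == (4, padded_len); seq_ids is None
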
    def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        """Decode the specified prefilled requests.

        Args:
            batches (`List[CachedBatch]`):
                A list of previous batches containing the prefilled requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        # batches contains a list composed of:
        # - the batch id returned by the last decode,
        # - the batch id(s) returned by the last prefill(s)
        # Batches are always concatenated during prefill, so we can
        # just carry on with decoding. We adopt the id of the first
        # batch in the list as our next batch id.
        next_batch_id = batches[0].id
        request_ids = []
        for batch in batches:
            request_ids += batch.request_ids
        cleared_request_ids = []
        for slot in self.slots:
            if slot.state == Slot.State.READY and slot.request_id not in request_ids:
                cleared_request_ids.append(slot.request_id)
                slot.clear()
        if len(cleared_request_ids) > 0:
            logger.info(f"Clearing slot for requests {cleared_request_ids} as they are no longer requested.")
        active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
        if len(active_slots) < len(request_ids):
            raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
        if self.model.continuous_batching:
            decode_slots = active_slots
            seq_ids = torch.tensor([slot.id for slot in decode_slots])
        else:
            decode_slots = self.slots
            seq_ids = None
        # Reconstruct input_ids and attention_mask from decode slots
        n_slots = len(decode_slots)
        input_ids = torch.full([n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64)
        max_length = 0
        for slot in decode_slots:
            max_length = max(max_length, slot.attention_mask.size(-1))
        attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
        for i, slot in enumerate(decode_slots):
            if slot.state != Slot.State.EMPTY:
                # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                input_ids[i, 0] = slot.next_token
                attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
        model_inputs = self.model.prepare_inputs_for_decode(input_ids, attention_mask, seq_ids)
        logits = self.model(**model_inputs)[0]
        return self._generate_token(decode_slots, next_batch_id, logits, input_ids)

    def _generate_token(
        self, slots: List[Slot], next_batch_id: int, logits: torch.Tensor, input_ids: torch.LongTensor
    ) -> Tuple[List[Generation], CachedBatch]:
        generations = []
        active_slots = False
        for i, slot in enumerate(slots):
            if slot.state != Slot.State.READY:
                continue
            request_id = slot.request_id
            next_token_logits = logits[i : i + 1, -1, :]
            slot_input_ids = input_ids[i : i + 1, :]
            next_token = slot.select(slot_input_ids, next_token_logits)
            next_token_text = slot.append(next_token)
            generated_text = None
            finish_reason = None
            if next_token == self.tokenizer.eos_token_id:
                finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
            elif slot.stopped:
                if slot.generated_tokens == slot.max_new_tokens:
                    finish_reason = FinishReason.FINISH_REASON_LENGTH
                else:
                    finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
            if finish_reason is not None:
                # We must include the generated text for each finished sequence in the response
                generated_text = GeneratedText(
                    text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason
                )
                logger.debug(f"Decode complete for request {request_id} with {slot.generated_tokens} tokens")
                # Mark the slot as available
                slot.clear()
            else:
                active_slots = True
            generations.append(
                Generation(
                    request_id=request_id,
                    prefill_tokens=None,
                    tokens=Tokens(
                        ids=[next_token],
                        logprobs=[0],
                        texts=[next_token_text],
                        is_special=[next_token in self.special_tokens],
                    ),
                    generated_text=generated_text,
                )
            )
        batch = None
        if active_slots:
            # Whatever initial batch these requests came from, we always return all pending requests in a single batch
            request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
            batch = self._cached_batch(next_batch_id, request_ids)
        else:
            logger.debug("No more pending requests")
        return generations, batch

    def _cached_batch(self, batch_id: int, request_ids: List):
        size = len(request_ids)
        max_tokens = size * self.model.max_length
        return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)

    def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
        """Remove requests that are not listed from the specified batch

        Args:
            batch_id (`int`):
                The id of a cached batch.
            keep_request_ids (`List[int]`):
                The list of requests that must be kept.

        Return:
            A `CachedBatch` containing the pending requests.
        """
        keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids]
        self._clear(keep_slot_ids)
        return self._cached_batch(batch_id, keep_request_ids)

    def clear(self, batch_id: Optional[int] = None):
        """Remove a subset or all requests from the generator"""
        keep_ids = []
        if batch_id is not None:
            keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
        return self._clear(keep_ids)

    def _clear(self, keep_slot_ids: List):
        for slot in self.slots:
            if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
                logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
                slot.clear()

    @classmethod
    def from_pretrained(cls, model_id: str, revision: Optional[str] = None):
        """Instantiate a NeuronGenerator.

        Args:
            model_id (`str`):
                A hub model id or the path to a local model. This path must also contain a Tokenizer.
            revision (`Optional[str]`, defaults to `None`):
                The revision of the model on the HuggingFace hub.

        Returns:
            A NeuronGenerator.
        """
        config = AutoConfig.from_pretrained(model_id)
        neuron_config = getattr(config, "neuron", None)
        start = time.time()
        if neuron_config is None:
            export_kwargs = get_export_kwargs_from_env()
            logger.info(f"Exporting model to neuron with config: {export_kwargs}.")
            model = NeuronModelForCausalLM.from_pretrained(
                model_id, revision=revision, low_cpu_mem_usage=True, export=True, **export_kwargs
            )
        else:
            logger.info("Loading model on neuron devices (this can take a few minutes).")
            model = NeuronModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, revision=revision)
        end = time.time()
        logger.info(f"Model successfully loaded in {end - start:.2f} s.")
        tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
        return cls(model, tokenizer)
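# Illustrative sketch only (not part of the original module): instantiating the generator
# directly. The model id is a placeholder; it must point either to a model already exported
# for Neuron (a "neuron" section in its config.json) or to one that can be exported with the
# parameters returned by get_export_kwargs_from_env().
#
#   generator = NeuronGenerator.from_pretrained("my-org/my-neuron-llama")
#   print(generator.info)       # InfoResponse(requires_padding=True, dtype=..., device_type="xla")
#   generations, cached_batch = generator.prefill(batch)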