mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
# What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
122 lines
4.2 KiB
Python
122 lines
4.2 KiB
Python
import inspect
|
||
import torch
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import List, Tuple, Optional, TypeVar, Type
|
||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||
|
||
from text_generation_server.models.types import Batch, Generation
|
||
from text_generation_server.utils.speculate import get_speculate
|
||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||
|
||
# Type variable for the concrete Batch subclass a Model implementation works with;
# it ties the batch passed into generate_token()/warmup() to the batch it returns.
B = TypeVar("B", bound=Batch)
class Model(ABC):
    """Abstract base class shared by the text-generation model wrappers.

    Subclasses must provide ``batch_type`` and ``generate_token``. This base
    class stores the model / tokenizer / device configuration, exposes it via
    the ``info`` property, and offers ``decode_token`` for incremental
    (streaming) detokenization.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
    ):
        """Store the model configuration and validate that weights are loaded.

        Args:
            model: the wrapped torch module; it is switched to eval mode.
            tokenizer: tokenizer used for decoding and special-id lookup.
            requires_padding: reported through ``info``; incompatible with a
                sliding window (see ``info``).
            dtype: dtype reported through ``info``.
            device: device whose ``.type`` is reported through ``info``.
            rank: this process' rank (stored only).
            world_size: number of processes (stored only).
            sliding_window: attention window size; ``-1`` and ``None`` both
                mean "no sliding window".
            speculate: speculation value; when ``None`` it is taken from
                ``get_speculate()``.

        Raises:
            RuntimeError: if any model parameter is still on the ``meta`` device.
        """
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is unpacked,
        # so also collect every added token flagged as special.
        # TODO report this to transformers.
        other_special_ids = {
            token_id
            for token_id, token in tokenizer.added_tokens_decoder.items()
            if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.all_special_ids.update(other_special_ids)

        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        # A window of -1 is normalized to None ("no sliding window").
        self.sliding_window = sliding_window if sliding_window != -1 else None

        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate

        # Whether the wrapped module's forward() accepts a `position_ids` argument.
        self.has_position_ids = (
            "position_ids" in inspect.signature(model.forward).parameters
        )

        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        """Describe this model's serving configuration.

        Raises:
            NotImplementedError: when the model requires padding and also has
                a sliding window, a combination that is not supported.
        """
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The concrete Batch class this model consumes."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        """Run one generation step on ``batch``.

        Returns a tuple of (generations, next batch or None, (int, int)); the
        meaning of the trailing int pair is defined by the implementations.
        """
        raise NotImplementedError

    def warmup(self, batch: B) -> Optional[int]:
        """Run a single generation step on ``batch`` and return ``None``.

        Subclasses may override this to return an int.
        """
        self.generate_token(batch)
        return None

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers.

        Decodes only the text added since the last emitted chunk.

        Args:
            all_input_ids: every token id generated so far.
            prefix_offset: start of the ids decoded as context for the new text.
            read_offset: end of the ids whose text was already emitted.
            skip_special_tokens: forwarded to ``tokenizer.decode``.

        Returns:
            ``(new_text, new_prefix_offset, new_read_offset)``. When new text is
            available, the offsets advance to ``(read_offset, len(all_input_ids))``.
            Otherwise ``("", prefix_offset, read_offset)`` is returned unchanged.
        """
        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        # "\ufffd" is U+FFFD REPLACEMENT CHARACTER. It is written as an escape so
        # the check cannot be broken by source-file re-encoding.
        if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
            # A replacement char at the end means it's a potential unfinished byte
            # sequence from byte fallback tokenization, so we wait for more ids.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model.
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)
        else:
            return "", prefix_offset, read_offset

    def check_initialized(self):
        """Raise if any model parameter still lives on the ``meta`` device.

        Meta-device tensors carry no data. Presumably weight loading skipped
        such parameters, so they are reported by name.

        Raises:
            RuntimeError: listing the names of the uninitialized parameters.
        """
        uninitialized_parameters = [
            name
            for name, param in self.model.named_parameters()
            if param.data.device == torch.device("meta")
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )