mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
WIP: Added custom_logits_processors API and connecting it to flow
This commit is contained in:
parent
4ca2c5c945
commit
ebc0a7152b
@ -4,7 +4,7 @@ import typer
|
||||
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
from enum import Enum
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
@ -38,6 +38,7 @@ def serve(
|
||||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
otlp_endpoint: Optional[str] = None,
|
||||
custom_modules: Optional[List[str]] = None,
|
||||
):
|
||||
if sharded:
|
||||
assert (
|
||||
@ -65,6 +66,14 @@ def serve(
|
||||
diagnose=False,
|
||||
)
|
||||
|
||||
# Import custom modules. This can be used for Custom Logits Processors,
|
||||
# in which these modules CustomLogitsProcessorsManager.register_factory() in their __init__.py,
|
||||
# registering themselves for custom logits processing.
|
||||
from importlib import import_module
|
||||
if custom_modules:
|
||||
for custom_module in custom_modules:
|
||||
import_module(custom_module)
|
||||
|
||||
# Import here after the logger is added to log potential import exceptions
|
||||
from text_generation_server import server
|
||||
from text_generation_server.tracing import setup_tracing
|
||||
|
@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||
requests_idx_mapping[r.id] = i
|
||||
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
||||
inputs.append(escape_custom_split_sequence(r.inputs))
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
|
@ -0,0 +1,29 @@
|
||||
from typing import Dict, Iterable
|
||||
from transformers import PreTrainedTokenizerBase, LogitsWarper
|
||||
import abc
|
||||
|
||||
|
||||
class CustomLogitsWarperFactory(abc.ABC):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_warper(self, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CustomLogitsProcessorsManager:
|
||||
processors: Dict[str, CustomLogitsWarperFactory] = {}
|
||||
|
||||
@classmethod
|
||||
def register_factory(cls, factory: CustomLogitsWarperFactory):
|
||||
"""Register a factory for a custom warper. This should be called by library developers in the __init__.py of their custom warper module,
|
||||
and the module should be included in the custom_modules command line argument of the text-generation-server CLI."""
|
||||
cls.processors[factory.name] = factory
|
||||
|
||||
@classmethod
|
||||
def create_warper(cls, name: str, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
|
||||
"""Create a custom warper by name."""
|
||||
if name not in cls.processors:
|
||||
raise ValueError(f"Unknown warper {name}. Known warpers: {', '.join(cls.processors.keys())}")
|
||||
return cls.processors[name].create_warper(tokenizer, params)
|
@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
from typing import Callable, List, NamedTuple, Optional, Tuple
|
||||
from text_generation_server.utils.custom_logits_processors import CustomLogitsProcessorsManager
|
||||
|
||||
import torch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
@ -16,6 +17,9 @@ from text_generation_server.utils.logits_process import (
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
|
||||
|
||||
class LogitsProcessorParams(NamedTuple):
|
||||
name: str
|
||||
params: List[str]
|
||||
|
||||
class NextTokenChooser:
|
||||
def __init__(
|
||||
@ -29,6 +33,8 @@ class NextTokenChooser:
|
||||
do_sample=False,
|
||||
seed=0,
|
||||
device="cpu",
|
||||
logits_processors_params: Optional[List[LogitsProcessorParams]] = None,
|
||||
tokenizer: PreTrainedTokenizerBase=None,
|
||||
):
|
||||
self.watermark_processor = (
|
||||
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||
@ -51,7 +57,10 @@ class NextTokenChooser:
|
||||
)
|
||||
else:
|
||||
self.static_warper = None
|
||||
|
||||
if logits_processors_params:
|
||||
self.custom_warpers = [CustomLogitsProcessorsManager.create_warper(name, params, tokenizer) for name, params in logits_processors_params]
|
||||
else:
|
||||
self.custom_warpers = None
|
||||
sampling = do_sample or has_warpers
|
||||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||
|
||||
@ -60,7 +69,9 @@ class NextTokenChooser:
|
||||
scores = self.watermark_processor(input_ids, scores)
|
||||
if self.repetition_processor is not None:
|
||||
scores = self.repetition_processor(input_ids, scores)
|
||||
|
||||
if self.custom_warpers is not None:
|
||||
for warper in self.custom_warpers:
|
||||
scores = warper(input_ids, scores)
|
||||
if self.static_warper is None:
|
||||
next_logprob = torch.log_softmax(scores, -1)
|
||||
else:
|
||||
@ -75,7 +86,12 @@ class NextTokenChooser:
|
||||
cls,
|
||||
pb: generate_pb2.NextTokenChooserParameters,
|
||||
device: torch.device,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> "NextTokenChooser":
|
||||
if pb.logits_processors:
|
||||
processors_params = [LogitsProcessorParams(name, params) for name, params in pb.logits_processors]
|
||||
else:
|
||||
processors_params = None
|
||||
return NextTokenChooser(
|
||||
watermark=pb.watermark,
|
||||
temperature=pb.temperature,
|
||||
@ -86,6 +102,8 @@ class NextTokenChooser:
|
||||
do_sample=pb.do_sample,
|
||||
seed=pb.seed,
|
||||
device=device,
|
||||
logits_processors_params=processors_params,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user