WIP: Added custom_logits_processors API and connecting it to flow

This commit is contained in:
Noam Gat 2023-11-21 17:43:47 +02:00
parent 4ca2c5c945
commit ebc0a7152b
4 changed files with 61 additions and 5 deletions

View File

@ -4,7 +4,7 @@ import typer
from pathlib import Path
from loguru import logger
from typing import Optional
from typing import List, Optional
from enum import Enum
from huggingface_hub import hf_hub_download
@ -38,6 +38,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
custom_modules: Optional[List[str]] = None,
):
if sharded:
assert (
@ -65,6 +66,14 @@ def serve(
diagnose=False,
)
# Import custom modules. This can be used for Custom Logits Processors,
# in which these modules CustomLogitsProcessorsManager.register_factory() in their __init__.py,
# registering themselves for custom logits processing.
from importlib import import_module
if custom_modules:
for custom_module in custom_modules:
import_module(custom_module)
# Import here after the logger is added to log potential import exceptions
from text_generation_server import server
from text_generation_server.tracing import setup_tracing

View File

@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)

View File

@ -0,0 +1,29 @@
from typing import Dict, Iterable
from transformers import PreTrainedTokenizerBase, LogitsWarper
import abc
class CustomLogitsWarperFactory(abc.ABC):
def __init__(self, name):
self.name = name
@abc.abstractmethod
def create_warper(self, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
raise NotImplementedError
class CustomLogitsProcessorsManager:
processors: Dict[str, CustomLogitsWarperFactory] = {}
@classmethod
def register_factory(cls, factory: CustomLogitsWarperFactory):
"""Register a factory for a custom warper. This should be called by library developers in the __init__.py of their custom warper module,
and the module should be included in the custom_modules command line argument of the text-generation-server CLI."""
cls.processors[factory.name] = factory
@classmethod
def create_warper(cls, name: str, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
"""Create a custom warper by name."""
if name not in cls.processors:
raise ValueError(f"Unknown warper {name}. Known warpers: {', '.join(cls.processors.keys())}")
return cls.processors[name].create_warper(tokenizer, params)

View File

@ -1,5 +1,6 @@
import re
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, NamedTuple, Optional, Tuple
from text_generation_server.utils.custom_logits_processors import CustomLogitsProcessorsManager
import torch
from text_generation_server.pb import generate_pb2
@ -16,6 +17,9 @@ from text_generation_server.utils.logits_process import (
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class LogitsProcessorParams(NamedTuple):
name: str
params: List[str]
class NextTokenChooser:
def __init__(
@ -29,6 +33,8 @@ class NextTokenChooser:
do_sample=False,
seed=0,
device="cpu",
logits_processors_params: Optional[List[LogitsProcessorParams]] = None,
tokenizer: PreTrainedTokenizerBase=None,
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
@ -51,7 +57,10 @@ class NextTokenChooser:
)
else:
self.static_warper = None
if logits_processors_params:
self.custom_warpers = [CustomLogitsProcessorsManager.create_warper(name, params, tokenizer) for name, params in logits_processors_params]
else:
self.custom_warpers = None
sampling = do_sample or has_warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
@ -60,7 +69,9 @@ class NextTokenChooser:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.custom_warpers is not None:
for warper in self.custom_warpers:
scores = warper(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
else:
@ -75,7 +86,12 @@ class NextTokenChooser:
cls,
pb: generate_pb2.NextTokenChooserParameters,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
) -> "NextTokenChooser":
if pb.logits_processors:
processors_params = [LogitsProcessorParams(name, params) for name, params in pb.logits_processors]
else:
processors_params = None
return NextTokenChooser(
watermark=pb.watermark,
temperature=pb.temperature,
@ -86,6 +102,8 @@ class NextTokenChooser:
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
logits_processors_params=processors_params,
tokenizer=tokenizer,
)