WIP: Added custom_logits_processors API and connected it to the flow

This commit is contained in:
Noam Gat 2023-11-21 17:43:47 +02:00
parent 4ca2c5c945
commit ebc0a7152b
4 changed files with 61 additions and 5 deletions

View File

@@ -4,7 +4,7 @@ import typer
 from pathlib import Path
 from loguru import logger
-from typing import Optional
+from typing import List, Optional
 from enum import Enum
 from huggingface_hub import hf_hub_download
@@ -38,6 +38,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
+    custom_modules: Optional[List[str]] = None,
 ):
     if sharded:
         assert (
@@ -65,6 +66,14 @@ def serve(
         diagnose=False,
     )
+
+    # Import custom modules. This can be used for custom logits processors:
+    # such modules call CustomLogitsProcessorsManager.register_factory() in their
+    # __init__.py, registering themselves for custom logits processing.
+    from importlib import import_module
+    if custom_modules:
+        for custom_module in custom_modules:
+            import_module(custom_module)
     # Import here after the logger is added to log potential import exceptions
     from text_generation_server import server
     from text_generation_server.tracing import setup_tracing
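
For context, the modules listed in custom_modules are only imported; each one is expected to register its factory at import time. Below is a minimal sketch of what such a module's __init__.py could look like; the package name my_warpers and the no-op warper are hypothetical, not part of this commit:

# my_warpers/__init__.py -- hypothetical example module
from typing import Iterable

from transformers import LogitsWarper, PreTrainedTokenizerBase

from text_generation_server.utils.custom_logits_processors import (
    CustomLogitsProcessorsManager,
    CustomLogitsWarperFactory,
)


class NoOpWarper(LogitsWarper):
    # Returns the scores unchanged; only demonstrates the plumbing.
    def __call__(self, input_ids, scores):
        return scores


class NoOpWarperFactory(CustomLogitsWarperFactory):
    def __init__(self):
        super().__init__("noop")

    def create_warper(self, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
        return NoOpWarper()


# Runs at import time, so listing this module in custom_modules is enough.
CustomLogitsProcessorsManager.register_factory(NoOpWarperFactory())

Assuming typer's usual option naming for the serve command, the module would then be loaded with something like text-generation-server serve <model-id> --custom-modules my_warpers, repeating the option for additional modules.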

View File

@@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

View File

@@ -0,0 +1,29 @@
+from typing import Dict, Iterable
+from transformers import PreTrainedTokenizerBase, LogitsWarper
+import abc
+
+
+class CustomLogitsWarperFactory(abc.ABC):
+    def __init__(self, name):
+        self.name = name
+
+    @abc.abstractmethod
+    def create_warper(self, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
+        raise NotImplementedError
+
+
+class CustomLogitsProcessorsManager:
+    processors: Dict[str, CustomLogitsWarperFactory] = {}
+
+    @classmethod
+    def register_factory(cls, factory: CustomLogitsWarperFactory):
+        """Register a factory for a custom warper. This should be called by library developers in the __init__.py of their custom warper module,
+        and the module should be included in the custom_modules command line argument of the text-generation-server CLI."""
+        cls.processors[factory.name] = factory
+
+    @classmethod
+    def create_warper(cls, name: str, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
+        """Create a custom warper by name."""
+        if name not in cls.processors:
+            raise ValueError(f"Unknown warper {name}. Known warpers: {', '.join(cls.processors.keys())}")
+        return cls.processors[name].create_warper(tokenizer, params)
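
To make the factory contract above concrete, here is a hedged sketch of a factory whose warper masks a list of banned words received as params, using the tokenizer handed to create_warper. The class names and the "banned_words" registration name are illustrative, not part of this commit:

from typing import Iterable, List

import torch
from transformers import LogitsWarper, PreTrainedTokenizerBase

from text_generation_server.utils.custom_logits_processors import CustomLogitsWarperFactory


class BannedWordsWarper(LogitsWarper):
    def __init__(self, banned_token_ids: List[int]):
        self.banned_token_ids = banned_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Setting the logits to -inf removes the banned tokens from sampling.
        if self.banned_token_ids:
            scores[:, self.banned_token_ids] = -float("inf")
        return scores


class BannedWordsWarperFactory(CustomLogitsWarperFactory):
    def __init__(self):
        super().__init__("banned_words")

    def create_warper(self, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
        # Treat each param as a word and collect every token id it encodes to.
        banned_ids = [
            token_id
            for word in params
            for token_id in tokenizer.encode(word, add_special_tokens=False)
        ]
        return BannedWordsWarper(banned_ids)

Because params is just an iterable of strings, the factory decides how to interpret them; here they are plain words, but they could equally encode a JSON schema, a regular expression, or any other per-request configuration.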

View File

@@ -1,5 +1,6 @@
 import re
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, NamedTuple, Optional, Tuple
+from text_generation_server.utils.custom_logits_processors import CustomLogitsProcessorsManager
 
 import torch
 from text_generation_server.pb import generate_pb2
@@ -16,6 +17,9 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+class LogitsProcessorParams(NamedTuple):
+    name: str
+    params: List[str]
 
 class NextTokenChooser:
     def __init__(
@@ -29,6 +33,8 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
+        logits_processors_params: Optional[List[LogitsProcessorParams]] = None,
+        tokenizer: PreTrainedTokenizerBase = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -51,7 +57,10 @@ class NextTokenChooser:
             )
         else:
             self.static_warper = None
+
+        if logits_processors_params:
+            self.custom_warpers = [CustomLogitsProcessorsManager.create_warper(name, tokenizer, params) for name, params in logits_processors_params]
+        else:
+            self.custom_warpers = None
 
         sampling = do_sample or has_warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
@@ -60,7 +69,9 @@ class NextTokenChooser:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
+        if self.custom_warpers is not None:
+            for warper in self.custom_warpers:
+                scores = warper(input_ids, scores)
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
         else:
@@ -75,7 +86,12 @@ class NextTokenChooser:
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
+        if pb.logits_processors:
+            processors_params = [LogitsProcessorParams(name, params) for name, params in pb.logits_processors]
+        else:
+            processors_params = None
         return NextTokenChooser(
             watermark=pb.watermark,
             temperature=pb.temperature,
@@ -86,6 +102,8 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            logits_processors_params=processors_params,
+            tokenizer=tokenizer,
         )
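
Putting the pieces together, the intended flow is: the batch code passes the tokenizer into NextTokenChooser.from_pb, the chooser turns each LogitsProcessorParams entry into a warper via CustomLogitsProcessorsManager, and those warpers run on every __call__ between the repetition penalty and the static warpers. A rough usage sketch, assuming this file is text_generation_server.utils.tokens and that a "banned_words" factory like the one sketched earlier has been registered:

import torch
from transformers import AutoTokenizer

from text_generation_server.utils.tokens import LogitsProcessorParams, NextTokenChooser

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for the sketch

chooser = NextTokenChooser(
    do_sample=False,
    device="cpu",
    logits_processors_params=[LogitsProcessorParams("banned_words", ["badword"])],
    tokenizer=tokenizer,
)

input_ids = torch.tensor([[1, 2, 3]])
scores = torch.randn(1, tokenizer.vocab_size)
# The custom warpers are applied before the next token id and logprobs are chosen.
next_id, logprobs = chooser(input_ids, scores)

Note that pb.logits_processors is read here but the corresponding proto change is not among the changed files in this diff, which is consistent with the commit's WIP status.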