mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
WIP: Added custom_logits_processors API and connecting it to flow
This commit is contained in:
parent
4ca2c5c945
commit
ebc0a7152b
@ -4,7 +4,7 @@ import typer
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
@ -38,6 +38,7 @@ def serve(
|
|||||||
logger_level: str = "INFO",
|
logger_level: str = "INFO",
|
||||||
json_output: bool = False,
|
json_output: bool = False,
|
||||||
otlp_endpoint: Optional[str] = None,
|
otlp_endpoint: Optional[str] = None,
|
||||||
|
custom_modules: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
if sharded:
|
if sharded:
|
||||||
assert (
|
assert (
|
||||||
@ -65,6 +66,14 @@ def serve(
|
|||||||
diagnose=False,
|
diagnose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Import custom modules. This can be used for Custom Logits Processors,
|
||||||
|
# in which these modules CustomLogitsProcessorsManager.register_factory() in their __init__.py,
|
||||||
|
# registering themselves for custom logits processing.
|
||||||
|
from importlib import import_module
|
||||||
|
if custom_modules:
|
||||||
|
for custom_module in custom_modules:
|
||||||
|
import_module(custom_module)
|
||||||
|
|
||||||
# Import here after the logger is added to log potential import exceptions
|
# Import here after the logger is added to log potential import exceptions
|
||||||
from text_generation_server import server
|
from text_generation_server import server
|
||||||
from text_generation_server.tracing import setup_tracing
|
from text_generation_server.tracing import setup_tracing
|
||||||
|
@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
||||||
inputs.append(escape_custom_split_sequence(r.inputs))
|
inputs.append(escape_custom_split_sequence(r.inputs))
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
from typing import Dict, Iterable
|
||||||
|
from transformers import PreTrainedTokenizerBase, LogitsWarper
|
||||||
|
import abc
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLogitsWarperFactory(abc.ABC):
    """Abstract factory for user-supplied logits warpers.

    Library developers subclass this, implement ``create_warper``, and
    register an instance via
    ``CustomLogitsProcessorsManager.register_factory`` (typically from the
    ``__init__.py`` of a module passed on the CLI), keyed by ``self.name``.
    """

    def __init__(self, name):
        # Registry key under which this factory is looked up by name.
        self.name = name

    @abc.abstractmethod
    def create_warper(self, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
        """Build and return a ``LogitsWarper`` for one request.

        Args:
            tokenizer: tokenizer of the model being served.
            params: opaque string parameters forwarded from the request.
        """
        raise NotImplementedError
||||||
|
|
||||||
|
|
||||||
|
class CustomLogitsProcessorsManager:
    """Process-wide registry of custom logits-warper factories, keyed by name."""

    # Shared class-level registry; populated at import time when custom
    # modules call register_factory() from their __init__.py.
    processors: Dict[str, CustomLogitsWarperFactory] = {}

    @classmethod
    def register_factory(cls, factory: CustomLogitsWarperFactory) -> None:
        """Register a factory for a custom warper. This should be called by library developers in the __init__.py of their custom warper module,
        and the module should be included in the custom_modules command line argument of the text-generation-server CLI."""
        cls.processors[factory.name] = factory

    @classmethod
    def create_warper(cls, name: str, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
        """Create a custom warper by name.

        Args:
            name: registry key of a previously registered factory.
            tokenizer: tokenizer of the model being served.
            params: opaque string parameters forwarded to the factory.

        Raises:
            ValueError: if no factory was registered under ``name``.
        """
        # EAFP: single dict lookup instead of `in` check + second lookup.
        try:
            factory = cls.processors[name]
        except KeyError:
            raise ValueError(f"Unknown warper {name}. Known warpers: {', '.join(cls.processors.keys())}") from None
        return factory.create_warper(tokenizer, params)
|
@ -1,5 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, NamedTuple, Optional, Tuple
|
||||||
|
from text_generation_server.utils.custom_logits_processors import CustomLogitsProcessorsManager
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
@ -16,6 +17,9 @@ from text_generation_server.utils.logits_process import (
|
|||||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||||
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
|
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
|
||||||
|
|
||||||
|
class LogitsProcessorParams(NamedTuple):
|
||||||
|
name: str
|
||||||
|
params: List[str]
|
||||||
|
|
||||||
class NextTokenChooser:
|
class NextTokenChooser:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -29,6 +33,8 @@ class NextTokenChooser:
|
|||||||
do_sample=False,
|
do_sample=False,
|
||||||
seed=0,
|
seed=0,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
|
logits_processors_params: Optional[List[LogitsProcessorParams]] = None,
|
||||||
|
tokenizer: PreTrainedTokenizerBase=None,
|
||||||
):
|
):
|
||||||
self.watermark_processor = (
|
self.watermark_processor = (
|
||||||
WatermarkLogitsProcessor(device=device) if watermark else None
|
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||||
@ -51,7 +57,10 @@ class NextTokenChooser:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.static_warper = None
|
self.static_warper = None
|
||||||
|
if logits_processors_params:
|
||||||
|
self.custom_warpers = [CustomLogitsProcessorsManager.create_warper(name, params, tokenizer) for name, params in logits_processors_params]
|
||||||
|
else:
|
||||||
|
self.custom_warpers = None
|
||||||
sampling = do_sample or has_warpers
|
sampling = do_sample or has_warpers
|
||||||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||||
|
|
||||||
@ -60,7 +69,9 @@ class NextTokenChooser:
|
|||||||
scores = self.watermark_processor(input_ids, scores)
|
scores = self.watermark_processor(input_ids, scores)
|
||||||
if self.repetition_processor is not None:
|
if self.repetition_processor is not None:
|
||||||
scores = self.repetition_processor(input_ids, scores)
|
scores = self.repetition_processor(input_ids, scores)
|
||||||
|
if self.custom_warpers is not None:
|
||||||
|
for warper in self.custom_warpers:
|
||||||
|
scores = warper(input_ids, scores)
|
||||||
if self.static_warper is None:
|
if self.static_warper is None:
|
||||||
next_logprob = torch.log_softmax(scores, -1)
|
next_logprob = torch.log_softmax(scores, -1)
|
||||||
else:
|
else:
|
||||||
@ -75,7 +86,12 @@ class NextTokenChooser:
|
|||||||
cls,
|
cls,
|
||||||
pb: generate_pb2.NextTokenChooserParameters,
|
pb: generate_pb2.NextTokenChooserParameters,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
) -> "NextTokenChooser":
|
) -> "NextTokenChooser":
|
||||||
|
if pb.logits_processors:
|
||||||
|
processors_params = [LogitsProcessorParams(name, params) for name, params in pb.logits_processors]
|
||||||
|
else:
|
||||||
|
processors_params = None
|
||||||
return NextTokenChooser(
|
return NextTokenChooser(
|
||||||
watermark=pb.watermark,
|
watermark=pb.watermark,
|
||||||
temperature=pb.temperature,
|
temperature=pb.temperature,
|
||||||
@ -86,6 +102,8 @@ class NextTokenChooser:
|
|||||||
do_sample=pb.do_sample,
|
do_sample=pb.do_sample,
|
||||||
seed=pb.seed,
|
seed=pb.seed,
|
||||||
device=device,
|
device=device,
|
||||||
|
logits_processors_params=processors_params,
|
||||||
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user