From ebc0a7152be469c99a59dc87605a6f02318c3c4e Mon Sep 17 00:00:00 2001
From: Noam Gat
Date: Tue, 21 Nov 2023 17:43:47 +0200
Subject: [PATCH] WIP: Add custom_logits_processors API and connect it to the flow

---
 server/text_generation_server/cli.py           | 11 ++++++-
 .../models/galactica.py                        |  2 +-
 .../utils/custom_logits_processors.py          | 29 +++++++++++++++++++
 server/text_generation_server/utils/tokens.py  | 24 +++++++++++++--
 4 files changed, 61 insertions(+), 5 deletions(-)
 create mode 100644 server/text_generation_server/utils/custom_logits_processors.py

diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 301acb6b..c668b138 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -4,7 +4,7 @@ import typer
 
 from pathlib import Path
 from loguru import logger
-from typing import Optional
+from typing import List, Optional
 from enum import Enum
 
 from huggingface_hub import hf_hub_download
@@ -38,6 +38,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
+    custom_modules: Optional[List[str]] = None,
 ):
     if sharded:
         assert (
@@ -65,6 +66,14 @@ def serve(
         diagnose=False,
     )
 
+    # Import custom modules. This can be used for custom logits processors:
+    # these modules call CustomLogitsProcessorsManager.register_factory() in their __init__.py,
+    # registering themselves for custom logits processing.
+    from importlib import import_module
+    if custom_modules:
+        for custom_module in custom_modules:
+            import_module(custom_module)
+
     # Import here after the logger is added to log potential import exceptions
     from text_generation_server import server
     from text_generation_server.tracing import setup_tracing
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index b296c96e..65f25728 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/utils/custom_logits_processors.py b/server/text_generation_server/utils/custom_logits_processors.py
new file mode 100644
index 00000000..44680304
--- /dev/null
+++ b/server/text_generation_server/utils/custom_logits_processors.py
@@ -0,0 +1,29 @@
+from typing import Dict, Iterable
+from transformers import PreTrainedTokenizerBase, LogitsWarper
+import abc
+
+
+class CustomLogitsWarperFactory(abc.ABC):
+    def __init__(self, name):
+        self.name = name
+
+    @abc.abstractmethod
+    def create_warper(self, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
+        raise NotImplementedError
+
+
+class CustomLogitsProcessorsManager:
+    processors: Dict[str, CustomLogitsWarperFactory] = {}
+
+    @classmethod
+    def register_factory(cls, factory: CustomLogitsWarperFactory):
+        """Register a factory for a custom warper. This should be called by library developers in the __init__.py of their custom warper module,
+        and the module should be included in the custom_modules command line argument of the text-generation-server CLI."""
+        cls.processors[factory.name] = factory
+
+    @classmethod
+    def create_warper(cls, name: str, tokenizer: PreTrainedTokenizerBase, params: Iterable[str]) -> LogitsWarper:
+        """Create a custom warper by name."""
+        if name not in cls.processors:
+            raise ValueError(f"Unknown warper {name}. Known warpers: {', '.join(cls.processors.keys())}")
+        return cls.processors[name].create_warper(tokenizer, params)
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 0ff07417..bd368b8c 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,5 +1,6 @@
 import re
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, NamedTuple, Optional, Tuple
+from text_generation_server.utils.custom_logits_processors import CustomLogitsProcessorsManager
 
 import torch
 from text_generation_server.pb import generate_pb2
@@ -16,6 +17,9 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
+class LogitsProcessorParams(NamedTuple):
+    name: str
+    params: List[str]
 
 class NextTokenChooser:
     def __init__(
@@ -29,6 +33,8 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
+        logits_processors_params: Optional[List[LogitsProcessorParams]] = None,
+        tokenizer: PreTrainedTokenizerBase = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -51,7 +57,10 @@ class NextTokenChooser:
             )
         else:
             self.static_warper = None
-
+        if logits_processors_params:
+            self.custom_warpers = [CustomLogitsProcessorsManager.create_warper(name, tokenizer, params) for name, params in logits_processors_params]
+        else:
+            self.custom_warpers = None
         sampling = do_sample or has_warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
@@ -60,7 +69,9 @@ class NextTokenChooser:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
-
+        if self.custom_warpers is not None:
+            for warper in self.custom_warpers:
+                scores = warper(input_ids, scores)
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
         else:
@@ -75,7 +86,12 @@ class NextTokenChooser:
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
+        if pb.logits_processors:
+            processors_params = [LogitsProcessorParams(lp.name, lp.params) for lp in pb.logits_processors]
+        else:
+            processors_params = None
         return NextTokenChooser(
             watermark=pb.watermark,
             temperature=pb.temperature,
@@ -86,6 +102,8 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            logits_processors_params=processors_params,
+            tokenizer=tokenizer,
         )
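Example usage (an illustrative sketch, not part of the patch): a library developer would register a factory at import time in their package's __init__.py and load that package through the new custom_modules argument of text-generation-server serve. The module name my_warpers, the warper name "ban_words", and the BanWordsWarper/BanWordsWarperFactory classes below are hypothetical; only CustomLogitsWarperFactory, CustomLogitsProcessorsManager, LogitsWarper, and PreTrainedTokenizerBase come from this patch and transformers.

# my_warpers/__init__.py -- hypothetical module, loaded via the custom_modules CLI argument
import torch
from transformers import LogitsWarper, PreTrainedTokenizerBase

from text_generation_server.utils.custom_logits_processors import (
    CustomLogitsProcessorsManager,
    CustomLogitsWarperFactory,
)


class BanWordsWarper(LogitsWarper):
    """Sets the score of every banned token id to -inf before sampling."""

    def __init__(self, banned_token_ids):
        self.banned_token_ids = list(banned_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.banned_token_ids:
            scores[:, self.banned_token_ids] = -float("inf")
        return scores


class BanWordsWarperFactory(CustomLogitsWarperFactory):
    def __init__(self):
        # The name is the key a request would use to select this warper.
        super().__init__("ban_words")

    def create_warper(self, tokenizer: PreTrainedTokenizerBase, params) -> LogitsWarper:
        # Each param is treated as a word to ban; collect all of its token ids.
        banned_token_ids = sorted(
            {tid for word in params for tid in tokenizer.encode(word, add_special_tokens=False)}
        )
        return BanWordsWarper(banned_token_ids)


# Registering at import time makes the warper available as soon as the server
# imports this module during startup.
CustomLogitsProcessorsManager.register_factory(BanWordsWarperFactory())

A request would then select the warper by name through the logits_processors field that NextTokenChooser.from_pb reads; the corresponding proto change is not included in this WIP patch.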