fix: remove compilation artifacts from logit processor

This commit is contained in:
drbh 2024-03-08 03:44:38 +00:00
parent 1f7be736d2
commit d031919c8a

View File

@ -3,10 +3,8 @@ import torch
from loguru import logger
from typing import Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from functools import lru_cache
from typing import List, Optional, DefaultDict
import time
@ -520,19 +518,6 @@ class GrammarLogitProcessor(LogitsProcessor):
return fsm_grammar_state
return fsm.next_state(fsm_grammar_state, next_token_id)
# TODO: move grammar compilation into the router
@staticmethod
@lru_cache(maxsize=32, typed=True)
def _cached_compile_fsm(grammar_type, schema, tokenizer):
start_time = time.time()
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
schema = build_regex_from_schema(schema)
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity
fsm = RegexFSM(schema, tokenizer)
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm
@staticmethod
@lru_cache(maxsize=32, typed=True)
def _cached_adapt_tokenizer(tokenizer):