feat: bump outlines and pydantic logic

This commit is contained in:
drbh 2024-03-18 21:00:23 +00:00
parent 03003d1eaf
commit 09f2e8ed13
10 changed files with 123 additions and 105 deletions

View File

@@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from pydantic import BaseModel, validator from pydantic import BaseModel, field_validator
from typing import Optional, List, Union, Any from typing import Optional, List, Union, Any
from text_generation.errors import ValidationError from text_generation.errors import ValidationError
@@ -32,7 +32,7 @@ class Message(BaseModel):
# Role of the message sender # Role of the message sender
role: str role: str
# Content of the message # Content of the message
content: Optional[str] content: Optional[str] = None
# Optional name of the message sender # Optional name of the message sender
name: Optional[str] = None name: Optional[str] = None
# Tool calls associated with the chat completion # Tool calls associated with the chat completion
@@ -56,7 +56,7 @@ class ChatCompletionComplete(BaseModel):
# Reason for completion # Reason for completion
finish_reason: str finish_reason: str
# Usage details of the chat completion # Usage details of the chat completion
usage: Any usage: Optional[Any] = None
class Function(BaseModel): class Function(BaseModel):
@@ -73,7 +73,7 @@ class ChoiceDeltaToolCall(BaseModel):
class ChoiceDelta(BaseModel): class ChoiceDelta(BaseModel):
role: str role: str
content: Optional[str] content: Optional[str] = None
tool_calls: Optional[ChoiceDeltaToolCall] tool_calls: Optional[ChoiceDeltaToolCall]
@@ -176,7 +176,7 @@ class Parameters(BaseModel):
# grammar to use for generation # grammar to use for generation
grammar: Optional[Grammar] = None grammar: Optional[Grammar] = None
@validator("best_of") @field_validator("best_of")
def valid_best_of(cls, field_value, values): def valid_best_of(cls, field_value, values):
if field_value is not None: if field_value is not None:
if field_value <= 0: if field_value <= 0:
@@ -195,55 +195,55 @@ class Parameters(BaseModel):
return field_value return field_value
@validator("repetition_penalty") @field_validator("repetition_penalty")
def valid_repetition_penalty(cls, v): def valid_repetition_penalty(cls, v):
if v is not None and v <= 0: if v is not None and v <= 0:
raise ValidationError("`repetition_penalty` must be strictly positive") raise ValidationError("`repetition_penalty` must be strictly positive")
return v return v
@validator("seed") @field_validator("seed")
def valid_seed(cls, v): def valid_seed(cls, v):
if v is not None and v < 0: if v is not None and v < 0:
raise ValidationError("`seed` must be positive") raise ValidationError("`seed` must be positive")
return v return v
@validator("temperature") @field_validator("temperature")
def valid_temp(cls, v): def valid_temp(cls, v):
if v is not None and v <= 0: if v is not None and v <= 0:
raise ValidationError("`temperature` must be strictly positive") raise ValidationError("`temperature` must be strictly positive")
return v return v
@validator("top_k") @field_validator("top_k")
def valid_top_k(cls, v): def valid_top_k(cls, v):
if v is not None and v <= 0: if v is not None and v <= 0:
raise ValidationError("`top_k` must be strictly positive") raise ValidationError("`top_k` must be strictly positive")
return v return v
@validator("top_p") @field_validator("top_p")
def valid_top_p(cls, v): def valid_top_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0): if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`top_p` must be > 0.0 and < 1.0") raise ValidationError("`top_p` must be > 0.0 and < 1.0")
return v return v
@validator("truncate") @field_validator("truncate")
def valid_truncate(cls, v): def valid_truncate(cls, v):
if v is not None and v <= 0: if v is not None and v <= 0:
raise ValidationError("`truncate` must be strictly positive") raise ValidationError("`truncate` must be strictly positive")
return v return v
@validator("typical_p") @field_validator("typical_p")
def valid_typical_p(cls, v): def valid_typical_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0): if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0") raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v return v
@validator("top_n_tokens") @field_validator("top_n_tokens")
def valid_top_n_tokens(cls, v): def valid_top_n_tokens(cls, v):
if v is not None and v <= 0: if v is not None and v <= 0:
raise ValidationError("`top_n_tokens` must be strictly positive") raise ValidationError("`top_n_tokens` must be strictly positive")
return v return v
@validator("grammar") @field_validator("grammar")
def valid_grammar(cls, v): def valid_grammar(cls, v):
if v is not None: if v is not None:
if v.type == GrammarType.Regex and not v.value: if v.type == GrammarType.Regex and not v.value:
@@ -261,15 +261,15 @@ class Request(BaseModel):
# Whether to stream output tokens # Whether to stream output tokens
stream: bool = False stream: bool = False
@validator("inputs") @field_validator("inputs")
def valid_input(cls, v): def valid_input(cls, v):
if not v: if not v:
raise ValidationError("`inputs` cannot be empty") raise ValidationError("`inputs` cannot be empty")
return v return v
@validator("stream") @field_validator("stream")
def valid_best_of_stream(cls, field_value, values): def valid_best_of_stream(cls, field_value, values):
parameters = values["parameters"] parameters = values.data["parameters"]
if ( if (
parameters is not None parameters is not None
and parameters.best_of is not None and parameters.best_of is not None

View File

@@ -25,6 +25,7 @@ from text_generation.types import (
Grammar, Grammar,
ChatComplete, ChatComplete,
ChatCompletionChunk, ChatCompletionChunk,
ChatCompletionComplete,
) )
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -42,11 +43,16 @@ class ResponseComparator(JSONSnapshotExtension):
exclude=None, exclude=None,
matcher=None, matcher=None,
): ):
if isinstance(data, Response): if (
data = data.dict() isinstance(data, Response)
or isinstance(data, ChatComplete)
or isinstance(data, ChatCompletionChunk)
or isinstance(data, ChatCompletionComplete)
):
data = data.model_dump()
if isinstance(data, List): if isinstance(data, List):
data = [d.dict() for d in data] data = [d.model_dump() for d in data]
data = self._filter( data = self._filter(
data=data, depth=0, path=(), exclude=exclude, matcher=matcher data=data, depth=0, path=(), exclude=exclude, matcher=matcher

View File

@@ -13,7 +13,7 @@
"usage": null "usage": null
} }
], ],
"created": 1708957015, "created": 1710795556,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",

View File

@@ -8,7 +8,8 @@
"content": null, "content": null,
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": { "tool_calls": [
{
"function": { "function": {
"description": null, "description": null,
"name": "tools", "name": "tools",
@@ -21,11 +22,12 @@
"id": 0, "id": 0,
"type": "function" "type": "function"
} }
]
}, },
"usage": null "usage": null
} }
], ],
"created": 1709079417, "created": 1710795556,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",

View File

@@ -8,7 +8,8 @@
"content": null, "content": null,
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": { "tool_calls": [
{
"function": { "function": {
"description": null, "description": null,
"name": "tools", "name": "tools",
@@ -21,11 +22,12 @@
"id": 0, "id": 0,
"type": "function" "type": "function"
} }
]
}, },
"usage": null "usage": null
} }
], ],
"created": 1709079492, "created": 1710795557,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",

View File

@@ -8,7 +8,8 @@
"content": null, "content": null,
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": { "tool_calls": [
{
"function": { "function": {
"description": null, "description": null,
"name": "tools", "name": "tools",
@@ -20,11 +21,12 @@
"id": 0, "id": 0,
"type": "function" "type": "function"
} }
]
}, },
"usage": null "usage": null
} }
], ],
"created": 1709079493, "created": 1710795557,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",

View File

@@ -10,16 +10,16 @@
"name": null "name": null
}, },
"id": "", "id": "",
"index": 20, "index": 0,
"type": "function" "type": "function"
} }
}, },
"finish_reason": "eos_token", "finish_reason": "eos_token",
"index": 20, "index": 0,
"logprobs": null "logprobs": null
} }
], ],
"created": 1709087088, "created": 1710795499,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",

View File

@@ -98,7 +98,7 @@ async def test_flash_llama_grammar_no_tools(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip # @pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@@ -119,7 +119,8 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
], ],
) )
assert response.choices[0].message.content == None assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == { assert response.choices[0].message.tool_calls == [
{
"function": { "function": {
"description": None, "description": None,
"name": "tools", "name": "tools",
@@ -132,10 +133,11 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
"id": 0, "id": 0,
"type": "function", "type": "function",
} }
]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip # @pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_auto( async def test_flash_llama_grammar_tools_auto(
@@ -159,7 +161,8 @@ async def test_flash_llama_grammar_tools_auto(
], ],
) )
assert response.choices[0].message.content == None assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == { assert response.choices[0].message.tool_calls == [
{
"function": { "function": {
"description": None, "description": None,
"name": "tools", "name": "tools",
@@ -172,10 +175,11 @@ async def test_flash_llama_grammar_tools_auto(
"id": 0, "id": 0,
"type": "function", "type": "function",
} }
]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip # @pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_choice( async def test_flash_llama_grammar_tools_choice(
@@ -199,7 +203,8 @@ async def test_flash_llama_grammar_tools_choice(
], ],
) )
assert response.choices[0].message.content == None assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == { assert response.choices[0].message.tool_calls == [
{
"id": 0, "id": 0,
"type": "function", "type": "function",
"function": { "function": {
@@ -208,10 +213,11 @@ async def test_flash_llama_grammar_tools_choice(
"parameters": {"format": "celsius", "location": "New York, NY"}, "parameters": {"format": "celsius", "location": "New York, NY"},
}, },
} }
]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip # @pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_stream( async def test_flash_llama_grammar_tools_stream(

View File

@@ -34,7 +34,7 @@ peft = { version = "^0.9.0", optional = true }
torch = { version = "^2.1.1", optional = true } torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1" scipy = "^1.11.1"
pillow = "^10.0.0" pillow = "^10.0.0"
outlines= { version = "^0.0.27", optional = true } outlines= { version = "0.0.36", optional = true }
[tool.poetry.extras] [tool.poetry.extras]
torch = ["torch"] torch = ["torch"]

View File

@@ -6,7 +6,7 @@ from typing import Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object from outlines.fsm.json_schema import build_regex_from_schema
from functools import lru_cache from functools import lru_cache
from typing import List, Optional, DefaultDict from typing import List, Optional, DefaultDict
import time import time
@@ -512,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
def _cached_compile_fsm(grammar_type, schema, tokenizer): def _cached_compile_fsm(grammar_type, schema, tokenizer):
start_time = time.time() start_time = time.time()
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
schema = build_regex_from_object(schema) schema = build_regex_from_schema(schema)
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity pass # schema is already a regex just here for clarity
fsm = RegexFSM(schema, tokenizer) fsm = RegexFSM(schema, tokenizer)