feat: bump outlines and update pydantic logic

This commit is contained in:
drbh 2024-03-18 21:00:23 +00:00
parent 03003d1eaf
commit 09f2e8ed13
10 changed files with 123 additions and 105 deletions

View File

@ -1,5 +1,5 @@
from enum import Enum
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from typing import Optional, List, Union, Any
from text_generation.errors import ValidationError
@ -32,7 +32,7 @@ class Message(BaseModel):
# Role of the message sender
role: str
# Content of the message
content: Optional[str]
content: Optional[str] = None
# Optional name of the message sender
name: Optional[str] = None
# Tool calls associated with the chat completion
@ -56,7 +56,7 @@ class ChatCompletionComplete(BaseModel):
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Any
usage: Optional[Any] = None
class Function(BaseModel):
@ -73,7 +73,7 @@ class ChoiceDeltaToolCall(BaseModel):
class ChoiceDelta(BaseModel):
role: str
content: Optional[str]
content: Optional[str] = None
tool_calls: Optional[ChoiceDeltaToolCall]
@ -176,7 +176,7 @@ class Parameters(BaseModel):
# grammar to use for generation
grammar: Optional[Grammar] = None
@validator("best_of")
@field_validator("best_of")
def valid_best_of(cls, field_value, values):
if field_value is not None:
if field_value <= 0:
@ -195,55 +195,55 @@ class Parameters(BaseModel):
return field_value
@validator("repetition_penalty")
@field_validator("repetition_penalty")
def valid_repetition_penalty(cls, v):
if v is not None and v <= 0:
raise ValidationError("`repetition_penalty` must be strictly positive")
return v
@validator("seed")
@field_validator("seed")
def valid_seed(cls, v):
if v is not None and v < 0:
raise ValidationError("`seed` must be positive")
return v
@validator("temperature")
@field_validator("temperature")
def valid_temp(cls, v):
if v is not None and v <= 0:
raise ValidationError("`temperature` must be strictly positive")
return v
@validator("top_k")
@field_validator("top_k")
def valid_top_k(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_k` must be strictly positive")
return v
@validator("top_p")
@field_validator("top_p")
def valid_top_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`top_p` must be > 0.0 and < 1.0")
return v
@validator("truncate")
@field_validator("truncate")
def valid_truncate(cls, v):
if v is not None and v <= 0:
raise ValidationError("`truncate` must be strictly positive")
return v
@validator("typical_p")
@field_validator("typical_p")
def valid_typical_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v
@validator("top_n_tokens")
@field_validator("top_n_tokens")
def valid_top_n_tokens(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_n_tokens` must be strictly positive")
return v
@validator("grammar")
@field_validator("grammar")
def valid_grammar(cls, v):
if v is not None:
if v.type == GrammarType.Regex and not v.value:
@ -261,15 +261,15 @@ class Request(BaseModel):
# Whether to stream output tokens
stream: bool = False
@validator("inputs")
@field_validator("inputs")
def valid_input(cls, v):
if not v:
raise ValidationError("`inputs` cannot be empty")
return v
@validator("stream")
@field_validator("stream")
def valid_best_of_stream(cls, field_value, values):
parameters = values["parameters"]
parameters = values.data["parameters"]
if (
parameters is not None
and parameters.best_of is not None

View File

@ -25,6 +25,7 @@ from text_generation.types import (
Grammar,
ChatComplete,
ChatCompletionChunk,
ChatCompletionComplete,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@ -42,11 +43,16 @@ class ResponseComparator(JSONSnapshotExtension):
exclude=None,
matcher=None,
):
if isinstance(data, Response):
data = data.dict()
if (
isinstance(data, Response)
or isinstance(data, ChatComplete)
or isinstance(data, ChatCompletionChunk)
or isinstance(data, ChatCompletionComplete)
):
data = data.model_dump()
if isinstance(data, List):
data = [d.dict() for d in data]
data = [d.model_dump() for d in data]
data = self._filter(
data=data, depth=0, path=(), exclude=exclude, matcher=matcher

View File

@ -13,7 +13,7 @@
"usage": null
}
],
"created": 1708957015,
"created": 1710795556,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",

View File

@ -8,24 +8,26 @@
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1709079417,
"created": 1710795556,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",

View File

@ -8,24 +8,26 @@
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1709079492,
"created": 1710795557,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",

View File

@ -8,23 +8,25 @@
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY"
}
},
"id": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY"
}
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1709079493,
"created": 1710795557,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",

View File

@ -10,16 +10,16 @@
"name": null
},
"id": "",
"index": 20,
"index": 0,
"type": "function"
}
},
"finish_reason": "eos_token",
"index": 20,
"index": 0,
"logprobs": null
}
],
"created": 1709087088,
"created": 1710795499,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",

View File

@ -98,7 +98,7 @@ async def test_flash_llama_grammar_no_tools(
assert response == response_snapshot
@pytest.mark.skip
# @pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -119,23 +119,25 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
assert response.choices[0].message.tool_calls == [
{
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
},
"id": 0,
"type": "function",
}
"id": 0,
"type": "function",
}
]
assert response == response_snapshot
@pytest.mark.skip
# @pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
@ -159,23 +161,25 @@ async def test_flash_llama_grammar_tools_auto(
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
assert response.choices[0].message.tool_calls == [
{
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
},
"id": 0,
"type": "function",
}
"id": 0,
"type": "function",
}
]
assert response == response_snapshot
@pytest.mark.skip
# @pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
@ -199,19 +203,21 @@ async def test_flash_llama_grammar_tools_choice(
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"id": 0,
"type": "function",
"function": {
"description": None,
"name": "tools",
"parameters": {"format": "celsius", "location": "New York, NY"},
},
}
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"type": "function",
"function": {
"description": None,
"name": "tools",
"parameters": {"format": "celsius", "location": "New York, NY"},
},
}
]
assert response == response_snapshot
@pytest.mark.skip
# @pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(

View File

@ -34,7 +34,7 @@ peft = { version = "^0.9.0", optional = true }
torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"
outlines= { version = "^0.0.27", optional = true }
outlines= { version = "0.0.36", optional = true }
[tool.poetry.extras]
torch = ["torch"]

View File

@ -6,7 +6,7 @@ from typing import Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from outlines.fsm.json_schema import build_regex_from_schema
from functools import lru_cache
from typing import List, Optional, DefaultDict
import time
@ -512,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
def _cached_compile_fsm(grammar_type, schema, tokenizer):
start_time = time.time()
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
schema = build_regex_from_object(schema)
schema = build_regex_from_schema(schema)
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity
fsm = RegexFSM(schema, tokenizer)