mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: bump outlines and pydantic logic
This commit is contained in:
parent
03003d1eaf
commit
09f2e8ed13
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing import Optional, List, Union, Any
|
||||
|
||||
from text_generation.errors import ValidationError
|
||||
@ -32,7 +32,7 @@ class Message(BaseModel):
|
||||
# Role of the message sender
|
||||
role: str
|
||||
# Content of the message
|
||||
content: Optional[str]
|
||||
content: Optional[str] = None
|
||||
# Optional name of the message sender
|
||||
name: Optional[str] = None
|
||||
# Tool calls associated with the chat completion
|
||||
@ -56,7 +56,7 @@ class ChatCompletionComplete(BaseModel):
|
||||
# Reason for completion
|
||||
finish_reason: str
|
||||
# Usage details of the chat completion
|
||||
usage: Any
|
||||
usage: Optional[Any] = None
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
@ -73,7 +73,7 @@ class ChoiceDeltaToolCall(BaseModel):
|
||||
|
||||
class ChoiceDelta(BaseModel):
|
||||
role: str
|
||||
content: Optional[str]
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[ChoiceDeltaToolCall]
|
||||
|
||||
|
||||
@ -176,7 +176,7 @@ class Parameters(BaseModel):
|
||||
# grammar to use for generation
|
||||
grammar: Optional[Grammar] = None
|
||||
|
||||
@validator("best_of")
|
||||
@field_validator("best_of")
|
||||
def valid_best_of(cls, field_value, values):
|
||||
if field_value is not None:
|
||||
if field_value <= 0:
|
||||
@ -195,55 +195,55 @@ class Parameters(BaseModel):
|
||||
|
||||
return field_value
|
||||
|
||||
@validator("repetition_penalty")
|
||||
@field_validator("repetition_penalty")
|
||||
def valid_repetition_penalty(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`repetition_penalty` must be strictly positive")
|
||||
return v
|
||||
|
||||
@validator("seed")
|
||||
@field_validator("seed")
|
||||
def valid_seed(cls, v):
|
||||
if v is not None and v < 0:
|
||||
raise ValidationError("`seed` must be positive")
|
||||
return v
|
||||
|
||||
@validator("temperature")
|
||||
@field_validator("temperature")
|
||||
def valid_temp(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`temperature` must be strictly positive")
|
||||
return v
|
||||
|
||||
@validator("top_k")
|
||||
@field_validator("top_k")
|
||||
def valid_top_k(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`top_k` must be strictly positive")
|
||||
return v
|
||||
|
||||
@validator("top_p")
|
||||
@field_validator("top_p")
|
||||
def valid_top_p(cls, v):
|
||||
if v is not None and (v <= 0 or v >= 1.0):
|
||||
raise ValidationError("`top_p` must be > 0.0 and < 1.0")
|
||||
return v
|
||||
|
||||
@validator("truncate")
|
||||
@field_validator("truncate")
|
||||
def valid_truncate(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`truncate` must be strictly positive")
|
||||
return v
|
||||
|
||||
@validator("typical_p")
|
||||
@field_validator("typical_p")
|
||||
def valid_typical_p(cls, v):
|
||||
if v is not None and (v <= 0 or v >= 1.0):
|
||||
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
|
||||
return v
|
||||
|
||||
@validator("top_n_tokens")
|
||||
@field_validator("top_n_tokens")
|
||||
def valid_top_n_tokens(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`top_n_tokens` must be strictly positive")
|
||||
return v
|
||||
|
||||
@validator("grammar")
|
||||
@field_validator("grammar")
|
||||
def valid_grammar(cls, v):
|
||||
if v is not None:
|
||||
if v.type == GrammarType.Regex and not v.value:
|
||||
@ -261,15 +261,15 @@ class Request(BaseModel):
|
||||
# Whether to stream output tokens
|
||||
stream: bool = False
|
||||
|
||||
@validator("inputs")
|
||||
@field_validator("inputs")
|
||||
def valid_input(cls, v):
|
||||
if not v:
|
||||
raise ValidationError("`inputs` cannot be empty")
|
||||
return v
|
||||
|
||||
@validator("stream")
|
||||
@field_validator("stream")
|
||||
def valid_best_of_stream(cls, field_value, values):
|
||||
parameters = values["parameters"]
|
||||
parameters = values.data["parameters"]
|
||||
if (
|
||||
parameters is not None
|
||||
and parameters.best_of is not None
|
||||
|
@ -25,6 +25,7 @@ from text_generation.types import (
|
||||
Grammar,
|
||||
ChatComplete,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionComplete,
|
||||
)
|
||||
|
||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||
@ -42,11 +43,16 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
exclude=None,
|
||||
matcher=None,
|
||||
):
|
||||
if isinstance(data, Response):
|
||||
data = data.dict()
|
||||
if (
|
||||
isinstance(data, Response)
|
||||
or isinstance(data, ChatComplete)
|
||||
or isinstance(data, ChatCompletionChunk)
|
||||
or isinstance(data, ChatCompletionComplete)
|
||||
):
|
||||
data = data.model_dump()
|
||||
|
||||
if isinstance(data, List):
|
||||
data = [d.dict() for d in data]
|
||||
data = [d.model_dump() for d in data]
|
||||
|
||||
data = self._filter(
|
||||
data=data, depth=0, path=(), exclude=exclude, matcher=matcher
|
||||
|
@ -13,7 +13,7 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1708957015,
|
||||
"created": 1710795556,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
|
@ -8,24 +8,26 @@
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1709079417,
|
||||
"created": 1710795556,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
|
@ -8,24 +8,26 @@
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1709079492,
|
||||
"created": 1710795557,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
|
@ -8,23 +8,25 @@
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY"
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY"
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1709079493,
|
||||
"created": 1710795557,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
|
@ -10,16 +10,16 @@
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 20,
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"finish_reason": "eos_token",
|
||||
"index": 20,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1709087088,
|
||||
"created": 1710795499,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
|
@ -98,7 +98,7 @@ async def test_flash_llama_grammar_no_tools(
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
# @pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
||||
@ -119,23 +119,25 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
||||
],
|
||||
)
|
||||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == {
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
},
|
||||
},
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
}
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
# @pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_auto(
|
||||
@ -159,23 +161,25 @@ async def test_flash_llama_grammar_tools_auto(
|
||||
],
|
||||
)
|
||||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == {
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
},
|
||||
},
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
}
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
# @pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_choice(
|
||||
@ -199,19 +203,21 @@ async def test_flash_llama_grammar_tools_choice(
|
||||
],
|
||||
)
|
||||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == {
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {"format": "celsius", "location": "New York, NY"},
|
||||
},
|
||||
}
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {"format": "celsius", "location": "New York, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
# @pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_stream(
|
||||
|
@ -34,7 +34,7 @@ peft = { version = "^0.9.0", optional = true }
|
||||
torch = { version = "^2.1.1", optional = true }
|
||||
scipy = "^1.11.1"
|
||||
pillow = "^10.0.0"
|
||||
outlines= { version = "^0.0.27", optional = true }
|
||||
outlines= { version = "0.0.36", optional = true }
|
||||
|
||||
[tool.poetry.extras]
|
||||
torch = ["torch"]
|
||||
|
@ -6,7 +6,7 @@ from typing import Dict, Union
|
||||
from text_generation_server.pb.generate_pb2 import GrammarType
|
||||
|
||||
from outlines.fsm.fsm import RegexFSM
|
||||
from outlines.fsm.json_schema import build_regex_from_object
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, DefaultDict
|
||||
import time
|
||||
@ -512,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||
def _cached_compile_fsm(grammar_type, schema, tokenizer):
|
||||
start_time = time.time()
|
||||
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
|
||||
schema = build_regex_from_object(schema)
|
||||
schema = build_regex_from_schema(schema)
|
||||
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
|
||||
pass # schema is already a regex just here for clarity
|
||||
fsm = RegexFSM(schema, tokenizer)
|
||||
|
Loading…
Reference in New Issue
Block a user