diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 4a308cef..b88e3b42 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, field_validator
 from typing import Optional, List, Union, Any

 from text_generation.errors import ValidationError
@@ -32,7 +32,7 @@ class Message(BaseModel):
     # Role of the message sender
     role: str
     # Content of the message
-    content: Optional[str]
+    content: Optional[str] = None
     # Optional name of the message sender
     name: Optional[str] = None
     # Tool calls associated with the chat completion
@@ -56,7 +56,7 @@ class ChatCompletionComplete(BaseModel):
     # Reason for completion
     finish_reason: str
     # Usage details of the chat completion
-    usage: Any
+    usage: Optional[Any] = None


 class Function(BaseModel):
@@ -73,7 +73,7 @@ class ChoiceDeltaToolCall(BaseModel):

 class ChoiceDelta(BaseModel):
     role: str
-    content: Optional[str]
+    content: Optional[str] = None
     tool_calls: Optional[ChoiceDeltaToolCall]


@@ -176,7 +176,7 @@ class Parameters(BaseModel):
     # grammar to use for generation
     grammar: Optional[Grammar] = None

-    @validator("best_of")
+    @field_validator("best_of")
     def valid_best_of(cls, field_value, values):
         if field_value is not None:
             if field_value <= 0:
@@ -195,55 +195,55 @@ class Parameters(BaseModel):

         return field_value

-    @validator("repetition_penalty")
+    @field_validator("repetition_penalty")
     def valid_repetition_penalty(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`repetition_penalty` must be strictly positive")
         return v

-    @validator("seed")
+    @field_validator("seed")
     def valid_seed(cls, v):
         if v is not None and v < 0:
             raise ValidationError("`seed` must be positive")
         return v

-    @validator("temperature")
+    @field_validator("temperature")
     def valid_temp(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`temperature` must be strictly positive")
         return v

-    @validator("top_k")
+    @field_validator("top_k")
     def valid_top_k(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`top_k` must be strictly positive")
         return v

-    @validator("top_p")
+    @field_validator("top_p")
     def valid_top_p(cls, v):
         if v is not None and (v <= 0 or v >= 1.0):
             raise ValidationError("`top_p` must be > 0.0 and < 1.0")
         return v

-    @validator("truncate")
+    @field_validator("truncate")
     def valid_truncate(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`truncate` must be strictly positive")
         return v

-    @validator("typical_p")
+    @field_validator("typical_p")
     def valid_typical_p(cls, v):
         if v is not None and (v <= 0 or v >= 1.0):
             raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
         return v

-    @validator("top_n_tokens")
+    @field_validator("top_n_tokens")
     def valid_top_n_tokens(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`top_n_tokens` must be strictly positive")
         return v

-    @validator("grammar")
+    @field_validator("grammar")
     def valid_grammar(cls, v):
         if v is not None:
             if v.type == GrammarType.Regex and not v.value:
@@ -261,15 +261,15 @@ class Request(BaseModel):
     # Whether to stream output tokens
     stream: bool = False

-    @validator("inputs")
+    @field_validator("inputs")
     def valid_input(cls, v):
         if not v:
             raise ValidationError("`inputs` cannot be empty")
         return v

-    @validator("stream")
+    @field_validator("stream")
     def valid_best_of_stream(cls, field_value, values):
-        parameters = values["parameters"]
+        parameters = values.data["parameters"]
         if (
             parameters is not None
             and parameters.best_of is not None
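Note on the validator changes above: the `validator` to `field_validator` switch is the pydantic v1 to v2 migration. In v2, a validator that needs other fields no longer receives a plain `values` dict; it receives a `ValidationInfo` object whose `.data` attribute holds the already-validated fields, hence `values.data["parameters"]` in `valid_best_of_stream`. A minimal sketch of the pattern (model and field names are illustrative, not taken from this patch):

    from typing import Optional
    from pydantic import BaseModel, ValidationInfo, field_validator

    class Example(BaseModel):
        best_of: Optional[int] = None
        stream: bool = False

        @field_validator("stream")
        def valid_stream(cls, v: bool, values: ValidationInfo) -> bool:
            # v2 passes ValidationInfo; fields validated earlier live in .data
            best_of = values.data.get("best_of")
            if v and best_of is not None and best_of > 1:
                raise ValueError("`best_of` must be 1 when `stream` is True")
            return v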
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 96cf43ad..32bf4e54 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -25,6 +25,7 @@ from text_generation.types import (
     Grammar,
     ChatComplete,
     ChatCompletionChunk,
+    ChatCompletionComplete,
 )

 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -42,11 +43,16 @@ class ResponseComparator(JSONSnapshotExtension):
         exclude=None,
         matcher=None,
     ):
-        if isinstance(data, Response):
-            data = data.dict()
+        if (
+            isinstance(data, Response)
+            or isinstance(data, ChatComplete)
+            or isinstance(data, ChatCompletionChunk)
+            or isinstance(data, ChatCompletionComplete)
+        ):
+            data = data.model_dump()

         if isinstance(data, List):
-            data = [d.dict() for d in data]
+            data = [d.model_dump() for d in data]

         data = self._filter(
             data=data, depth=0, path=(), exclude=exclude, matcher=matcher
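The conftest change follows pydantic v2's serialization rename: `.dict()` is deprecated in favour of `.model_dump()`. For example (a sketch with a made-up model):

    from pydantic import BaseModel

    class Token(BaseModel):
        id: int
        text: str

    token = Token(id=0, text="hello")
    print(token.model_dump())  # {'id': 0, 'text': 'hello'}
    # token.dict() still works in v2 but emits a deprecation warning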
"type": "function" + } + ] }, "usage": null } ], - "created": 1709079492, + "created": 1710795557, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json index e03a5511..e5e8e690 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json @@ -8,23 +8,25 @@ "content": null, "name": null, "role": "assistant", - "tool_calls": { - "function": { - "description": null, - "name": "tools", - "parameters": { - "format": "celsius", - "location": "New York, NY" - } - }, - "id": 0, - "type": "function" - } + "tool_calls": [ + { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY" + } + }, + "id": 0, + "type": "function" + } + ] }, "usage": null } ], - "created": 1709079493, + "created": 1710795557, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json index ceec31d9..6eb5fe0d 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json @@ -10,16 +10,16 @@ "name": null }, "id": "", - "index": 20, + "index": 0, "type": "function" } }, "finish_reason": "eos_token", - "index": 20, + "index": 0, "logprobs": null } ], - "created": 1709087088, + "created": 1710795499, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 21bcbb52..9d244b3c 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -98,7 +98,7 @@ async def test_flash_llama_grammar_no_tools( assert response == response_snapshot -@pytest.mark.skip +# @pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): @@ -119,23 +119,25 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna ], ) assert response.choices[0].message.content == None - assert response.choices[0].message.tool_calls == { - "function": { - "description": None, - "name": "tools", - "parameters": { - "format": "celsius", - "location": "New York, NY", - "num_days": 14, + assert response.choices[0].message.tool_calls == [ + { + "function": { + "description": None, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14, + }, }, - }, - "id": 0, - "type": "function", - } + "id": 0, + "type": "function", + } + ] assert response == response_snapshot -@pytest.mark.skip +# @pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_auto( @@ -159,23 +161,25 @@ async def test_flash_llama_grammar_tools_auto( ], ) assert response.choices[0].message.content == None - assert response.choices[0].message.tool_calls == { - "function": { - 
"description": None, - "name": "tools", - "parameters": { - "format": "celsius", - "location": "New York, NY", - "num_days": 14, + assert response.choices[0].message.tool_calls == [ + { + "function": { + "description": None, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14, + }, }, - }, - "id": 0, - "type": "function", - } + "id": 0, + "type": "function", + } + ] assert response == response_snapshot -@pytest.mark.skip +# @pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_choice( @@ -199,19 +203,21 @@ async def test_flash_llama_grammar_tools_choice( ], ) assert response.choices[0].message.content == None - assert response.choices[0].message.tool_calls == { - "id": 0, - "type": "function", - "function": { - "description": None, - "name": "tools", - "parameters": {"format": "celsius", "location": "New York, NY"}, - }, - } + assert response.choices[0].message.tool_calls == [ + { + "id": 0, + "type": "function", + "function": { + "description": None, + "name": "tools", + "parameters": {"format": "celsius", "location": "New York, NY"}, + }, + } + ] assert response == response_snapshot -@pytest.mark.skip +# @pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_stream( diff --git a/server/pyproject.toml b/server/pyproject.toml index 25a31c2c..75fb201a 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -34,7 +34,7 @@ peft = { version = "^0.9.0", optional = true } torch = { version = "^2.1.1", optional = true } scipy = "^1.11.1" pillow = "^10.0.0" -outlines= { version = "^0.0.27", optional = true } +outlines= { version = "0.0.36", optional = true } [tool.poetry.extras] torch = ["torch"] diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index cd7efec8..b4ffb863 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -6,7 +6,7 @@ from typing import Dict, Union from text_generation_server.pb.generate_pb2 import GrammarType from outlines.fsm.fsm import RegexFSM -from outlines.fsm.json_schema import build_regex_from_object +from outlines.fsm.json_schema import build_regex_from_schema from functools import lru_cache from typing import List, Optional, DefaultDict import time @@ -512,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor): def _cached_compile_fsm(grammar_type, schema, tokenizer): start_time = time.time() if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: - schema = build_regex_from_object(schema) + schema = build_regex_from_schema(schema) elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: pass # schema is already a regex just here for clarity fsm = RegexFSM(schema, tokenizer)