From 4716bd51ad94b88c36c71329e31b67abf909d26c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 18 Sep 2024 12:41:21 +0200 Subject: [PATCH] Adding the assert. --- clients/python/text_generation/types.py | 3 +- flake.nix | 28 +++ ...t_flash_llama_completion_stream_usage.json | 206 ++++++++++++++++++ .../models/test_completion_prompts.py | 20 +- nix/client.nix | 2 +- nix/server.nix | 2 + 6 files changed, 253 insertions(+), 8 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index e36dd470..40852d9d 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -168,7 +168,7 @@ class ChatCompletionComplete(BaseModel): # Log probabilities for the chat completion logprobs: Optional[Any] # Reason for completion - finish_reason: str + finish_reason: Optional[str] # Usage details of the chat completion usage: Optional[Any] = None @@ -191,6 +191,7 @@ class ChatCompletionChunk(BaseModel): model: str system_fingerprint: str choices: List[Choice] + usage: Optional[Any] class Parameters(BaseModel): diff --git a/flake.nix b/flake.nix index 674db0d8..07348e74 100644 --- a/flake.nix +++ b/flake.nix @@ -70,6 +70,34 @@ client = pkgs.python3.pkgs.callPackage ./nix/client.nix { }; in { + checks = { + rust = + with pkgs; + rustPlatform.buildRustPackage { + name = "rust-checks"; + src = ./.; + cargoLock = { + lockFile = ./Cargo.lock; + }; + buildInputs = [ openssl.dev ]; + nativeBuildInputs = [ + clippy + pkg-config + protobuf + python3 + rustfmt + ]; + buildPhase = '' + cargo check + ''; + checkPhase = '' + cargo fmt -- --check + cargo test -j $NIX_BUILD_CORES + cargo clippy + ''; + installPhase = "touch $out"; + }; + }; formatter = pkgs.nixfmt-rfc-style; devShells = with pkgs; rec { default = pure; diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json new file mode 100644 index 00000000..8c7be4cb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json @@ -0,0 +1,206 @@ +[ + { + "choices": [ + { + "delta": { + "content": "**", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "Deep", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": " Learning", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": ":", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": " An", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": " Overview", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "**\n", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656044, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "================================", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656044, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "=====", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656044, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "\n\n", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": "length", + "index": 0, + "logprobs": null + } + ], + "created": 1726656044, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 40, + "total_tokens": 50 + } + } +] diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index 243388a2..21ebc915 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -3,7 +3,7 @@ import requests import json from aiohttp import ClientSession -from text_generation.types import Completion, ChatComplete +from text_generation.types import Completion, ChatCompletionChunk @pytest.fixture(scope="module") @@ -68,6 +68,7 @@ async def test_flash_llama_completion_stream_usage( } string = "" chunks = [] + is_final = False async with ClientSession(headers=flash_llama_completion.headers) as session: async with session.post(url, json=request) as response: # iterate over the stream @@ -84,17 +85,24 @@ async def test_flash_llama_completion_stream_usage( chunk = [json.loads(c) for c in chunk] for c in chunk: - chunks.append(ChatComplete(**c)) + chunks.append(ChatCompletionChunk(**c)) assert "choices" in c if len(c["choices"]) == 1: index = c["choices"][0]["index"] assert index == 0 - string += c["choices"][0]["text"] - elif len(c["choices"]) == 0: - assert c["usage"] is not None + string += c["choices"][0]["delta"]["content"] + + has_usage = c["usage"] is not None + assert not is_final + if has_usage: + is_final = True else: raise RuntimeError("Expected different payload") - assert string == "" + assert is_final + assert ( + string + == "**Deep Learning: An Overview**\n=====================================\n\n" + ) assert chunks == response_snapshot diff --git a/nix/client.nix b/nix/client.nix index a6a375b6..d9c9e0f5 100644 --- a/nix/client.nix +++ b/nix/client.nix @@ -6,7 +6,7 @@ }: buildPythonPackage { - name = "text-generation"; + name = "text-generation-x"; src = ../clients/python; diff --git a/nix/server.nix b/nix/server.nix index cfdb3f01..5921da7f 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -21,6 +21,7 @@ loguru, mamba-ssm, marlin-kernels, + moe-kernels, opentelemetry-api, opentelemetry-exporter-otlp, opentelemetry-instrumentation-grpc, @@ -88,6 +89,7 @@ buildPythonPackage { loguru mamba-ssm marlin-kernels + moe-kernels opentelemetry-api opentelemetry-exporter-otlp opentelemetry-instrumentation-grpc