diff --git a/flake.lock b/flake.lock index a6190789..8d0f4070 100644 --- a/flake.lock +++ b/flake.lock @@ -479,11 +479,11 @@ "systems": "systems_6" }, "locked": { - "lastModified": 1710146030, - "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "lastModified": 1726560853, + "narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=", "owner": "numtide", "repo": "flake-utils", - "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a", "type": "github" }, "original": { @@ -853,11 +853,11 @@ ] }, "locked": { - "lastModified": 1726280639, - "narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=", + "lastModified": 1726626348, + "narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "e9f8641c92f26fd1e076e705edb12147c384171d", + "rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 3d349ff2..674db0d8 100644 --- a/flake.nix +++ b/flake.nix @@ -67,31 +67,10 @@ ''; }; server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; }; + client = pkgs.python3.pkgs.callPackage ./nix/client.nix { }; in { - checks = { - rust = with pkgs; rustPlatform.buildRustPackage { - name = "rust-checks"; - src = ./.; - cargoLock = { - lockFile = ./Cargo.lock; - }; - buildInputs = [ openssl.dev ]; - nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ]; - buildPhase = '' - cargo check - ''; - checkPhase = '' - cargo fmt -- --check - cargo test -j $NIX_BUILD_CORES - cargo clippy - ''; - installPhase = "touch $out"; - } ; - }; - formatter = pkgs.nixfmt-rfc-style; - devShells = with pkgs; rec { default = pure; @@ -106,10 +85,11 @@ test = mkShell { buildInputs = [ - # benchmark - # launcher - # router + benchmark + launcher + router server + client openssl.dev pkg-config cargo diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index a3b6651d..243388a2 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -3,9 +3,7 @@ import requests import json from aiohttp import ClientSession -from text_generation.types import ( - Completion, -) +from text_generation.types import Completion, ChatComplete @pytest.fixture(scope="module") @@ -50,6 +48,56 @@ def test_flash_llama_completion_single_prompt( assert response == response_snapshot +@pytest.mark.release +async def test_flash_llama_completion_stream_usage( + flash_llama_completion, response_snapshot +): + url = f"{flash_llama_completion.base_url}/v1/chat/completions" + request = { + "model": "tgi", + "messages": [ + { + "role": "user", + "content": "What is Deep Learning?", + } + ], + "max_tokens": 10, + "temperature": 0.0, + "stream_options": {"include_usage": True}, + "stream": True, + } + string = "" + chunks = [] + async with ClientSession(headers=flash_llama_completion.headers) as session: + async with session.post(url, json=request) as response: + # iterate over the stream + async for chunk in response.content.iter_any(): + # remove "data:" + chunk = chunk.decode().split("\n\n") + # remove "data:" if present + chunk = [c.replace("data:", "") for c in chunk] + # remove empty strings + chunk = [c for c in chunk if c] + # remove completion marking chunk + chunk = [c for c in chunk if c != " [DONE]"] + # parse json + chunk = [json.loads(c) for c in chunk] + + for c in chunk: + chunks.append(ChatComplete(**c)) + assert "choices" in c + if len(c["choices"]) == 1: + index = c["choices"][0]["index"] + assert index == 0 + string += c["choices"][0]["text"] + elif len(c["choices"]) == 0: + assert c["usage"] is not None + else: + raise RuntimeError("Expected different payload") + assert string == "" + assert chunks == response_snapshot + + @pytest.mark.release def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot): response = requests.post( diff --git a/nix/client.nix b/nix/client.nix new file mode 100644 index 00000000..a6a375b6 --- /dev/null +++ b/nix/client.nix @@ -0,0 +1,25 @@ +{ + buildPythonPackage, + poetry-core, + huggingface-hub, + pydantic, +}: + +buildPythonPackage { + name = "text-generation"; + + src = ../clients/python; + + pyproject = true; + + build-system = [ poetry-core ]; + + nativeBuildInputs = [ ]; + + pythonRemoveDeps = [ ]; + + dependencies = [ + huggingface-hub + pydantic + ]; +} diff --git a/nix/server.nix b/nix/server.nix index 5921da7f..cfdb3f01 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -21,7 +21,6 @@ loguru, mamba-ssm, marlin-kernels, - moe-kernels, opentelemetry-api, opentelemetry-exporter-otlp, opentelemetry-instrumentation-grpc, @@ -89,7 +88,6 @@ buildPythonPackage { loguru mamba-ssm marlin-kernels - moe-kernels opentelemetry-api opentelemetry-exporter-otlp opentelemetry-instrumentation-grpc