Fetch stuff from nix integration test for easier testing.

This commit is contained in:
Nicolas Patry 2024-09-18 12:09:26 +02:00
parent 678721bcf0
commit 162549f37f
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
5 changed files with 87 additions and 36 deletions

View File

@ -479,11 +479,11 @@
"systems": "systems_6"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"lastModified": 1726560853,
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
"type": "github"
},
"original": {
@ -853,11 +853,11 @@
]
},
"locked": {
"lastModified": 1726280639,
"narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
"lastModified": 1726626348,
"narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
"rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
"type": "github"
},
"original": {

View File

@ -67,31 +67,10 @@
'';
};
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };
in
{
checks = {
rust = with pkgs; rustPlatform.buildRustPackage {
name = "rust-checks";
src = ./.;
cargoLock = {
lockFile = ./Cargo.lock;
};
buildInputs = [ openssl.dev ];
nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ];
buildPhase = ''
cargo check
'';
checkPhase = ''
cargo fmt -- --check
cargo test -j $NIX_BUILD_CORES
cargo clippy
'';
installPhase = "touch $out";
} ;
};
formatter = pkgs.nixfmt-rfc-style;
devShells = with pkgs; rec {
default = pure;
@ -106,10 +85,11 @@
test = mkShell {
buildInputs =
[
# benchmark
# launcher
# router
benchmark
launcher
router
server
client
openssl.dev
pkg-config
cargo

View File

@ -3,9 +3,7 @@ import requests
import json
from aiohttp import ClientSession
from text_generation.types import (
Completion,
)
from text_generation.types import Completion, ChatComplete
@pytest.fixture(scope="module")
@ -50,6 +48,56 @@ def test_flash_llama_completion_single_prompt(
assert response == response_snapshot
@pytest.mark.release
async def test_flash_llama_completion_stream_usage(
    flash_llama_completion, response_snapshot
):
    """Stream a chat completion with ``include_usage`` and check the chunks.

    Posts a streaming request to ``/v1/chat/completions``, parses the
    server-sent events of the response, accumulates the generated text and
    compares the parsed chunks against the stored snapshot.

    Args:
        flash_llama_completion: fixture exposing ``base_url`` and ``headers``
            of the server under test.
        response_snapshot: fixture the collected chunks are compared against.

    Raises:
        RuntimeError: if a chunk carries more than one choice.
    """
    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
    request = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        "max_tokens": 10,
        "temperature": 0.0,
        "stream_options": {"include_usage": True},
        "stream": True,
    }

    # Raw events, still as bytes; events are separated by a blank line.
    events = []
    pending = b""
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            async for data in response.content.iter_any():
                # A network read may stop in the middle of an event (or even in
                # the middle of a multi-byte character), so keep the trailing
                # partial event as bytes until its terminator arrives.
                pending += data
                *complete, pending = pending.split(b"\n\n")
                events.extend(complete)
    # Whatever is left is a final event that was not followed by a blank line.
    events.append(pending)

    string = ""
    chunks = []
    saw_usage = False
    for raw_event in events:
        payload = raw_event.decode().strip()
        # Drop only the leading SSE field name; a "data:" substring inside the
        # generated content must be left untouched.
        if payload.startswith("data:"):
            payload = payload[len("data:") :].strip()
        # Skip empty events and the end-of-stream marker.
        if not payload or payload == "[DONE]":
            continue

        c = json.loads(payload)
        # NOTE(review): these payloads are streamed chunks; confirm that
        # ChatComplete validates them, or switch to the client's chunk model.
        chunks.append(ChatComplete(**c))
        assert "choices" in c
        if c.get("usage") is not None:
            saw_usage = True

        if len(c["choices"]) == 1:
            choice = c["choices"][0]
            assert choice["index"] == 0
            # Chat completion chunks carry the new text under delta.content,
            # not under the legacy completions field "text".
            string += choice["delta"].get("content") or ""
        elif len(c["choices"]) == 0:
            # A chunk without choices is only expected to report usage.
            assert c["usage"] is not None
        else:
            raise RuntimeError("Expected different payload")

    # The request asked for usage, so at least one chunk must report it.
    assert saw_usage
    # The exact text is pinned by the snapshot below; here we only make sure
    # that something was actually generated.
    assert string != ""
    assert chunks == response_snapshot
@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(

25
nix/client.nix Normal file
View File

@ -0,0 +1,25 @@
# Nix package for the Python client whose sources live in ../clients/python.
# The arguments are filled in by `callPackage` from a Python package set.
{
  buildPythonPackage,
  huggingface-hub,
  poetry-core,
  pydantic,
}:

buildPythonPackage {
  name = "text-generation";
  pyproject = true;
  src = ../clients/python;

  # Build-time inputs: poetry-core is the only build-system requirement.
  build-system = [ poetry-core ];
  nativeBuildInputs = [ ];

  # Runtime dependencies made available to the package.
  # NOTE(review): confirm this list matches the dependencies declared by the
  # client's own pyproject.
  dependencies = [
    huggingface-hub
    pydantic
  ];

  # Empty list: no declared dependency is removed.
  pythonRemoveDeps = [ ];
}

View File

@ -21,7 +21,6 @@
loguru,
mamba-ssm,
marlin-kernels,
moe-kernels,
opentelemetry-api,
opentelemetry-exporter-otlp,
opentelemetry-instrumentation-grpc,
@ -89,7 +88,6 @@ buildPythonPackage {
loguru
mamba-ssm
marlin-kernels
moe-kernels
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-instrumentation-grpc