Fetch stuff from nix integration test for easier testing.

This commit is contained in:
  parent 678721bcf0
  commit 162549f37f

flake.lock (12 lines changed)
@@ -479,11 +479,11 @@
       "systems": "systems_6"
     },
     "locked": {
-      "lastModified": 1710146030,
-      "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
+      "lastModified": 1726560853,
+      "narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
       "owner": "numtide",
       "repo": "flake-utils",
-      "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
+      "rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
       "type": "github"
     },
     "original": {
@@ -853,11 +853,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1726280639,
-      "narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
+      "lastModified": 1726626348,
+      "narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
       "owner": "oxalica",
       "repo": "rust-overlay",
-      "rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
+      "rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
       "type": "github"
     },
     "original": {
flake.nix (30 lines changed)
@@ -67,31 +67,10 @@
         '';
       };
       server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
+      client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };
     in
     {
-      checks = {
-        rust = with pkgs; rustPlatform.buildRustPackage {
-          name = "rust-checks";
-          src = ./.;
-          cargoLock = {
-            lockFile = ./Cargo.lock;
-          };
-          buildInputs = [ openssl.dev ];
-          nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ];
-          buildPhase = ''
-            cargo check
-          '';
-          checkPhase = ''
-            cargo fmt -- --check
-            cargo test -j $NIX_BUILD_CORES
-            cargo clippy
-          '';
-          installPhase = "touch $out";
-        };
-      };
-
       formatter = pkgs.nixfmt-rfc-style;
 
       devShells = with pkgs; rec {
         default = pure;
@@ -106,10 +85,11 @@
       test = mkShell {
         buildInputs =
           [
-            # benchmark
-            # launcher
-            # router
+            benchmark
+            launcher
+            router
             server
+            client
             openssl.dev
             pkg-config
             cargo
integration-tests/models/test_completion_prompts.py

@@ -3,9 +3,7 @@ import requests
 import json
 from aiohttp import ClientSession
 
-from text_generation.types import (
-    Completion,
-)
+from text_generation.types import Completion, ChatComplete
 
 
 @pytest.fixture(scope="module")
@@ -50,6 +48,56 @@ def test_flash_llama_completion_single_prompt(
     assert response == response_snapshot
 
 
+@pytest.mark.release
+async def test_flash_llama_completion_stream_usage(
+    flash_llama_completion, response_snapshot
+):
+    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
+    request = {
+        "model": "tgi",
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is Deep Learning?",
+            }
+        ],
+        "max_tokens": 10,
+        "temperature": 0.0,
+        "stream_options": {"include_usage": True},
+        "stream": True,
+    }
+    string = ""
+    chunks = []
+    async with ClientSession(headers=flash_llama_completion.headers) as session:
+        async with session.post(url, json=request) as response:
+            # iterate over the stream
+            async for chunk in response.content.iter_any():
+                # split the raw bytes into individual SSE frames
+                chunk = chunk.decode().split("\n\n")
+                # remove "data:" if present
+                chunk = [c.replace("data:", "") for c in chunk]
+                # remove empty strings
+                chunk = [c for c in chunk if c]
+                # remove completion marking chunk
+                chunk = [c for c in chunk if c != " [DONE]"]
+                # parse json
+                chunk = [json.loads(c) for c in chunk]
+
+                for c in chunk:
+                    chunks.append(ChatComplete(**c))
+                    assert "choices" in c
+                    if len(c["choices"]) == 1:
+                        index = c["choices"][0]["index"]
+                        assert index == 0
+                        string += c["choices"][0]["text"]
+                    elif len(c["choices"]) == 0:
+                        assert c["usage"] is not None
+                    else:
+                        raise RuntimeError("Expected different payload")
+    assert string == ""
+    assert chunks == response_snapshot
+
+
 @pytest.mark.release
 def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
     response = requests.post(
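For context on the two branches above: with "stream_options": {"include_usage": True}, an OpenAI-compatible chat endpoint sends one final data frame before [DONE] whose "choices" array is empty and which carries the token accounting; that is the case the len(c["choices"]) == 0 branch asserts on. A minimal, self-contained sketch of parsing such a frame follows; every field value in it is made up for illustration, not captured from TGI:

import json

# Hypothetical final SSE frame under "stream_options": {"include_usage": True}:
# an empty "choices" list and only the usage totals (all values illustrative).
frame = (
    'data: {"object": "chat.completion.chunk", "model": "tgi", "choices": [], '
    '"usage": {"prompt_tokens": 7, "completion_tokens": 10, "total_tokens": 17}}'
)

payload = json.loads(frame.removeprefix("data: "))
assert payload["choices"] == []  # the usage frame carries no delta
assert (
    payload["usage"]["total_tokens"]
    == payload["usage"]["prompt_tokens"] + payload["usage"]["completion_tokens"]
)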
nix/client.nix (new file, 25 lines)

@@ -0,0 +1,25 @@
+{
+  buildPythonPackage,
+  poetry-core,
+  huggingface-hub,
+  pydantic,
+}:
+
+buildPythonPackage {
+  name = "text-generation";
+
+  src = ../clients/python;
+
+  pyproject = true;
+
+  build-system = [ poetry-core ];
+
+  nativeBuildInputs = [ ];
+
+  pythonRemoveDeps = [ ];
+
+  dependencies = [
+    huggingface-hub
+    pydantic
+  ];
+}
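The new derivation simply repackages the existing Python client from clients/python, which is what lets the test dev shell above import text_generation in the integration tests (for example from nix develop .#test). As a quick sanity check that the packaged client works, a sketch like the following should run in that shell, assuming a text-generation-inference server is already listening on the address given (this commit does not start one):

from text_generation import Client

# The address is an example; point it at a running text-generation-inference
# server. Nothing in this commit launches one.
client = Client("http://127.0.0.1:8080")

response = client.generate("What is Deep Learning?", max_new_tokens=10)
print(response.generated_text)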
nix/server.nix

@@ -21,7 +21,6 @@
   loguru,
-  mamba-ssm,
   marlin-kernels,
   moe-kernels,
   opentelemetry-api,
   opentelemetry-exporter-otlp,
   opentelemetry-instrumentation-grpc,

@@ -89,7 +88,6 @@ buildPythonPackage {
   loguru
-  mamba-ssm
   marlin-kernels
   moe-kernels
   opentelemetry-api
   opentelemetry-exporter-otlp
   opentelemetry-instrumentation-grpc