mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Fetch stuff from nix integration test for easier testing.
This commit is contained in:
parent
678721bcf0
commit
162549f37f
12
flake.lock
12
flake.lock
@ -479,11 +479,11 @@
|
|||||||
"systems": "systems_6"
|
"systems": "systems_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1710146030,
|
"lastModified": 1726560853,
|
||||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
|
||||||
"owner": "numtide",
|
"owner": "numtide",
|
||||||
"repo": "flake-utils",
|
"repo": "flake-utils",
|
||||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -853,11 +853,11 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1726280639,
|
"lastModified": 1726626348,
|
||||||
"narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
|
"narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
|
"rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
30
flake.nix
30
flake.nix
@ -67,31 +67,10 @@
|
|||||||
'';
|
'';
|
||||||
};
|
};
|
||||||
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
|
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
|
||||||
|
client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
checks = {
|
|
||||||
rust = with pkgs; rustPlatform.buildRustPackage {
|
|
||||||
name = "rust-checks";
|
|
||||||
src = ./.;
|
|
||||||
cargoLock = {
|
|
||||||
lockFile = ./Cargo.lock;
|
|
||||||
};
|
|
||||||
buildInputs = [ openssl.dev ];
|
|
||||||
nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ];
|
|
||||||
buildPhase = ''
|
|
||||||
cargo check
|
|
||||||
'';
|
|
||||||
checkPhase = ''
|
|
||||||
cargo fmt -- --check
|
|
||||||
cargo test -j $NIX_BUILD_CORES
|
|
||||||
cargo clippy
|
|
||||||
'';
|
|
||||||
installPhase = "touch $out";
|
|
||||||
} ;
|
|
||||||
};
|
|
||||||
|
|
||||||
formatter = pkgs.nixfmt-rfc-style;
|
formatter = pkgs.nixfmt-rfc-style;
|
||||||
|
|
||||||
devShells = with pkgs; rec {
|
devShells = with pkgs; rec {
|
||||||
default = pure;
|
default = pure;
|
||||||
|
|
||||||
@ -106,10 +85,11 @@
|
|||||||
test = mkShell {
|
test = mkShell {
|
||||||
buildInputs =
|
buildInputs =
|
||||||
[
|
[
|
||||||
# benchmark
|
benchmark
|
||||||
# launcher
|
launcher
|
||||||
# router
|
router
|
||||||
server
|
server
|
||||||
|
client
|
||||||
openssl.dev
|
openssl.dev
|
||||||
pkg-config
|
pkg-config
|
||||||
cargo
|
cargo
|
||||||
|
@ -3,9 +3,7 @@ import requests
|
|||||||
import json
|
import json
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
from text_generation.types import (
|
from text_generation.types import Completion, ChatComplete
|
||||||
Completion,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
@ -50,6 +48,56 @@ def test_flash_llama_completion_single_prompt(
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
|
async def test_flash_llama_completion_stream_usage(
|
||||||
|
flash_llama_completion, response_snapshot
|
||||||
|
):
|
||||||
|
url = f"{flash_llama_completion.base_url}/v1/chat/completions"
|
||||||
|
request = {
|
||||||
|
"model": "tgi",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is Deep Learning?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"stream_options": {"include_usage": True},
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
string = ""
|
||||||
|
chunks = []
|
||||||
|
async with ClientSession(headers=flash_llama_completion.headers) as session:
|
||||||
|
async with session.post(url, json=request) as response:
|
||||||
|
# iterate over the stream
|
||||||
|
async for chunk in response.content.iter_any():
|
||||||
|
# remove "data:"
|
||||||
|
chunk = chunk.decode().split("\n\n")
|
||||||
|
# remove "data:" if present
|
||||||
|
chunk = [c.replace("data:", "") for c in chunk]
|
||||||
|
# remove empty strings
|
||||||
|
chunk = [c for c in chunk if c]
|
||||||
|
# remove completion marking chunk
|
||||||
|
chunk = [c for c in chunk if c != " [DONE]"]
|
||||||
|
# parse json
|
||||||
|
chunk = [json.loads(c) for c in chunk]
|
||||||
|
|
||||||
|
for c in chunk:
|
||||||
|
chunks.append(ChatComplete(**c))
|
||||||
|
assert "choices" in c
|
||||||
|
if len(c["choices"]) == 1:
|
||||||
|
index = c["choices"][0]["index"]
|
||||||
|
assert index == 0
|
||||||
|
string += c["choices"][0]["text"]
|
||||||
|
elif len(c["choices"]) == 0:
|
||||||
|
assert c["usage"] is not None
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Expected different payload")
|
||||||
|
assert string == ""
|
||||||
|
assert chunks == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
25
nix/client.nix
Normal file
25
nix/client.nix
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
{
|
||||||
|
buildPythonPackage,
|
||||||
|
poetry-core,
|
||||||
|
huggingface-hub,
|
||||||
|
pydantic,
|
||||||
|
}:
|
||||||
|
|
||||||
|
buildPythonPackage {
|
||||||
|
name = "text-generation";
|
||||||
|
|
||||||
|
src = ../clients/python;
|
||||||
|
|
||||||
|
pyproject = true;
|
||||||
|
|
||||||
|
build-system = [ poetry-core ];
|
||||||
|
|
||||||
|
nativeBuildInputs = [ ];
|
||||||
|
|
||||||
|
pythonRemoveDeps = [ ];
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
huggingface-hub
|
||||||
|
pydantic
|
||||||
|
];
|
||||||
|
}
|
@ -21,7 +21,6 @@
|
|||||||
loguru,
|
loguru,
|
||||||
mamba-ssm,
|
mamba-ssm,
|
||||||
marlin-kernels,
|
marlin-kernels,
|
||||||
moe-kernels,
|
|
||||||
opentelemetry-api,
|
opentelemetry-api,
|
||||||
opentelemetry-exporter-otlp,
|
opentelemetry-exporter-otlp,
|
||||||
opentelemetry-instrumentation-grpc,
|
opentelemetry-instrumentation-grpc,
|
||||||
@ -89,7 +88,6 @@ buildPythonPackage {
|
|||||||
loguru
|
loguru
|
||||||
mamba-ssm
|
mamba-ssm
|
||||||
marlin-kernels
|
marlin-kernels
|
||||||
moe-kernels
|
|
||||||
opentelemetry-api
|
opentelemetry-api
|
||||||
opentelemetry-exporter-otlp
|
opentelemetry-exporter-otlp
|
||||||
opentelemetry-instrumentation-grpc
|
opentelemetry-instrumentation-grpc
|
||||||
|
Loading…
Reference in New Issue
Block a user