Adding the assert.

This commit is contained in:
Nicolas Patry 2024-09-18 12:41:21 +02:00
parent 162549f37f
commit 4716bd51ad
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
6 changed files with 253 additions and 8 deletions

View File

@ -168,7 +168,7 @@ class ChatCompletionComplete(BaseModel):
# Log probabilities for the chat completion # Log probabilities for the chat completion
logprobs: Optional[Any] logprobs: Optional[Any]
# Reason for completion # Reason for completion
finish_reason: str finish_reason: Optional[str]
# Usage details of the chat completion # Usage details of the chat completion
usage: Optional[Any] = None usage: Optional[Any] = None
@ -191,6 +191,7 @@ class ChatCompletionChunk(BaseModel):
model: str model: str
system_fingerprint: str system_fingerprint: str
choices: List[Choice] choices: List[Choice]
usage: Optional[Any]
class Parameters(BaseModel): class Parameters(BaseModel):

View File

@ -70,6 +70,34 @@
client = pkgs.python3.pkgs.callPackage ./nix/client.nix { }; client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };
in in
{ {
checks = {
rust =
with pkgs;
rustPlatform.buildRustPackage {
name = "rust-checks";
src = ./.;
cargoLock = {
lockFile = ./Cargo.lock;
};
buildInputs = [ openssl.dev ];
nativeBuildInputs = [
clippy
pkg-config
protobuf
python3
rustfmt
];
buildPhase = ''
cargo check
'';
checkPhase = ''
cargo fmt -- --check
cargo test -j $NIX_BUILD_CORES
cargo clippy
'';
installPhase = "touch $out";
};
};
formatter = pkgs.nixfmt-rfc-style; formatter = pkgs.nixfmt-rfc-style;
devShells = with pkgs; rec { devShells = with pkgs; rec {
default = pure; default = pure;

View File

@ -0,0 +1,206 @@
[
{
"choices": [
{
"delta": {
"content": "**",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "Deep",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " Learning",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": ":",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " An",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " Overview",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "**\n",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "================================",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "=====",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "\n\n",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 40,
"total_tokens": 50
}
}
]

View File

@ -3,7 +3,7 @@ import requests
import json import json
from aiohttp import ClientSession from aiohttp import ClientSession
from text_generation.types import Completion, ChatComplete from text_generation.types import Completion, ChatCompletionChunk
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -68,6 +68,7 @@ async def test_flash_llama_completion_stream_usage(
} }
string = "" string = ""
chunks = [] chunks = []
is_final = False
async with ClientSession(headers=flash_llama_completion.headers) as session: async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response: async with session.post(url, json=request) as response:
# iterate over the stream # iterate over the stream
@ -84,17 +85,24 @@ async def test_flash_llama_completion_stream_usage(
chunk = [json.loads(c) for c in chunk] chunk = [json.loads(c) for c in chunk]
for c in chunk: for c in chunk:
chunks.append(ChatComplete(**c)) chunks.append(ChatCompletionChunk(**c))
assert "choices" in c assert "choices" in c
if len(c["choices"]) == 1: if len(c["choices"]) == 1:
index = c["choices"][0]["index"] index = c["choices"][0]["index"]
assert index == 0 assert index == 0
string += c["choices"][0]["text"] string += c["choices"][0]["delta"]["content"]
elif len(c["choices"]) == 0:
assert c["usage"] is not None has_usage = c["usage"] is not None
assert not is_final
if has_usage:
is_final = True
else: else:
raise RuntimeError("Expected different payload") raise RuntimeError("Expected different payload")
assert string == "" assert is_final
assert (
string
== "**Deep Learning: An Overview**\n=====================================\n\n"
)
assert chunks == response_snapshot assert chunks == response_snapshot

View File

@ -6,7 +6,7 @@
}: }:
buildPythonPackage { buildPythonPackage {
name = "text-generation"; name = "text-generation-x";
src = ../clients/python; src = ../clients/python;

View File

@ -21,6 +21,7 @@
loguru, loguru,
mamba-ssm, mamba-ssm,
marlin-kernels, marlin-kernels,
moe-kernels,
opentelemetry-api, opentelemetry-api,
opentelemetry-exporter-otlp, opentelemetry-exporter-otlp,
opentelemetry-instrumentation-grpc, opentelemetry-instrumentation-grpc,
@ -88,6 +89,7 @@ buildPythonPackage {
loguru loguru
mamba-ssm mamba-ssm
marlin-kernels marlin-kernels
moe-kernels
opentelemetry-api opentelemetry-api
opentelemetry-exporter-otlp opentelemetry-exporter-otlp
opentelemetry-instrumentation-grpc opentelemetry-instrumentation-grpc