Only send the usage when asked for.

Nicolas Patry 2024-09-18 12:56:59 +02:00
parent 4716bd51ad
commit df287fe758
3 changed files with 69 additions and 12 deletions
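
The client-visible contract after this change is OpenAI-style: a streaming chat completion only carries a `usage` payload when the request opts in via `stream_options.include_usage`. A minimal sketch of the two request shapes, using `serde_json` (the `json!` bodies mirror the integration test below; this sketch is illustration, not part of the commit):

use serde_json::json;

fn main() {
    // Opts in: the final stream chunk is expected to carry `usage`.
    let with_usage = json!({
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is Deep Learning?"}],
        "max_tokens": 10,
        "stream": true,
        "stream_options": {"include_usage": true}
    });

    // No `stream_options`: no chunk should carry `usage` at all.
    let without_usage = json!({
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is Deep Learning?"}],
        "max_tokens": 10,
        "stream": true
    });

    println!("{with_usage}");
    println!("{without_usage}");
}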


@@ -68,7 +68,7 @@ async def test_flash_llama_completion_stream_usage(
     }
     string = ""
     chunks = []
-    is_final = False
+    had_usage = False
     async with ClientSession(headers=flash_llama_completion.headers) as session:
         async with session.post(url, json=request) as response:
             # iterate over the stream
@@ -93,18 +93,68 @@ async def test_flash_llama_completion_stream_usage(
                         string += c["choices"][0]["delta"]["content"]
                         has_usage = c["usage"] is not None
-                        assert not is_final
+                        assert not had_usage
                         if has_usage:
-                            is_final = True
+                            had_usage = True
                     else:
                         raise RuntimeError("Expected different payload")
-    assert is_final
+    assert had_usage
     assert (
         string
         == "**Deep Learning: An Overview**\n=====================================\n\n"
     )
     assert chunks == response_snapshot
+
+    request = {
+        "model": "tgi",
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is Deep Learning?",
+            }
+        ],
+        "max_tokens": 10,
+        "temperature": 0.0,
+        "stream": True,
+    }
+    string = ""
+    chunks = []
+    had_usage = False
+    async with ClientSession(headers=flash_llama_completion.headers) as session:
+        async with session.post(url, json=request) as response:
+            # iterate over the stream
+            async for chunk in response.content.iter_any():
+                # remove "data:"
+                chunk = chunk.decode().split("\n\n")
+                # remove "data:" if present
+                chunk = [c.replace("data:", "") for c in chunk]
+                # remove empty strings
+                chunk = [c for c in chunk if c]
+                # remove completion marking chunk
+                chunk = [c for c in chunk if c != " [DONE]"]
+                # parse json
+                chunk = [json.loads(c) for c in chunk]
+                for c in chunk:
+                    chunks.append(ChatCompletionChunk(**c))
+                    assert "choices" in c
+                    if len(c["choices"]) == 1:
+                        index = c["choices"][0]["index"]
+                        assert index == 0
+                        string += c["choices"][0]["delta"]["content"]
+                        has_usage = c["usage"] is not None
+                        assert not had_usage
+                        if has_usage:
+                            had_usage = True
+                    else:
+                        raise RuntimeError("Expected different payload")
+    assert not had_usage
+    assert (
+        string
+        == "**Deep Learning: An Overview**\n=====================================\n\n"
+    )
 
 
 @pytest.mark.release
 def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):


@@ -6,7 +6,7 @@
 }:
 
 buildPythonPackage {
-  name = "text-generation-x";
+  name = "text-generation";
 
   src = ../clients/python;


@@ -1175,6 +1175,7 @@ async fn chat_completions(
         seed,
         stop,
         stream,
+        stream_options,
         tools,
         tool_choice,
         tool_prompt,
@@ -1267,17 +1268,23 @@ async fn chat_completions(
                     let (usage, finish_reason) = match stream_token.details {
                         Some(details) => {
-                            let completion_tokens = details.generated_tokens;
-                            let prompt_tokens = details.input_length;
-                            let total_tokens = prompt_tokens + completion_tokens;
-                            (
-                                Some(Usage {
-                                    completion_tokens,
-                                    prompt_tokens,
-                                    total_tokens,
-                                }),
-                                Some(details.finish_reason.format(true)),
-                            )
+                            let usage = if stream_options
+                                .as_ref()
+                                .map(|s| s.include_usage)
+                                .unwrap_or(false)
+                            {
+                                let completion_tokens = details.generated_tokens;
+                                let prompt_tokens = details.input_length;
+                                let total_tokens = prompt_tokens + completion_tokens;
+                                Some(Usage {
+                                    completion_tokens,
+                                    prompt_tokens,
+                                    total_tokens,
+                                })
+                            } else {
+                                None
+                            };
+                            (usage, Some(details.finish_reason.format(true)))
                         }
                         None => (None, None),
                     };
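
The gating above is the usual `Option` pattern; here is a standalone sketch of the same check, with `StreamOptions` reduced to the one field this commit reads (the real type lives in the router and carries more than this):

#[derive(Clone, Copy)]
struct StreamOptions {
    include_usage: bool,
}

// Same shape as the router's check: default to *not* sending usage
// unless the client explicitly asked for it.
fn should_send_usage(stream_options: Option<StreamOptions>) -> bool {
    stream_options
        .map(|s| s.include_usage)
        .unwrap_or(false)
}

fn main() {
    assert!(!should_send_usage(None));
    assert!(!should_send_usage(Some(StreamOptions { include_usage: false })));
    assert!(should_send_usage(Some(StreamOptions { include_usage: true })));
}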