Only send the usage when asked for.

This commit is contained in:
Nicolas Patry 2024-09-18 12:56:59 +02:00
parent 4716bd51ad
commit df287fe758
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
3 changed files with 69 additions and 12 deletions

View File

@ -68,7 +68,7 @@ async def test_flash_llama_completion_stream_usage(
}
string = ""
chunks = []
is_final = False
had_usage = False
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
@ -93,18 +93,68 @@ async def test_flash_llama_completion_stream_usage(
string += c["choices"][0]["delta"]["content"]
has_usage = c["usage"] is not None
assert not is_final
assert not had_usage
if has_usage:
is_final = True
had_usage = True
else:
raise RuntimeError("Expected different payload")
assert is_final
assert had_usage
assert (
string
== "**Deep Learning: An Overview**\n=====================================\n\n"
)
assert chunks == response_snapshot
request = {
"model": "tgi",
"messages": [
{
"role": "user",
"content": "What is Deep Learning?",
}
],
"max_tokens": 10,
"temperature": 0.0,
"stream": True,
}
string = ""
chunks = []
had_usage = False
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
async for chunk in response.content.iter_any():
# split the decoded payload on blank lines into individual events
chunk = chunk.decode().split("\n\n")
# remove "data:" if present
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# remove completion marking chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json
chunk = [json.loads(c) for c in chunk]
for c in chunk:
chunks.append(ChatCompletionChunk(**c))
assert "choices" in c
if len(c["choices"]) == 1:
index = c["choices"][0]["index"]
assert index == 0
string += c["choices"][0]["delta"]["content"]
has_usage = c["usage"] is not None
assert not had_usage
if has_usage:
had_usage = True
else:
raise RuntimeError("Expected different payload")
assert not had_usage
assert (
string
== "**Deep Learning: An Overview**\n=====================================\n\n"
)
@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):

View File

@ -6,7 +6,7 @@
}:
buildPythonPackage {
name = "text-generation-x";
name = "text-generation";
src = ../clients/python;

View File

@ -1175,6 +1175,7 @@ async fn chat_completions(
seed,
stop,
stream,
stream_options,
tools,
tool_choice,
tool_prompt,
@ -1267,17 +1268,23 @@ async fn chat_completions(
let (usage, finish_reason) = match stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
(
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
}),
Some(details.finish_reason.format(true)),
)
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
}
None => (None, None),
};