mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

Only send the usage when asked for.

parent 4716bd51ad
commit df287fe758
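In short: after this change a streaming chat completion only carries a `usage` object when the request opts in through `stream_options`; previously the test expected usage to arrive on the stream unconditionally. Below is a minimal sketch of an opting-in request, assuming a local TGI server exposing the OpenAI-style `/v1/chat/completions` route; the URL is a placeholder, the `stream_options` JSON shape is inferred from the router code in this diff, and `requests` stands in for the aiohttp session the tests use:

    import json
    import requests  # assumption: any HTTP client works; the tests use aiohttp

    request = {
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is Deep Learning?"}],
        "max_tokens": 10,
        "stream": True,
        # Opt in to usage reporting; leave stream_options out and, after
        # this change, every chunk's "usage" stays null.
        "stream_options": {"include_usage": True},
    }

    with requests.post(
        "http://localhost:8080/v1/chat/completions", json=request, stream=True
    ) as resp:
        for line in resp.iter_lines():
            if not line or line == b"data: [DONE]":
                continue
            chunk = json.loads(line.decode().replace("data:", "", 1))
            if chunk.get("usage") is not None:
                print(chunk["usage"])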
@@ -68,7 +68,7 @@ async def test_flash_llama_completion_stream_usage(
     }
     string = ""
     chunks = []
-    is_final = False
+    had_usage = False
     async with ClientSession(headers=flash_llama_completion.headers) as session:
         async with session.post(url, json=request) as response:
             # iterate over the stream
@@ -93,18 +93,68 @@ async def test_flash_llama_completion_stream_usage(
                         string += c["choices"][0]["delta"]["content"]
 
                         has_usage = c["usage"] is not None
-                        assert not is_final
+                        assert not had_usage
                         if has_usage:
-                            is_final = True
+                            had_usage = True
                     else:
                         raise RuntimeError("Expected different payload")
-    assert is_final
+    assert had_usage
     assert (
         string
         == "**Deep Learning: An Overview**\n=====================================\n\n"
     )
     assert chunks == response_snapshot
 
+    request = {
+        "model": "tgi",
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is Deep Learning?",
+            }
+        ],
+        "max_tokens": 10,
+        "temperature": 0.0,
+        "stream": True,
+    }
+    string = ""
+    chunks = []
+    had_usage = False
+    async with ClientSession(headers=flash_llama_completion.headers) as session:
+        async with session.post(url, json=request) as response:
+            # iterate over the stream
+            async for chunk in response.content.iter_any():
+                # remove "data:"
+                chunk = chunk.decode().split("\n\n")
+                # remove "data:" if present
+                chunk = [c.replace("data:", "") for c in chunk]
+                # remove empty strings
+                chunk = [c for c in chunk if c]
+                # remove completion marking chunk
+                chunk = [c for c in chunk if c != " [DONE]"]
+                # parse json
+                chunk = [json.loads(c) for c in chunk]
+
+                for c in chunk:
+                    chunks.append(ChatCompletionChunk(**c))
+                    assert "choices" in c
+                    if len(c["choices"]) == 1:
+                        index = c["choices"][0]["index"]
+                        assert index == 0
+                        string += c["choices"][0]["delta"]["content"]
+
+                        has_usage = c["usage"] is not None
+                        assert not had_usage
+                        if has_usage:
+                            had_usage = True
+                    else:
+                        raise RuntimeError("Expected different payload")
+    assert not had_usage
+    assert (
+        string
+        == "**Deep Learning: An Overview**\n=====================================\n\n"
+    )
+
+
 @pytest.mark.release
 def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
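The parsing steps in the test are worth isolating: one `iter_any()` read may carry several `data: {...}` events, so the test splits on the SSE event delimiter, strips the `data:` field name, drops empty fragments and the `[DONE]` terminator, and JSON-decodes the rest. A standalone sketch of the same pipeline (hypothetical helper name, not part of the commit):

    import json
    from typing import List

    def parse_sse_chunk(raw: bytes) -> List[dict]:
        # One network read may contain several "data: {...}" events.
        events = raw.decode().split("\n\n")
        # Strip the SSE field name, then drop empties and the terminator.
        events = [e.replace("data:", "") for e in events]
        events = [e for e in events if e and e != " [DONE]"]
        return [json.loads(e) for e in events]

    # Two events plus the end-of-stream marker arriving in a single read:
    raw = b'data: {"usage": null}\n\ndata: {"usage": {"total_tokens": 13}}\n\ndata: [DONE]\n\n'
    print(parse_sse_chunk(raw))
    # [{'usage': None}, {'usage': {'total_tokens': 13}}]

Splitting on the delimiter rather than assuming one event per read is what keeps the test stable when the server coalesces several events into a single write.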
@@ -6,7 +6,7 @@
 }:
 
 buildPythonPackage {
-  name = "text-generation-x";
+  name = "text-generation";
 
   src = ../clients/python;
 
@@ -1175,6 +1175,7 @@ async fn chat_completions(
         seed,
         stop,
         stream,
+        stream_options,
         tools,
         tool_choice,
         tool_prompt,
@@ -1267,17 +1268,23 @@ async fn chat_completions(
 
                 let (usage, finish_reason) = match stream_token.details {
                     Some(details) => {
+                        let usage = if stream_options
+                            .as_ref()
+                            .map(|s| s.include_usage)
+                            .unwrap_or(false)
+                        {
                             let completion_tokens = details.generated_tokens;
                             let prompt_tokens = details.input_length;
                             let total_tokens = prompt_tokens + completion_tokens;
-                        (
                             Some(Usage {
                                 completion_tokens,
                                 prompt_tokens,
                                 total_tokens,
-                            }),
-                        Some(details.finish_reason.format(true)),
-                        )
+                            })
+                        } else {
+                            None
+                        };
+                        (usage, Some(details.finish_reason.format(true)))
                     }
                     None => (None, None),
                 };
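The gating reads as: build a `Usage` only if the request's `stream_options.include_usage` was set, otherwise pair `None` with the finish reason. A Python paraphrase of the same decision (hypothetical names, sketched from the hunk above; the `finish_reason.format(true)` formatting is omitted):

    from typing import Optional, Tuple

    def usage_and_finish_reason(
        details: Optional[dict], stream_options: Optional[dict]
    ) -> Tuple[Optional[dict], Optional[str]]:
        # Mirrors the match on stream_token.details: no details, no usage.
        if details is None:
            return None, None
        include = bool(stream_options and stream_options.get("include_usage"))
        usage = None
        if include:
            completion_tokens = details["generated_tokens"]
            prompt_tokens = details["input_length"]
            usage = {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
        return usage, details["finish_reason"]

    # Without opting in, usage is suppressed even though details exist:
    print(usage_and_finish_reason(
        {"generated_tokens": 10, "input_length": 3, "finish_reason": "length"}, None
    ))
    # (None, 'length')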