diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py
index 21ebc915..6c359f1e 100644
--- a/integration-tests/models/test_completion_prompts.py
+++ b/integration-tests/models/test_completion_prompts.py
@@ -68,7 +68,7 @@ async def test_flash_llama_completion_stream_usage(
     }
     string = ""
     chunks = []
-    is_final = False
+    had_usage = False
     async with ClientSession(headers=flash_llama_completion.headers) as session:
         async with session.post(url, json=request) as response:
             # iterate over the stream
@@ -93,18 +93,68 @@ async def test_flash_llama_completion_stream_usage(
                     string += c["choices"][0]["delta"]["content"]

                     has_usage = c["usage"] is not None
-                    assert not is_final
+                    assert not had_usage
                     if has_usage:
-                        is_final = True
+                        had_usage = True
                 else:
                     raise RuntimeError("Expected different payload")
-    assert is_final
+    assert had_usage
     assert (
         string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
     )
     assert chunks == response_snapshot

+    request = {
+        "model": "tgi",
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is Deep Learning?",
+            }
+        ],
+        "max_tokens": 10,
+        "temperature": 0.0,
+        "stream": True,
+    }
+    string = ""
+    chunks = []
+    had_usage = False
+    async with ClientSession(headers=flash_llama_completion.headers) as session:
+        async with session.post(url, json=request) as response:
+            # iterate over the stream
+            async for chunk in response.content.iter_any():
+                # split the raw payload into individual SSE events
+                chunk = chunk.decode().split("\n\n")
+                # remove "data:" if present
+                chunk = [c.replace("data:", "") for c in chunk]
+                # remove empty strings
+                chunk = [c for c in chunk if c]
+                # remove completion marking chunk
+                chunk = [c for c in chunk if c != " [DONE]"]
+                # parse json
+                chunk = [json.loads(c) for c in chunk]
+
+                for c in chunk:
+                    chunks.append(ChatCompletionChunk(**c))
+                    assert "choices" in c
+                    if len(c["choices"]) == 1:
+                        index = c["choices"][0]["index"]
+                        assert index == 0
+                        string += c["choices"][0]["delta"]["content"]
+
+                        has_usage = c["usage"] is not None
+                        assert not had_usage
+                        if has_usage:
+                            had_usage = True
+                    else:
+                        raise RuntimeError("Expected different payload")
+    assert not had_usage
+    assert (
+        string
+        == "**Deep Learning: An Overview**\n=====================================\n\n"
+    )
+

 @pytest.mark.release
 def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
diff --git a/nix/client.nix b/nix/client.nix
index d9c9e0f5..a6a375b6 100644
--- a/nix/client.nix
+++ b/nix/client.nix
@@ -6,7 +6,7 @@
 }:

 buildPythonPackage {
-  name = "text-generation-x";
+  name = "text-generation";

   src = ../clients/python;

diff --git a/router/src/server.rs b/router/src/server.rs
index 16ca1f10..176fd856 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1175,6 +1175,7 @@ async fn chat_completions(
         seed,
         stop,
         stream,
+        stream_options,
         tools,
         tool_choice,
         tool_prompt,
@@ -1267,17 +1268,23 @@ async fn chat_completions(

             let (usage, finish_reason) = match stream_token.details {
                 Some(details) => {
-                    let completion_tokens = details.generated_tokens;
-                    let prompt_tokens = details.input_length;
-                    let total_tokens = prompt_tokens + completion_tokens;
-                    (
+                    let usage = if stream_options
+                        .as_ref()
+                        .map(|s| s.include_usage)
+                        .unwrap_or(false)
+                    {
+                        let completion_tokens = details.generated_tokens;
+                        let prompt_tokens = details.input_length;
+                        let total_tokens = prompt_tokens + completion_tokens;
                         Some(Usage {
                             completion_tokens,
                             prompt_tokens,
                             total_tokens,
-                        }),
-                        Some(details.finish_reason.format(true)),
-                    )
+                        })
+                    } else {
+                        None
+                    };
+                    (usage, Some(details.finish_reason.format(true)))
                 }
                 None => (None, None),
             };
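Reviewer note: a minimal sketch, not part of the diff, of how a client could exercise
the new stream_options field once this lands. The host/port, endpoint path, and the
use of the requests library are assumptions; the model name, request shape, and the
{"include_usage": true} option mirror the integration test above.

    import json

    import requests

    request = {
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is Deep Learning?"}],
        "max_tokens": 10,
        "temperature": 0.0,
        "stream": True,
        # new in this PR: ask the server to attach token usage to the stream
        "stream_options": {"include_usage": True},
    }

    with requests.post(
        "http://localhost:8080/v1/chat/completions", json=request, stream=True
    ) as response:
        for line in response.iter_lines():
            # skip keep-alive blanks and the terminal "[DONE]" marker
            if not line or line == b"data: [DONE]":
                continue
            chunk = json.loads(line.removeprefix(b"data:"))
            # per the first test above, exactly one chunk (the final one,
            # where the router sees stream_token.details) should carry a
            # non-null usage object
            if chunk.get("usage") is not None:
                print(chunk["usage"])

With stream_options unset, as in the second test above, every chunk reports usage as
null, which is what the closing "assert not had_usage" checks.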