Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 14:52:20 +00:00)
* Attempt at automatic max batch prefill (sketched below).
* Taking into account number of shards.
* Adding more cards.
* Adding A100 + H100.
* Adding a few more cards.
* Logprobs cost too much.
* h100 better name, and keep factor of 2.
* Damn inflated sparse tflops.
* Typo in h100.
* Updated the flops calculation (checked with fvcore).
* chunking by default.
* Fix prefix caching for chat completion since we removed logprobs.
* More tests.
* Dropping all the prefill logprobs.
* Add a flag that enables users to get logprobs back.
* Repairing prompt token counting.
* Fixing a few tests.
* Remove some scaffolding.
* Attempting to reduce the issues (workarounds for now).
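The first few items describe deriving a default max-batch-prefill budget from each card's compute: dense (not sparse) TFLOPS, scaled by the number of shards, with a factor-of-2 FLOPs-per-parameter rule checked against fvcore. Below is a minimal sketch of that idea under those assumptions; the card table, the latency target, and every name in it are illustrative, not TGI's actual implementation.

# Hedged sketch: estimate a default prefill-token budget from GPU TFLOPS.
# All names and numbers here are illustrative assumptions, not TGI's code.

# Dense (non-sparse) BF16 TFLOPS per card; spec sheets often quote the
# inflated 2:4-sparsity figure, which the commit message calls out.
GPU_TFLOPS = {
    "a100": 312.0,
    "h100": 989.0,
}


def flops_per_token(num_params: float) -> float:
    # Rule of thumb: ~2 FLOPs per parameter per token in the forward pass
    # (one multiply-accumulate per weight); the commit mentions keeping
    # the factor of 2 and checking the result with fvcore.
    return 2.0 * num_params


def max_prefill_tokens(card: str, num_params: float, num_shards: int,
                       target_latency_s: float = 1.0) -> int:
    # The compute budget scales with the number of tensor-parallel shards.
    budget = GPU_TFLOPS[card] * 1e12 * num_shards * target_latency_s
    return int(budget / flops_per_token(num_params))


# Example: an 11B model on two H100s gives roughly
# max_prefill_tokens("h100", 11e9, num_shards=2) ~= 90_000 tokens.

Using the dense rather than the sparse figure matters here: the 2:4-sparsity number on vendor spec sheets is double what a dense forward pass can actually reach.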
96 lines · 2.9 KiB · Python
import asyncio

import pytest


@pytest.fixture(scope="module")
def mllama_handle(launcher):
    # Launch a TGI server for the vision checkpoint, sharded across 2 GPUs.
    with launcher(
        "meta-llama/Llama-3.2-11B-Vision-Instruct",
        num_shard=2,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def mllama(mllama_handle):
    # Wait up to 300 seconds for the server to report healthy, then hand
    # out its client.
    await mllama_handle.health(300)
    return mllama_handle.client


@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
    response = await mllama.chat(
        max_tokens=10,
        temperature=0.0,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Can you tell me a very short story based on the image?",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                        },
                    },
                ],
            },
        ],
    )

    assert response.usage == {
        "completion_tokens": 10,
        "prompt_tokens": 50,
        "total_tokens": 60,
    }
    assert (
        response.choices[0].message.content
        == "In a bustling city, a chicken named Cluck"
    )
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
    futures = [
        mllama.chat(
            max_tokens=10,
            temperature=0.0,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Can you tell me a very short story based on the image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                            },
                        },
                    ],
                },
            ],
        )
        # TODO: with v3, 4 concurrent requests break here, because nothing
        # accounts for the image VRAM; mllama is the only model doing its
        # own thing.
        for _ in range(2)
    ]
    responses = await asyncio.gather(*futures)

    _ = [response.choices[0].message.content for response in responses]

    # XXX: TODO: Fix this test.
    # assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
    # assert len(generated_texts) == 4
    # assert generated_texts, all(
    #     [text == generated_texts[0] for text in generated_texts]
    # )
    # assert responses == response_snapshot