diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md
index 87e76df1..fa1f9f61 100644
--- a/docs/source/supported_models.md
+++ b/docs/source/supported_models.md
@@ -22,6 +22,8 @@ The following models are optimized and can be served with TGI, which uses custom
 - [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
 - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
 - [Phi](https://huggingface.co/microsoft/phi-2)
+- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b-instruct) (Multimodal)
+- [Llava-next](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) (Multimodal)

 If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:

diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 022b2298..e8ce0d84 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -277,6 +277,8 @@ def launcher(event_loop):
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
+        max_input_length: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)
@@ -314,6 +316,12 @@ def launcher(event_loop):
             args.append(revision)
         if trust_remote_code:
             args.append("--trust-remote-code")
+        if max_input_length:
+            args.append("--max-input-length")
+            args.append(str(max_input_length))
+        if max_total_tokens:
+            args.append("--max-total-tokens")
+            args.append(str(max_total_tokens))

         env["LOG_LEVEL"] = "info,text_generation_router=debug"

@@ -347,6 +355,8 @@ def launcher(event_loop):
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
+        max_input_length: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)

@@ -367,6 +377,12 @@ def launcher(event_loop):
             args.append(revision)
         if trust_remote_code:
             args.append("--trust-remote-code")
+        if max_input_length:
+            args.append("--max-input-length")
+            args.append(str(max_input_length))
+        if max_total_tokens:
+            args.append("--max-total-tokens")
+            args.append(str(max_total_tokens))

         client = docker.from_env()

diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py
index 882971f2..7fb70a8f 100644
--- a/integration-tests/models/test_idefics.py
+++ b/integration-tests/models/test_idefics.py
@@ -33,6 +33,7 @@ async def test_idefics(idefics, response_snapshot):
     )

     assert response.details.generated_tokens == 10
+    assert response.generated_text == "\n\nDeep learning is a new type of machine"
     assert response == response_snapshot


@@ -48,6 +49,7 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):

     generated_texts = [r.generated_text for r in responses]

+    assert generated_texts[0] == "\n\nDeep learning is a new type of machine"
     assert len(generated_texts) == 4
     assert generated_texts, all(
         [text == generated_texts[0] for text in generated_texts]
diff --git a/integration-tests/models/test_llava_next.py b/integration-tests/models/test_llava_next.py
new file mode 100644
index 00000000..be968a4b
--- /dev/null
+++ b/integration-tests/models/test_llava_next.py
@@ -0,0 +1,82 @@
+import pytest
+import base64
+
+
+# TODO fix the server parser to count inline image tokens correctly
+def get_chicken():
+    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.fixture(scope="module")
+def flash_llava_next_handle(launcher):
+    with launcher(
+        "llava-hf/llava-v1.6-mistral-7b-hf",
+        num_shard=4,
+        max_input_length=4000,
+        max_total_tokens=4096,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llava_next(flash_llava_next_handle):
+    await flash_llava_next_handle.health(300)
+    return flash_llava_next_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
+    chicken = get_chicken()
+    response = await flash_llava_next.generate(
+        f"User:![]({chicken})Can you tell me a very short story based on the image?",
+        max_new_tokens=10,
+    )
+    assert response.generated_text == "toto"
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
+    response = await flash_llava_next.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        stop_sequences=["test"],
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 5
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llava_next_load(
+    flash_llava_next, generate_load, response_snapshot
+):
+    chicken = get_chicken()
+    responses = await generate_load(
+        flash_llava_next,
+        f"User:![]({chicken})Can you tell me a very short story based on the image?",
+        max_new_tokens=10,
+        n=4,
+    )
+    generated_texts = [r.generated_text for r in responses]
+    assert generated_texts[0] == "\n\nDeep learning is a new type of machine"
+    assert len(generated_texts) == 4
+    assert all([r.generated_text == generated_texts[0] for r in responses])
+
+    assert responses == response_snapshot
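For context, the new tests feed the image to the model inline: the prompt is plain text containing a markdown image whose URL is a base64 data URI, which keeps the integration tests self-contained rather than depending on a remote image URL. Below is a minimal sketch of sending the same kind of request to an already-running TGI server outside the test harness; the endpoint URL, port, image path, and the image_to_data_uri helper are illustrative assumptions and not part of this change.

    import base64

    import requests

    # Assumption: a TGI server (e.g. serving llava-hf/llava-v1.6-mistral-7b-hf) is listening here.
    TGI_URL = "http://localhost:8080"


    def image_to_data_uri(path: str) -> str:
        # Same encoding scheme as get_chicken() in the test: raw bytes -> base64 -> data URI.
        with open(path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
        return f"data:image/png;base64,{encoded}"


    image = image_to_data_uri("integration-tests/images/chicken_on_money.png")
    prompt = f"User:![]({image})Can you tell me a very short story based on the image?"

    # TGI's /generate endpoint takes the prompt as `inputs` plus generation parameters.
    response = requests.post(
        f"{TGI_URL}/generate",
        json={"inputs": prompt, "parameters": {"max_new_tokens": 10}},
        timeout=60,
    )
    response.raise_for_status()
    print(response.json()["generated_text"])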