feat(test): add more models to integration tests

This commit is contained in:
baptiste 2025-04-04 07:40:37 +00:00
parent 064c7bd03c
commit 32b00e039b
6 changed files with 319 additions and 51 deletions

View File

@ -54,3 +54,9 @@ run-integration-tests:
	DOCKER_VOLUME=${root_dir}/data \
	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests

# This is used to capture the expected outputs for the integration tests, offering an easy way to add more models to the integration tests.
capture-expected-outputs-for-integration-tests:
	DOCKER_VOLUME=${root_dir}/data \
	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py

View File

@ -109,6 +109,11 @@ Then run the following command to run the integration tests:
make -C backends/gaudi run-integration-tests
```
To capture the expected outputs for the integration tests, you can run the following command:
```bash
make -C backends/gaudi capture-expected-outputs-for-integration-tests
```
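The capture script only runs models whose expected outputs are still set to `"unknown"` in `TEST_CONFIGS` (defined in `test_model.py`). A hypothetical new entry might look like the sketch below; the model id and launcher args are placeholders to adapt:

```python
# Hypothetical entry to add inside TEST_CONFIGS in test_model.py.
# Both expected outputs are set to "unknown" so the capture script picks the
# model up and records the real outputs into expected_outputs.json.
"my-org/my-new-model": {
    "model_id": "my-org/my-new-model",
    "input": "What is Deep Learning?",
    "expected_greedy_output": "unknown",
    "expected_batch_output": "unknown",
    "args": [
        "--max-input-tokens",
        "512",
        "--max-total-tokens",
        "1024",
        "--max-batch-size",
        "4",
    ],
},
```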
#### How the integration tests work
The integration tests work as follows:

View File

@ -0,0 +1,85 @@
import json
import os
from typing import Dict, Any, Generator

import pytest
from test_model import TEST_CONFIGS

UNKNOWN_CONFIGS = {
    name: config
    for name, config in TEST_CONFIGS.items()
    if config["expected_greedy_output"] == "unknown"
    or config["expected_batch_output"] == "unknown"
}


@pytest.fixture(scope="module", params=UNKNOWN_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
    """Fixture that provides model configurations for testing."""
    test_config = UNKNOWN_CONFIGS[request.param]
    test_config["test_name"] = request.param
    return test_config


@pytest.fixture(scope="module")
def test_name(test_config):
    yield test_config["test_name"]


@pytest.fixture(scope="module")
def tgi_service(launcher, test_config, test_name) -> Generator:
    """Fixture that provides a TGI service for testing."""
    with launcher(test_config["model_id"], test_name) as service:
        yield service


@pytest.mark.asyncio
async def test_capture_expected_outputs(tgi_service, test_config, test_name):
    """Test that captures expected outputs for models with unknown outputs."""
    print(f"Testing {test_name} with {test_config['model_id']}")

    # Wait for service to be ready
    await tgi_service.health(1000)
    client = tgi_service.client

    # Test single request (greedy)
    print("Testing single request...")
    response = await client.generate(
        test_config["input"],
        max_new_tokens=32,
    )
    greedy_output = response.generated_text

    # Test multiple requests (batch)
    print("Testing batch requests...")
    responses = []
    for _ in range(4):
        response = await client.generate(
            test_config["input"],
            max_new_tokens=32,
        )
        responses.append(response.generated_text)

    # Store results in a JSON file
    output_file = "server/integration-tests/expected_outputs.json"
    results = {}

    # Try to load existing results if file exists
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            results = json.load(f)

    # Update results for this model
    results[test_name] = {
        "model_id": test_config["model_id"],
        "input": test_config["input"],
        "greedy_output": greedy_output,
        "batch_outputs": responses,
        "args": test_config["args"],
    }

    # Save updated results
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nResults for {test_name} saved to {output_file}")

View File

@ -8,6 +8,7 @@ import threading
import time
from tempfile import TemporaryDirectory
from typing import List
import socket
import docker
import pytest
@ -73,12 +74,12 @@ logger.add(
# signal.signal(signal.SIGTERM, cleanup_handler)
def stream_container_logs(container):
def stream_container_logs(container, test_name):
"""Stream container logs in a separate thread."""
try:
for log in container.logs(stream=True, follow=True):
print(
"[TGI Server Logs] " + log.decode("utf-8"),
f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
end="",
file=sys.stderr,
flush=True,
@ -178,11 +179,21 @@ def launcher(data_volume):
logger.info(
f"Starting docker launcher for model {model_id} and test {test_name}"
)
port = 8080
# Get a random available port
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
return port
port = get_free_port()
logger.debug(f"Using port {port}")
client = docker.from_env()
container_name = f"tgi-gaudi-test-{test_name}"
container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"
try:
container = client.containers.get(container_name)
@ -225,7 +236,7 @@ def launcher(data_volume):
environment=env,
detach=True,
volumes=volumes,
ports={"80/tcp": 8080},
ports={"80/tcp": port},
**HABANA_RUN_ARGS,
)
@ -234,7 +245,7 @@ def launcher(data_volume):
# Start log streaming in a background thread
log_thread = threading.Thread(
target=stream_container_logs,
args=(container,),
args=(container, test_name),
daemon=True, # This ensures the thread will be killed when the main program exits
)
log_thread.start()

View File

@ -4,8 +4,17 @@ from text_generation import AsyncClient
import pytest
from Levenshtein import distance as levenshtein_distance
# Models that are not working but should... :(
# - google/gemma-2-2b-it
# - google/flan-t5-large
# - codellama/CodeLlama-13b-hf
# - ibm-granite/granite-3.0-8b-instruct
# - microsoft/Phi-3.5-MoE-instruct
# - microsoft/Phi-3-mini-4k-instruct
# The config in args is not optimized for speed but only checks that inference works for the different model architectures
TEST_CONFIGS = {
"llama3-8b-shared-gaudi1": {
"meta-llama/Llama-3.1-8B-Instruct-shared": {
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use",
@ -25,11 +34,10 @@ TEST_CONFIGS = {
"2048",
],
},
"llama3-8b-gaudi1": {
"meta-llama/Llama-3.1-8B-Instruct": {
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
# "expected_sampling_output": " Part 2: Backpropagation\nIn my last post , I introduced the concept of deep learning and covered some high-level topics, including feedforward neural networks.",
"expected_batch_output": " A Beginners Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
"env_config": {},
"args": [
@ -43,8 +51,160 @@ TEST_CONFIGS = {
"2048",
],
},
"meta-llama/Llama-2-7b-chat-hf": {
"model_id": "meta-llama/Llama-2-7b-chat-hf",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
"expected_batch_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"mistralai/Mistral-7B-Instruct-v0.3": {
"model_id": "mistralai/Mistral-7B-Instruct-v0.3",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"bigcode/starcoder2-3b": {
"model_id": "bigcode/starcoder2-3b",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"google/gemma-7b-it": {
"model_id": "google/gemma-7b-it",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"Qwen/Qwen2-0.5B-Instruct": {
"model_id": "Qwen/Qwen2-0.5B-Instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
"expected_batch_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
"--max-batch-prefill-tokens",
"2048",
],
},
"tiiuae/falcon-7b-instruct": {
"model_id": "tiiuae/falcon-7b-instruct",
"input": "What is Deep Learning?",
"expected_greedy_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
"expected_batch_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"microsoft/phi-1_5": {
"model_id": "microsoft/phi-1_5",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
"expected_batch_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"openai-community/gpt2": {
"model_id": "openai-community/gpt2",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
"expected_batch_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"facebook/opt-125m": {
"model_id": "facebook/opt-125m",
"input": "What is Deep Learning?",
"expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
"expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
"EleutherAI/gpt-j-6b": {
"model_id": "EleutherAI/gpt-j-6b",
"input": "What is Deep Learning?",
"expected_greedy_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
"expected_batch_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
"args": [
"--max-input-tokens",
"512",
"--max-total-tokens",
"1024",
"--max-batch-size",
"4",
],
},
}
print(f"Testing {len(TEST_CONFIGS)} models")
@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
@ -102,22 +262,6 @@ async def test_model_single_request(
assert response.details.generated_tokens == 32
assert response.generated_text == expected_outputs["greedy"]
# TO FIX: Add sampling tests later; the Gaudi backend seems to be flaky with sampling, even with a fixed seed
# Sampling
# response = await tgi_client.generate(
# "What is Deep Learning?",
# do_sample=True,
# top_k=50,
# top_p=0.9,
# repetition_penalty=1.2,
# max_new_tokens=100,
# seed=42,
# decoder_input_details=True,
# )
# assert expected_outputs["sampling"] in response.generated_text
@pytest.mark.asyncio
async def test_model_multiple_requests(

View File

@ -41,34 +41,12 @@ curl 127.0.0.1:8080/generate \
## How-to Guides
### How to Run Specific Models
You can view the full list of supported models in the [Supported Models](https://huggingface.co/docs/text-generation-inference/backends/gaudi#supported-models) section.
The following models have been validated on Gaudi2:
| Model | Model ID | BF16 | | FP8 | |
|-----------------------|----------------------------------------|-------------|------------|-------------|------------|
| | | Single Card | Multi-Card | Single Card | Multi-Card |
| Llama2-7B | meta-llama/Llama-2-7b-chat-hf | ✔ | ✔ | ✔ | ✔ |
| Llama2-70B | meta-llama/Llama-2-70b-chat-hf | | ✔ | | ✔ |
| Llama3-8B | meta-llama/Meta-Llama-3.1-8B-Instruct | ✔ | ✔ | ✔ | ✔ |
| Llama3-70B | meta-llama/Meta-Llama-3-70B-Instruct | | ✔ | | ✔ |
| Llama3.1-8B | meta-llama/Meta-Llama-3.1-8B-Instruct | ✔ | ✔ | ✔ | ✔ |
| Llama3.1-70B | meta-llama/Meta-Llama-3.1-70B-Instruct | | ✔ | | ✔ |
| CodeLlama-13B | codellama/CodeLlama-13b-hf | ✔ | ✔ | ✔ | ✔ |
| Mixtral-8x7B | mistralai/Mixtral-8x7B-Instruct-v0.1 | ✔ | ✔ | ✔ | ✔ |
| Mistral-7B | mistralai/Mistral-7B-Instruct-v0.3 | ✔ | ✔ | ✔ | ✔ |
| Falcon-180B | tiiuae/falcon-180B-chat | | ✔ | | ✔ |
| Qwen2-72B | Qwen/Qwen2-72B-Instruct | | ✔ | | ✔ |
| Starcoder2-3b | bigcode/starcoder2-3b | ✔ | ✔ | ✔ | |
| Starcoder2-15b | bigcode/starcoder2-15b | ✔ | ✔ | ✔ | |
| Starcoder | bigcode/starcoder | ✔ | ✔ | ✔ | ✔ |
| Gemma-7b | google/gemma-7b-it | ✔ | ✔ | ✔ | ✔ |
| Llava-v1.6-Mistral-7B | llava-hf/llava-v1.6-mistral-7b-hf | ✔ | ✔ | ✔ | ✔ |
To run any of these models:
For example, to run Llama3.1-8B, you can use the following command:
```bash
model=MODEL_ID_THAT_YOU_WANT_TO_RUN
model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
hf_token=YOUR_ACCESS_TOKEN
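# The rest of this command is a sketch: the image tag below is a placeholder,
# use the Gaudi image that matches your TGI release.
docker run --runtime=habana -p 8080:80 -v $volume:/data \
  -e HF_TOKEN=$hf_token \
  -e HABANA_VISIBLE_DEVICES=all \
  --cap-add=sys_nice --ipc=host \
  ghcr.io/huggingface/text-generation-inference:latest-gaudi \
  --model-id $model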
@ -87,7 +65,7 @@ The validated docker commands can be found in the [examples/docker_commands fold
### How to Enable Multi-Card Inference (Sharding)
TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards.
TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards. This is recommended for running large models and for speeding up inference.
For example, on a machine with 8 Gaudi cards, you can run:
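A hedged sketch of such a command follows; the image tag and the generic Docker flags are placeholders copied from the single-card example, while the sharding itself is controlled by the `--sharded` and `--num-shard` launcher flags:

```bash
docker run --runtime=habana -p 8080:80 -v $volume:/data \
  -e HF_TOKEN=$hf_token \
  -e HABANA_VISIBLE_DEVICES=all \
  --cap-add=sys_nice --ipc=host \
  ghcr.io/huggingface/text-generation-inference:latest-gaudi \
  --model-id meta-llama/Llama-2-70b-chat-hf \
  --sharded true --num-shard 8  # shard the model across all 8 Gaudi cards
```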
@ -270,6 +248,45 @@ Note: Model warmup can take several minutes, especially for FP8 inference. For f
This section contains reference information about the Gaudi backend.
### Supported Models
Text Generation Inference enables serving optimized models on Gaudi hardware. The following sections list which models (VLMs & LLMs) are supported on Gaudi.
**Large Language Models (LLMs)**
- [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
- [Llama2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)
- [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
- [Llama3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)
- [Llama3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
- [Llama3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
- [CodeLlama-13B](https://huggingface.co/codellama/CodeLlama-13b-hf)
- [Opt-125m](https://huggingface.co/facebook/opt-125m)
- [OpenAI-gpt2](https://huggingface.co/openai-community/gpt2)
- [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
- [Qwen2-72B](https://huggingface.co/Qwen/Qwen2-72B-Instruct)
- [Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B-Instruct)
- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5)
- [Gemma-7b](https://huggingface.co/google/gemma-7b-it)
- [Starcoder2-3b](https://huggingface.co/bigcode/starcoder2-3b)
- [Starcoder2-15b](https://huggingface.co/bigcode/starcoder2-15b)
- [Starcoder](https://huggingface.co/bigcode/starcoder)
- [falcon-7b-instruct](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [Falcon-180B](https://huggingface.co/tiiuae/falcon-180B-chat)
- [gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b)
**Vision-Language Models (VLMs)**
- [LLaVA-v1.6-Mistral-7B](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
- [Mllama (Multimodal Llama from Meta)](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b)
- [Idefics 2.5](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)
- [Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)
- [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)
We also support, on a best-effort basis, models with different parameter counts that share a supported architecture, but these models have not been tested. For example, the Gaudi backend supports `meta-llama/Llama-3.2-1B` since it uses the standard Llama3 architecture. If you have an issue with a model, please open an issue on the [Gaudi backend repository](https://github.com/huggingface/text-generation-inference/issues).
### Environment Variables
The following table contains the environment variables that can be used to configure the Gaudi backend: