diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile
index bd1f1eeae..f760f4d6e 100644
--- a/backends/gaudi/Makefile
+++ b/backends/gaudi/Makefile
@@ -54,3 +54,9 @@ run-integration-tests:
 	DOCKER_VOLUME=${root_dir}/data \
 	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
 	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests
+
+# This is used to capture the expected outputs for the integration tests, offering an easy way to add more models to the integration tests
+capture-expected-outputs-for-integration-tests:
+	DOCKER_VOLUME=${root_dir}/data \
+	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
+	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py
diff --git a/backends/gaudi/README.md b/backends/gaudi/README.md
index 9c7367e5b..ba890f0b1 100644
--- a/backends/gaudi/README.md
+++ b/backends/gaudi/README.md
@@ -109,6 +109,11 @@ Then run the following command to run the integration tests:
 make -C backends/gaudi run-integration-tests
 ```
 
+To capture the expected outputs for the integration tests, you can run the following command:
+```bash
+make -C backends/gaudi capture-expected-outputs-for-integration-tests
+```
+
 #### How the integration tests works
 
 The integration tests works as follows:
diff --git a/backends/gaudi/server/integration-tests/capture_expected_outputs.py b/backends/gaudi/server/integration-tests/capture_expected_outputs.py
new file mode 100644
index 000000000..051b9d698
--- /dev/null
+++ b/backends/gaudi/server/integration-tests/capture_expected_outputs.py
@@ -0,0 +1,85 @@
+import json
+import os
+from typing import Dict, Any, Generator
+
+import pytest
+from test_model import TEST_CONFIGS
+
+UNKNOWN_CONFIGS = {
+    name: config
+    for name, config in TEST_CONFIGS.items()
+    if config["expected_greedy_output"] == "unknown"
+    or config["expected_batch_output"] == "unknown"
+}
+
+
+@pytest.fixture(scope="module", params=UNKNOWN_CONFIGS.keys())
+def test_config(request) -> Dict[str, Any]:
+    """Fixture that provides model configurations for testing."""
+    test_config = UNKNOWN_CONFIGS[request.param]
+    test_config["test_name"] = request.param
+    return test_config
+
+
+@pytest.fixture(scope="module")
+def test_name(test_config):
+    yield test_config["test_name"]
+
+
+@pytest.fixture(scope="module")
+def tgi_service(launcher, test_config, test_name) -> Generator:
+    """Fixture that provides a TGI service for testing."""
+    with launcher(test_config["model_id"], test_name) as service:
+        yield service
+
+
+@pytest.mark.asyncio
+async def test_capture_expected_outputs(tgi_service, test_config, test_name):
+    """Test that captures expected outputs for models with unknown outputs."""
+    print(f"Testing {test_name} with {test_config['model_id']}")
+
+    # Wait for service to be ready
+    await tgi_service.health(1000)
+    client = tgi_service.client
+
+    # Test single request (greedy)
+    print("Testing single request...")
+    response = await client.generate(
+        test_config["input"],
+        max_new_tokens=32,
+    )
+    greedy_output = response.generated_text
+
+    # Test multiple requests (batch)
+    print("Testing batch requests...")
+    responses = []
+    for _ in range(4):
+        response = await client.generate(
+            test_config["input"],
+            max_new_tokens=32,
+        )
+        responses.append(response.generated_text)
+
+    # Store results in a JSON file
+    output_file = "server/integration-tests/expected_outputs.json"
+    results = {}
+
+    # Try to load existing results if file exists
+    if os.path.exists(output_file):
+        with open(output_file, "r") as f:
+            results = json.load(f)
+
+    # Update results for this model
+    results[test_name] = {
+        "model_id": test_config["model_id"],
+        "input": test_config["input"],
+        "greedy_output": greedy_output,
+        "batch_outputs": responses,
+        "args": test_config["args"],
+    }
+
+    # Save updated results
+    with open(output_file, "w") as f:
+        json.dump(results, f, indent=2)
+
+    print(f"\nResults for {test_name} saved to {output_file}")
diff --git a/backends/gaudi/server/integration-tests/conftest.py b/backends/gaudi/server/integration-tests/conftest.py
index 9ead5c45f..aa50f833d 100644
--- a/backends/gaudi/server/integration-tests/conftest.py
+++ b/backends/gaudi/server/integration-tests/conftest.py
@@ -8,6 +8,7 @@ import threading
 import time
 from tempfile import TemporaryDirectory
 from typing import List
+import socket
 
 import docker
 import pytest
@@ -73,12 +74,12 @@ logger.add(
 # signal.signal(signal.SIGTERM, cleanup_handler)
 
 
-def stream_container_logs(container):
+def stream_container_logs(container, test_name):
     """Stream container logs in a separate thread."""
     try:
         for log in container.logs(stream=True, follow=True):
             print(
-                "[TGI Server Logs] " + log.decode("utf-8"),
+                f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
                 end="",
                 file=sys.stderr,
                 flush=True,
@@ -178,11 +179,21 @@ def launcher(data_volume):
         logger.info(
            f"Starting docker launcher for model {model_id} and test {test_name}"
         )
-        port = 8080
+
+        # Get a random available port
+        def get_free_port():
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("", 0))
+                s.listen(1)
+                port = s.getsockname()[1]
+            return port
+
+        port = get_free_port()
+        logger.debug(f"Using port {port}")
 
         client = docker.from_env()
 
-        container_name = f"tgi-gaudi-test-{test_name}"
+        container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"
 
         try:
             container = client.containers.get(container_name)
@@ -225,7 +236,7 @@ def launcher(data_volume):
             environment=env,
             detach=True,
             volumes=volumes,
-            ports={"80/tcp": 8080},
+            ports={"80/tcp": port},
             **HABANA_RUN_ARGS,
         )
 
@@ -234,7 +245,7 @@ def launcher(data_volume):
         # Start log streaming in a background thread
         log_thread = threading.Thread(
             target=stream_container_logs,
-            args=(container,),
+            args=(container, test_name),
             daemon=True,  # This ensures the thread will be killed when the main program exits
         )
         log_thread.start()
diff --git a/backends/gaudi/server/integration-tests/test_model.py b/backends/gaudi/server/integration-tests/test_model.py
index f6c58650d..7185a56a0 100644
--- a/backends/gaudi/server/integration-tests/test_model.py
+++ b/backends/gaudi/server/integration-tests/test_model.py
@@ -4,8 +4,17 @@ from text_generation import AsyncClient
 import pytest
 from Levenshtein import distance as levenshtein_distance
 
+# Models that are not working but should... :(
+# - google/gemma-2-2b-it
+# - google/flan-t5-large
+# - codellama/CodeLlama-13b-hf
+# - ibm-granite/granite-3.0-8b-instruct
+# - microsoft/Phi-3.5-MoE-instruct
+# - microsoft/Phi-3-mini-4k-instruct
+
+# The config in args is not optimized for speed but only checks that inference is working for the different model architectures
 TEST_CONFIGS = {
-    "llama3-8b-shared-gaudi1": {
+    "meta-llama/Llama-3.1-8B-Instruct-shared": {
         "model_id": "meta-llama/Llama-3.1-8B-Instruct",
         "input": "What is Deep Learning?",
         "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
@@ -25,11 +34,10 @@ TEST_CONFIGS = {
             "2048",
         ],
     },
-    "llama3-8b-gaudi1": {
+    "meta-llama/Llama-3.1-8B-Instruct": {
         "model_id": "meta-llama/Llama-3.1-8B-Instruct",
         "input": "What is Deep Learning?",
         "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
-        # "expected_sampling_output": " Part 2: Backpropagation\nIn my last post , I introduced the concept of deep learning and covered some high-level topics, including feedforward neural networks.",
         "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
         "env_config": {},
         "args": [
@@ -43,8 +51,160 @@ TEST_CONFIGS = {
             "2048",
         ],
     },
+    "meta-llama/Llama-2-7b-chat-hf": {
+        "model_id": "meta-llama/Llama-2-7b-chat-hf",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
+        "expected_batch_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+            "--max-batch-prefill-tokens",
+            "2048",
+        ],
+    },
+    "mistralai/Mistral-7B-Instruct-v0.3": {
+        "model_id": "mistralai/Mistral-7B-Instruct-v0.3",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
+        "expected_batch_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+            "--max-batch-prefill-tokens",
+            "2048",
+        ],
+    },
+    "bigcode/starcoder2-3b": {
+        "model_id": "bigcode/starcoder2-3b",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
+        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+            "--max-batch-prefill-tokens",
+            "2048",
+        ],
+    },
+    "google/gemma-7b-it": {
+        "model_id": "google/gemma-7b-it",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
+        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+            "--max-batch-prefill-tokens",
+            "2048",
+        ],
+    },
+    "Qwen/Qwen2-0.5B-Instruct": {
+        "model_id": "Qwen/Qwen2-0.5B-Instruct",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
+        "expected_batch_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+            "--max-batch-prefill-tokens",
+            "2048",
+        ],
+    },
+    "tiiuae/falcon-7b-instruct": {
+        "model_id": "tiiuae/falcon-7b-instruct",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
+        "expected_batch_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+        ],
+    },
+    "microsoft/phi-1_5": {
+        "model_id": "microsoft/phi-1_5",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
+        "expected_batch_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+        ],
+    },
+    "openai-community/gpt2": {
+        "model_id": "openai-community/gpt2",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
+        "expected_batch_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+        ],
+    },
+    "facebook/opt-125m": {
+        "model_id": "facebook/opt-125m",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
+        "expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+        ],
+    },
+    "EleutherAI/gpt-j-6b": {
+        "model_id": "EleutherAI/gpt-j-6b",
+        "input": "What is Deep Learning?",
+        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
+        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
+        "args": [
+            "--max-input-tokens",
+            "512",
+            "--max-total-tokens",
+            "1024",
+            "--max-batch-size",
+            "4",
+        ],
+    },
 }
 
+print(f"Testing {len(TEST_CONFIGS)} models")
+
 
 @pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
 def test_config(request) -> Dict[str, Any]:
@@ -102,22 +262,6 @@ async def test_model_single_request(
     assert response.details.generated_tokens == 32
     assert response.generated_text == expected_outputs["greedy"]
 
-    # TO FIX: Add sampling tests later, the gaudi backend seems to be flaky with the sampling even with a fixed seed
-
-    # Sampling
-    # response = await tgi_client.generate(
-    #     "What is Deep Learning?",
-    #     do_sample=True,
-    #     top_k=50,
-    #     top_p=0.9,
-    #     repetition_penalty=1.2,
-    #     max_new_tokens=100,
-    #     seed=42,
-    #     decoder_input_details=True,
-    # )
-
-    # assert expected_outputs["sampling"] in response.generated_text
-
 
 @pytest.mark.asyncio
 async def test_model_multiple_requests(
diff --git a/docs/source/backends/gaudi.mdx b/docs/source/backends/gaudi.mdx
index 7e0e69ab1..58413f54f 100644
--- a/docs/source/backends/gaudi.mdx
+++ b/docs/source/backends/gaudi.mdx
@@ -41,34 +41,12 @@ curl 127.0.0.1:8080/generate \
 
 ## How-to Guides
 
-### How to Run Specific Models
+You can view the full list of supported models in the [Supported Models](https://huggingface.co/docs/text-generation-inference/backends/gaudi#supported-models) section.
 
-The following models have been validated on Gaudi2:
-
-| Model                 | Model ID                               | BF16        |            | FP8         |            |
-|-----------------------|----------------------------------------|-------------|------------|-------------|------------|
-|                       |                                        | Single Card | Multi-Card | Single Card | Multi-Card |
-| Llama2-7B             | meta-llama/Llama-2-7b-chat-hf          | ✔           | ✔          | ✔           | ✔          |
-| Llama2-70B            | meta-llama/Llama-2-70b-chat-hf         |             | ✔          |             | ✔          |
-| Llama3-8B             | meta-llama/Meta-Llama-3.1-8B-Instruct  | ✔           | ✔          | ✔           | ✔          |
-| Llama3-70B            | meta-llama/Meta-Llama-3-70B-Instruct   |             | ✔          |             | ✔          |
-| Llama3.1-8B           | meta-llama/Meta-Llama-3.1-8B-Instruct  | ✔           | ✔          | ✔           | ✔          |
-| Llama3.1-70B          | meta-llama/Meta-Llama-3.1-70B-Instruct |             | ✔          |             | ✔          |
-| CodeLlama-13B         | codellama/CodeLlama-13b-hf             | ✔           | ✔          | ✔           | ✔          |
-| Mixtral-8x7B          | mistralai/Mixtral-8x7B-Instruct-v0.1   | ✔           | ✔          | ✔           | ✔          |
-| Mistral-7B            | mistralai/Mistral-7B-Instruct-v0.3     | ✔           | ✔          | ✔           | ✔          |
-| Falcon-180B           | tiiuae/falcon-180B-chat                |             | ✔          |             | ✔          |
-| Qwen2-72B             | Qwen/Qwen2-72B-Instruct                |             | ✔          |             | ✔          |
-| Starcoder2-3b         | bigcode/starcoder2-3b                  | ✔           | ✔          | ✔           |            |
-| Starcoder2-15b        | bigcode/starcoder2-15b                 | ✔           | ✔          | ✔           |            |
-| Starcoder             | bigcode/starcoder                      | ✔           | ✔          | ✔           | ✔          |
-| Gemma-7b              | google/gemma-7b-it                     | ✔           | ✔          | ✔           | ✔          |
-| Llava-v1.6-Mistral-7B | llava-hf/llava-v1.6-mistral-7b-hf      | ✔           | ✔          | ✔           | ✔          |
-
-To run any of these models:
+For example, to run Llama3.1-8B, you can use the following command:
 
 ```bash
-model=MODEL_ID_THAT_YOU_WANT_TO_RUN
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 hf_token=YOUR_ACCESS_TOKEN
 
@@ -87,7 +65,7 @@ The validated docker commands can be found in the [examples/docker_commands fold
 
 ### How to Enable Multi-Card Inference (Sharding)
 
-TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards.
+TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards. This is recommended for running large models and for speeding up inference.
 
 For example, on a machine with 8 Gaudi cards, you can run:
 
@@ -270,6 +248,44 @@ Note: Model warmup can take several minutes, especially for FP8 inference. For f
 
 This section contains reference information about the Gaudi backend.
 
+### Supported Models
+
+Text Generation Inference enables serving optimized models on Gaudi hardware. The following sections list which models (VLMs & LLMs) are supported on Gaudi.
+
+**Large Language Models (LLMs)**
+- [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
+- [Llama2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)
+- [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
+- [Llama3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)
+- [Llama3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
+- [Llama3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
+- [CodeLlama-13B](https://huggingface.co/codellama/CodeLlama-13b-hf)
+- [Opt-125m](https://huggingface.co/facebook/opt-125m)
+- [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
+- [Qwen2-72B](https://huggingface.co/Qwen/Qwen2-72B-Instruct)
+- [Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B-Instruct)
+- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5)
+- [Gemma-7b](https://huggingface.co/google/gemma-7b-it)
+- [Starcoder2-3b](https://huggingface.co/bigcode/starcoder2-3b)
+- [Starcoder2-15b](https://huggingface.co/bigcode/starcoder2-15b)
+- [Starcoder](https://huggingface.co/bigcode/starcoder)
+- [Falcon-7b-instruct](https://huggingface.co/tiiuae/falcon-7b-instruct)
+- [Falcon-180B](https://huggingface.co/tiiuae/falcon-180B-chat)
+- [GPT-2](https://huggingface.co/openai-community/gpt2)
+- [gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b)
+
+**Vision-Language Models (VLMs)**
+- [LLaVA-v1.6-Mistral-7B](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
+- [Mllama (Multimodal Llama from Meta)](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
+- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b)
+- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b)
+- [Idefics 2.5](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)
+- [Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)
+- [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)
+
+Models with a different parameter count that share one of the architectures above are also supported on a best-effort basis, but they have not been tested. For example, the Gaudi backend supports `meta-llama/Llama-3.2-1B`, since it uses the standard Llama3 architecture. If you have an issue with a model, please open an issue on the [Gaudi backend repository](https://github.com/huggingface/text-generation-inference/issues).
+
 ### Environment Variables
 
 The following table contains the environment variables that can be used to configure the Gaudi backend:
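
Usage note (not part of the patch): after `make -C backends/gaudi capture-expected-outputs-for-integration-tests` has written `expected_outputs.json`, the captured strings are meant to replace the `"unknown"` placeholders in `TEST_CONFIGS`. The sketch below shows one way the captured file could be reviewed before copying values over; the JSON keys and output path are taken from `capture_expected_outputs.py` above, while the script itself is purely illustrative.

```python
import json

# Path written by capture_expected_outputs.py, relative to the backend's working directory.
with open("server/integration-tests/expected_outputs.json") as f:
    captured = json.load(f)

for test_name, result in captured.items():
    batch = result["batch_outputs"]
    # Greedy decoding should produce identical outputs across the 4 batch requests; flag any drift.
    if len(set(batch)) > 1:
        print(f"WARNING: {test_name} returned non-identical batch outputs")
    print(f"{test_name} ({result['model_id']}):")
    print(f"  expected_greedy_output = {result['greedy_output']!r}")
    print(f"  expected_batch_output  = {batch[0]!r}")
```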