mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 06:12:07 +00:00

feat(test): add more models to integration tests

This commit is contained in:
parent 064c7bd03c
commit 32b00e039b
@@ -54,3 +54,9 @@ run-integration-tests:
	DOCKER_VOLUME=${root_dir}/data \
	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests

# This is used to capture the expected outputs for the integration tests, offering an easy way to add more models to the integration tests
capture-expected-outputs-for-integration-tests:
	DOCKER_VOLUME=${root_dir}/data \
	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py
@@ -109,6 +109,11 @@ Then run the following command to run the integration tests:
make -C backends/gaudi run-integration-tests
```

To capture the expected outputs for the integration tests, you can run the following command:
```bash
make -C backends/gaudi capture-expected-outputs-for-integration-tests
```
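As an illustration only (not part of this commit): to add a new model, you append an entry to `TEST_CONFIGS` in `test_model.py` whose expected outputs are still the literal string `"unknown"`; the capture target above then generates the reference outputs for that entry. The model id and args below are placeholders.

```python
# Hypothetical new entry for TEST_CONFIGS in
# backends/gaudi/server/integration-tests/test_model.py; the model id is a placeholder.
# "unknown" is the sentinel that capture_expected_outputs.py keys on (UNKNOWN_CONFIGS).
new_entry = {
    "example-org/example-model": {
        "model_id": "example-org/example-model",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "unknown",  # filled in after the capture run
        "expected_batch_output": "unknown",   # filled in after the capture run
        "args": [
            "--max-input-tokens", "512",
            "--max-total-tokens", "1024",
            "--max-batch-size", "4",
        ],
    }
}
```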
#### How the integration tests work
The integration tests work as follows:
@@ -0,0 +1,85 @@
import json
import os
from typing import Dict, Any, Generator

import pytest
from test_model import TEST_CONFIGS

UNKNOWN_CONFIGS = {
    name: config
    for name, config in TEST_CONFIGS.items()
    if config["expected_greedy_output"] == "unknown"
    or config["expected_batch_output"] == "unknown"
}


@pytest.fixture(scope="module", params=UNKNOWN_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
    """Fixture that provides model configurations for testing."""
    test_config = UNKNOWN_CONFIGS[request.param]
    test_config["test_name"] = request.param
    return test_config


@pytest.fixture(scope="module")
def test_name(test_config):
    yield test_config["test_name"]


@pytest.fixture(scope="module")
def tgi_service(launcher, test_config, test_name) -> Generator:
    """Fixture that provides a TGI service for testing."""
    with launcher(test_config["model_id"], test_name) as service:
        yield service


@pytest.mark.asyncio
async def test_capture_expected_outputs(tgi_service, test_config, test_name):
    """Test that captures expected outputs for models with unknown outputs."""
    print(f"Testing {test_name} with {test_config['model_id']}")

    # Wait for service to be ready
    await tgi_service.health(1000)
    client = tgi_service.client

    # Test single request (greedy)
    print("Testing single request...")
    response = await client.generate(
        test_config["input"],
        max_new_tokens=32,
    )
    greedy_output = response.generated_text

    # Test multiple requests (batch)
    print("Testing batch requests...")
    responses = []
    for _ in range(4):
        response = await client.generate(
            test_config["input"],
            max_new_tokens=32,
        )
        responses.append(response.generated_text)

    # Store results in a JSON file
    output_file = "server/integration-tests/expected_outputs.json"
    results = {}

    # Try to load existing results if file exists
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            results = json.load(f)

    # Update results for this model
    results[test_name] = {
        "model_id": test_config["model_id"],
        "input": test_config["input"],
        "greedy_output": greedy_output,
        "batch_outputs": responses,
        "args": test_config["args"],
    }

    # Save updated results
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nResults for {test_name} saved to {output_file}")
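For reference, a sketch of the shape of `expected_outputs.json` after a capture run; the field names mirror the `results[test_name]` dict above, while the model id and values shown here are placeholders rather than actual captured outputs.

```python
# Hypothetical shape of server/integration-tests/expected_outputs.json after a
# capture run, mirroring the `results[test_name]` dict written above by
# json.dump(results, f, indent=2). All values below are placeholders.
expected_outputs = {
    "example-org/example-model": {
        "model_id": "example-org/example-model",
        "input": "What is Deep Learning?",
        "greedy_output": "<text returned by the single greedy request>",
        "batch_outputs": ["<output 1>", "<output 2>", "<output 3>", "<output 4>"],
        "args": ["--max-input-tokens", "512", "--max-total-tokens", "1024"],
    }
}
```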
@@ -8,6 +8,7 @@ import threading
import time
from tempfile import TemporaryDirectory
from typing import List
import socket

import docker
import pytest
@@ -73,12 +74,12 @@ logger.add(
# signal.signal(signal.SIGTERM, cleanup_handler)


def stream_container_logs(container):
def stream_container_logs(container, test_name):
    """Stream container logs in a separate thread."""
    try:
        for log in container.logs(stream=True, follow=True):
            print(
                "[TGI Server Logs] " + log.decode("utf-8"),
                f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
                end="",
                file=sys.stderr,
                flush=True,
@@ -178,11 +179,21 @@ def launcher(data_volume):
    logger.info(
        f"Starting docker launcher for model {model_id} and test {test_name}"
    )
    port = 8080

    # Get a random available port
    def get_free_port():
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            s.listen(1)
            port = s.getsockname()[1]
        return port

    port = get_free_port()
    logger.debug(f"Using port {port}")

    client = docker.from_env()

    container_name = f"tgi-gaudi-test-{test_name}"
    container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"

    try:
        container = client.containers.get(container_name)
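A note inferred from this diff (not stated in it): the new `TEST_CONFIGS` keys are full Hugging Face model ids containing `/`, which is not a valid character in Docker container names, hence the `replace('/', '-')` above. A minimal illustration, using two keys added in this commit:

```python
# Minimal illustration of the container naming, using two TEST_CONFIGS keys from this commit.
name_a = f"tgi-gaudi-test-{'meta-llama/Llama-2-7b-chat-hf'.replace('/', '-')}"
name_b = f"tgi-gaudi-test-{'Qwen/Qwen2-0.5B-Instruct'.replace('/', '-')}"
print(name_a)  # tgi-gaudi-test-meta-llama-Llama-2-7b-chat-hf
print(name_b)  # tgi-gaudi-test-Qwen-Qwen2-0.5B-Instruct
assert name_a != name_b  # each test gets its own container name (and its own free port)
```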
@@ -225,7 +236,7 @@ def launcher(data_volume):
        environment=env,
        detach=True,
        volumes=volumes,
        ports={"80/tcp": 8080},
        ports={"80/tcp": port},
        **HABANA_RUN_ARGS,
    )
@@ -234,7 +245,7 @@ def launcher(data_volume):
    # Start log streaming in a background thread
    log_thread = threading.Thread(
        target=stream_container_logs,
        args=(container,),
        args=(container, test_name),
        daemon=True,  # This ensures the thread will be killed when the main program exits
    )
    log_thread.start()
@@ -4,8 +4,17 @@ from text_generation import AsyncClient
import pytest
from Levenshtein import distance as levenshtein_distance

# Models that are not working but should... :(
# - google/gemma-2-2b-it
# - google/flan-t5-large
# - codellama/CodeLlama-13b-hf
# - ibm-granite/granite-3.0-8b-instruct
# - microsoft/Phi-3.5-MoE-instruct
# - microsoft/Phi-3-mini-4k-instruct

# The config in args is not optimized for speed but only checks that inference works for the different model architectures
TEST_CONFIGS = {
    "llama3-8b-shared-gaudi1": {
    "meta-llama/Llama-3.1-8B-Instruct-shared": {
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "input": "What is Deep Learning?",
        "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
@@ -25,11 +34,10 @@ TEST_CONFIGS = {
            "2048",
        ],
    },
    "llama3-8b-gaudi1": {
    "meta-llama/Llama-3.1-8B-Instruct": {
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "input": "What is Deep Learning?",
        "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
        # "expected_sampling_output": " Part 2: Backpropagation\nIn my last post , I introduced the concept of deep learning and covered some high-level topics, including feedforward neural networks.",
        "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
        "env_config": {},
        "args": [
@@ -43,8 +51,160 @@ TEST_CONFIGS = {
            "2048",
        ],
    },
    "meta-llama/Llama-2-7b-chat-hf": {
        "model_id": "meta-llama/Llama-2-7b-chat-hf",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
        "expected_batch_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "mistralai/Mistral-7B-Instruct-v0.3": {
        "model_id": "mistralai/Mistral-7B-Instruct-v0.3",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
        "expected_batch_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "bigcode/starcoder2-3b": {
        "model_id": "bigcode/starcoder2-3b",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "google/gemma-7b-it": {
        "model_id": "google/gemma-7b-it",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "Qwen/Qwen2-0.5B-Instruct": {
        "model_id": "Qwen/Qwen2-0.5B-Instruct",
        "input": "What is Deep Learning?",
        "expected_greedy_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
        "expected_batch_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
            "--max-batch-prefill-tokens",
            "2048",
        ],
    },
    "tiiuae/falcon-7b-instruct": {
        "model_id": "tiiuae/falcon-7b-instruct",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
        "expected_batch_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
    "microsoft/phi-1_5": {
        "model_id": "microsoft/phi-1_5",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
        "expected_batch_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
    "openai-community/gpt2": {
        "model_id": "openai-community/gpt2",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
        "expected_batch_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
    "facebook/opt-125m": {
        "model_id": "facebook/opt-125m",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
        "expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
    "EleutherAI/gpt-j-6b": {
        "model_id": "EleutherAI/gpt-j-6b",
        "input": "What is Deep Learning?",
        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
        "args": [
            "--max-input-tokens",
            "512",
            "--max-total-tokens",
            "1024",
            "--max-batch-size",
            "4",
        ],
    },
}

print(f"Testing {len(TEST_CONFIGS)} models")


@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
def test_config(request) -> Dict[str, Any]:
@@ -102,22 +262,6 @@ async def test_model_single_request(
    assert response.details.generated_tokens == 32
    assert response.generated_text == expected_outputs["greedy"]

    # TO FIX: Add sampling tests later, the gaudi backend seems to be flaky with the sampling even with a fixed seed

    # Sampling
    # response = await tgi_client.generate(
    #     "What is Deep Learning?",
    #     do_sample=True,
    #     top_k=50,
    #     top_p=0.9,
    #     repetition_penalty=1.2,
    #     max_new_tokens=100,
    #     seed=42,
    #     decoder_input_details=True,
    # )

    # assert expected_outputs["sampling"] in response.generated_text


@pytest.mark.asyncio
async def test_model_multiple_requests(
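A side note on the sampling flakiness mentioned in the removed block above (an illustration only, not part of this commit): if seeded sampling tests are reintroduced later, one option is a tolerance-based comparison instead of exact matching, using the `levenshtein_distance` already imported in `test_model.py`. The threshold below is an arbitrary placeholder.

```python
# Hypothetical helper, not part of this diff: a tolerance-based comparison that
# could back a seeded-sampling test. The max_distance of 10 is a placeholder.
from Levenshtein import distance as levenshtein_distance


def outputs_close_enough(generated: str, expected: str, max_distance: int = 10) -> bool:
    # Tolerate small drift between runs instead of requiring an exact string match.
    return levenshtein_distance(generated, expected) <= max_distance
```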
@@ -41,34 +41,12 @@ curl 127.0.0.1:8080/generate \

## How-to Guides

### How to Run Specific Models
You can view the full list of supported models in the [Supported Models](https://huggingface.co/docs/text-generation-inference/backends/gaudi#supported-models) section.

The following models have been validated on Gaudi2:

| Model | Model ID | BF16 | | FP8 | |
|-----------------------|----------------------------------------|-------------|------------|-------------|------------|
| | | Single Card | Multi-Card | Single Card | Multi-Card |
| Llama2-7B | meta-llama/Llama-2-7b-chat-hf | ✔ | ✔ | ✔ | ✔ |
| Llama2-70B | meta-llama/Llama-2-70b-chat-hf | | ✔ | | ✔ |
| Llama3-8B | meta-llama/Meta-Llama-3.1-8B-Instruct | ✔ | ✔ | ✔ | ✔ |
| Llama3-70B | meta-llama/Meta-Llama-3-70B-Instruct | | ✔ | | ✔ |
| Llama3.1-8B | meta-llama/Meta-Llama-3.1-8B-Instruct | ✔ | ✔ | ✔ | ✔ |
| Llama3.1-70B | meta-llama/Meta-Llama-3.1-70B-Instruct | | ✔ | | ✔ |
| CodeLlama-13B | codellama/CodeLlama-13b-hf | ✔ | ✔ | ✔ | ✔ |
| Mixtral-8x7B | mistralai/Mixtral-8x7B-Instruct-v0.1 | ✔ | ✔ | ✔ | ✔ |
| Mistral-7B | mistralai/Mistral-7B-Instruct-v0.3 | ✔ | ✔ | ✔ | ✔ |
| Falcon-180B | tiiuae/falcon-180B-chat | | ✔ | | ✔ |
| Qwen2-72B | Qwen/Qwen2-72B-Instruct | | ✔ | | ✔ |
| Starcoder2-3b | bigcode/starcoder2-3b | ✔ | ✔ | ✔ | |
| Starcoder2-15b | bigcode/starcoder2-15b | ✔ | ✔ | ✔ | |
| Starcoder | bigcode/starcoder | ✔ | ✔ | ✔ | ✔ |
| Gemma-7b | google/gemma-7b-it | ✔ | ✔ | ✔ | ✔ |
| Llava-v1.6-Mistral-7B | llava-hf/llava-v1.6-mistral-7b-hf | ✔ | ✔ | ✔ | ✔ |

To run any of these models:
For example, to run Llama3.1-8B, you can use the following command:

```bash
model=MODEL_ID_THAT_YOU_WANT_TO_RUN
model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
hf_token=YOUR_ACCESS_TOKEN
@@ -87,7 +65,7 @@ The validated docker commands can be found in the [examples/docker_commands fold

### How to Enable Multi-Card Inference (Sharding)

TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards.
TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards. This is recommended to run large models and to speed up inference.

For example, on a machine with 8 Gaudi cards, you can run:
@@ -270,6 +248,45 @@ Note: Model warmup can take several minutes, especially for FP8 inference. For f

This section contains reference information about the Gaudi backend.

### Supported Models

Text Generation Inference enables serving optimized models on Gaudi hardware. The following sections list which models (VLMs & LLMs) are supported on Gaudi.

**Large Language Models (LLMs)**
- [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
- [Llama2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)
- [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
- [Llama3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)
- [LLama3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
- [LLama3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
- [CodeLlama-13B](https://huggingface.co/codellama/CodeLlama-13b-hf)
- [Opt-125m](https://huggingface.co/facebook/opt-125m)
- [OpenAI-gpt2](https://huggingface.co/openai-community/gpt2)
- [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
- [Qwen2-72B](https://huggingface.co/Qwen/Qwen2-72B-Instruct)
- [Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B-Instruct)
- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5)
- [Gemma-7b](https://huggingface.co/google/gemma-7b-it)
- [Starcoder2-3b](https://huggingface.co/bigcode/starcoder2-3b)
- [Starcoder2-15b](https://huggingface.co/bigcode/starcoder2-15b)
- [Starcoder](https://huggingface.co/bigcode/starcoder)
- [falcon-7b-instruct](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [Falcon-180B](https://huggingface.co/tiiuae/falcon-180B-chat)
- [GPT-2](https://huggingface.co/openai-community/gpt2)
- [gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b)

**Vision-Language Models (VLMs)**
- [LLaVA-v1.6-Mistral-7B](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
- [Mllama (Multimodal Llama from Meta)](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b)
- [Idefics 2.5](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)
- [Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)
- [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)

We also support, on a best-effort basis, models with a different parameter count that use the same model architecture, but those models were not tested. For example, the gaudi backend supports `meta-llama/Llama-3.2-1B` as the architecture is the standard llama3 architecture. If you have an issue with a model, please open an issue on the [Gaudi backend repository](https://github.com/huggingface/text-generation-inference/issues).

### Environment Variables

The following table contains the environment variables that can be used to configure the Gaudi backend: