def load_text_model(prefix, config, weights, name=None):
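    """Instantiate the flash text model matching ``config.model_type``.

    ``prefix`` locates the language-model weights inside the checkpoint, and
    ``name`` is only forwarded to the Mistral implementation. Raises
    ``RuntimeError`` for unsupported model types.

    Illustrative usage (the prefix here is an example, not a fixed value)::

        text_model = load_text_model("language_model", config, weights)
    """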
    if config.model_type == "llama":
        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
            FlashLlamaForCausalLM,
        )

        return FlashLlamaForCausalLM(prefix, config, weights)
    elif config.model_type == "mistral":
        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
            FlashMistralForCausalLM,
        )

        return FlashMistralForCausalLM(prefix, config, weights, name=name)
elif config.model_type == "gemma":
|
|
|
|
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
|
|
|
FlashGemmaForCausalLM,
|
|
|
|
)
|
|
|
|
|
|
|
|
return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
|
elif config.model_type == "gemma2":
|
|
|
|
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
|
|
|
FlashGemma2ForCausalLM,
|
|
|
|
)
|
|
|
|
|
|
|
|
return FlashGemma2ForCausalLM(prefix, config, weights)
|
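    # PaliGemma's language model is a Gemma model, so its text tower reuses
    # the flash Gemma implementation below.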
elif config.model_type == "paligemma":
|
|
|
|
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
|
|
|
FlashGemmaForCausalLM,
|
|
|
|
)
|
|
|
|
|
|
|
|
return FlashGemmaForCausalLM(prefix, config, weights)
|
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")


def load_vision_model(prefix, config, weights):
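    """Instantiate the vision tower matching ``config.model_type``.

    Supports CLIP and SigLIP vision transformers; raises ``RuntimeError``
    for any other model type.
    """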
    if config.model_type == "clip_vision_model":
        from text_generation_server.models.custom_modeling.clip import (
            CLIPVisionTransformer,
        )

        return CLIPVisionTransformer(
            prefix=f"{prefix}.vision_model", config=config, weights=weights
        )
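    # Unlike the CLIP branch above, this branch hardcodes the
    # "vision_tower.vision_model" weight prefix instead of deriving it from
    # the ``prefix`` argument.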
if config.model_type == "siglip_vision_model":
|
|
|
|
from text_generation_server.models.custom_modeling.siglip import (
|
|
|
|
SiglipVisionTransformer,
|
|
|
|
)
|
|
|
|
|
|
|
|
return SiglipVisionTransformer(
|
2024-07-26 14:29:09 +00:00
|
|
|
prefix="vision_tower.vision_model", config=config, weights=weights
|
        )
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")