mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
Defaulting to bf16.
parent defc477a03
commit 8cda9ca2f7
@@ -10,7 +10,9 @@ from transformers import (
 )
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
+from text_generation_server.models.custom_modeling.idefics_processing import (
+    IdeficsProcessor,
+)
 from transformers import LlamaTokenizerFast
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
@@ -35,7 +37,9 @@ class IDEFICSSharded(IdeficsCausalLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+            # 9b seems to work correctly enough in float16, but 80b seems
+            # to be really saturating for f16.
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
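For context (not part of this diff): the saturation the new comment refers to is a dynamic-range limit of float16, whose largest finite value is 65504, while bfloat16 keeps float32's exponent range at reduced mantissa precision, so large activations in the 80b model stay finite. A minimal sketch, assuming any recent PyTorch build:

import torch

# float16 overflows past ~65504; bfloat16 shares float32's exponent range.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38

x = torch.tensor([70000.0])
print(x.to(torch.float16))   # tensor([inf], dtype=torch.float16)     -> saturates to inf
print(x.to(torch.bfloat16))  # tensor([70144.], dtype=torch.bfloat16) -> finite, coarser precision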