mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
enable bfloat16 for cpu
if there's no cuda. disable custom kernels Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
c8a01d7591
commit
c44fce6c09
@ -153,7 +153,7 @@ def get_model(
|
|||||||
)
|
)
|
||||||
elif model_type == "mpt":
|
elif model_type == "mpt":
|
||||||
return MPTSharded(
|
return MPTSharded(
|
||||||
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
|
model_id, revision, quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_type == "gpt_neox":
|
elif model_type == "gpt_neox":
|
||||||
|
@ -51,7 +51,7 @@ class BLOOMSharded(CausalLM):
|
|||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -492,7 +492,7 @@ class CausalLM(Model):
|
|||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -40,7 +40,7 @@ from text_generation_server.utils.layers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
CUSTOM_KERNELS_ENABLED = False
|
CUSTOM_KERNELS_ENABLED = False
|
||||||
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
|
if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
|
||||||
try:
|
try:
|
||||||
from custom_kernels import fused_bloom_attention_cuda
|
from custom_kernels import fused_bloom_attention_cuda
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ from text_generation_server.utils.layers import (
|
|||||||
|
|
||||||
|
|
||||||
CUSTOM_KERNELS_ENABLED = False
|
CUSTOM_KERNELS_ENABLED = False
|
||||||
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
|
if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
|
||||||
try:
|
try:
|
||||||
from custom_kernels import fused_attention_cuda
|
from custom_kernels import fused_attention_cuda
|
||||||
|
|
||||||
|
@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM):
|
|||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -33,7 +33,7 @@ class GPTNeoxSharded(CausalLM):
|
|||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -42,7 +42,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
self.device, self.dtype = device, dtype
|
self.device, self.dtype = device, dtype
|
||||||
|
|
||||||
config = IdeficsConfig.from_pretrained(
|
config = IdeficsConfig.from_pretrained(
|
||||||
|
@ -560,7 +560,7 @@ class IdeficsCausalLM(Model):
|
|||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -43,14 +43,16 @@ class MPTSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.float16
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("MPTSharded is only available on GPU")
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -31,7 +31,7 @@ class OPTSharded(CausalLM):
|
|||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -23,7 +23,7 @@ class RW(CausalLM):
|
|||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -30,7 +30,7 @@ class SantaCoder(CausalLM):
|
|||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -541,7 +541,7 @@ class Seq2SeqLM(Model):
|
|||||||
raise ValueError("quantization is not available on CPU")
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -34,7 +34,7 @@ class T5Sharded(Seq2SeqLM):
|
|||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
Loading…
Reference in New Issue
Block a user