enable bfloat16 for cpu

if there's no cuda, disable custom kernels

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent c8a01d7591
commit c44fce6c09
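Every hunk below applies the same pattern: the CPU branch keeps torch.float32 as its default, but a caller-supplied dtype (for example torch.bfloat16) is no longer overwritten. A minimal standalone sketch of that pattern, with a hypothetical pick_device_and_dtype helper standing in for the various model constructors:

import torch

def pick_device_and_dtype(dtype=None):
    # GPU branch: default float16 (bfloat16 for IDEFICS); CPU branch: default float32.
    # In both branches an explicitly requested dtype now takes precedence over the default.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float16 if dtype is None else dtype
    else:
        device = torch.device("cpu")
        dtype = torch.float32 if dtype is None else dtype
    return device, dtype

# Requesting bfloat16 on a machine without CUDA now takes effect instead of
# being silently replaced by float32:
device, dtype = pick_device_and_dtype(torch.bfloat16)
assert dtype == torch.bfloat16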
@@ -153,7 +153,7 @@ def get_model(
         )
     elif model_type == "mpt":
         return MPTSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id, revision, quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code
         )
 
     elif model_type == "gpt_neox":
@@ -51,7 +51,7 @@ class BLOOMSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -492,7 +492,7 @@ class CausalLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -40,7 +40,7 @@ from text_generation_server.utils.layers import (
 )
 
 CUSTOM_KERNELS_ENABLED = False
-if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
+if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
     try:
         from custom_kernels import fused_bloom_attention_cuda
 
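After this change the custom CUDA kernels are only attempted when a GPU is actually present; setting DISABLE_CUSTOM_KERNELS=True still disables them explicitly. A standalone sketch of the gate (the custom_kernels_allowed helper is illustrative, not part of the diff):

import os
import torch

def custom_kernels_allowed() -> bool:
    # Kernels are attempted only when CUDA is available AND the env var is not
    # the exact string "True"; any other value, or an unset variable, allows them.
    disabled = os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
    return torch.cuda.is_available() and not disabled

# On a CPU-only host this is always False, so `from custom_kernels import ...`
# (and its CUDA build dependency) is never reached at import time.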
@@ -49,7 +49,7 @@ from text_generation_server.utils.layers import (
 
 
 CUSTOM_KERNELS_ENABLED = False
-if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
+if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
     try:
         from custom_kernels import fused_attention_cuda
 
@@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -33,7 +33,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -42,7 +42,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
         self.device, self.dtype = device, dtype
 
         config = IdeficsConfig.from_pretrained(
@@ -560,7 +560,7 @@ class IdeficsCausalLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -43,14 +43,16 @@ class MPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
-            raise NotImplementedError("MPTSharded is only available on GPU")
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
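MPTSharded previously raised NotImplementedError when CUDA was absent; with the new dtype parameter, and the dtype=dtype forwarding in get_model in the first hunk above, it now follows the same CPU fallback as the other model classes. A hypothetical direct construction call, assuming the text_generation_server.models.mpt module path and that the server's usual distributed environment has already been set up (in practice this goes through get_model, and the argument values here are placeholders):

import torch
from text_generation_server.models.mpt import MPTSharded

# Sketch only: requires the model weights and the server's process-group setup.
model = MPTSharded(
    "mosaicml/mpt-7b",        # model_id (placeholder)
    revision=None,
    quantize=None,
    dtype=torch.bfloat16,     # honored on CPU now instead of raising
    trust_remote_code=True,
)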
@@ -31,7 +31,7 @@ class OPTSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -23,7 +23,7 @@ class RW(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -30,7 +30,7 @@ class SantaCoder(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -541,7 +541,7 @@ class Seq2SeqLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
@@ -34,7 +34,7 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         config = AutoConfig.from_pretrained(
             model_id,
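The commit title, "enable bfloat16 for cpu", relies on PyTorch supporting bfloat16 tensors and kernels on CPU. A quick sanity check independent of the server code (a sketch; exact operator coverage depends on the installed PyTorch version):

import torch

x = torch.randn(4, 4, dtype=torch.bfloat16)   # bfloat16 tensor allocated on CPU
y = x @ x                                      # matmul executes on CPU in bfloat16
print(y.dtype, y.device)                       # torch.bfloat16 cpu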