feat(server): use float16

This commit is contained in:
OlivierDehaene 2023-05-09 18:36:20 +02:00
parent ad66f6ef9a
commit 090352d965
8 changed files with 8 additions and 8 deletions

View File

@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
dtype = torch.float16
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@ -452,7 +452,7 @@ class CausalLM(Model):
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
dtype = torch.float16
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@ -199,7 +199,7 @@ class GalacticaSharded(Galactica):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
dtype = torch.float16
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
dtype = torch.float16
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@ -54,7 +54,7 @@ class OPTSharded(OPT):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
dtype = torch.float16
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@ -17,7 +17,7 @@ class SantaCoder(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
dtype = torch.float16
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@ -506,7 +506,7 @@ class Seq2SeqLM(Model):
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
dtype = torch.float16
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM):
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
dtype = torch.float16
else:
device = torch.device("cpu")
dtype = torch.float32