feat(server): use float16

This commit is contained in:
OlivierDehaene 2023-05-09 18:36:20 +02:00
parent ad66f6ef9a
commit 090352d965
8 changed files with 8 additions and 8 deletions

View File

@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM):
self.master = self.rank == 0 self.master = self.rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32

View File

@ -452,7 +452,7 @@ class CausalLM(Model):
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")

View File

@ -199,7 +199,7 @@ class GalacticaSharded(Galactica):
self.master = self.rank == 0 self.master = self.rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32

View File

@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM):
self.master = self.rank == 0 self.master = self.rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32

View File

@ -54,7 +54,7 @@ class OPTSharded(OPT):
self.master = self.rank == 0 self.master = self.rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32

View File

@ -17,7 +17,7 @@ class SantaCoder(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")

View File

@ -506,7 +506,7 @@ class Seq2SeqLM(Model):
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")

View File

@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM):
self.master = self.rank == 0 self.master = self.rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32