Updating all models.

This commit is contained in:
Nicolas Patry 2023-05-04 12:31:51 +02:00
parent 1185f66205
commit e2d167256a
10 changed files with 85 additions and 34 deletions

View File

@ -49,7 +49,12 @@ class BloomCausalLMBatch(CausalLMBatch):
class BLOOM(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
super(BLOOM, self).__init__(
model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
)
@ -61,7 +66,10 @@ class BLOOM(CausalLM):
class BLOOMSharded(BLOOM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
@ -111,7 +119,7 @@ class BLOOMSharded(BLOOM):
def load_weights(
model,
filenames: List[str],
quantize: bool,
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
@ -165,7 +173,7 @@ class BLOOMSharded(BLOOM):
tensor = tensor.contiguous().to(dtype)
if quantize:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -215,9 +223,14 @@ class BLOOMSharded(BLOOM):
return linear
module.linear = replace_linear(state)
else:
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":

View File

@ -447,7 +447,7 @@ class CausalLM(Model):
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
quantize: Optional[str] = None,
decode_buffer: int = 3,
):
if torch.cuda.is_available():
@ -468,7 +468,7 @@ class CausalLM(Model):
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
load_in_8bit=quantize == "bitsandbytes",
).eval()
tokenizer.pad_token_id = (
self.model.config.pad_token_id

View File

@ -92,8 +92,8 @@ class FastLinear(nn.Linear):
self.quantized = False
self.bnb_linear = None
def prepare_weights(self, quantize: bool = False):
if quantize:
def prepare_weights(self, quantize: Optional[str] = None):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -117,8 +117,12 @@ class FastLinear(nn.Linear):
# Delete reference to data
self.weight = None
self.bias = None
else:
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:

View File

@ -67,8 +67,8 @@ class FastLinear(nn.Linear):
self.quantized = False
self.bnb_linear = None
def prepare_weights(self, quantize: bool = False):
if quantize:
def prepare_weights(self, quantize: Optional[str] = None):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -92,8 +92,12 @@ class FastLinear(nn.Linear):
# Delete reference to data
self.weight = None
self.bias = None
else:
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:

View File

@ -364,7 +364,7 @@ class FlashCausalLM(Model):
model_cls: Type[PreTrainedModel],
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
quantize: Optional[str] = None,
decode_buffer: int = 3,
):
self.past_pad = None
@ -382,7 +382,7 @@ class FlashCausalLM(Model):
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize,
load_in_8bit=quantize == "bitsandbytes",
)
.eval()
.to(device)

View File

@ -193,7 +193,10 @@ class Galactica(OPT):
class GalacticaSharded(Galactica):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
@ -242,7 +245,7 @@ class GalacticaSharded(Galactica):
def load_weights(
model,
filenames: List[str],
quantize: bool,
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
@ -297,7 +300,7 @@ class GalacticaSharded(Galactica):
tensor = tensor.contiguous().to(dtype)
if quantize:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -347,9 +350,14 @@ class GalacticaSharded(Galactica):
return linear
module.linear = replace_linear(state)
else:
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":

View File

@ -32,7 +32,10 @@ except Exception as e:
class GPTNeoxSharded(CausalLM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
@ -81,7 +84,7 @@ class GPTNeoxSharded(CausalLM):
def load_weights(
model,
filenames: List[str],
quantize: bool,
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
@ -146,7 +149,7 @@ class GPTNeoxSharded(CausalLM):
tensor = tensor.contiguous().to(dtype)
if quantize:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -196,9 +199,14 @@ class GPTNeoxSharded(CausalLM):
return linear
module.linear = replace_linear(state)
else:
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor

View File

@ -14,7 +14,12 @@ EOD = "<|endoftext|>"
class SantaCoder(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -46,7 +51,7 @@ class SantaCoder(CausalLM):
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=True, # required
)
.to(device)

View File

@ -501,7 +501,7 @@ class Seq2SeqLM(Model):
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
quantize: Optional[str] = None,
decode_buffer: int = 3,
):
if torch.cuda.is_available():
@ -519,7 +519,7 @@ class Seq2SeqLM(Model):
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
load_in_8bit=quantize == "bitsandbytes",
).eval()
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"

View File

@ -32,7 +32,10 @@ except Exception as e:
class T5Sharded(Seq2SeqLM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
@ -81,7 +84,7 @@ class T5Sharded(Seq2SeqLM):
def load_weights(
model,
filenames: List[str],
quantize: bool,
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
@ -152,7 +155,7 @@ class T5Sharded(Seq2SeqLM):
tensor = tensor.contiguous().to(dtype)
if quantize:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -203,8 +206,14 @@ class T5Sharded(Seq2SeqLM):
module.linear = replace_linear(state)
else:
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor