Updating all models.

Nicolas Patry 2023-05-04 12:31:51 +02:00
parent 1185f66205
commit e2d167256a
10 changed files with 85 additions and 34 deletions

View File

@@ -49,7 +49,12 @@ class BloomCausalLMBatch(CausalLMBatch):


 class BLOOM(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         super(BLOOM, self).__init__(
             model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
         )
@@ -61,7 +66,10 @@ class BLOOM(CausalLM):

 class BLOOMSharded(BLOOM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -111,7 +119,7 @@ class BLOOMSharded(BLOOM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -165,7 +173,7 @@ class BLOOMSharded(BLOOM):
                     tensor = tensor.contiguous().to(dtype)

-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -215,9 +223,14 @@ class BLOOMSharded(BLOOM):
                                 return linear

                             module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")

                     module._parameters[param_name] = tensor
                     if name == "word_embeddings.weight":

View File

@@ -447,7 +447,7 @@ class CausalLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -468,7 +468,7 @@ class CausalLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer.pad_token_id = (
             self.model.config.pad_token_id
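For the non-sharded models, the string is simply mapped onto transformers' boolean `load_in_8bit` flag, so 8-bit loading is enabled exactly when `quantize == "bitsandbytes"`. A hedged sketch of that call shape (the dtype/device handling here is illustrative, not the file's exact code):

from typing import Optional

import torch
from transformers import AutoModelForCausalLM


def load_causal_lm(model_id: str, quantize: Optional[str] = None):
    # Any value other than "bitsandbytes" leaves load_in_8bit at False.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto" if torch.cuda.is_available() else None,
        load_in_8bit=quantize == "bitsandbytes",
    ).eval()

Note that, unlike the sharded loaders, an unrecognized string here falls back silently to an unquantized load rather than raising.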

View File

@@ -92,8 +92,8 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None

-    def prepare_weights(self, quantize: bool = False):
-        if quantize:
+    def prepare_weights(self, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -117,8 +117,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:

View File

@@ -67,8 +67,8 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None

-    def prepare_weights(self, quantize: bool = False):
-        if quantize:
+    def prepare_weights(self, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -92,8 +92,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:

View File

@@ -364,7 +364,7 @@ class FlashCausalLM(Model):
         model_cls: Type[PreTrainedModel],
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         self.past_pad = None
@@ -382,7 +382,7 @@ class FlashCausalLM(Model):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                load_in_8bit=quantize,
+                load_in_8bit=quantize == "bitsandbytes",
             )
             .eval()
             .to(device)

View File

@@ -193,7 +193,10 @@ class Galactica(OPT):

 class GalacticaSharded(Galactica):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -242,7 +245,7 @@ class GalacticaSharded(Galactica):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -297,7 +300,7 @@ class GalacticaSharded(Galactica):
                     tensor = tensor.contiguous().to(dtype)

-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -347,9 +350,14 @@ class GalacticaSharded(Galactica):
                                 return linear

                             module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")

                     module._parameters[param_name] = tensor
                     if name == "model.decoder.embed_tokens.weight":

View File

@@ -32,7 +32,10 @@ except Exception as e:

 class GPTNeoxSharded(CausalLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -81,7 +84,7 @@ class GPTNeoxSharded(CausalLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -146,7 +149,7 @@ class GPTNeoxSharded(CausalLM):
                     tensor = tensor.contiguous().to(dtype)

-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -196,9 +199,14 @@ class GPTNeoxSharded(CausalLM):
                                 return linear

                             module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")

                     if current_parameter_tensor is not None:
                         module._parameters[param_name] = tensor

View File

@@ -14,7 +14,12 @@ EOD = "<|endoftext|>"


 class SantaCoder(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -46,7 +51,7 @@ class SantaCoder(CausalLM):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                load_in_8bit=quantize,
+                load_in_8bit=quantize == "bitsandbytes",
                 trust_remote_code=True,  # required
             )
             .to(device)

View File

@@ -501,7 +501,7 @@ class Seq2SeqLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -519,7 +519,7 @@ class Seq2SeqLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"

View File

@@ -32,7 +32,10 @@ except Exception as e:

 class T5Sharded(Seq2SeqLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -81,7 +84,7 @@ class T5Sharded(Seq2SeqLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -152,7 +155,7 @@ class T5Sharded(Seq2SeqLM):
                     tensor = tensor.contiguous().to(dtype)

-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -203,8 +206,14 @@ class T5Sharded(Seq2SeqLM):

                             module.linear = replace_linear(state)

-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")

                     if current_parameter_tensor is not None:
                         module._parameters[param_name] = tensor
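With these signatures, callers now select the quantization scheme by name rather than by a boolean. A hypothetical call site under the new API (the import path is assumed from the class names above, and the model ids are only examples):

# Assumed module path; the diff shows only the class definitions.
from text_generation_server.models.bloom import BLOOM, BLOOMSharded

model = BLOOM("bigscience/bloom-560m", quantize=None)  # full precision
sharded = BLOOMSharded("bigscience/bloom", quantize="bitsandbytes")  # 8-bit weights
# quantize="gptq" currently raises NotImplementedError in every loader.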