Updating all models.
commit e2d167256a (parent 1185f66205)
@@ -49,7 +49,12 @@ class BloomCausalLMBatch(CausalLMBatch):
 
 
 class BLOOM(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         super(BLOOM, self).__init__(
             model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
         )
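Throughout this commit the `quantize` flag migrates from a bool to an optional backend name. A minimal standalone sketch of the new calling convention (`describe_load` is a hypothetical helper, not part of the repository):

from typing import Optional

def describe_load(
    model_id: str,
    revision: Optional[str] = None,
    quantize: Optional[str] = None,
) -> str:
    # `quantize` is now a backend name ("bitsandbytes", "gptq") or None,
    # so new quantization schemes can be added without more bool flags.
    if quantize not in (None, "bitsandbytes", "gptq"):
        raise ValueError(f"Unexpected quantize `{quantize}`")
    return f"{model_id}@{revision or 'main'} quantize={quantize}"

print(describe_load("bigscience/bloom-560m", quantize="bitsandbytes"))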
@@ -61,7 +66,10 @@ class BLOOM(CausalLM):
 
 class BLOOMSharded(BLOOM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -111,7 +119,7 @@ class BLOOMSharded(BLOOM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -165,7 +173,7 @@ class BLOOMSharded(BLOOM):
 
                 tensor = tensor.contiguous().to(dtype)
 
-                if quantize:
+                if quantize == "bitsandbytes":
                     if not HAS_BITS_AND_BYTES:
                         raise ImportError(
                             "bitsandbytes is not available on your machine either because it is not installed "
@@ -215,9 +223,14 @@ class BLOOMSharded(BLOOM):
                         return linear
 
                     module.linear = replace_linear(state)
 
-                else:
+                elif quantize == "gptq":
+                    raise NotImplementedError(
+                        "`gptq` is not implemented for now"
+                    )
+                elif quantize is None:
                     tensor = tensor.to(device)
+                else:
+                    raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 module._parameters[param_name] = tensor
                 if name == "word_embeddings.weight":
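The same three-way dispatch recurs in every sharded `load_weights` below. A self-contained sketch of the control flow, with the bitsandbytes branch stubbed out (`place_tensor` is illustrative only, not a repository function):

from typing import Optional
import torch

def place_tensor(
    tensor: torch.Tensor, quantize: Optional[str], device: torch.device
) -> torch.Tensor:
    if quantize == "bitsandbytes":
        # The real code replaces the module's Linear with a bitsandbytes
        # Linear8bitLt here; stubbed out to a plain device move.
        return tensor.to(device)
    elif quantize == "gptq":
        raise NotImplementedError("`gptq` is not implemented for now")
    elif quantize is None:
        return tensor.to(device)
    else:
        raise ValueError(f"Unexpected quantize `{quantize}`")

t = place_tensor(torch.ones(2, 2), None, torch.device("cpu"))
print(t.device)  # cpu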
@@ -447,7 +447,7 @@ class CausalLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -468,7 +468,7 @@ class CausalLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer.pad_token_id = (
             self.model.config.pad_token_id
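`from_pretrained` still expects a boolean `load_in_8bit`, so the string option is collapsed back to a bool at the call site; only the "bitsandbytes" value turns 8-bit loading on (`to_load_in_8bit` is a hypothetical helper illustrating the expression):

from typing import Optional

def to_load_in_8bit(quantize: Optional[str]) -> bool:
    # Mirrors `load_in_8bit=quantize == "bitsandbytes"` in the diff.
    return quantize == "bitsandbytes"

assert to_load_in_8bit("bitsandbytes") is True
assert to_load_in_8bit(None) is False
assert to_load_in_8bit("gptq") is False  # gptq takes a different path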
@@ -92,8 +92,8 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None
 
-    def prepare_weights(self, quantize: bool = False):
-        if quantize:
+    def prepare_weights(self, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -117,8 +117,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
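A runnable miniature of the new `prepare_weights` dispatch (`TinyFastLinear` is illustrative, not the repository's `FastLinear`); the `None` branch keeps the existing behavior of pre-transposing the weight for the custom matmul in `forward`:

from typing import Optional
import torch
from torch import nn

class TinyFastLinear(nn.Linear):
    def prepare_weights(self, quantize: Optional[str] = None):
        if quantize == "bitsandbytes":
            # The real code builds a bnb Linear8bitLt and drops the
            # fp16 weight; elided in this sketch.
            raise ImportError("bitsandbytes path elided in this sketch")
        elif quantize == "gptq":
            raise NotImplementedError("`gptq` is not implemented for now")
        elif quantize is None:
            # Pre-transpose (detached) so forward could matmul directly.
            self.weight = nn.Parameter(self.weight.data.T)
        else:
            raise ValueError(f"Unexpected quantize `{quantize}`")

layer = TinyFastLinear(4, 8)
layer.prepare_weights(None)
print(layer.weight.shape)  # torch.Size([4, 8]) after the transpose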
@@ -67,8 +67,8 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None
 
-    def prepare_weights(self, quantize: bool = False):
-        if quantize:
+    def prepare_weights(self, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -92,8 +92,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
@@ -364,7 +364,7 @@ class FlashCausalLM(Model):
         model_cls: Type[PreTrainedModel],
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         self.past_pad = None
@@ -382,7 +382,7 @@ class FlashCausalLM(Model):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                load_in_8bit=quantize,
+                load_in_8bit=quantize == "bitsandbytes",
             )
             .eval()
             .to(device)
@@ -193,7 +193,10 @@ class Galactica(OPT):
 
 class GalacticaSharded(Galactica):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -242,7 +245,7 @@ class GalacticaSharded(Galactica):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -297,7 +300,7 @@ class GalacticaSharded(Galactica):
 
                 tensor = tensor.contiguous().to(dtype)
 
-                if quantize:
+                if quantize == "bitsandbytes":
                     if not HAS_BITS_AND_BYTES:
                         raise ImportError(
                             "bitsandbytes is not available on your machine either because it is not installed "
@@ -347,9 +350,14 @@ class GalacticaSharded(Galactica):
                         return linear
 
                     module.linear = replace_linear(state)
 
-                else:
+                elif quantize == "gptq":
+                    raise NotImplementedError(
+                        "`gptq` is not implemented for now"
+                    )
+                elif quantize is None:
                     tensor = tensor.to(device)
+                else:
+                    raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 module._parameters[param_name] = tensor
                 if name == "model.decoder.embed_tokens.weight":
@@ -32,7 +32,10 @@ except Exception as e:
 
 class GPTNeoxSharded(CausalLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -81,7 +84,7 @@ class GPTNeoxSharded(CausalLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -146,7 +149,7 @@ class GPTNeoxSharded(CausalLM):
 
                 tensor = tensor.contiguous().to(dtype)
 
-                if quantize:
+                if quantize == "bitsandbytes":
                     if not HAS_BITS_AND_BYTES:
                         raise ImportError(
                             "bitsandbytes is not available on your machine either because it is not installed "
@@ -196,9 +199,14 @@ class GPTNeoxSharded(CausalLM):
                         return linear
 
                     module.linear = replace_linear(state)
 
-                else:
+                elif quantize == "gptq":
+                    raise NotImplementedError(
+                        "`gptq` is not implemented for now"
+                    )
+                elif quantize is None:
                     tensor = tensor.to(device)
+                else:
+                    raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 if current_parameter_tensor is not None:
                     module._parameters[param_name] = tensor
@@ -14,7 +14,12 @@ EOD = "<|endoftext|>"
 
 
 class SantaCoder(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -46,7 +51,7 @@ class SantaCoder(CausalLM):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                load_in_8bit=quantize,
+                load_in_8bit=quantize == "bitsandbytes",
                 trust_remote_code=True,  # required
             )
             .to(device)
@@ -501,7 +501,7 @@ class Seq2SeqLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -519,7 +519,7 @@ class Seq2SeqLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
@@ -32,7 +32,10 @@ except Exception as e:
 
 class T5Sharded(Seq2SeqLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
    ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -81,7 +84,7 @@ class T5Sharded(Seq2SeqLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -152,7 +155,7 @@ class T5Sharded(Seq2SeqLM):
 
                 tensor = tensor.contiguous().to(dtype)
 
-                if quantize:
+                if quantize == "bitsandbytes":
                     if not HAS_BITS_AND_BYTES:
                         raise ImportError(
                             "bitsandbytes is not available on your machine either because it is not installed "
@@ -203,8 +206,14 @@ class T5Sharded(Seq2SeqLM):
 
                    module.linear = replace_linear(state)
 
-                else:
+                elif quantize == "gptq":
+                    raise NotImplementedError(
+                        "`gptq` is not implemented for now"
+                    )
+                elif quantize is None:
                     tensor = tensor.to(device)
+                else:
+                    raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 if current_parameter_tensor is not None:
                     module._parameters[param_name] = tensor
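End to end, the change lets the server accept the quantization backend by name. A hypothetical argparse front end (not the repository's actual launcher) shows the intended interface:

import argparse

# Hypothetical CLI front end for the new option; the real launcher
# lives elsewhere in the repository.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--quantize",
    choices=["bitsandbytes", "gptq"],
    default=None,
    help="Quantization backend; omit for full precision.",
)
args = parser.parse_args(["--quantize", "bitsandbytes"])
print(args.quantize)  # bitsandbytes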