Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)

commit e2d167256a
parent 1185f66205

    Updating all models.
@@ -49,7 +49,12 @@ class BloomCausalLMBatch(CausalLMBatch):
 
 
 class BLOOM(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         super(BLOOM, self).__init__(
             model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
         )
@@ -61,7 +66,10 @@ class BLOOM(CausalLM):
 
 class BLOOMSharded(BLOOM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -111,7 +119,7 @@ class BLOOMSharded(BLOOM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -165,7 +173,7 @@ class BLOOMSharded(BLOOM):
 
                     tensor = tensor.contiguous().to(dtype)
 
-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -215,9 +223,14 @@ class BLOOMSharded(BLOOM):
                             return linear
 
                         module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     module._parameters[param_name] = tensor
                     if name == "word_embeddings.weight":
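Note: the hunk above (and the matching ones for Galactica, GPT-NeoX and T5 further down) all introduce the same dispatch: `quantize` is now an optional string rather than a bool, with "bitsandbytes", "gptq" (not implemented yet) and None as the accepted values. A minimal, self-contained sketch of that dispatch, outside the repository's weight-loading code (the function name and return strings are illustrative only):

    from typing import Optional

    def describe_quantize(quantize: Optional[str]) -> str:
        # Mirrors the branch structure introduced by this commit.
        if quantize == "bitsandbytes":
            return "load weights through bitsandbytes int8 linear layers"
        elif quantize == "gptq":
            raise NotImplementedError("`gptq` is not implemented for now")
        elif quantize is None:
            return "load weights unquantized on the target device"
        else:
            raise ValueError(f"Unexpected quantize `{quantize}`")

    assert describe_quantize(None) == "load weights unquantized on the target device"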
@@ -447,7 +447,7 @@ class CausalLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -468,7 +468,7 @@ class CausalLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer.pad_token_id = (
             self.model.config.pad_token_id
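Note: for the non-sharded models the string option is folded back into the boolean `load_in_8bit` argument that `from_pretrained` expects, which is exactly what the `quantize == "bitsandbytes"` comparison above evaluates to. A small sketch of the same mapping (the helper name is illustrative only):

    from typing import Optional

    def to_load_in_8bit(quantize: Optional[str]) -> bool:
        # Only the bitsandbytes backend maps to transformers' load_in_8bit flag;
        # None or any other value simply skips the 8-bit path.
        return quantize == "bitsandbytes"

    assert to_load_in_8bit("bitsandbytes") is True
    assert to_load_in_8bit(None) is False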
@@ -92,8 +92,8 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None
 
-    def prepare_weights(self, quantize: bool = False):
-        if quantize:
+    def prepare_weights(self, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -117,8 +117,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
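Note: `prepare_weights` in both copies of `FastLinear` gets the same treatment. The stand-in class below is not the repository's layer; it keeps only the dispatch so the accepted values and failure modes are explicit (assumes torch is installed):

    from typing import Optional

    from torch import nn

    class TinyLinear(nn.Linear):
        # Stand-in for FastLinear: only the prepare_weights() dispatch is mirrored.
        def prepare_weights(self, quantize: Optional[str] = None):
            if quantize == "bitsandbytes":
                # The real layer swaps itself for a bitsandbytes Linear8bitLt here.
                self.quantized = True
            elif quantize == "gptq":
                raise NotImplementedError("`gptq` is not implemented for now")
            elif quantize is None:
                # Unquantized path, as in the diff: keep a transposed weight.
                self.weight = nn.Parameter(self.weight.T)
            else:
                raise ValueError(f"Unexpected quantize `{quantize}`")

    layer = TinyLinear(8, 8)
    layer.prepare_weights(None)        # unquantized path
    # layer.prepare_weights("awq")     # would raise ValueError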
@@ -67,8 +67,8 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None
 
-    def prepare_weights(self, quantize: bool = False):
-        if quantize:
+    def prepare_weights(self, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -92,8 +92,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
@@ -364,7 +364,7 @@ class FlashCausalLM(Model):
         model_cls: Type[PreTrainedModel],
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         self.past_pad = None
@@ -382,7 +382,7 @@ class FlashCausalLM(Model):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                load_in_8bit=quantize,
+                load_in_8bit=quantize == "bitsandbytes",
             )
             .eval()
             .to(device)
@@ -193,7 +193,10 @@ class Galactica(OPT):
 
 class GalacticaSharded(Galactica):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -242,7 +245,7 @@ class GalacticaSharded(Galactica):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -297,7 +300,7 @@ class GalacticaSharded(Galactica):
 
                     tensor = tensor.contiguous().to(dtype)
 
-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -347,9 +350,14 @@ class GalacticaSharded(Galactica):
                             return linear
 
                         module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     module._parameters[param_name] = tensor
                     if name == "model.decoder.embed_tokens.weight":
@@ -32,7 +32,10 @@ except Exception as e:
 
 class GPTNeoxSharded(CausalLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -81,7 +84,7 @@ class GPTNeoxSharded(CausalLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -146,7 +149,7 @@ class GPTNeoxSharded(CausalLM):
 
                     tensor = tensor.contiguous().to(dtype)
 
-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -196,9 +199,14 @@ class GPTNeoxSharded(CausalLM):
                             return linear
 
                         module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     if current_parameter_tensor is not None:
                         module._parameters[param_name] = tensor
@@ -14,7 +14,12 @@ EOD = "<|endoftext|>"
 
 
 class SantaCoder(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -46,7 +51,7 @@ class SantaCoder(CausalLM):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                load_in_8bit=quantize,
+                load_in_8bit=quantize == "bitsandbytes",
                 trust_remote_code=True,  # required
             )
             .to(device)
@@ -501,7 +501,7 @@ class Seq2SeqLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -519,7 +519,7 @@ class Seq2SeqLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
@@ -32,7 +32,10 @@ except Exception as e:
 
 class T5Sharded(Seq2SeqLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
    ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -81,7 +84,7 @@ class T5Sharded(Seq2SeqLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -152,7 +155,7 @@ class T5Sharded(Seq2SeqLM):
 
                     tensor = tensor.contiguous().to(dtype)
 
-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -203,8 +206,14 @@ class T5Sharded(Seq2SeqLM):
 
                         module.linear = replace_linear(state)
 
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     if current_parameter_tensor is not None:
                         module._parameters[param_name] = tensor