Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 19:34:53 +00:00
fix load_weights
commit 7c281908cf
parent 622daeb0c8
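
The pattern applied across the sharded models, shown here as a minimal sketch rather than the exact file contents: `load_weights` now receives the target `dtype` and casts each tensor as it is loaded, so the caller no longer casts the whole model afterwards.

    # before: weights loaded as-is, whole model cast at the end
    self.load_weights(model, filenames, quantize=quantize, device=device,
                      rank=self.rank, world_size=self.world_size)
    self.model = model.eval().to(dtype)

    # after: dtype is passed down and applied per tensor inside load_weights
    self.load_weights(model, filenames, quantize=quantize, device=device,
                      dtype=dtype, rank=self.rank, world_size=self.world_size)
    # ... inside load_weights:
    tensor = tensor.contiguous().to(dtype)
    # caller only needs eval()
    self.model = model.eval()
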
@@ -19,8 +19,10 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
-    from text_generation_server.models.flash_santacoder import FlashSantacoder, FlashSantacoderSharded
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoder,
+        FlashSantacoderSharded,
+    )

     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:

@@ -83,7 +85,9 @@ def get_model(
     if "bigcode" in model_id:
         if sharded:
             if not FLASH_ATTENTION:
-                raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder"))
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
             return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder

@@ -88,10 +88,11 @@ class BLOOMSharded(BLOOM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -104,6 +105,7 @@ class BLOOMSharded(BLOOM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -153,7 +155,7 @@ class BLOOMSharded(BLOOM):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:

@@ -209,7 +209,9 @@ class FlashMQAttention(torch.nn.Module):
             self.num_heads = self.num_heads // process_group.size()
             self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
             self.c_proj = TensorParallelRowLinear(
-                hidden_size, hidden_size, process_group=process_group,
+                hidden_size,
+                hidden_size,
+                process_group=process_group,
             )

     def forward(

@@ -64,11 +64,12 @@ class FlashNeoXSharded(FlashNeoX):
             model,
             filenames,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
         model.post_load_weights()
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -80,6 +81,7 @@ class FlashNeoXSharded(FlashNeoX):
         model,
         filenames: List[str],
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -138,7 +140,7 @@ class FlashNeoXSharded(FlashNeoX):
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

-            tensor = tensor.contiguous()
+            tensor = tensor.contiguous().to(dtype)

             if current_parameter_tensor is not None:
                 module._parameters[param_name] = tensor

@@ -58,11 +58,7 @@ class FlashSantacoder(FlashCausalLM):
         model = FlashSantacoderForCausalLM(config)

         self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype,
-            config.architectures[0].startswith("GPT2")
+            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
         )
         self.model = model.eval()

@@ -77,7 +73,7 @@ class FlashSantacoder(FlashCausalLM):
         filenames: List[Path],
         device: torch.device,
         dtype: torch.dtype,
-        transpose: bool
+        transpose: bool,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")

@@ -179,7 +175,9 @@ class FlashSantacoderSharded(FlashSantacoder):
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

         if quantize:
-            raise NotImplementedError("FlashSantacoderSharded does not support quantization")
+            raise NotImplementedError(
+                "FlashSantacoderSharded does not support quantization"
+            )

         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"

@@ -247,7 +245,9 @@ class FlashSantacoderSharded(FlashSantacoder):
                         block_size = size // world_size
                         start = rank * block_size
                         stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        tensor = (
+                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        )
                 elif isinstance(module, TensorParallelRowLinear):
                     if param_name == "weight":
                         dim = 0 if transpose else 1

@@ -255,7 +255,11 @@ class FlashSantacoderSharded(FlashSantacoder):
                             block_size = size // world_size
                             start = rank * block_size
                             stop = (rank + 1) * block_size
-                            tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                            tensor = (
+                                slice_[start:stop]
+                                if dim == 0
+                                else slice_[:, start:stop]
+                            )
                         else:
                             tensor = slice_[:]
                             # XXX: Hack for Rowlinear to add the bias only once.

@@ -219,10 +219,11 @@ class GalacticaSharded(Galactica):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -235,6 +236,7 @@ class GalacticaSharded(Galactica):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -285,7 +287,7 @@ class GalacticaSharded(Galactica):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:

@@ -64,10 +64,11 @@ class GPTNeoxSharded(CausalLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -80,6 +81,7 @@ class GPTNeoxSharded(CausalLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -140,7 +142,7 @@ class GPTNeoxSharded(CausalLM):
                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:

@@ -80,10 +80,11 @@ class OPTSharded(OPT):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -96,6 +97,7 @@ class OPTSharded(OPT):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -146,7 +148,7 @@ class OPTSharded(OPT):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:

@@ -64,10 +64,11 @@ class T5Sharded(Seq2SeqLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,

@@ -80,6 +81,7 @@ class T5Sharded(Seq2SeqLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -146,7 +148,7 @@ class T5Sharded(Seq2SeqLM):
                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES: