fix load_weights

OlivierDehaene 2023-04-11 20:10:33 +02:00
parent 622daeb0c8
commit 7c281908cf
9 changed files with 47 additions and 25 deletions
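In short: instead of building each sharded model in the checkpoint dtype and converting it afterwards with `model.eval().to(dtype)`, `load_weights` now receives `dtype` and casts every tensor as it is loaded (`tensor = tensor.contiguous().to(dtype)`). A minimal sketch of that pattern, using a toy module and a hypothetical stand-alone `load_weights` helper rather than the actual text_generation_server code:

import torch


def load_weights(model: torch.nn.Module, state_dict: dict, dtype: torch.dtype) -> None:
    # Toy stand-in for the per-model load_weights methods touched by this commit:
    # dtype is passed in and every tensor is cast while it is being loaded.
    for name, tensor in state_dict.items():
        module_name, param_name = name.rsplit(".", 1)
        module = model.get_submodule(module_name)
        tensor = tensor.contiguous().to(dtype)
        module._parameters[param_name] = torch.nn.Parameter(tensor, requires_grad=False)


if __name__ == "__main__":
    model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    checkpoint = {k: v.clone() for k, v in model.state_dict().items()}
    load_weights(model, checkpoint, torch.float16)
    model = model.eval()  # no trailing .to(dtype) needed any more
    print(next(model.parameters()).dtype)  # torch.float16

Presumably this also keeps a trailing whole-model `.to(dtype)` from re-converting tensors that should stay in their loaded dtype.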

View File

@@ -19,8 +19,10 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
-    from text_generation_server.models.flash_santacoder import FlashSantacoder, FlashSantacoderSharded
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoder,
+        FlashSantacoderSharded,
+    )
     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
@@ -83,7 +85,9 @@ def get_model(
     if "bigcode" in model_id:
         if sharded:
             if not FLASH_ATTENTION:
-                raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder"))
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
             return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder

View File

@@ -88,10 +88,11 @@ class BLOOMSharded(BLOOM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -104,6 +105,7 @@ class BLOOMSharded(BLOOM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -153,7 +155,7 @@ class BLOOMSharded(BLOOM):
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
                    if quantize:
                        if not HAS_BITS_AND_BYTES:

View File

@@ -209,7 +209,9 @@ class FlashMQAttention(torch.nn.Module):
             self.num_heads = self.num_heads // process_group.size()
             self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
             self.c_proj = TensorParallelRowLinear(
-                hidden_size, hidden_size, process_group=process_group,
+                hidden_size,
+                hidden_size,
+                process_group=process_group,
             )
     def forward(

View File

@@ -64,11 +64,12 @@ class FlashNeoXSharded(FlashNeoX):
             model,
             filenames,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
         model.post_load_weights()
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class FlashNeoXSharded(FlashNeoX):
         model,
         filenames: List[str],
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -138,7 +140,7 @@ class FlashNeoXSharded(FlashNeoX):
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor

View File

@@ -58,11 +58,7 @@ class FlashSantacoder(FlashCausalLM):
         model = FlashSantacoderForCausalLM(config)
         self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype,
-            config.architectures[0].startswith("GPT2")
+            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
         )
         self.model = model.eval()
@@ -77,7 +73,7 @@ class FlashSantacoder(FlashCausalLM):
         filenames: List[Path],
         device: torch.device,
         dtype: torch.dtype,
-        transpose: bool
+        transpose: bool,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
@@ -179,7 +175,9 @@ class FlashSantacoderSharded(FlashSantacoder):
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
         if quantize:
-            raise NotImplementedError("FlashSantacoderSharded does not support quantization")
+            raise NotImplementedError(
+                "FlashSantacoderSharded does not support quantization"
+            )
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
@@ -247,7 +245,9 @@ class FlashSantacoderSharded(FlashSantacoder):
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
-                        tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        tensor = (
+                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        )
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            dim = 0 if transpose else 1
@@ -255,7 +255,11 @@ class FlashSantacoderSharded(FlashSantacoder):
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
-                            tensor = slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                            tensor = (
+                                slice_[start:stop]
+                                if dim == 0
+                                else slice_[:, start:stop]
+                            )
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.

View File

@@ -219,10 +219,11 @@ class GalacticaSharded(Galactica):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -235,6 +236,7 @@ class GalacticaSharded(Galactica):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -285,7 +287,7 @@ class GalacticaSharded(Galactica):
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
                    if quantize:
                        if not HAS_BITS_AND_BYTES:

View File

@@ -64,10 +64,11 @@ class GPTNeoxSharded(CausalLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class GPTNeoxSharded(CausalLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -140,7 +142,7 @@ class GPTNeoxSharded(CausalLM):
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
                    if quantize:
                        if not HAS_BITS_AND_BYTES:

View File

@@ -80,10 +80,11 @@ class OPTSharded(OPT):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -96,6 +97,7 @@ class OPTSharded(OPT):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class OPTSharded(OPT):
                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                        )
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
                    if quantize:
                        if not HAS_BITS_AND_BYTES:

View File

@@ -64,10 +64,11 @@ class T5Sharded(Seq2SeqLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class T5Sharded(Seq2SeqLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class T5Sharded(Seq2SeqLM):
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )
-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)
                    if quantize:
                        if not HAS_BITS_AND_BYTES: