Add support for Baichuan models

xiaoyuze 2023-09-05 15:57:32 +08:00
parent 033230ae66
commit 2a1f306e26
3 changed files with 51 additions and 10 deletions

View File

@@ -182,7 +182,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    elif model_type == "llama":
+    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
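
The model_type that get_model branches on comes from the checkpoint's configuration, so a Baichuan checkpoint now reuses the FlashLlama code path. A quick, hedged way to check which branch a given checkpoint would take (the Hub id below is only an example, not part of this commit):

# Hedged sketch: print the model_type that the dispatch above branches on.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("baichuan-inc/Baichuan-7B", trust_remote_code=True)
print(config.model_type)  # expected "baichuan", routed to FlashLlama when FLASH_ATTENTION is set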

View File

@@ -208,13 +208,37 @@ class FlashLlamaAttention(torch.nn.Module):
         if config.num_attention_heads != config.num_key_value_heads:
             self.query_key_value = _load_gqa(config, prefix, weights)
         else:
-            self.query_key_value = TensorParallelColumnLinear.load_multi(
-                config,
-                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                dim=0,
-                weights=weights,
-                bias=False,
-            )
+            try:
+                # Baichuan packs Q, K and V into a single W_pack weight.
+                # Load it on the CPU, split it into three equal chunks,
+                # shard each chunk for this rank, then concatenate the
+                # shards back into a single QKV projection.
+                def load_packed(cls, config, prefix: str, weights, bias: bool):
+                    packed_tensor = weights.get_tensor(prefix, to_device=False)
+                    total_size = packed_tensor.size()[0]
+                    single_size = total_size // 3
+                    q_tensor = packed_tensor[0:single_size, :]
+                    k_tensor = packed_tensor[single_size : single_size * 2, :]
+                    v_tensor = packed_tensor[single_size * 2 : total_size, :]
+                    q_weight = weights.get_tensor_shard(q_tensor, dim=0)
+                    k_weight = weights.get_tensor_shard(k_tensor, dim=0)
+                    v_weight = weights.get_tensor_shard(v_tensor, dim=0)
+                    weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
+                    return cls(get_linear(weight, None, config.quantize))
+
+                self.query_key_value = load_packed(
+                    TensorParallelColumnLinear,
+                    config,
+                    prefix=f"{prefix}.W_pack.weight",
+                    weights=weights,
+                    bias=False,
+                )
+            except Exception:
+                # No packed W_pack weight: fall back to the standard Llama
+                # layout with separate q_proj / k_proj / v_proj tensors.
+                self.query_key_value = TensorParallelColumnLinear.load_multi(
+                    config,
+                    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                    dim=0,
+                    weights=weights,
+                    bias=False,
+                )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
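
The reason for splitting W_pack before sharding is that each tensor-parallel rank needs matching Q, K and V rows; slicing the packed matrix directly would hand some ranks nothing but Q rows. A minimal, self-contained sketch of that logic in plain PyTorch, with made-up toy sizes:

# Hedged sketch of the split-then-shard logic above; sizes are hypothetical.
import torch

hidden_size, world_size = 8, 2                             # toy dimensions
packed = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32)
packed = packed.reshape(3 * hidden_size, hidden_size)      # stand-in for W_pack

single = packed.size(0) // 3
q, k, v = packed[:single], packed[single : 2 * single], packed[2 * single :]

def shard(t: torch.Tensor, rank: int) -> torch.Tensor:
    # Row-slice this rank's block, mirroring get_tensor_shard(dim=0).
    block = t.size(0) // world_size
    return t[rank * block : (rank + 1) * block]

for rank in range(world_size):
    qkv = torch.cat([shard(q, rank), shard(k, rank), shard(v, rank)], dim=0)
    # Each rank ends up with matching Q, K and V rows rather than an
    # arbitrary contiguous slice of the packed matrix.
    assert qkv.shape == (3 * hidden_size // world_size, hidden_size)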

View File

@@ -62,7 +62,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

-    def get_tensor(self, tensor_name: str):
+    def get_tensor(self, tensor_name: str, to_device=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
@@ -70,7 +70,8 @@ class Weights:
         # u4 which are disguised as int32
         if tensor.dtype not in [torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
+        # When to_device is False the tensor stays on the CPU so callers can
+        # slice it before moving it themselves.
+        if to_device:
+            tensor = tensor.to(device=self.device)
         return tensor

     def get_partial_sharded(self, tensor_name: str, dim: int):
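
With the new flag, a caller can fetch a weight without moving it to the device, slice it, and only then shard the slices. A hedged usage sketch against the methods in this file (the weights instance and the tensor name are assumed for illustration, not taken from this commit):

# Hedged sketch: `weights` is a Weights instance; the tensor name is an example.
packed = weights.get_tensor("model.layers.0.self_attn.W_pack.weight", to_device=False)
single = packed.size(0) // 3
q_shard = weights.get_tensor_shard(packed[:single], dim=0)  # sliced on CPU, returned on self.device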
@@ -137,6 +138,22 @@ class Weights:
         w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
         weight = torch.cat(w, dim=dim)
         return weight

+    def get_tensor_shard(self, var, dim):
+        # Slice this rank's block of an already-loaded tensor along `dim`,
+        # then cast it to the target dtype and move it to the target device.
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+        block_size = var.size()[dim] // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        if dim == 0:
+            tensor = var[start:stop]
+        elif dim == 1:
+            tensor = var[:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
+
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
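
get_tensor_shard assigns each rank a contiguous, equally sized block along the chosen dimension. A small sketch of that arithmetic with made-up numbers (4 ranks sharding 4096 rows along dim 0):

# Hedged sketch of the per-rank block arithmetic used by get_tensor_shard.
world_size, rows = 4, 4096            # made-up values
block_size = rows // world_size       # 1024 rows per rank
for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    print(f"rank {rank}: rows [{start}, {stop})")
# rank 0: rows [0, 1024) ... rank 3: rows [3072, 4096)

Note the integer division: a dimension that is not divisible by the world size would leave the trailing rows unassigned.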