Add support for Baichuan models

xiaoyuze 2023-09-05 15:57:32 +08:00
parent 033230ae66
commit 2a1f306e26
3 changed files with 51 additions and 10 deletions


@@ -182,7 +182,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
-    elif model_type == "llama":
+    elif model_type == "llama" or model_type == "baichuan":
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
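
Baichuan's decoder is Llama-like, so routing model_type == "baichuan" onto the FlashLlama path reuses the existing flash-attention implementation; the main structural difference, the packed W_pack attention weight, is handled in the next file. A minimal sketch, outside of TGI, of the dispatch condition this hunk adds; the inline config mimics a Baichuan checkpoint's config.json and the field values are illustrative:

import json

# Placeholder config resembling a Baichuan checkpoint's config.json (values illustrative).
config = json.loads('{"model_type": "baichuan", "hidden_size": 5120, "num_attention_heads": 40}')
model_type = config["model_type"]  # "baichuan" for Baichuan checkpoints
if model_type == "llama" or model_type == "baichuan":
    print("dispatched to FlashLlama (when FLASH_ATTENTION is available)")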


@@ -208,6 +208,29 @@ class FlashLlamaAttention(torch.nn.Module):
         if config.num_attention_heads != config.num_key_value_heads:
             self.query_key_value = _load_gqa(config, prefix, weights)
         else:
+            try:
+                def load_packed(cls, config, prefix: str, weights, bias: bool):
+                    packed_tensor = weights.get_tensor(prefix, to_device=False)
+                    # QKV
+                    total_size = packed_tensor.size()[0]
+                    single_size = total_size // 3
+                    q_tensor = packed_tensor[0: single_size, :]
+                    k_tensor = packed_tensor[single_size: single_size * 2, :]
+                    v_tensor = packed_tensor[single_size * 2: total_size, :]
+                    q_weight = weights.get_tensor_shard(q_tensor, dim=0)
+                    k_weight = weights.get_tensor_shard(k_tensor, dim=0)
+                    v_weight = weights.get_tensor_shard(v_tensor, dim=0)
+                    weight = torch.concat([q_weight, k_weight, v_weight], dim=0)
+                    return cls(get_linear(weight, None, config.quantize))
+
+                self.query_key_value = load_packed(TensorParallelColumnLinear,
+                    config,
+                    prefix=f"{prefix}.W_pack.weight",
+                    weights=weights,
+                    bias=False,
+                )
+            except:
+                self.query_key_value = TensorParallelColumnLinear.load_multi(
                     config,
                     prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -215,6 +238,7 @@ class FlashLlamaAttention(torch.nn.Module):
                     weights=weights,
                     bias=False,
                 )
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
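
The load_packed helper above reflects how Baichuan stores attention weights: q, k and v are fused row-wise into a single W_pack matrix (roughly (3 * hidden_size, hidden_size) for the non-GQA checkpoints), so it is cut into three equal slices along dim 0, each slice is sharded across the tensor-parallel ranks with get_tensor_shard, and the shards are concatenated back into the fused QKV weight expected by FlashLlamaAttention. If W_pack is missing (a plain Llama checkpoint), the bare except falls back to the usual q_proj/k_proj/v_proj loading. A standalone sketch of the same arithmetic with plain torch tensors; the names and toy sizes are illustrative, not TGI APIs:

import torch

hidden_size, world_size, rank = 8, 2, 0  # toy sizes; in TGI rank/world_size come from the process group

# Fused [q; k; v] rows, like W_pack.weight in a Baichuan checkpoint.
w_pack = torch.randn(3 * hidden_size, hidden_size)

single = w_pack.size(0) // 3
q, k, v = w_pack[:single], w_pack[single:2 * single], w_pack[2 * single:]

def shard(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Each rank keeps one contiguous block along `dim`, mirroring get_tensor_shard.
    block = t.size(dim) // world_size
    return t.narrow(dim, rank * block, block)

qkv_shard = torch.cat([shard(q), shard(k), shard(v)], dim=0)
assert qkv_shard.shape == (3 * hidden_size // world_size, hidden_size)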


@@ -62,7 +62,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

-    def get_tensor(self, tensor_name: str):
+    def get_tensor(self, tensor_name: str, to_device = True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
@@ -70,6 +70,7 @@ class Weights:
         # u4 which are disguised as int32
         if tensor.dtype not in [torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
+        if to_device:
             tensor = tensor.to(device=self.device)
         return tensor
@@ -138,6 +139,22 @@ class Weights:
         weight = torch.cat(w, dim=dim)
         return weight

+    def get_tensor_shard(self, var, dim):
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+        block_size = var.size()[dim] // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        if dim == 0:
+            tensor = var[start:stop]
+        elif dim == 1:
+            tensor = var[:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
+
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
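
get_tensor_shard differs from the existing get_sharded helper in that it shards a tensor that is already in memory (here a q/k/v slice of W_pack) rather than a named tensor read from a safetensors file, then casts the shard to the target dtype and moves it to the device. Each rank keeps one contiguous block along dim, which for evenly divisible sizes is the same as taking the rank-th chunk. A minimal sketch of that slicing rule, outside of TGI; the function and variable names are illustrative:

import torch

def tensor_shard(var: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Contiguous block per rank; sizes that don't divide evenly would silently drop the remainder.
    block = var.size(dim) // world_size
    start, stop = rank * block, (rank + 1) * block
    if dim == 0:
        return var[start:stop]
    elif dim == 1:
        return var[:, start:stop]
    raise NotImplementedError("only dim 0 and 1 are supported")

x = torch.arange(24.0).reshape(6, 4)
for rank in range(3):
    # Identical to torch.chunk when the size divides evenly across ranks.
    assert torch.equal(tensor_shard(x, 0, rank, 3), x.chunk(3, dim=0)[rank])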