Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

commit 2a1f306e26 (parent 033230ae66)

    fit for baichuan models

Baichuan models reuse the Llama architecture but store the attention query/key/value projections as a single fused W_pack weight. This commit routes model_type "baichuan" to the FlashLlama path and adds the weight handling needed to split and shard that fused matrix.
@@ -182,7 +182,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )

-    elif model_type == "llama":
+    elif model_type == "llama" or model_type == "baichuan":
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
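Baichuan checkpoints declare their own model_type in config.json but otherwise follow the Llama layout, which is why they can share the FlashLlama branch. A minimal sketch of what the router sees (illustrative only; the checkpoint id is an example and not part of this commit):

from transformers import AutoConfig

# Baichuan ships custom modeling code on the Hub, hence trust_remote_code=True.
config = AutoConfig.from_pretrained("baichuan-inc/Baichuan-7B", trust_remote_code=True)
print(config.model_type)  # "baichuan" -> now handled by the same branch as "llama"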
@@ -208,6 +208,29 @@ class FlashLlamaAttention(torch.nn.Module):
         if config.num_attention_heads != config.num_key_value_heads:
             self.query_key_value = _load_gqa(config, prefix, weights)
         else:
+            try:
+                def load_packed(cls, config, prefix: str, weights, bias: bool):
+                    packed_tensor = weights.get_tensor(prefix, to_device=False)
+                    # QKV
+                    total_size = packed_tensor.size()[0]
+                    single_size = total_size // 3
+                    q_tensor = packed_tensor[0: single_size, :]
+                    k_tensor = packed_tensor[single_size: single_size * 2, :]
+                    v_tensor = packed_tensor[single_size * 2 : total_size, :]
+                    q_weight = weights.get_tensor_shard(q_tensor, dim=0)
+                    k_weight = weights.get_tensor_shard(k_tensor, dim=0)
+                    v_weight = weights.get_tensor_shard(v_tensor, dim=0)
+                    weight = torch.concat([q_weight, k_weight, v_weight], dim=0)
+                    return cls(get_linear(weight, None, config.quantize))
+
+                self.query_key_value = load_packed(TensorParallelColumnLinear,
+                    config,
+                    prefix=f"{prefix}.W_pack.weight",
+                    weights=weights,
+                    bias=False,
+                )
+
+            except:
-            self.query_key_value = TensorParallelColumnLinear.load_multi(
+                self.query_key_value = TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
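The bare except above treats any failure while loading W_pack as "this checkpoint is not packed" and falls back to the usual split q_proj/k_proj/v_proj path. A hedged alternative sketch that probes the checkpoint explicitly instead of catching every exception (the helper is hypothetical and assumes Weights.get_filename raises when a tensor name is not present):

def has_packed_qkv(weights, prefix: str) -> bool:
    # Hypothetical helper, not part of the commit: detect Baichuan's fused QKV
    # weight up front rather than relying on try/except around the whole load path.
    try:
        weights.get_filename(f"{prefix}.W_pack.weight")
        return True
    except Exception:
        return False  # Llama-style checkpoints keep separate q_proj / k_proj / v_proj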
@@ -215,6 +238,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 weights=weights,
                 bias=False,
             )
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
@@ -62,7 +62,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

-    def get_tensor(self, tensor_name: str):
+    def get_tensor(self, tensor_name: str, to_device=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
@@ -70,6 +70,7 @@ class Weights:
         # u4 which are disguised as int32
         if tensor.dtype not in [torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
+        if to_device:
+            tensor = tensor.to(device=self.device)
         return tensor

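The new to_device flag lets a caller read the full fused matrix without moving it to the accelerator, slice it, and only then transfer the much smaller per-rank shard; that is what load_packed above relies on. A usage sketch, assuming weights is an initialized Weights instance and prefix names a Baichuan attention block (both are stand-ins, not new API):

# Keep the full [3 * hidden, hidden] W_pack matrix off the device...
packed = weights.get_tensor(f"{prefix}.W_pack.weight", to_device=False)
single = packed.size()[0] // 3
q_rows = packed[0:single, :]
# ...and let get_tensor_shard move only this rank's slice to weights.device.
q_shard = weights.get_tensor_shard(q_rows, dim=0)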
@@ -138,6 +139,22 @@ class Weights:
             weight = torch.cat(w, dim=dim)
         return weight

+    def get_tensor_shard(self, var, dim):
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+        block_size = var.size()[dim] // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        if dim == 0:
+            tensor = var[start:stop]
+        elif dim == 1:
+            tensor = var[:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
+
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
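Taken together, get_tensor_shard and load_packed implement ordinary tensor-parallel row slicing of the fused QKV matrix. A standalone sketch of the same arithmetic with plain tensors (the toy sizes and the world_size/rank values are stand-ins for the process-group calls used above):

import torch

hidden_size = 8
world_size, rank = 2, 0  # stand-ins for self.process_group.size() / .rank()

# Fused [q; k; v] along dim 0, the layout Baichuan stores in W_pack.
w_pack = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32)
w_pack = w_pack.reshape(3 * hidden_size, hidden_size)

single = w_pack.size(0) // 3
q, k, v = w_pack[:single], w_pack[single : 2 * single], w_pack[2 * single :]

block = q.size(0) // world_size  # rows owned by each tensor-parallel rank

def shard(t):
    # Contiguous block of rows for this rank, mirroring get_tensor_shard(dim=0).
    return t[rank * block : (rank + 1) * block]

qkv_shard = torch.cat([shard(q), shard(k), shard(v)], dim=0)
print(qkv_shard.shape)  # torch.Size([12, 8]): 3 * (8 // 2) rows for this rank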