mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
fit for baichuan models
This commit is contained in:
parent
033230ae66
commit
2a1f306e26
@ -182,7 +182,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == "llama":
|
||||
elif model_type == "llama" or model_type == "baichuan":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashLlama(
|
||||
model_id,
|
||||
|
@ -208,13 +208,37 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
self.query_key_value = _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
try:
|
||||
def load_packed(cls, config, prefix: str, weights, bias: bool):
|
||||
packed_tensor = weights.get_tensor(prefix, to_device=False)
|
||||
#QKV
|
||||
total_size = packed_tensor.size()[0]
|
||||
single_size = total_size // 3
|
||||
q_tensor = packed_tensor[0: single_size, :]
|
||||
k_tensor = packed_tensor[single_size: single_size * 2, :]
|
||||
v_tensor = packed_tensor[single_size * 2 : total_size, :]
|
||||
q_weight = weights.get_tensor_shard(q_tensor, dim = 0)
|
||||
k_weight = weights.get_tensor_shard(k_tensor, dim = 0)
|
||||
v_weight = weights.get_tensor_shard(v_tensor, dim = 0)
|
||||
weight = torch.concat([q_weight, k_weight, v_weight], dim = 0)
|
||||
return cls(get_linear(weight, None, config.quantize))
|
||||
|
||||
self.query_key_value = load_packed(TensorParallelColumnLinear,
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack.weight",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
except:
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
|
@ -62,7 +62,7 @@ class Weights:
|
||||
def get_shape(self, tensor_name: str):
|
||||
return self._get_slice(tensor_name).get_shape()
|
||||
|
||||
def get_tensor(self, tensor_name: str):
|
||||
def get_tensor(self, tensor_name: str, to_device = True):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
@ -70,7 +70,8 @@ class Weights:
|
||||
# u4 which are disguised as int32
|
||||
if tensor.dtype not in [torch.int32, torch.int64]:
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
if to_device:
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_partial_sharded(self, tensor_name: str, dim: int):
|
||||
@ -137,6 +138,22 @@ class Weights:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
return weight
|
||||
|
||||
def get_tensor_shard(self, var, dim):
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
block_size = var.size()[dim] // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
if dim == 0:
|
||||
tensor = var[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = var[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "gptq":
|
||||
|
Loading…
Reference in New Issue
Block a user