diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 932ab32e..e6fe1372 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -182,7 +182,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == "llama":
+    elif model_type == "llama" or model_type == "baichuan":
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index f0e1236d..41db4fa7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -208,13 +208,37 @@ class FlashLlamaAttention(torch.nn.Module):
         if config.num_attention_heads != config.num_key_value_heads:
             self.query_key_value = _load_gqa(config, prefix, weights)
         else:
-            self.query_key_value = TensorParallelColumnLinear.load_multi(
-                config,
-                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                dim=0,
-                weights=weights,
-                bias=False,
-            )
+            try:
+                def load_packed(cls, config, prefix: str, weights, bias: bool):
+                    packed_tensor = weights.get_tensor(prefix, to_device=False)
+                    # QKV
+                    total_size = packed_tensor.size()[0]
+                    single_size = total_size // 3
+                    q_tensor = packed_tensor[0: single_size, :]
+                    k_tensor = packed_tensor[single_size: single_size * 2, :]
+                    v_tensor = packed_tensor[single_size * 2: total_size, :]
+                    q_weight = weights.get_tensor_shard(q_tensor, dim=0)
+                    k_weight = weights.get_tensor_shard(k_tensor, dim=0)
+                    v_weight = weights.get_tensor_shard(v_tensor, dim=0)
+                    weight = torch.concat([q_weight, k_weight, v_weight], dim=0)
+                    return cls(get_linear(weight, None, config.quantize))
+
+                self.query_key_value = load_packed(
+                    TensorParallelColumnLinear,
+                    config,
+                    prefix=f"{prefix}.W_pack.weight",
+                    weights=weights,
+                    bias=False,
+                )
+
+            except:
+                self.query_key_value = TensorParallelColumnLinear.load_multi(
+                    config,
+                    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                    dim=0,
+                    weights=weights,
+                    bias=False,
+                )
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index ef662ce1..25d7ff1d 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -62,7 +62,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()
 
-    def get_tensor(self, tensor_name: str):
+    def get_tensor(self, tensor_name: str, to_device=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
@@ -70,7 +70,8 @@ class Weights:
         # u4 which are disguised as int32
         if tensor.dtype not in [torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
+        if to_device:
+            tensor = tensor.to(device=self.device)
         return tensor
 
     def get_partial_sharded(self, tensor_name: str, dim: int):
@@ -137,6 +138,22 @@ class Weights:
         w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
         weight = torch.cat(w, dim=dim)
         return weight
+
+    def get_tensor_shard(self, var, dim):
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+        block_size = var.size()[dim] // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        if dim == 0:
+            tensor = var[start:stop]
+        elif dim == 1:
+            tensor = var[:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
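
The new `load_packed` helper and `Weights.get_tensor_shard` together split Baichuan's packed `W_pack` weight into its Q, K and V thirds and keep only the current tensor-parallel rank's rows of each. Below is a minimal standalone sketch of that slicing; the sizes (`hidden_size=8`, `world_size=2`) and the dummy tensor are made up for illustration, whereas the real code takes them from the model config and the process group.

```python
# Sketch of the packed-QKV sharding performed by load_packed / get_tensor_shard.
# Sizes and values here are invented for illustration only.
import torch

hidden_size = 8
world_size = 2

# Baichuan stores Q, K and V stacked row-wise in a single W_pack weight of shape
# [3 * hidden_size, hidden_size] instead of separate q_proj/k_proj/v_proj tensors.
w_pack = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32).reshape(
    3 * hidden_size, hidden_size
)

def shard_rows(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Same slicing as Weights.get_tensor_shard(var, dim=0): each rank keeps one
    # contiguous block of rows.
    block = t.size(0) // world_size
    return t[rank * block : (rank + 1) * block]

for rank in range(world_size):
    q, k, v = w_pack.chunk(3, dim=0)  # split the packed tensor into Q, K, V
    qkv_shard = torch.cat(
        [shard_rows(x, rank, world_size) for x in (q, k, v)], dim=0
    )
    # Each rank ends up with [3 * hidden_size / world_size, hidden_size], laid out
    # as Q-shard | K-shard | V-shard.
    print(rank, qkv_shard.shape)
```

Sharding Q, K and V separately rather than slicing the packed tensor as one block keeps the per-rank layout identical to what `TensorParallelColumnLinear.load_multi` produces for a standard Llama checkpoint, so the rest of `FlashLlamaAttention` can stay unchanged.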