import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex


class LayerConcat(torch.nn.Module):
    """
    Apply multiple layers to the input and concatenate their outputs.
    """

    def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
        """
        `dim` is the dimension along which layer outputs are concatenated.
        """
        super().__init__()
        self.layers = layers
        self.dim = dim

    def forward(self, x: torch.Tensor):
        outputs = [layer(x) for layer in self.layers]
        return torch.cat(outputs, self.dim)


class SuperLayer(torch.nn.Module):
    """Thin wrapper that delegates the forward pass to an inner linear layer."""

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear.forward(x)


class TensorParallelHead(SuperLayer):
    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if config.quantize == "exl2":
            try:
                # If the input (piece) embedding and the LM head share weights,
                # we have non-quantized weights...
                weight = weights.get_tensor(f"{prefix}.weight")
            except Exception:
                # ...otherwise they are quantized.
                weight = weights.get_weights_col(prefix)
            should_gather = weights.process_group.size() > 1
        elif weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by the number of shards,
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            # Fast path: compute the local logits with a plain matmul and
            # gather all shards directly into one preallocated tensor.
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # all_gather_into_tensor concatenates along the first dimension,
                # so gather the transposed output and transpose back afterwards.
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)
            if SYSTEM == "ipex":
                ipex.distributed.all_gather_into_tensor(
                    world_out, gather_input, group=self.process_group
                )
            else:
                torch.distributed.all_gather_into_tensor(
                    world_out, gather_input, group=self.process_group
                )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        # Generic path: run the wrapped linear, then gather and concatenate
        # the per-shard outputs along the vocabulary dimension.
        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]

        if SYSTEM == "ipex":
            ipex.distributed.all_gather(world_output, output, group=self.process_group)
        else:
            torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the gate and up projections were joined after the fact."""
        weight = weights.get_weights_col_packed_gate_up(prefix)
        if bias:
            raise NotImplementedError("packed_gate_up only implemented without bias")
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load_qkv(
        cls,
        config,
        prefix: str,
        weights,
        bias: bool,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """Specific method when the QKV projections were joined after the fact."""
        weight = weights.get_weights_col_packed_qkv(
            prefix,
            num_heads=num_heads,
            num_key_value_heads=num_key_value_heads,
        )
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_weights_col(prefix)
        if bias:
            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        if config.quantize == "exl2":
            # exl2 weights cannot be concatenated ahead of time, so keep one
            # linear per prefix and concatenate their outputs at runtime.
            linears = []
            for prefix in prefixes:
                weight = weights.get_weights_col(prefix)
                b = weights.get_tensor(f"{prefix}.bias") if bias else None
                linears.append(get_linear(weight, b))
            linear = LayerConcat(linears)
        else:
            weight = weights.get_multi_weights_col(prefixes, dim=dim)
            if bias:
                b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
                bias = torch.cat(b, dim=dim)
            else:
                bias = None
            linear = get_linear(weight, bias)
        return cls(linear)


class TensorParallelRowLinear(SuperLayer):
    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_weights_row(prefix)

        if bias and weights.process_group.rank() == 0:
            # Add the bias only on rank 0; the all_reduce in forward would
            # otherwise sum it once per shard.
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            if SYSTEM == "ipex":
                ipex.distributed.all_reduce(out, group=self.process_group)
            else:
                torch.distributed.all_reduce(out, group=self.process_group)
        return out


class TensorParallelEmbedding(torch.nn.Module):
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group
        world_size = process_group.size()
        rank = process_group.rank()

        block_size = (num_embeddings + world_size - 1) // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        # Usually block_size, may be smaller when vocab_size does not divide evenly.
        self.null_idx = weight.shape[0]
        self.process_group = weights.process_group
        self.reduce = reduce

        """Additional 0 entry used for masking"""
        self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Map all out-of-range ids to `self.null_idx` (the extra all-zero row)
        # and translate local ids into [0, self.max_id - self.min_id).
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            if SYSTEM == "ipex":
                ipex.distributed.all_reduce(out, group=self.process_group)
            else:
                torch.distributed.all_reduce(out, group=self.process_group)
        return out
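

# The sketch below is illustrative only and not part of the serving path: it
# demonstrates, with made-up toy values, the id-remapping trick used by
# TensorParallelEmbedding above. Each shard keeps rows [min_id, max_id) of the
# vocabulary plus one extra all-zero row appended by F.pad; ids owned by
# another shard are redirected to that zero row, so the subsequent all_reduce
# sums exactly one non-zero contribution per token.
if __name__ == "__main__":
    vocab_size, hidden, world_size, rank = 10, 4, 2, 1  # toy, single-process values
    block_size = (vocab_size + world_size - 1) // world_size  # 5 rows per shard
    min_id = rank * block_size  # 5
    max_id = min(vocab_size, (rank + 1) * block_size)  # 10
    shard = torch.randn(max_id - min_id, hidden)
    null_idx = shard.shape[0]  # index of the padded all-zero row
    padded = F.pad(shard, (0, 0, 0, 1))  # append one all-zero row
    ids = torch.tensor([[2, 7, 9]])  # id 2 belongs to the other shard
    local_ids = torch.where((ids < min_id) | (ids >= max_id), null_idx, ids - min_id)
    out = F.embedding(local_ids, padded)
    assert torch.all(out[0, 0] == 0)  # the foreign id contributes zeros on this shard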