diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 2a55898d..9cf5c80f 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -664,27 +664,27 @@ class TensorParallelHead(SuperLayer):
             return super().forward(input)
 
         world_size = self.process_group.size()
-        # if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
-        #     out_dim = self.linear.weight.shape[0]
+        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
+            out_dim = self.linear.weight.shape[0]
 
-        #     if input.shape[0] == 1:
-        #         world_out = input.new_empty(1, out_dim * world_size)
-        #         local_out = input.new_empty(1, out_dim)
-        #         gather_input = local_out
-        #     else:
-        #         world_out = input.new_empty(out_dim * world_size, input.shape[0])
-        #         gather_input = input.new_empty(out_dim, input.shape[0])
-        #         local_out = gather_input.T
+            if input.shape[0] == 1:
+                world_out = input.new_empty(1, out_dim * world_size)
+                local_out = input.new_empty(1, out_dim)
+                gather_input = local_out
+            else:
+                world_out = input.new_empty(out_dim * world_size, input.shape[0])
+                gather_input = input.new_empty(out_dim, input.shape[0])
+                local_out = gather_input.T
 
-        #     torch.mm(input, self.linear.weight.T, out=local_out)
+            torch.mm(input, self.linear.weight.T, out=local_out)
 
-        #     torch.distributed.all_gather_into_tensor(
-        #         world_out, gather_input, group=self.process_group
-        #     )
+            torch.distributed.all_gather_into_tensor(
+                world_out, gather_input, group=self.process_group
+            )
 
-        #     if input.shape[0] == 1:
-        #         return world_out
-        #     return world_out.T
+            if input.shape[0] == 1:
+                return world_out
+            return world_out.T
 
         output = super().forward(input)
         world_output = [
@@ -943,7 +943,6 @@ try:
                 self._sin_k_cached = None
                 self.scaling_factor = scaling_factor
                 self.dynamic_args = None
-                self._update_cos_sin_cache(torch.float16, inv_freq.device, seqlen=4096)
 
             def forward(
                 self,
@@ -1087,6 +1086,8 @@ try:
                 # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
                 dtype = torch.float32
 
+            self._update_cos_sin_cache(dtype, position_ids.device, max_s)
+
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
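The uncommented branch in `TensorParallelHead.forward` is a latency fast path for the sharded lm_head: each rank projects the hidden states against its own weight shard with `torch.mm`, and `torch.distributed.all_gather_into_tensor` stacks the per-rank results along dim 0, which is why the batch > 1 case writes through `gather_input.T` and transposes the gathered tensor back at the end. Below is a minimal single-process sketch of that shape logic; the `world_size`, the example shapes, and the simulated gather via `torch.cat` are assumptions for illustration, not the PR's distributed code.

```python
# Single-process sketch (assumed shapes/world_size) of the re-enabled fast path:
# each "rank" fills gather_input (out_dim, batch) via the transposed view local_out,
# and the gather concatenates rank shards along dim 0.
import torch

world_size = 4                      # hypothetical number of tensor-parallel ranks
batch, hidden, out_dim = 3, 8, 5    # out_dim = vocab shard per rank

inp = torch.randn(batch, hidden)
# One weight shard of shape (out_dim, hidden) per rank, as FastLinear stores it.
shards = [torch.randn(out_dim, hidden) for _ in range(world_size)]

gather_inputs = []
for w in shards:
    gather_input = inp.new_empty(out_dim, batch)
    local_out = gather_input.T                 # view of shape (batch, out_dim)
    torch.mm(inp, w.T, out=local_out)          # fills gather_input in place
    gather_inputs.append(gather_input)

# Stand-in for all_gather_into_tensor: concatenate rank shards along dim 0.
world_out = torch.cat(gather_inputs, dim=0)    # (out_dim * world_size, batch)
gathered = world_out.T                         # (batch, out_dim * world_size)

# Reference: the full, unsharded projection.
full_weight = torch.cat(shards, dim=0)         # (out_dim * world_size, hidden)
reference = inp @ full_weight.T
torch.testing.assert_close(gathered, reference)
print("fast-path gather matches the full projection")
```

The single-row case needs no transpose at all, since a (1, out_dim) local result already lines up with how the gather stacks shards into `world_out` of shape (1, out_dim * world_size).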
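The rotary change removes the eager `_update_cos_sin_cache(torch.float16, inv_freq.device, seqlen=4096)` call from `__init__` and instead refreshes the cache inside `get_cos_sin`, so the cos/sin table tracks the runtime dtype and device and can grow with `max_s` (which is what dynamically scaled RoPE variants need). A minimal sketch of such a lazily grown cache follows, using a hypothetical `LazyRotaryCache` class rather than the repo's `PositionRotaryEmbedding`.

```python
# Hypothetical sketch of a lazily grown cos/sin cache: rebuilt only when the
# requested sequence length exceeds the cached one or the dtype/device changed.
import torch


class LazyRotaryCache:
    def __init__(self, dim: int, base: float = 10000.0):
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.dtype != dtype
            or self._cos_cached.device != device
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        # Refreshing here (rather than once in __init__) keeps the cache in step
        # with the runtime dtype/device and with max_s as longer sequences arrive.
        self._update_cos_sin_cache(dtype, position_ids.device, max_s)
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos.unsqueeze(1), sin.unsqueeze(1)


cache = LazyRotaryCache(dim=64)
pos = torch.arange(16)
cos, sin = cache.get_cos_sin(pos, max_s=16, dtype=torch.float32)
cos, sin = cache.get_cos_sin(pos, max_s=8192, dtype=torch.float32)  # grows the cache
print(cos.shape, cache._cos_cached.shape)  # torch.Size([16, 1, 32]) torch.Size([8192, 32])
```

Because the rebuild is guarded on sequence length, dtype, and device, the call is a no-op on the hot path once the cache is warm, so invoking it on every `get_cos_sin` is cheap.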