Dummy changes.

Nicolas Patry 2024-04-22 14:10:30 +00:00
parent b564adc057
commit af08e359af


@@ -664,27 +664,27 @@ class TensorParallelHead(SuperLayer):
             return super().forward(input)

         world_size = self.process_group.size()
-        # if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
-        #     out_dim = self.linear.weight.shape[0]
+        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
+            out_dim = self.linear.weight.shape[0]

-        #     if input.shape[0] == 1:
-        #         world_out = input.new_empty(1, out_dim * world_size)
-        #         local_out = input.new_empty(1, out_dim)
-        #         gather_input = local_out
-        #     else:
-        #         world_out = input.new_empty(out_dim * world_size, input.shape[0])
-        #         gather_input = input.new_empty(out_dim, input.shape[0])
-        #         local_out = gather_input.T
+            if input.shape[0] == 1:
+                world_out = input.new_empty(1, out_dim * world_size)
+                local_out = input.new_empty(1, out_dim)
+                gather_input = local_out
+            else:
+                world_out = input.new_empty(out_dim * world_size, input.shape[0])
+                gather_input = input.new_empty(out_dim, input.shape[0])
+                local_out = gather_input.T

-        #     torch.mm(input, self.linear.weight.T, out=local_out)
+            torch.mm(input, self.linear.weight.T, out=local_out)

-        #     torch.distributed.all_gather_into_tensor(
-        #         world_out, gather_input, group=self.process_group
-        #     )
+            torch.distributed.all_gather_into_tensor(
+                world_out, gather_input, group=self.process_group
+            )

-        #     if input.shape[0] == 1:
-        #         return world_out
-        #     return world_out.T
+            if input.shape[0] == 1:
+                return world_out
+            return world_out.T

         output = super().forward(input)
         world_output = [
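
Note (added for clarity, not part of the commit): the hunk above re-enables the fused lm_head gather path. The sketch below isolates that logic as a standalone function, assuming a 2D input of shape [tokens, hidden] and a row-sharded weight of shape [out_dim, hidden] on each rank; the name gathered_lm_head is illustrative and not from the repository.

import torch
import torch.distributed as dist


def gathered_lm_head(input: torch.Tensor, weight: torch.Tensor, group) -> torch.Tensor:
    # Each rank computes its shard of the logits into a preallocated buffer,
    # then the shards are concatenated with a single all_gather_into_tensor call.
    world_size = dist.get_world_size(group)
    out_dim = weight.shape[0]
    if input.shape[0] == 1:
        # Single token: the local result can be gathered as-is along dim 1.
        world_out = input.new_empty(1, out_dim * world_size)
        local_out = input.new_empty(1, out_dim)
        gather_input = local_out
    else:
        # Batched: gather contiguous [out_dim, tokens] blocks and transpose back at the end.
        world_out = input.new_empty(out_dim * world_size, input.shape[0])
        gather_input = input.new_empty(out_dim, input.shape[0])
        local_out = gather_input.T
    torch.mm(input, weight.T, out=local_out)
    dist.all_gather_into_tensor(world_out, gather_input, group=group)
    if input.shape[0] == 1:
        return world_out
    return world_out.T
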
@@ -943,7 +943,6 @@ try:
             self._sin_k_cached = None
             self.scaling_factor = scaling_factor
             self.dynamic_args = None
-            self._update_cos_sin_cache(torch.float16, inv_freq.device, seqlen=4096)

         def forward(
             self,
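
Note (added for clarity, not part of the commit): with the eager _update_cos_sin_cache(torch.float16, ...) call removed from __init__, the cos/sin tables are only built lazily when the embedding is actually used with the real dtype, device and sequence length (see the next hunk). A minimal sketch of such a lazy cache, with a guard for the still-empty cache added for illustration; the cached attribute names are assumptions based on the surrounding code.

def _update_cos_sin_cache(self, dtype, device, seqlen):
    # Recompute the tables only when the requested length, device or dtype
    # no longer matches what is cached (assumes self.inv_freq plus the
    # _seq_len_cached / _cos_cached / _sin_cached attributes from __init__).
    if (
        self._cos_cached is None
        or seqlen > self._seq_len_cached
        or self._cos_cached.device != device
        or self._cos_cached.dtype != dtype
    ):
        self._seq_len_cached = seqlen
        t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq.to(device=t.device))
        self._cos_cached = torch.cos(freqs).to(dtype)
        self._sin_cached = torch.sin(freqs).to(dtype)
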
@@ -1087,6 +1086,8 @@ try:
                 # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
                 dtype = torch.float32

             self._update_cos_sin_cache(dtype, position_ids.device, max_s)

+            cos = torch.index_select(self._cos_cached, 0, position_ids)
+            sin = torch.index_select(self._sin_cached, 0, position_ids)
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
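
Note (added for clarity, not part of the commit): a self-contained illustration of what the two added index_select lines do. Each flattened position id selects one row of the cached cos table, and the unsqueeze mentioned in the trailing comment adds a head dimension so the result broadcasts over attention heads. Shapes and variable names here are illustrative.

import torch

max_s, half_rotary = 8, 4
cos_cached = torch.randn(max_s, half_rotary)            # stands in for self._cos_cached
position_ids = torch.tensor([0, 1, 2, 0, 1])            # flattened positions across sequences
cos = torch.index_select(cos_cached, 0, position_ids)   # [5, half_rotary], one row per token
cos = cos.unsqueeze(1)                                  # [5, 1, half_rotary], broadcasts over heads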