diff --git a/my_optims/nccl_test/CMakeLists.txt b/my_optims/nccl_test/CMakeLists.txt
new file mode 100644
index 00000000..722ffc34
--- /dev/null
+++ b/my_optims/nccl_test/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_executable(test_nccl test_nccl.cc)
+target_link_libraries(test_nccl PUBLIC -lcublas -lcublasLt -lcudart
+                      nvtx_utils mpi_utils nccl_utils memory_utils)
\ No newline at end of file
diff --git a/my_optims/nccl_test/my_custom_comm.cc b/my_optims/nccl_test/my_custom_comm.cc
new file mode 100644
index 00000000..48c65f16
--- /dev/null
+++ b/my_optims/nccl_test/my_custom_comm.cc
@@ -0,0 +1,171 @@
+// NOTE: the original include targets were lost in extraction; the headers below are
+// inferred from the code (MPI, NCCL, the CUDA runtime and the PyTorch C++ extension API).
+#include <mpi.h>
+#include <nccl.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+#include <cstdio>
+#include <cstdlib>
+#include <tuple>
+
+struct NcclParam {
+    int          rank_{0};
+    int          world_size_{1};
+    ncclUniqueId nccl_uid_;
+    ncclComm_t   nccl_comm_ = nullptr;
+    cudaStream_t stream_;
+
+    NcclParam(): rank_(0), world_size_(1), nccl_comm_(nullptr){};
+    NcclParam(int rank, int world_size): rank_(rank), world_size_(world_size){};
+    NcclParam(NcclParam const& param):
+        rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
+};
+
+#define NCCLCHECK(cmd)                                                                            \
+    do {                                                                                          \
+        ncclResult_t r = cmd;                                                                     \
+        if (r != ncclSuccess) {                                                                   \
+            printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
+            exit(EXIT_FAILURE);                                                                   \
+        }                                                                                         \
+    } while (0)
+
+// The tuple element types were lost in extraction; they are reconstructed from the
+// return statement at the end of this function.
+std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size)
+{
+    // int argc = 0;
+    // char** argv = nullptr;
+
+    // // Initialize MPI
+    // MPI_Init(&argc, &argv);
+    int rank, world_size;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rank);        // rank of the current process
+    MPI_Comm_size(MPI_COMM_WORLD, &world_size);  // total number of processes
+    // printf("rank:%d, world_size:%d\n", rank, world_size);
+
+    // Select the GPU for this rank
+    int device, device_count;
+    cudaGetDeviceCount(&device_count);
+    cudaSetDevice(rank % device_count);
+    cudaGetDevice(&device);
+    struct cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, device);
+
+    int mpi_initialized;
+    MPI_Initialized(&mpi_initialized);
+
+    static NcclParam tensor_para;
+    static NcclParam pipeline_para;
+    // Convert WORLD communicator into 2D grid (k * n) communicator.
+    // row = a tensor parallel group, col = a pipeline parallel group.
+    MPI_Comm grid_comm, tp_comm, pp_comm;
+
+    int dims[2]    = {pipeline_para_size, tensor_para_size};
+    int periods[2] = {0, 0};
+    MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm);
+
+    // Split 2D communicator into rows and cols.
+    int tp_remain_dims[2] = {false, true};
+    int pp_remain_dims[2] = {true, false};
+    MPI_Cart_sub(grid_comm, tp_remain_dims, &tp_comm);
+    MPI_Cart_sub(grid_comm, pp_remain_dims, &pp_comm);
+
+    int tp_rank, pp_rank;
+    MPI_Comm_rank(tp_comm, &tp_rank);
+    MPI_Comm_rank(pp_comm, &pp_rank);
+    printf("tp_rank:%d, pp_rank:%d\n", tp_rank, pp_rank);
+
+    ncclUniqueId tp_uid;
+    ncclUniqueId pp_uid;
+    // The root of each group creates a nccl uid.
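+    // Only the root of each sub-communicator calls ncclGetUniqueId(); the uid is then
+    // broadcast over the matching MPI sub-communicator so that every member can pass the
+    // same uid (plus its own sub-rank) to ncclCommInitRank below.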
+    if (tp_rank == 0) {
+        NCCLCHECK(ncclGetUniqueId(&tp_uid));
+    }
+    if (pp_rank == 0) {
+        NCCLCHECK(ncclGetUniqueId(&pp_uid));
+    }
+    // Broadcast nccl uid to share the same nccl uid across gpus in the same group.
+    MPI_Bcast(&tp_uid, sizeof(tp_uid), MPI_BYTE, 0, tp_comm);
+    MPI_Bcast(&pp_uid, sizeof(pp_uid), MPI_BYTE, 0, pp_comm);
+
+    ncclComm_t tp_nccl_comm, pp_nccl_comm;
+    NCCLCHECK(ncclCommInitRank(&tp_nccl_comm, tensor_para_size, tp_uid, tp_rank));
+    NCCLCHECK(ncclCommInitRank(&pp_nccl_comm, pipeline_para_size, pp_uid, pp_rank));
+
+    tensor_para.world_size_ = tensor_para_size;
+    tensor_para.rank_       = tp_rank;
+    tensor_para.nccl_uid_   = tp_uid;
+    tensor_para.nccl_comm_  = tp_nccl_comm;
+    cudaStreamCreate(&tensor_para.stream_);
+    pipeline_para.world_size_ = pipeline_para_size;
+    pipeline_para.rank_       = pp_rank;
+    pipeline_para.nccl_uid_   = pp_uid;
+    pipeline_para.nccl_comm_  = pp_nccl_comm;
+
+    NcclParam* tensor_para_ptr   = &tensor_para;
+    NcclParam* pipeline_para_ptr = &pipeline_para;
+
+    return std::make_tuple(tensor_para_ptr, pipeline_para_ptr);
+}
+
+void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
+{
+    // Destroy the NCCL communicators
+    NcclParam tensor_para   = *tensor_para_ptr;
+    NcclParam pipeline_para = *pipeline_para_ptr;
+    if (tensor_para.nccl_comm_ != nullptr) {
+        ncclCommDestroy(tensor_para.nccl_comm_);
+    }
+    if (pipeline_para.nccl_comm_ != nullptr) {
+        ncclCommDestroy(pipeline_para.nccl_comm_);
+    }
+    // MPI_Finalize();
+}
+
+ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
+{
+    ncclDataType_t nccl_type;
+    if (torch_type == torch::kFloat16) {
+        nccl_type = ncclHalf;
+    }
+    else if (torch_type == torch::kFloat32) {
+        nccl_type = ncclFloat;
+    }
+    else if (torch_type == torch::kFloat64) {
+        nccl_type = ncclDouble;
+    }
+    else if (torch_type == torch::kInt32) {
+        nccl_type = ncclInt32;
+    }
+    else if (torch_type == torch::kInt64) {
+        nccl_type = ncclInt64;
+    }
+    else if (torch_type == torch::kInt8) {
+        nccl_type = ncclInt8;
+    }
+    else {
+        printf("[ERROR] NCCL only supports float, half and int types\n");
+        exit(-1);
+    }
+
+    return nccl_type;
+}
+void custom_allreduce(torch::Tensor tensor, NcclParam* nccl_param_ptr)
+{
+    void*             data_ptr   = tensor.data_ptr();
+    size_t            count      = tensor.numel();
+    torch::ScalarType torch_type = tensor.scalar_type();
+
+    NcclParam nccl_param = *nccl_param_ptr;
+    // cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
+    NCCLCHECK(ncclGroupStart());
+    NCCLCHECK(ncclAllReduce(
+        (const void*)data_ptr, (void*)data_ptr, count, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
+    NCCLCHECK(ncclGroupEnd());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    py::class_<NcclParam>(m, "NcclParam").def(py::init<>());
+    m.def("init_nccl", &init_nccl, py::return_value_policy::reference, "");
+    m.def("finalize_nccl", &finalize_nccl, "");
+    m.def("custom_allreduce", &custom_allreduce, "");
+}
\ No newline at end of file
diff --git a/my_optims/nccl_test/setup.py b/my_optims/nccl_test/setup.py
new file mode 100644
index 00000000..2d385967
--- /dev/null
+++ b/my_optims/nccl_test/setup.py
@@ -0,0 +1,12 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(name='my_custom_comm',
+      include_dirs=["/usr/local/cuda/targets/x86_64-linux/include/",
+                    ],
+      ext_modules=[CUDAExtension('my_custom_comm',
+                                 ['my_custom_comm.cc'],
+                                 libraries=["mpi"]
+                                 ),],
+      cmdclass={'build_ext': BuildExtension},
+      )
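For reference, MPI_Cart_create in init_nccl lays ranks out row-major over dims = {pipeline_para_size, tensor_para_size} with reordering disabled, so the (tp_rank, pp_rank) pair printed by the extension can be predicted from the global MPI rank alone. The snippet below is illustrative only; the helper name grid_coords is made up here and is not part of the extension.

def grid_coords(global_rank: int, tensor_para_size: int, pipeline_para_size: int):
    """Mirror the row-major 2D grid built by init_nccl's MPI_Cart_create/MPI_Cart_sub."""
    pp_rank = global_rank // tensor_para_size  # row index = pipeline-parallel group
    tp_rank = global_rank % tensor_para_size   # column index = tensor-parallel group
    return tp_rank, pp_rank

# With tensor_para_size=2, pipeline_para_size=2:
#   rank 0 -> (tp 0, pp 0)   rank 1 -> (tp 1, pp 0)
#   rank 2 -> (tp 0, pp 1)   rank 3 -> (tp 1, pp 1)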
diff --git a/my_optims/nccl_test/test_extension.py b/my_optims/nccl_test/test_extension.py
new file mode 100644
index 00000000..277f9531
--- /dev/null
+++ b/my_optims/nccl_test/test_extension.py
@@ -0,0 +1,22 @@
+import torch
+import my_custom_comm
+from mpi4py import MPI
+
+COMM = MPI.COMM_WORLD
+rank = COMM.Get_rank()
+world_size = COMM.Get_size()
+tp_ptr, pp_ptr = my_custom_comm.init_nccl(2, 1)
+print(tp_ptr)
+device = rank % torch.cuda.device_count()
+torch.cuda.set_device(device)
+torch.cuda.set_per_process_memory_fraction(1., device)
+print(rank, world_size)
+
+t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]],
+                 dtype=torch.float16).to('cuda')
+
+print(my_custom_comm.custom_allreduce(t, tp_ptr))
+print(t)
+
+torch.cuda.synchronize()
+my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)
\ No newline at end of file
diff --git a/my_optims/nccl_test/test_extension_torch.py b/my_optims/nccl_test/test_extension_torch.py
new file mode 100644
index 00000000..909bdcc0
--- /dev/null
+++ b/my_optims/nccl_test/test_extension_torch.py
@@ -0,0 +1,32 @@
+import torch
+import torch.distributed as dist
+import my_custom_comm
+import os
+assert dist.is_mpi_available()
+
+# rank, world_size = int(os.getenv('RANK')), int(os.getenv('WORLD_SIZE'))
+print(os.environ)
+dist.init_process_group(backend=dist.Backend.MPI,
+                        # rank=rank,
+                        # world_size=2
+                        )
+assert dist.is_initialized()
+rank = dist.get_rank()
+world_size = dist.get_world_size()
+print(f'{rank=},{world_size=}')
+
+# tp_ptr, pp_ptr = my_custom_comm.init_nccl(2, 1)
+# print(tp_ptr)
+# device = rank % torch.cuda.device_count()
+# torch.cuda.set_device(device)
+# torch.cuda.set_per_process_memory_fraction(1., device)
+# print(rank, world_size)
+
+# t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]],
+#                  dtype=torch.float16).to('cuda')
+
+# print(my_custom_comm.custom_allreduce(t, tp_ptr))
+# print(t)
+
+# torch.cuda.synchronize()
+# my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)
\ No newline at end of file
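Before the FasterTransformer benchmark below, a minimal correctness check for the extension can be sketched as follows. It assumes the same setup as test_extension.py (2-way tensor parallelism, 1-way pipeline parallelism) and a launch such as `mpirun -np 2 python check_allreduce.py`; the file name is hypothetical.

import torch
from mpi4py import MPI
import my_custom_comm

rank = MPI.COMM_WORLD.Get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())
tp_ptr, pp_ptr = my_custom_comm.init_nccl(2, 1)

x = torch.ones(8, dtype=torch.float16, device='cuda')
torch.cuda.synchronize()                    # the extension reduces on its own stream, so sync around the call
my_custom_comm.custom_allreduce(x, tp_ptr)  # in-place sum over the tensor-parallel group
torch.cuda.synchronize()
assert torch.all(x == 2.0), x               # every element should equal the TP world size (2)

my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)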
diff --git a/my_optims/nccl_test/test_nccl.cc b/my_optims/nccl_test/test_nccl.cc
new file mode 100644
index 00000000..246a9eb9
--- /dev/null
+++ b/my_optims/nccl_test/test_nccl.cc
@@ -0,0 +1,88 @@
+#include "src/fastertransformer/utils/mpi_utils.h"
+#include "src/fastertransformer/utils/nccl_utils.h"
+#include "src/fastertransformer/utils/nvtx_utils.h"
+#include "src/fastertransformer/utils/memory_utils.h"
+// NOTE: the original system include targets were lost in extraction; these two are
+// inferred from the use of cudaProfilerStart() and std::cout below.
+#include <cuda_profiler_api.h>
+#include <iostream>
+
+
+using namespace fastertransformer;
+
+template<typename T>
+void test_nccl();
+
+int main(int argc, char **argv){
+    mpi::initialize(&argc, &argv);
+
+    // The template argument was lost in extraction; float is assumed here.
+    test_nccl<float>();
+
+    mpi::finalize();
+
+    return 0;
+}
+
+
+template<typename T>
+void test_nccl()
+{
+    // int rank = 0;
+    // int world_size = 1;
+    int rank       = mpi::getCommWorldRank();
+    int world_size = mpi::getCommWorldSize();
+    if (rank == 0) {
+        printf("Total ranks: %d.\n", world_size);
+    }
+    int device, device_count;
+    check_cuda_error(cudaGetDeviceCount(&device_count));
+    check_cuda_error(cudaSetDevice(rank % device_count));
+    check_cuda_error(cudaGetDevice(&device));
+    struct cudaDeviceProp prop;
+    check_cuda_error(cudaGetDeviceProperties(&prop, device));
+    printf("Device %s\n", prop.name);
+    printf("P%d is running with GPU #%d.\n", rank, device);
+    print_mem_usage();
+
+    // Declare the data pointers
+    std::vector<T*> weights_ptr = std::vector<T*>(1);
+    size_t shape = 4096 * 2048;
+    deviceMalloc(&weights_ptr[0], shape, true);           // allocate GPU memory and initialize it randomly
+    // deviceFill(weights_ptr[0], (size_t)shape, (T)2.0); // fill the GPU buffer with a constant
+
+    // Initialize NCCL
+    int tensor_para_size   = 2;
+    int pipeline_para_size = 1;
+    NcclParam tensor_para;
+    NcclParam pipeline_para;
+    ftNcclInitialize(tensor_para, pipeline_para, tensor_para_size, pipeline_para_size);
+    std::cout << "tensor_para info:" << tensor_para.rank_ << std::endl;
+
+    cudaStream_t stream;
+    cudaStreamCreate(&stream);
+
+    mpi::barrier();
+    cudaProfilerStart();
+    ft_nvtx::setScope("run_time");
+    PUSH_RANGE("run time")
+    for (int i = 0; i < 32; i++) {
+        cudaDeviceSynchronize();
+        ftNcclAllReduceSum(weights_ptr[0], weights_ptr[0], shape, tensor_para, stream);
+        cudaDeviceSynchronize();
+        deviceMalloc(&weights_ptr[0], shape, true);
+    }
+
+    mpi::barrier();
+    POP_RANGE;
+    ft_nvtx::resetScope();
+
+    // T* hBuf = new T[shape];  // allocate a CPU buffer
+    // cudaD2Hcpy(hBuf, weights_ptr[0], shape);
+    // {  // print from the CPU to inspect the data
+    //     for (size_t i = 0; i < shape; i++) {
+    //         printf("%f ", (float)hBuf[i]);
+    //     }
+    //     std::cout << std::endl;
+    // }
+    // delete[] hBuf;
+
+    return;
+}
\ No newline at end of file
diff --git a/my_optims/nccl_test/torch_nccl.py b/my_optims/nccl_test/torch_nccl.py
new file mode 100644
index 00000000..3f138f0d
--- /dev/null
+++ b/my_optims/nccl_test/torch_nccl.py
@@ -0,0 +1,30 @@
+import torch
+import torch.distributed as dist
+import os
+import time
+
+
+def test_nccl():
+    rank, world_size = int(os.getenv('RANK')), int(os.getenv('WORLD_SIZE'))
+    dist.init_process_group(backend='nccl',
+                            rank=rank,
+                            world_size=world_size)
+    process_group = dist.group.WORLD
+    torch.cuda.set_device(rank % world_size)
+    shape = 4096 * 2048
+    weight = torch.randn([shape], dtype=torch.float16).to("cuda")
+
+    # nccl test
+    dist.barrier(process_group)
+    for i in range(32):
+        torch.cuda.synchronize()
+        dist.all_reduce(weight, group=process_group)
+        torch.cuda.synchronize()
+        weight = torch.randn([shape], dtype=torch.float16).to("cuda")
+
+    dist.barrier(process_group)
+
+
+
+if __name__ == '__main__':
+    test_nccl()
\ No newline at end of file
diff --git a/my_optims/optims/test_llama_fa.py b/my_optims/optims/test_llama_fa.py
new file mode 100644
index 00000000..f643d043
--- /dev/null
+++ b/my_optims/optims/test_llama_fa.py
@@ -0,0 +1,709 @@
+# encoding:utf-8
+# -------------------------------------------#
+# Filename: optims -- test_llama_fa.py
+#
+# Description:
+# Version: 1.0
+# Created: 2023/9/18-20:50
+# Last modified by:
+# Author: 'zhaohuayang@myhexin.com'
+# Company: 同花顺网络信息股份有限公司
+# -------------------------------------------#
+import math
+import time
+from pathlib import Path
+from typing import List, Dict, Optional
+from typing import Tuple
+
+# Flash attention imports
+import dropout_layer_norm
+import flash_attn_2_cuda
+import numpy as np
+import rotary_emb
+import torch
+import torch.distributed
+import transformers
+from safetensors import safe_open
+from torch import nn
+from torch.nn import functional as F
+from transformers.activations import ACT2FN
+# vllm imports
+from vllm import attention_ops, cache_ops
+
+BLOCK_SIZE = 16
+
+
+class FastLinear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(weight)
+        if bias is not None:
+            self.bias = nn.Parameter(bias)
+        else:
+            self.bias = None
+
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
+        weight = weights.get_tensor(f"{prefix}.weight")
+        if bias:
+            bias = weights.get_tensor(f"{prefix}.bias")
+        else:
+            bias = None
+        return cls(weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return F.linear(input, self.weight, self.bias)
+
+
+class PositionRotaryEmbedding(nn.Module):
+    def __init__(self, inv_freq, scaling_factor):
+        super().__init__()
+        self.inv_freq = inv_freq
+        self._seq_len_cached = 0
self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.scaling_factor = scaling_factor + self.dynamic_args = None + + @staticmethod + def _create_inv_freq(dim, base, device): + inv_freq = 1.0 / ( + base + ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + return inv_freq + + @classmethod + def static(cls, config, dim, base, device): + inv_freq = cls._create_inv_freq(dim, base, device) + scaling_factor = None + return cls(inv_freq, scaling_factor) + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin( + self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype + ): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + x1 = x[..., :rotary_dim] + x2 = x[..., rotary_dim: 2 * rotary_dim] + + rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False) + return x + + +class Weights: + def __init__( + self, + filenames: List[Path], + device, + dtype, + process_group, + aliases: Optional[Dict[str, List[str]]] = None, + ): + routing = {} + for filename in filenames: + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases + self.routing = routing + self.device = device + self.dtype = dtype + self.process_group = process_group + self._handles = {} + + def _get_handle(self, filename): + if filename not in self._handles: + f = safe_open(filename, framework="pytorch") + self._handles[filename] = f + + return self._handles[filename] + + def get_filename(self, tensor_name: str) -> (str, str): + filename = self.routing.get(tensor_name, None) + if filename is None: + aliases = self.aliases.get(tensor_name, []) + for alias in aliases: + filename = self.routing.get(alias, None) + if filename is not None: + return str(filename), alias + raise RuntimeError(f"weight {tensor_name} does not exist") + return str(filename), tensor_name + + def get_tensor(self, tensor_name: str, to_device=True): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + tensor = f.get_tensor(tensor_name) + # Special case for gptq which shouldn't convert + # u4 which are disguised as int32 + if tensor.dtype not in [torch.int32, torch.int64]: + tensor = tensor.to(dtype=self.dtype) + if to_device: + tensor = tensor.to(device=self.device) + return tensor + + def 
load_multi_linear(self, config, prefixes): + weight = torch.cat([self.get_tensor(f"{p}.weight") for p in prefixes], dim=0) + return FastLinear(weight, bias=None) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + # faster post attention rms norm + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + + +class FlashLlamaAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + self.softmax_scale = self.head_size ** -0.5 + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + # q,k,v,o and rotary + self.query_key_value = weights.load_multi_linear(config, + [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]) + self.o_proj = FastLinear.load(config, f"{prefix}.o_proj", weights, bias=False) + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, dim=self.head_size, base=config.rope_theta, device=weights.device + ) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + + cache_ops.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn_2_cuda.varlen_fwd( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + cu_seqlen_prefill, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + -1, + 0, + False, + None, + ) + # Decode + else: + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + attention_ops.paged_attention_v1( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class LlamaMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.intermediate_size = config.intermediate_size + self.act = ACT2FN[act] + # Fuse gate and up proj + self.gate_up_proj = weights.load_multi_linear(config, 
[f"{prefix}.gate_proj", f"{prefix}.up_proj"]) + self.down_proj = FastLinear.load(config, f"{prefix}.down_proj", weights, bias=False) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashLlamaLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + + self.input_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.self_attn = FlashLlamaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.post_attention_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class FlashLlamaModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + embeddings = weights.get_tensor(f"model.embed_tokens.weight") + self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)), + padding_idx=config.pad_token_id) + self.layers = nn.ModuleList( + [ + FlashLlamaLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + # for layer_id in range(1) + ] + ) + self.norm = LlamaRMSNorm( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashLlamaForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + self.config = config + self.model = FlashLlamaModel(config, weights) + self.lm_head = FastLinear.load(config, prefix="lm_head", weights=weights, bias=False) + + def forward( + self, + input_ids: torch.Tensor, + 
position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits + + +class CacheManager: + def __init__( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.block_size = BLOCK_SIZE + self.num_blocks = num_blocks + self.device = device + + element_size = torch.tensor([], dtype=dtype).element_size() + x = self.block_size // element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") + self.slots = torch.arange( + 0, num_blocks * self.block_size, dtype=torch.int32 + ).view(num_blocks, self.block_size) + + def allocate(self, blocks, max_blocks, needed_blocks_slots): + """ + blocks: 总共需要的blocks数量 + max_blocks: 最大的blocks数量大小 + needed_blocks_slots: 每个序列所需的blocks及其对应的序列长度 + """ + # Get free blocks indices by finding values in mask that are not set to 0 + free_block_indices = self.free_block_mask.nonzero() + assert ( + len(free_block_indices) >= blocks + ), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks" + + # Slice by the number of required blocks + block_indices = free_block_indices[: blocks] + block_indices = block_indices.flatten() + + # Padded block tables + block_tables_tensor = torch.zeros( + (len(needed_blocks_slots), max_blocks), dtype=torch.int32 + ) + + # Allocate paged attention blocks + cumulative_blocks = 0 + slots = [] + block_tables = [] + for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots): + # Get allocated blocks for this sequence + allocated_blocks = block_indices[ + cumulative_blocks: cumulative_blocks + needed_blocks + ] + # Get slots for the allocated blocks + allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots] + + slots.append(allocated_slots) + block_tables.append(allocated_blocks.tolist()) + block_tables_tensor[i, :needed_blocks] = allocated_blocks + cumulative_blocks += needed_blocks + + # Allocate the required number of blocks by setting the mask to 0 + self.free_block_mask[block_indices] = 0 + + return block_tables, block_tables_tensor.to(self.device), torch.concat(slots).to(self.device) + + def free(self, block_indices: Optional[List[int]]): + if block_indices is not None and block_indices: + # Reset mask + self.free_block_mask[block_indices] = 1 + + +def generate(tokenizer, model, prompt, max_new_tokens=10): + input_ids = tokenizer(prompt).input_ids + + def warmup(): + print("start warmup...") + global CACHE_MANAGER + blocks = 260 + CACHE_MANAGER = CacheManager(blocks, + model.config.num_hidden_layers, + model.config.num_key_value_heads, + model.config.hidden_size // model.config.num_attention_heads, + torch.float16, + device) + input_length = 1024 + bs = 4 + warmup_inputs = { + 
'input_ids': torch.arange(1, input_length + 1, dtype=torch.int64, device=device).repeat(bs), + 'position_ids': torch.arange(0, input_length, dtype=torch.int32, device=device).repeat(bs), + 'cu_seqlen_prefill': torch.tensor([i * input_length for i in range(bs + 1)], dtype=torch.int32, + device=device), + 'block_tables': torch.arange(0, blocks, dtype=torch.int32, device=device).split(blocks // bs), + 'slots': torch.arange(0, 4144, dtype=torch.int32, device=device), + 'input_lengths': torch.tensor([input_length] * 4, dtype=torch.int32, device=device), + 'max_s': 1024, + 'lm_head_indices': None + } + model.forward(**warmup_inputs, kv_cache=CACHE_MANAGER.kv_cache) + + del CACHE_MANAGER + torch.cuda.empty_cache() + + # 预热 + warmup() + + print("start speed test running") + # 申请缓存空间 + global CACHE_MANAGER + CACHE_MANAGER = CacheManager(100, + model.config.num_hidden_layers, + model.config.num_key_value_heads, + model.config.hidden_size // model.config.num_attention_heads, + torch.float16, + device) + total_tokens = len(input_ids) + max_new_tokens - 1 + needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + needed_blocks_slots = [(needed_blocks, total_tokens)] + _, block_tables_tensor, slots = CACHE_MANAGER.allocate(needed_blocks, needed_blocks, needed_blocks_slots) + # forward循环 + loops = 10 + for loop in range(loops): + print(f"loop {loop}...") + times = [] + new_tokens = [] + for step in range(max_new_tokens): + if step == 0: + # prefill step + slot_indices = torch.arange(0, 0 + len(input_ids), dtype=torch.int64) + inputs = { + 'input_ids': torch.tensor(input_ids, dtype=torch.int64, device=device), + 'position_ids': torch.arange(0, len(input_ids), dtype=torch.int32, device=device), + 'cu_seqlen_prefill': torch.tensor([0, len(input_ids)], dtype=torch.int32, device=device), + 'block_tables': block_tables_tensor, + 'slots': slots[slot_indices], + 'input_lengths': torch.tensor([len(input_ids)], dtype=torch.int32, device=device), + 'max_s': len(input_ids), + 'lm_head_indices': torch.tensor([0 + len(input_ids) - 1], dtype=torch.int32, device=device) + } + else: + # incremental step + current_length = len(input_ids) + step + inputs = { + 'input_ids': new_tokens[-1], + 'position_ids': torch.tensor([current_length - 1], dtype=torch.int32, device=device), + 'cu_seqlen_prefill': None, + 'block_tables': block_tables_tensor, + 'slots': torch.tensor([current_length - 1], dtype=torch.int32, device=device), + 'input_lengths': torch.tensor([current_length], dtype=torch.int32, device=device), + 'max_s': current_length, + 'lm_head_indices': None + } + torch.cuda.synchronize() + s_time = time.time() + logits = model.forward(**inputs, kv_cache=CACHE_MANAGER.kv_cache) + torch.cuda.synchronize() + cost_time = time.time() - s_time + next_token_id = logits.argmax(dim=-1) + new_tokens.append(next_token_id) + times.append(round(cost_time, 6)) + + if loop == 0: + new_tokens = torch.concat(new_tokens) + print(tokenizer.decode(new_tokens, skip_special_tokens=True)) + + elapsed_time = np.mean(times) + print(f"total new tokens: {max_new_tokens}, cost time: {sum(times):.6f} s\n" + f"time_per_token: {elapsed_time * 1000:.3f} ms, tps: {1 / elapsed_time:.2f} tokens/s") + + +def main(model_path): + # step 0: 定义路径与属性 + model_path = Path(model_path) + config = transformers.AutoConfig.from_pretrained(model_path) + model_files = list(model_path.glob('*.safetensors')) + + # step 1: 定义tokenizer与权重 + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="left", truncation_side="left") + weights = 
Weights(model_files, device, torch.float16, process_group=None) + + # step2: 定义模型 + model = FlashLlamaForCausalLM(config, weights).eval() + print(model) + + # step3: 推理 + with torch.no_grad(): + prompt = "who are you?" + generate(tokenizer, model, prompt, max_new_tokens=100) + + +if __name__ == '__main__': + CACHE_MANAGER: Optional[CacheManager] = None + device = torch.device("cuda") + main('/code/models/llama-7b-hf') diff --git a/my_optims/optims/test_llama_hf.py b/my_optims/optims/test_llama_hf.py new file mode 100644 index 00000000..d1f96570 --- /dev/null +++ b/my_optims/optims/test_llama_hf.py @@ -0,0 +1,42 @@ +# encoding:utf-8 +# -------------------------------------------# +# Filename: aime-local-inference-server -- test_llama_hf.py +# +# Description: +# Version: 1.0 +# Created: 2023/9/7-15:19 +# Last modified by: +# Author: 'zhaohuayang@myhexin.com' +# Company: 同花顺网络信息股份有限公司 +# -------------------------------------------# +import time + +import torch +import transformers + + +def test_llama(): + tokenizer = transformers.AutoTokenizer.from_pretrained("/code/models/llama-7b-hf") + model = transformers.AutoModelForCausalLM.from_pretrained("/code/models/llama-7b-hf", torch_dtype=torch.float16, + device_map="auto").half().eval() + prompt = "who are you?" + input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda') + # warm up + output_ids = model.generate(input_ids, max_new_tokens=100) + i_l = len(input_ids[0]) + o_l = len(output_ids[0]) + print(f'input length: {i_l}\t' + f'output length: {o_l}') + print(tokenizer.decode(output_ids[0], skip_special_tokens=True)) + # tps + loop = 10 + s_time = time.time() + for _ in range(loop): + model.generate(input_ids, max_new_tokens=100) + mean_ = (time.time() - s_time) / loop + + print(f"tps: {(o_l - i_l) / mean_:.4f}") + + +if __name__ == '__main__': + test_llama() diff --git a/my_optims/optims/test_tp/flash_llama_modeling.py b/my_optims/optims/test_tp/flash_llama_modeling.py new file mode 100644 index 00000000..080b2171 --- /dev/null +++ b/my_optims/optims/test_tp/flash_llama_modeling.py @@ -0,0 +1,467 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
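+# Tensor-parallel flash-attention LLaMA modeling used by test_llama_fa_tp.py; the
+# TensorParallel* building blocks it relies on are defined in layers.py alongside it.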
+ +from typing import Optional, List, Tuple + +# Flash attention imports +import dropout_layer_norm +import flash_attn_2_cuda +import torch +import torch.distributed +from torch import nn +from transformers.activations import ACT2FN +# vllm imports +from vllm import attention_ops, cache_ops +from torch.nn import functional as F + +from layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + get_linear, +) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + else: + # faster post attention rms norm + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + if config.model_type == "baichuan": + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.W_pack", + weights=weights, + bias=False, + ) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize != "gptq": + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear(weight, bias=None, quantize=config.quantize) + ) + + +class FlashLlamaAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + # self.rotary_emb = 
PositionRotaryEmbedding.load( + # config=config, prefix=f"{prefix}.rotary_emb", weights=weights + # ) + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, dim=self.head_size, base=config.rope_theta, device=weights.device + ) + + self.softmax_scale = self.head_size ** -0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + + cache_ops.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn_2_cuda.varlen_fwd( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + cu_seqlen_prefill, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + -1, + 0, + False, + None, + ) + # Decode + else: + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + attention_ops.paged_attention_v1( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class LlamaMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none", + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashLlamaLayer(nn.Module): + def __init__(self, layer_id, 
config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = FlashLlamaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class FlashLlamaModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + # self.embed_tokens = TensorParallelEmbedding( + # prefix="model.embed_tokens", weights=weights + # ) + embeddings = weights.get_tensor(f"model.embed_tokens.weight") + self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)), + padding_idx=config.pad_token_id) + + self.layers = nn.ModuleList( + [ + FlashLlamaLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + # for layer_id in range(1) + ] + ) + self.norm = LlamaRMSNorm( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashLlamaForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = FlashLlamaModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + 
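+                # (With the `assert False` above, this fallback always runs: the full
+                # lm_head weight is loaded on every rank and no all-gather is needed.)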
lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/my_optims/optims/test_tp/layers.py b/my_optims/optims/test_tp/layers.py new file mode 100644 index 00000000..828bda9c --- /dev/null +++ b/my_optims/optims/test_tp/layers.py @@ -0,0 +1,399 @@ +import os +from typing import List + +import torch +import torch.distributed +from torch import nn +from torch.nn import functional as F + + +class FastLinear(nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.weight = nn.Parameter(weight) + if bias is not None: + self.bias = nn.Parameter(bias) + else: + self.bias = None + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls(weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight, self.bias) + + +def get_linear(weight, bias, quantize): + linear = FastLinear(weight, bias) + return linear + + +class SuperLayer(nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear.forward(x) + + +class TensorParallelHead(SuperLayer): + def __init__(self, linear, process_group, should_gather: bool): + super().__init__(linear) + self.process_group = process_group + self.should_gather = should_gather + + @staticmethod + def load(config, prefix: str, weights): + if weights.process_group.size() > 1: + try: + assert False + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + should_gather = True + except AssertionError: + # If the vocab size is not divisible by number of shards + # just load the entire thing. 
+ weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + else: + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + + # GPTQ doesn't quantize heads (nor embeddings) + if config.quantize == "gptq": + quantize = None + else: + quantize = config.quantize + return TensorParallelHead( + get_linear(weight, bias=None, quantize=quantize), + process_group=weights.process_group, + should_gather=should_gather, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.should_gather: + return super().forward(input) + world_size = self.process_group.size() + if len(input.shape) == 2 and isinstance(self.linear, FastLinear): + out_dim = self.linear.weight.shape[0] + + if input.shape[0] == 1: + world_out = input.new_empty(1, out_dim * world_size) + local_out = input.new_empty(1, out_dim) + gather_input = local_out + else: + world_out = input.new_empty(out_dim * world_size, input.shape[0]) + gather_input = input.new_empty(out_dim, input.shape[0]) + local_out = gather_input.T + + torch.mm(input, self.linear.weight.T, out=local_out) + + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + + if input.shape[0] == 1: + return world_out + return world_out.T + + output = super().forward(input) + world_output = [ + torch.empty_like(output) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(world_output, output, group=self.process_group) + world_output = torch.cat(world_output, dim=-1) + return world_output + + +class TensorParallelColumnLinear(SuperLayer): + @classmethod + def load_qkv(cls, config, prefix: str, weights, bias: bool): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_qkv( + prefix, quantize=config.quantize + ) + if bias: + raise NotImplementedError("packed_qkv only implemented for baichuan") + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + return cls.load_multi(config, [prefix], weights, bias, dim=0) + + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + weight = weights.get_multi_weights_col( + prefixes, quantize=config.quantize, dim=dim + ) + + if bias: + b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] + bias = torch.cat(b, dim=dim) + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + +class TensorParallelRowLinear(SuperLayer): + def __init__(self, linear, process_group): + super().__init__(linear) + self.process_group = process_group + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls( + get_linear(weight, bias, config.quantize), + process_group=weights.process_group, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super().forward(input) + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class TensorParallelEmbedding(nn.Module): + def __init__(self, prefix: str, weights, reduce=True): + super().__init__() + weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) + num_embeddings = 
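+            # The loader dtype is temporarily switched to float32 below so inv_freq is
+            # read at full precision, then restored immediately afterwards.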
weights.get_shape(f"{prefix}.weight")[0] + + process_group = weights.process_group + + world_size = process_group.size() + rank = process_group.rank() + + block_size = num_embeddings // world_size + self.min_id = rank * block_size + self.max_id = min(num_embeddings, (rank + 1) * block_size) + self.null_idx = block_size + self.process_group = weights.process_group + self.reduce = reduce + + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = torch.nn.functional.embedding(input, self.weight) + if self.reduce and self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +try: + import dropout_layer_norm + + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +except ImportError: + pass + +try: + from flash_attn.layers.rotary import RotaryEmbedding + import rotary_emb + + + def _create_inv_freq(dim, base, device): + inv_freq = 1.0 / ( + base + ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + return inv_freq + + + def _get_rope_config(config): + if os.getenv("ROPE_SCALING", None) is not None: + rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])} + return rope_scaling + return getattr(config, "rope_scaling", None) + + + class PositionRotaryEmbedding(nn.Module): + def __init__(self, inv_freq, scaling_factor): + super().__init__() + self.inv_freq = inv_freq + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.scaling_factor = scaling_factor + self.dynamic_args = None + + @classmethod + def static(cls, config, dim, base, device): + inv_freq = _create_inv_freq(dim, base, device) + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding(dim=dim, + max_position_embeddings=config.max_position_embeddings, + base=base, device=inv_freq.device, + scaling_factor=scaling_factor) + else: + raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid") + return cls(inv_freq, scaling_factor) + + @classmethod + def load(cls, config, prefix, weights): + # XXX: Always load this in float32 ! 
+ dtype = weights.dtype + weights.dtype = torch.float32 + inv_freq = weights.get_tensor(f"{prefix}.inv_freq") + weights.dtype = dtype + + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding(dim=2 * inv_freq.shape[0], + max_position_embeddings=config.max_position_embeddings, + base=10000.0, device=inv_freq.device, + scaling_factor=scaling_factor) + else: + raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid") + return cls(inv_freq, scaling_factor) + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin( + self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype + ): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + x1 = x[..., :rotary_dim] + x2 = x[..., rotary_dim: 2 * rotary_dim] + + rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False) + return x + + + class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - ( + self.scaling_factor - 1)) ** (self.dim / (self.dim - 2)) + self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device) + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + +except ImportError: + pass diff --git a/my_optims/optims/test_tp/test_llama_fa_tp.py b/my_optims/optims/test_tp/test_llama_fa_tp.py new file mode 100644 index 00000000..b0d84f14 --- /dev/null 
+++ b/my_optims/optims/test_tp/test_llama_fa_tp.py @@ -0,0 +1,247 @@ +# encoding:utf-8 +# -------------------------------------------# +# Filename: optims -- test_llama_fa.py +# +# Description: +# Version: 1.0 +# Created: 2023/9/18-20:50 +# Last modified by: +# Author: 'zhaohuayang@myhexin.com' +# Company: 同花顺网络信息股份有限公司 +# -------------------------------------------# +import math +import time +from pathlib import Path +from typing import List, Optional + +# Flash attention imports +import numpy as np +import torch +import torch.distributed +import transformers + +from flash_llama_modeling import FlashLlamaForCausalLM +from utils import initialize_torch_distributed +from weights import Weights + +BLOCK_SIZE = 16 + + +class CacheManager: + def __init__( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.block_size = BLOCK_SIZE + self.num_blocks = num_blocks + self.device = device + + element_size = torch.tensor([], dtype=dtype).element_size() + x = self.block_size // element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") + self.slots = torch.arange( + 0, num_blocks * self.block_size, dtype=torch.int32 + ).view(num_blocks, self.block_size) + + def allocate(self, blocks, max_blocks, needed_blocks_slots): + """ + blocks: 总共需要的blocks数量 + max_blocks: 最大的blocks数量大小 + needed_blocks_slots: 每个序列所需的blocks及其对应的序列长度 + """ + # Get free blocks indices by finding values in mask that are not set to 0 + free_block_indices = self.free_block_mask.nonzero() + assert ( + len(free_block_indices) >= blocks + ), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks" + + # Slice by the number of required blocks + block_indices = free_block_indices[: blocks] + block_indices = block_indices.flatten() + + # Padded block tables + block_tables_tensor = torch.zeros( + (len(needed_blocks_slots), max_blocks), dtype=torch.int32 + ) + + # Allocate paged attention blocks + cumulative_blocks = 0 + slots = [] + block_tables = [] + for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots): + # Get allocated blocks for this sequence + allocated_blocks = block_indices[ + cumulative_blocks: cumulative_blocks + needed_blocks + ] + # Get slots for the allocated blocks + allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots] + + slots.append(allocated_slots) + block_tables.append(allocated_blocks.tolist()) + block_tables_tensor[i, :needed_blocks] = allocated_blocks + cumulative_blocks += needed_blocks + + # Allocate the required number of blocks by setting the mask to 0 + self.free_block_mask[block_indices] = 0 + + return block_tables, block_tables_tensor.to(self.device), torch.concat(slots).to(self.device) + + def free(self, block_indices: Optional[List[int]]): + if block_indices is not None and block_indices: + # Reset mask + self.free_block_mask[block_indices] = 1 + + +def generate(tokenizer, model, config, device, prompt, max_new_tokens=10): + input_ids = tokenizer(prompt).input_ids + + def warmup(): + print("start warmup...") + global CACHE_MANAGER + blocks = 260 + CACHE_MANAGER = CacheManager(blocks, + len(model.model.layers), + 
model.model.num_key_value_heads, + model.model.head_size, + torch.float16, + device) + input_length = 1024 + bs = 4 + warmup_inputs = { + 'input_ids': torch.arange(1, input_length + 1, dtype=torch.int64, device=device).repeat(bs), + 'position_ids': torch.arange(0, input_length, dtype=torch.int32, device=device).repeat(bs), + 'cu_seqlen_prefill': torch.tensor([i * input_length for i in range(bs + 1)], dtype=torch.int32, + device=device), + 'block_tables': torch.arange(0, blocks, dtype=torch.int32, device=device).split(blocks // bs), + 'slots': torch.arange(0, 4144, dtype=torch.int32, device=device), + 'input_lengths': torch.tensor([input_length] * 4, dtype=torch.int32, device=device), + 'max_s': 1024, + 'lm_head_indices': None + } + model.forward(**warmup_inputs, kv_cache=CACHE_MANAGER.kv_cache) + + del CACHE_MANAGER + torch.cuda.empty_cache() + + # 预热 + warmup() + + print("start speed test running") + # 申请缓存空间 + global CACHE_MANAGER + CACHE_MANAGER = CacheManager(100, + len(model.model.layers), + model.model.num_key_value_heads, + model.model.head_size, + torch.float16, + device) + total_tokens = len(input_ids) + max_new_tokens - 1 + needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + needed_blocks_slots = [(needed_blocks, total_tokens)] + _, block_tables_tensor, slots = CACHE_MANAGER.allocate(needed_blocks, needed_blocks, needed_blocks_slots) + # forward循环 + loops = 10 + tpss = [] + for loop in range(loops): + print(f"loop {loop}...") + times = [] + new_tokens = [] + for step in range(max_new_tokens): + if step == 0: + # prefill step + slot_indices = torch.arange(0, 0 + len(input_ids), dtype=torch.int64) + inputs = { + 'input_ids': torch.tensor(input_ids, dtype=torch.int64, device=device), + 'position_ids': torch.arange(0, len(input_ids), dtype=torch.int32, device=device), + 'cu_seqlen_prefill': torch.tensor([0, len(input_ids)], dtype=torch.int32, device=device), + 'block_tables': block_tables_tensor, + 'slots': slots[slot_indices], + 'input_lengths': torch.tensor([len(input_ids)], dtype=torch.int32, device=device), + 'max_s': len(input_ids), + 'lm_head_indices': torch.tensor([0 + len(input_ids) - 1], dtype=torch.int32, device=device) + } + else: + # incremental step + current_length = len(input_ids) + step + inputs = { + 'input_ids': new_tokens[-1], + 'position_ids': torch.tensor([current_length - 1], dtype=torch.int32, device=device), + 'cu_seqlen_prefill': None, + 'block_tables': block_tables_tensor, + 'slots': torch.tensor([current_length - 1], dtype=torch.int32, device=device), + 'input_lengths': torch.tensor([current_length], dtype=torch.int32, device=device), + 'max_s': current_length, + 'lm_head_indices': None + } + torch.cuda.synchronize() + s_time = time.time() + logits = model.forward(**inputs, kv_cache=CACHE_MANAGER.kv_cache) + torch.cuda.synchronize() + cost_time = time.time() - s_time + next_token_id = logits.argmax(dim=-1) + new_tokens.append(next_token_id) + times.append(round(cost_time, 6)) + + if loop == 0: + new_tokens = torch.concat(new_tokens) + print(tokenizer.decode(new_tokens, skip_special_tokens=True)) + + elapsed_time = np.mean(times) + tps = 1 / elapsed_time + tpss.append(tps) + print(times) + print(f"total new tokens: {max_new_tokens}, cost time: {sum(times):.6f} s\n" + f"time_per_token: {elapsed_time * 1000:.3f} ms, tps: {tps:.2f} tokens/s") + print(f'mean tps: {np.mean(tpss):.2f} tokens/s') + + +def main(model_path): + # init env + process_group, rank, world_size = initialize_torch_distributed() + # step 0: 定义路径与属性 + model_path = Path(model_path) + config 
= transformers.AutoConfig.from_pretrained(model_path) + config.quantize = None + model_files = list(model_path.glob('*.safetensors')) + + # step 1: 定义tokenizer与权重 + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="left", truncation_side="left") + device = torch.device(f"cuda:{rank}") + weights = Weights(model_files, device, torch.float16, process_group=process_group) + + # step2: 定义模型 + torch.distributed.barrier(group=process_group) + model = FlashLlamaForCausalLM(config, weights).eval() + torch.distributed.barrier(group=process_group) + print(model) + + # step3: 推理 + with torch.no_grad(): + prompt = "who are you?" + generate(tokenizer, model, config, device, prompt, max_new_tokens=100) + + +if __name__ == '__main__': + CACHE_MANAGER: Optional[CacheManager] = None + main('/code/models/llama-7b-hf') diff --git a/my_optims/optims/test_tp/utils.py b/my_optims/optims/test_tp/utils.py new file mode 100644 index 00000000..bcdb46bb --- /dev/null +++ b/my_optims/optims/test_tp/utils.py @@ -0,0 +1,100 @@ +# encoding:utf-8 +# -------------------------------------------# +# Filename: optims -- utils.py +# +# Description: +# Version: 1.0 +# Created: 2023/9/27-14:43 +# Last modified by: +# Author: 'zhaohuayang@myhexin.com' +# Company: 同花顺网络信息股份有限公司 +# -------------------------------------------# +import os +import time +from datetime import timedelta + +import torch +from loguru import logger +from torch.distributed import ProcessGroupNCCL + +RANK = int(os.getenv("LOCAL_RANK", "0")) +WORLD_SIZE = int(os.getenv("WORLD_SIZE", "2")) +NCCL_PORT = int(os.getenv("NCCL_PORT", "29500")) +MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) + + +class FakeBarrier: + def wait(self): + pass + + +class FakeGroup: + def __init__(self, rank, size): + self._rank = rank + self._size = size + + def allreduce(self, *args, **kwargs): + return FakeBarrier() + + def allgather(self, inputs, local_tensor, **kwargs): + assert ( + len(inputs[0]) == len(local_tensor) == 1 + ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors" + for input_ in inputs: + input_[0].data = local_tensor[0].data + return FakeBarrier() + + def barrier(self, *args, **kwargs): + return FakeBarrier() + + def size(self): + return self._size + + def rank(self): + return self._rank + + +def initialize_torch_distributed(): + # Set the device id. + assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" + device = RANK % torch.cuda.device_count() + torch.cuda.set_device(device) + torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) + backend = "nccl" + options = ProcessGroupNCCL.Options() + options.is_high_priority_stream = True + options._timeout = timedelta(seconds=60) + if not torch.distributed.is_initialized(): + # Call the init process. 
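+ # The rendezvous is a plain TCP store on localhost:NCCL_PORT using the NCCL backend with a high-priority stream and a 60 s timeout;
+ # rank and world size come from the LOCAL_RANK / WORLD_SIZE environment variables read at module import.
+ # A launcher that exports those variables is assumed, e.g. (hypothetical): torchrun --nproc_per_node=2 test_llama_fa_tp.py,
+ # moving NCCL_PORT off 29500 if the launcher already occupies that port.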
+ torch.distributed.init_process_group( + backend=backend, + init_method=f"tcp://localhost:{NCCL_PORT}", + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) + logger.info(f"torch.distributed is initialized on rank {RANK} of {WORLD_SIZE}.") + else: + logger.warning("torch.distributed is already initialized.") + + return torch.distributed.group.WORLD, RANK, WORLD_SIZE + + +class Timer: + def __init__(self): + self.times = [] + self.time = None + + def start(self): + torch.cuda.synchronize() + self.time = time.time() + + def end(self): + torch.cuda.synchronize() + self.times.append(time.time() - self.time) + + @property + def elapsed(self): + self.times.pop(0) + return round(sum(self.times) / len(self.times) * 1000, 2) diff --git a/my_optims/optims/test_tp/weights.py b/my_optims/optims/test_tp/weights.py new file mode 100644 index 00000000..dbd4d45d --- /dev/null +++ b/my_optims/optims/test_tp/weights.py @@ -0,0 +1,336 @@ +# encoding:utf-8 +# -------------------------------------------# +# Filename: optims -- weights.py +# +# Description: +# Version: 1.0 +# Created: 2023/9/27-15:05 +# Last modified by: +# Author: 'zhaohuayang@myhexin.com' +# Company: 同花顺网络信息股份有限公司 +# -------------------------------------------# +import json +import os +from pathlib import Path +from typing import List, Dict, Optional, Tuple + +import torch +from huggingface_hub import hf_hub_download +from loguru import logger +from safetensors import safe_open, SafetensorError + + +class Weights: + def __init__( + self, + filenames: List[Path], + device, + dtype, + process_group, + aliases: Optional[Dict[str, List[str]]] = None, + ): + routing = {} + for filename in filenames: + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases + self.routing = routing + self.device = device + self.dtype = dtype + self.process_group = process_group + self._handles = {} + + def _get_handle(self, filename): + if filename not in self._handles: + f = safe_open(filename, framework="pytorch") + self._handles[filename] = f + + return self._handles[filename] + + def get_filename(self, tensor_name: str) -> (str, str): + filename = self.routing.get(tensor_name, None) + if filename is None: + aliases = self.aliases.get(tensor_name, []) + for alias in aliases: + filename = self.routing.get(alias, None) + if filename is not None: + return str(filename), alias + raise RuntimeError(f"weight {tensor_name} does not exist") + return str(filename), tensor_name + + def _get_slice(self, tensor_name: str): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + return slice_ + + def get_shape(self, tensor_name: str): + return self._get_slice(tensor_name).get_shape() + + def get_tensor(self, tensor_name: str, to_device=True): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + tensor = f.get_tensor(tensor_name) + # Special case for gptq which shouldn't convert + # u4 which are disguised as int32 + if tensor.dtype not in [torch.int32, torch.int64]: + tensor = tensor.to(dtype=self.dtype) + if to_device: + tensor = tensor.to(device=self.device) + return tensor + + def get_partial_sharded(self, tensor_name: str, dim: int): + filename, tensor_name = self.get_filename(tensor_name) + f = 
self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + world_size = self.process_group.size() + rank = self.process_group.rank() + + size = slice_.get_shape()[dim] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + if dim == 0: + tensor = slice_[start:stop] + elif dim == 1: + tensor = slice_[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + # Special case for gptq which shouldn't convert + # u4 which are disguised as int32 + if tensor.dtype != torch.int32: + tensor = tensor.to(dtype=self.dtype) + tensor = tensor.to(device=self.device) + return tensor + + def get_sharded(self, tensor_name: str, dim: int): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + world_size = self.process_group.size() + size = slice_.get_shape()[dim] + assert ( + size % world_size == 0 + ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" + return self.get_partial_sharded(tensor_name, dim) + + def _get_qweight(self, name: str): + slice_ = self._get_slice(name) + total_size = slice_.get_shape()[1] + assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3" + single_size = total_size // 3 + world_size = self.process_group.size() + rank = self.process_group.rank() + + assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards" + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + q = slice_[:, start:stop] + k = slice_[:, start + single_size:stop + single_size] + v = slice_[:, start + 2 * single_size:stop + 2 * single_size] + weight = torch.cat([q, k, v], dim=1) + weight = weight.to(device=self.device) + return weight + + def get_weights_col_packed_qkv(self, prefix: str, quantize: str): + """ + Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being + already alternating Q,K,V within the main tensor + """ + if quantize == "gptq": + try: + qweight = self._get_qweight(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + qzeros = self._get_qweight(f"{prefix}.qzeros") + scales = self._get_qweight(f"{prefix}.scales") + scales = scales.to(dtype=self.dtype) + g_idx = self.get_tensor(f"{prefix}.g_idx") + + bits, groupsize = self._get_gptq_params() + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) + else: + slice_ = self._get_slice(f"{prefix}.weight") + total_size = slice_.get_shape()[0] + assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3" + single_size = total_size // 3 + world_size = self.process_group.size() + rank = self.process_group.rank() + + assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards" + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + q = slice_[start:stop] + k = slice_[start + single_size:stop + single_size] + v = slice_[start + 2 * single_size:stop + 2 * single_size] + weight = torch.cat([q, k, v], dim=0) + weight = weight.to(device=self.device) + weight = weight.to(dtype=self.dtype) + return weight + + def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): + if quantize == "gptq": + try: + qweight = torch.cat( + 
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + scales = torch.cat( + [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + bits, groupsize = self._get_gptq_params() + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) + else: + w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + weight = torch.cat(w, dim=dim) + return weight + + def get_tensor_shard(self, var, dim): + world_size = self.process_group.size() + rank = self.process_group.rank() + block_size = var.size()[dim] // world_size + start = rank * block_size + stop = (rank + 1) * block_size + if dim == 0: + tensor = var[start:stop] + elif dim == 1: + tensor = var[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + tensor = tensor.to(dtype=self.dtype) + tensor = tensor.to(device=self.device) + return tensor + + def get_multi_weights_row(self, prefix: str, quantize: str): + if quantize == "gptq": + use_exllama = True + bits, groupsize = self._get_gptq_params() + + if bits != 4: + use_exllama = False + + if self.process_group.size() > 1: + g_idx = self.get_tensor(f"{prefix}.g_idx") + if g_idx is not None: + if ( + not torch.equal( + g_idx.cpu(), + torch.tensor( + [i // groupsize for i in range(g_idx.shape[0])], + dtype=torch.int32, + ), + ) + and not (g_idx == 0).all() + ): + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + use_exllama = False + + try: + qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + logger.warning( + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True" + ) + use_exllama = False + else: + logger.info("Using exllama kernels") + + if use_exllama: + if groupsize >= 0: + # Exllama reorders the weights in advance and the activations on the fly, thus + # the scales and zero-points do not need to be reordered. + qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) + scales = self.get_sharded(f"{prefix}.scales", dim=0) + else: + qzeros = self.get_tensor(f"{prefix}.qzeros") + scales = self.get_tensor(f"{prefix}.scales") + + # For tp > 1, at this point we know we do not use act-order + if self.process_group.size() == 1: + g_idx = self.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None + else: + # The triton kernel reorders the scales/zero points instead of the weight/activation. + # Thus, each rank needs the full qzeros/scales. 
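+ # Concretely: qweight (sharded above) and g_idx (sharded below) are split along the input dimension (dim 0) across ranks,
+ # while the complete qzeros/scales tables are replicated on every rank, so the sharded g_idx still points into the full
+ # group tables without any remapping.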
+ qzeros = self.get_tensor(f"{prefix}.qzeros") + scales = self.get_tensor(f"{prefix}.scales") + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + else: + weight = self.get_sharded(f"{prefix}.weight", dim=1) + return weight + + def _get_gptq_params(self) -> Tuple[int, int]: + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except (SafetensorError, RuntimeError) as e: + try: + bits = self.gptq_bits + groupsize = self.gptq_groupsize + except Exception: + raise e + + return bits, groupsize + + def _set_gptq_params(self, model_id): + filename = "config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename) + with open(filename, "r") as f: + data = json.load(f) + self.gptq_bits = data["quantization_config"]["bits"] + self.gptq_groupsize = data["quantization_config"]["group_size"] + except Exception: + filename = "quantize_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename) + with open(filename, "r") as f: + data = json.load(f) + self.gptq_bits = data["bits"] + self.gptq_groupsize = data["group_size"] + except Exception: + pass diff --git a/my_optims/tgi_update/flash_llama.py b/my_optims/tgi_update/flash_llama.py new file mode 100644 index 00000000..094e0726 --- /dev/null +++ b/my_optims/tgi_update/flash_llama.py @@ -0,0 +1,94 @@ +import torch +import torch.distributed + +from opentelemetry import trace +from transformers import AutoConfig, AutoTokenizer +from transformers.models.llama import LlamaTokenizer +from typing import Optional + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, + LlamaConfig, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) +import os + +tracer = trace.get_tracer(__name__) +USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1 +if USE_CUSTOM_NCCL: + from text_generation_server.utils.my_dist import initialize_mpi_distributed + +class FlashLlama(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + if USE_CUSTOM_NCCL: + self.process_group, rank, world_size, COMM = initialize_mpi_distributed() + else: + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashLlama is only available on GPU") + + try: + raise + tokenizer = LlamaTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + except Exception: + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = LlamaConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + + if USE_CUSTOM_NCCL: + 
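+ # On the custom NCCL path, ranks synchronize through the COMM handle returned by initialize_mpi_distributed
+ # instead of the torch.distributed group used on the stock path below.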
COMM.barrier() + else: + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id) + + model = FlashLlamaForCausalLM(config, weights) + + if USE_CUSTOM_NCCL: + COMM.barrier() + else: + torch.distributed.barrier(group=self.process_group) + super(FlashLlama, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) diff --git a/my_optims/tgi_update/flash_llama_modeling.py b/my_optims/tgi_update/flash_llama_modeling.py new file mode 100644 index 00000000..c3cca1ac --- /dev/null +++ b/my_optims/tgi_update/flash_llama_modeling.py @@ -0,0 +1,513 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
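+ # Flash/paged-attention Llama modeling: forward() switches between flash_attn (prefill, when cu_seqlen_prefill is set)
+ # and paged_attention (decode), block_tables/slots address the paged KV cache filled by reshape_and_cache, and
+ # lm_head_indices lets the caller compute logits only for the last position of each sequence.
+ # The kv_cache passed in is allocated by the cache manager in this patch as, per layer, a key cache of shape
+ # (num_blocks, num_kv_heads, head_size // x, BLOCK_SIZE, x) with x = BLOCK_SIZE // element_size(dtype) and a value
+ # cache of shape (num_blocks, num_kv_heads, head_size, BLOCK_SIZE).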
+ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +# Flash attention imports +import dropout_layer_norm + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + get_linear, +) + + +class LlamaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_scaling=None, + rope_theta=10000.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + else: + # faster post attention rms norm + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + if config.model_type == "baichuan": + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.W_pack", + weights=weights, + bias=False, + ) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", 
f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear(weight, bias=None, quantize=config.quantize) + ) + + +class FlashLlamaAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + # self.rotary_emb = PositionRotaryEmbedding.load( + # config=config, prefix=f"{prefix}.rotary_emb", weights=weights + # ) + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + + paged_attention.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + 
block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class LlamaMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none", + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashLlamaLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = FlashLlamaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class FlashLlamaModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + import os + if int(os.getenv("USE_TP_EMBEDDING", "1")) == 1: + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + else: + from torch.nn import functional as F + from loguru import logger + embeddings = weights.get_tensor(f"model.embed_tokens.weight") + self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)), + padding_idx=config.pad_token_id) + logger.info("Disabled embedding tensor parallel! 
") + self.layers = nn.ModuleList( + [ + FlashLlamaLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashLlamaForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = FlashLlamaModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/my_optims/tgi_update/layers.py b/my_optims/tgi_update/layers.py new file mode 100644 index 00000000..d0667bda --- /dev/null +++ b/my_optims/tgi_update/layers.py @@ -0,0 +1,815 @@ +import os +import torch + +from torch import nn +from torch.nn import functional as F +from typing import List +from loguru import logger +from functools import lru_cache + +HAS_BITS_AND_BYTES = True +try: + import bitsandbytes as bnb + from bitsandbytes.nn import Int8Params, Params4bit + +except ImportError: + HAS_BITS_AND_BYTES = False + +from accelerate import init_empty_weights + +from text_generation_server.utils.gptq.quant_linear import QuantLinear + + +HAS_AWQ = True +try: + from text_generation_server.utils.awq.quantize.qmodule import WQLinear +except ImportError: + HAS_AWQ = False + +try: + major, _minor = torch.cuda.get_device_capability() +except Exception: + major = 1 +HAS_EXLLAMA = False +CAN_EXLLAMA = major >= 8 +if os.getenv("DISABLE_EXLLAMA") == "True": + HAS_EXLLAMA = False +elif CAN_EXLLAMA: + try: + from text_generation_server.utils.gptq.exllama import Ex4bitLinear + + HAS_EXLLAMA = True + except ImportError: + pass + +from typing import Optional + +HAS_EETQ = False +try: + from EETQ import quant_weights, w8_a16_gemm + + HAS_EETQ = True +except ImportError: + pass +import my_custom_comm +USE_CUSTOM_NCCL = 
int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1 +USE_LM_HEAD_PARALLEL = int(os.getenv("USE_LM_HEAD_PARALLEL", "1")) + +# Monkey patching +@classmethod +def load_layer_norm(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = nn.Parameter(weight) + ln.bias = nn.Parameter(bias) + return ln + + +@classmethod +def load_layer_norm_no_bias(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = nn.Parameter(weight) + ln.bias = None + return ln + + +@classmethod +def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) + + conv2d.weight = nn.Parameter(weight) + conv2d.bias = nn.Parameter(bias) + return conv2d + + +@classmethod +def load_conv2d_no_bias( + cls, prefix, weights, in_channels, out_channels, kernel_size, stride +): + weight = weights.get_tensor(f"{prefix}.weight") + with init_empty_weights(): + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) + + conv2d.weight = nn.Parameter(weight) + conv2d.bias = None + return conv2d + + +torch.nn.Conv2d.load = load_conv2d +torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias +torch.nn.LayerNorm.load = load_layer_norm +torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias + + +class FastLinear(nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.weight = nn.Parameter(weight) + if bias is not None: + self.bias = nn.Parameter(bias) + else: + self.bias = None + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls(weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight, self.bias) + + +class EETQLinear(nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + device = weight.device + weight = torch.t(weight).contiguous().cpu() + weight, scale = quant_weights(weight, torch.int8, False) + + self.weight = weight.cuda(device) + self.scale = scale.cuda(device) + self.bias = bias.cuda(device) if bias is not None else None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = w8_a16_gemm(input, self.weight, self.scale) + output = output + self.bias if self.bias is not None else output + return output + + +class Linear8bitLt(nn.Module): + def __init__( + self, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + ): + super().__init__() + assert ( + not memory_efficient_backward + ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + self.state = bnb.MatmulLtState() + self.index = index + + # Necessary for stacked layers + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not 
has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + weight.data, + has_fp16_weights=has_fp16_weights, + requires_grad=has_fp16_weights, + ) + self.weight.cuda(weight.device) + self.bias = bias + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + return out + + +class Linear4bit(nn.Module): + def __init__(self, weight, bias, quant_type): + super().__init__() + self.weight = Params4bit( + weight.data, + requires_grad=False, + compress_statistics=True, + quant_type=quant_type, + ) + self.compute_dtype = None + self.weight.cuda(weight.device) + self.bias = bias + + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, "quant_state", None) is None: + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." + ) + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit( + x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state + ) + + out = out.to(inp_dtype) + + return out + + +@lru_cache(1) +def warn_deprecate_bnb(): + logger.warning( + "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" + ) + + +def get_linear(weight, bias, quantize): + if quantize is None: + linear = FastLinear(weight, bias) + elif quantize == "eetq": + if HAS_EETQ: + linear = EETQLinear(weight, bias) + else: + raise ImportError( + "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" + ) + elif quantize == "bitsandbytes": + warn_deprecate_bnb() + linear = Linear8bitLt( + weight, + bias, + has_fp16_weights=False, + threshold=6.0, + ) + if bias is not None: + linear.bias = nn.Parameter(bias) + elif quantize == "bitsandbytes-fp4": + linear = Linear4bit( + weight, + bias, + quant_type="fp4", + ) + elif quantize == "bitsandbytes-nf4": + linear = Linear4bit( + weight, + bias, + quant_type="nf4", + ) + elif quantize == "gptq": + try: + qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight + except Exception: + raise NotImplementedError( + f"The passed weight is not `gptq` compatible, loader needs to be updated." 
+ ) + + if use_exllama: + linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + else: + linear = QuantLinear( + qweight, + qzeros, + scales, + g_idx, + bias, + bits, + groupsize, + ) + elif quantize == "awq": + try: + qweight, qzeros, scales, _, bits, groupsize, _ = weight + except Exception: + raise NotImplementedError( + f"The passed weight is not `awq` compatible, loader needs to be updated." + ) + linear = WQLinear( + w_bit=bits, + group_size=groupsize, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias is not None, + ) + else: + raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") + return linear + + +class SuperLayer(nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear.forward(x) + + +class TensorParallelHead(SuperLayer): + def __init__(self, linear, process_group, should_gather: bool): + super().__init__(linear) + self.process_group = process_group + self.should_gather = should_gather + + @staticmethod + def load(config, prefix: str, weights): + if weights.process_group.size() > 1: + try: + assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1 + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + should_gather = True + except AssertionError: + # If the vocab size is not divisible by number of shards + # just load the entire thing. + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + logger.info("Disabled lm head parallel! ") + else: + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + + # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) + if config.quantize in ["gptq", "awq", "eetq"]: + quantize = None + else: + quantize = config.quantize + return TensorParallelHead( + get_linear(weight, bias=None, quantize=quantize), + process_group=weights.process_group, + should_gather=should_gather, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.should_gather: + return super().forward(input) + + world_size = self.process_group.size() + if len(input.shape) == 2 and isinstance(self.linear, FastLinear): + out_dim = self.linear.weight.shape[0] + + if input.shape[0] == 1: + world_out = input.new_empty(1, out_dim * world_size) + local_out = input.new_empty(1, out_dim) + gather_input = local_out + else: + world_out = input.new_empty(out_dim * world_size, input.shape[0]) + gather_input = input.new_empty(out_dim, input.shape[0]) + local_out = gather_input.T + + torch.mm(input, self.linear.weight.T, out=local_out) + + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + + if input.shape[0] == 1: + return world_out + return world_out.T + + output = super().forward(input) + world_output = [ + torch.empty_like(output) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(world_output, output, group=self.process_group) + world_output = torch.cat(world_output, dim=-1) + return world_output + + +class TensorParallelColumnLinear(SuperLayer): + @classmethod + def load_qkv(cls, config, prefix: str, weights, bias: bool): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) + if bias: + raise NotImplementedError("packed_qkv only implemented for baichuan") + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + 
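+ # Column parallelism: the weight is sharded along its output dimension (dim 0), so each rank produces a contiguous
+ # slice of the output features; loading a single prefix is just the multi-prefix path with one entry.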
return cls.load_multi(config, [prefix], weights, bias, dim=0) + + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + weight = weights.get_multi_weights_col( + prefixes, quantize=config.quantize, dim=dim + ) + + if bias: + b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] + bias = torch.cat(b, dim=dim) + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + +class TensorParallelRowLinear(SuperLayer): + def __init__(self, linear, process_group): + super().__init__(linear) + self.process_group = process_group + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls( + get_linear(weight, bias, config.quantize), + process_group=weights.process_group, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super().forward(input) + if self.process_group.size() > 1: + if USE_CUSTOM_NCCL: + my_custom_comm.custom_allreduce(out, self.process_group.tp_comm) + else: + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class TensorParallelEmbedding(nn.Module): + def __init__(self, prefix: str, weights, reduce=True): + super().__init__() + weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) + num_embeddings = weights.get_shape(f"{prefix}.weight")[0] + + process_group = weights.process_group + + world_size = process_group.size() + rank = process_group.rank() + + block_size = num_embeddings // world_size + self.min_id = rank * block_size + self.max_id = min(num_embeddings, (rank + 1) * block_size) + self.null_idx = block_size + self.process_group = weights.process_group + self.reduce = reduce + + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = torch.nn.functional.embedding(input, self.weight) + if self.reduce and self.process_group.size() > 1: + if USE_CUSTOM_NCCL: + my_custom_comm.custom_allreduce(out, self.process_group.tp_comm) + else: + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +try: + import dropout_layer_norm + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +except ImportError: + pass + + +try: + from flash_attn.layers.rotary import RotaryEmbedding + import rotary_emb + + def _create_inv_freq(dim, base, device): + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + 
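+ # inv_freq[i] = 1 / base**(2i/dim) for i in [0, dim/2): the standard RoPE frequency ladder,
+ # covering wavelengths from 2*pi up to roughly 2*pi*base.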
return inv_freq + + def _get_rope_config(config): + if os.getenv("ROPE_SCALING", None) is not None: + rope_scaling = { + "type": os.environ["ROPE_SCALING"], + "factor": float(os.environ["ROPE_FACTOR"]), + } + return rope_scaling + return getattr(config, "rope_scaling", None) + + class PositionRotaryEmbedding(nn.Module): + def __init__(self, inv_freq, scaling_factor): + super().__init__() + self.inv_freq = inv_freq + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.scaling_factor = scaling_factor + self.dynamic_args = None + + @classmethod + def static(cls, config, dim, base, device): + inv_freq = _create_inv_freq(dim, base, device) + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding( + dim=dim, + max_position_embeddings=config.max_position_embeddings, + base=base, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) + elif rope_scaling["type"] == "yarn": + return YarnPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=rope_scaling["original_max_position_embeddings"], + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, + beta_slow=1 + + ) + else: + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) + return cls(inv_freq, scaling_factor) + + @classmethod + def load(cls, config, prefix, weights): + # XXX: Always load this in float32 ! + dtype = weights.dtype + weights.dtype = torch.float32 + inv_freq = weights.get_tensor(f"{prefix}.inv_freq") + weights.dtype = dtype + + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=config.max_position_embeddings, + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) + elif rope_scaling["type"] == "yarn": + return YarnPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=rope_scaling["original_max_position_embeddings"], + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + extrapolation_factor=1, + attn_factor=1, + beta_fast=32, + beta_slow=1 + + ) + else: + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) + return cls(inv_freq, scaling_factor) + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin( + 
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype + ): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + x1 = x[..., :rotary_dim] + x2 = x[..., rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False) + return x + + class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + newbase = self.base * ( + (self.scaling_factor * seqlen / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + self.inv_freq = _create_inv_freq( + self.dim, newbase, self.inv_freq.device + ) + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + + # Inverse dim formula to find dim based on number of rotations + import math + def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) + + # Find dim range bounds based on rotations + def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(find_correction_dim( + low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim( + high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim-1) # Clamp values just in case + + def linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_mscale(scale=1): + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for 
interpolation + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + inv_freq_extrapolation = _create_inv_freq( + self.dim, self.base, self.inv_freq.device + ) + freqs = 1.0 / inv_freq_extrapolation + inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) + low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings) + inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + + self.inv_freq = inv_freq + self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation + + + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) + self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) + +except ImportError: + pass diff --git a/my_optims/tgi_update/main.rs b/my_optims/tgi_update/main.rs new file mode 100644 index 00000000..147e4c38 --- /dev/null +++ b/my_optims/tgi_update/main.rs @@ -0,0 +1,1268 @@ +use clap::{Parser, ValueEnum}; +use nix::sys::signal::{self, Signal}; +use nix::unistd::Pid; +use serde::Deserialize; +use std::env; +use std::ffi::OsString; +use std::io::{BufRead, BufReader, Lines, Read}; +use std::os::unix::process::{CommandExt, ExitStatusExt}; +use std::path::Path; +use std::process::{Child, Command, ExitStatus, Stdio}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc::TryRecvError; +use std::sync::{mpsc, Arc}; +use std::thread; +use std::thread::sleep; +use std::time::{Duration, Instant}; +use std::{fs, io}; +use tracing_subscriber::EnvFilter; + +mod env_runtime; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Quantization { + /// 4 bit quantization. Requires a specific GTPQ quantized model: + /// https://hf.co/models?search=awq. + /// Should replace GPTQ models whereever possible because of the better latency + Awq, + /// 8 bit quantization, doesn't require specific model. + /// Should be a drop-in replacement to bitsandbytes with much better performance. + /// Kernels are from https://github.com/NetEase-FuXi/EETQ.git + Eetq, + /// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. + /// text-generation-inference will use exllama (faster) kernels whereever possible, and use + /// triton kernel (wider support) when it's not. + /// AWQ has faster kernels. + Gptq, + /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, + /// but it is known that the model will be much slower to run than the native f16. + #[deprecated( + since = "1.1.0", + note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" + )] + Bitsandbytes, + /// Bitsandbytes 4bit. 
Can be applied on any model, will cut the memory requirement by 4x, + /// but it is known that the model will be much slower to run than the native f16. + BitsandbytesNF4, + /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better + /// perplexity performance for you model + BitsandbytesFP4, +} + +impl std::fmt::Display for Quantization { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To keep in track with `server`. + match self { + Quantization::Bitsandbytes => { + write!(f, "bitsandbytes") + } + Quantization::BitsandbytesNF4 => { + write!(f, "bitsandbytes-nf4") + } + Quantization::BitsandbytesFP4 => { + write!(f, "bitsandbytes-fp4") + } + Quantization::Gptq => { + write!(f, "gptq") + } + Quantization::Awq => { + write!(f, "awq") + } + Quantization::Eetq => { + write!(f, "eetq") + } + } + } +} + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Dtype { + Float16, + #[clap(name = "bfloat16")] + BFloat16, +} + +impl std::fmt::Display for Dtype { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To keep in track with `server`. + match self { + Dtype::Float16 => { + write!(f, "float16") + } + Dtype::BFloat16 => { + write!(f, "bfloat16") + } + } + } +} + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum RopeScaling { + Linear, + Dynamic, +} + +impl std::fmt::Display for RopeScaling { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To keep in track with `server`. + match self { + RopeScaling::Linear => { + write!(f, "linear") + } + RopeScaling::Dynamic => { + write!(f, "dynamic") + } + } + } +} + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + /// The name of the model to load. + /// Can be a MODEL_ID as listed on like + /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`. + /// Or it can be a local directory containing the necessary files + /// as saved by `save_pretrained(...)` methods of transformers + #[clap(default_value = "bigscience/bloom-560m", long, env)] + model_id: String, + + /// The actual revision of the model if you're referring to a model + /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`. + #[clap(long, env)] + revision: Option, + + /// The number of tokenizer workers used for payload validation and truncation inside the + /// router. + #[clap(default_value = "2", long, env)] + validation_workers: usize, + + /// Whether to shard the model across multiple GPUs + /// By default text-generation-inference will use all available GPUs to run + /// the model. Setting it to `false` deactivates `num_shard`. + #[clap(long, env)] + sharded: Option, + + /// The number of shards to use if you don't want to use all GPUs on a given machine. + /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2` + /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to + /// launch 2 copies with 2 shard each on a given machine with 4 GPUs for instance. + #[clap(long, env)] + num_shard: Option, + + /// Whether you want the model to be quantized. + #[clap(long, env, value_enum)] + quantize: Option, + + /// The dtype to be forced upon the model. This option cannot be used with `--quantize`. + #[clap(long, env, value_enum)] + dtype: Option, + + /// Whether you want to execute hub modelling code. 
Explicitly passing a `revision` is + /// encouraged when loading a model with custom code to ensure no malicious code has been + /// contributed in a newer revision. + #[clap(long, env, value_enum)] + trust_remote_code: bool, + + /// The maximum amount of concurrent requests for this particular deployment. + /// Having a low limit will refuse clients requests instead of having them + /// wait for too long and is usually good to handle backpressure correctly. + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + + /// This is the maximum allowed value for clients to set `best_of`. + /// Best of makes `n` generations at the same time, and return the best + /// in terms of overall log probability over the entire generated sequence + #[clap(default_value = "2", long, env)] + max_best_of: usize, + + /// This is the maximum allowed value for clients to set `stop_sequences`. + /// Stop sequences are used to allow the model to stop on more than just + /// the EOS token, and enable more complex "prompting" where users can preprompt + /// the model in a specific way and define their "own" stop token aligned with + /// their prompt. + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + + /// This is the maximum allowed value for clients to set `top_n_tokens`. + /// `top_n_tokens is used to return information about the the `n` most likely + /// tokens at each generation step, instead of just the sampled token. This + /// information can be used for downstream tasks like for classification or + /// ranking. + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + + /// This is the maximum allowed input length (expressed in number of tokens) + /// for users. The larger this value, the longer prompt users can send which + /// can impact the overall memory required to handle the load. + /// Please note that some models have a finite range of sequence they can handle. + #[clap(default_value = "1024", long, env)] + max_input_length: usize, + + /// This is the most important value to set as it defines the "memory budget" + /// of running clients requests. + /// Clients will send input sequences and ask to generate `max_new_tokens` + /// on top. with a value of `1512` users can send either a prompt of + /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for + /// `1511` max_new_tokens. + /// The larger this value, the larger amount each request will be in your RAM + /// and the less effective batching can be. + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + + /// This represents the ratio of waiting queries vs running queries where + /// you want to start considering pausing the running queries to include the waiting + /// ones into the same batch. + /// `waiting_served_ratio=1.2` Means when 12 queries are waiting and there's + /// only 10 queries left in the current batch we check if we can fit those 12 + /// waiting queries into the batching strategy, and if yes, then batching happens + /// delaying the 10 running queries by a `prefill` run. + /// + /// This setting is only applied if there is room in the batch + /// as defined by `max_batch_total_tokens`. + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + + /// Limits the number of tokens for the prefill operation. + /// Since this operation take the most memory and is compute bound, it is interesting + /// to limit the number of requests that can be sent. 
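+    /// As an illustration only (not a tuning recommendation): with the default
+    /// `--max-batch-prefill-tokens 4096` and the default `--max-input-length 1024`,
+    /// at most four maximum-length prompts fit into a single prefill forward pass.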
+ #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + + /// **IMPORTANT** This is one critical control to allow maximum usage + /// of the available hardware. + /// + /// This represents the total amount of potential tokens within a batch. + /// When using padding (not recommended) this would be equivalent of + /// `batch_size` * `max_total_tokens`. + /// + /// However in the non-padded (flash attention) version this can be much finer. + /// + /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` + /// or a single query of `1000` tokens. + /// + /// Overall this number should be the largest possible amount that fits the + /// remaining memory (after the model is loaded). Since the actual memory overhead + /// depends on other parameters like if you're using quantization, flash attention + /// or the model implementation, text-generation-inference cannot infer this number + /// automatically. + #[clap(long, env)] + max_batch_total_tokens: Option, + + /// This setting defines how many tokens can be passed before forcing the waiting + /// queries to be put on the batch (if the size of the batch allows for it). + /// New queries require 1 `prefill` forward, which is different from `decode` + /// and therefore you need to pause the running batch in order to run `prefill` + /// to create the correct values for the waiting queries to be able to join the batch. + /// + /// With a value too small, queries will always "steal" the compute to run `prefill` + /// and running queries will be delayed by a lot. + /// + /// With a value too big, waiting queries could wait for a very long time + /// before being allowed a slot in the running batch. If your server is busy + /// that means that requests that could run in ~2s on an empty server could + /// end up running in ~20s because the query had to wait for 18s. + /// + /// This number is expressed in number of tokens to make it a bit more + /// "model" agnostic, but what should really matter is the overall latency + /// for end users. + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + + /// The IP address to listen on + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + + /// The port to listen on. + #[clap(default_value = "3000", long, short, env)] + port: u16, + + /// The name of the socket for gRPC communication between the webserver + /// and the shards. + #[clap(default_value = "/tmp/text-generation-server", long, env)] + shard_uds_path: String, + + /// The address the master shard will listen on. (setting used by torch distributed) + #[clap(default_value = "localhost", long, env)] + master_addr: String, + + /// The address the master port will listen on. (setting used by torch distributed) + #[clap(default_value = "29500", long, env)] + master_port: usize, + + /// The location of the huggingface hub cache. + /// Used to override the location if you want to provide a mounted disk for instance + #[clap(long, env)] + huggingface_hub_cache: Option, + + /// The location of the huggingface hub cache. + /// Used to override the location if you want to provide a mounted disk for instance + #[clap(long, env)] + weights_cache_override: Option, + + /// For some models (like bloom), text-generation-inference implemented custom + /// cuda kernels to speed up inference. Those kernels were only tested on A100. + /// Use this flag to disable them if you're running on different hardware and + /// encounter issues. 
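+    /// For example (names as derived by clap from the field below, not verified):
+    /// pass `--disable-custom-kernels` on the command line or set the
+    /// `DISABLE_CUSTOM_KERNELS` environment variable.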
+ #[clap(long, env)] + disable_custom_kernels: bool, + + /// Limit the CUDA available memory. + /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction. + #[clap(default_value = "1.0", long, env)] + cuda_memory_fraction: f32, + + /// Rope scaling will only be used for RoPE models + /// and allow rescaling the position rotary to accomodate for + /// larger prompts. + /// + /// Goes together with `rope_factor`. + /// + /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0 + /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0 + /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed + /// basically) + /// + /// `--rope-scaling linear --rope-factor` fully describes the scaling you want + #[clap(long, env)] + rope_scaling: Option, + + /// Rope scaling will only be used for RoPE models + /// See `rope_scaling` + #[clap(long, env)] + rope_factor: Option, + + /// Outputs the logs in JSON format (useful for telemetry) + #[clap(long, env)] + json_output: bool, + + #[clap(long, env)] + otlp_endpoint: Option, + + #[clap(long, env)] + cors_allow_origin: Vec, + #[clap(long, env)] + watermark_gamma: Option, + #[clap(long, env)] + watermark_delta: Option, + + /// Enable ngrok tunneling + #[clap(long, env)] + ngrok: bool, + + /// ngrok authentication token + #[clap(long, env)] + ngrok_authtoken: Option, + + /// ngrok edge + #[clap(long, env)] + ngrok_edge: Option, + + /// Display a lot of information about your runtime environment + #[clap(long, short, action)] + env: bool, +} + +#[derive(Debug)] +enum ShardStatus { + Ready, + Failed(usize), +} + +#[allow(clippy::too_many_arguments)] +fn shard_manager( + model_id: String, + revision: Option, + quantize: Option, + dtype: Option, + trust_remote_code: bool, + uds_path: String, + rank: usize, + world_size: usize, + master_addr: String, + master_port: usize, + huggingface_hub_cache: Option, + weights_cache_override: Option, + disable_custom_kernels: bool, + watermark_gamma: Option, + watermark_delta: Option, + cuda_memory_fraction: f32, + rope_scaling: Option, + rope_factor: Option, + otlp_endpoint: Option, + status_sender: mpsc::Sender, + shutdown: Arc, + _shutdown_sender: mpsc::Sender<()>, +) { + // Enter shard-manager tracing span + let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); + + // Get UDS path + let uds_string = format!("{uds_path}-{rank}"); + let uds = Path::new(&uds_string); + // Clean previous runs + if uds.exists() { + fs::remove_file(uds).unwrap(); + } + + // Process args + let mut shard_args = vec![ + "serve".to_string(), + model_id, + "--uds-path".to_string(), + uds_path, + "--logger-level".to_string(), + "INFO".to_string(), + "--json-output".to_string(), + ]; + + // Activate trust remote code + if trust_remote_code { + shard_args.push("--trust-remote-code".to_string()); + } + + // Activate tensor parallelism + if world_size > 1 { + shard_args.push("--sharded".to_string()); + } + + if let Some(quantize) = quantize { + shard_args.push("--quantize".to_string()); + shard_args.push(quantize.to_string()) + } + + if let Some(dtype) = dtype { + shard_args.push("--dtype".to_string()); + shard_args.push(dtype.to_string()) + } + + // Model optional revision + if let Some(revision) = revision { + shard_args.push("--revision".to_string()); + shard_args.push(revision) + } + + let rope = match (rope_scaling, rope_factor) { + (None, None) => None, + (Some(scaling), None) => Some((scaling, 1.0)), + (Some(scaling), 
Some(factor)) => Some((scaling, factor)), + (None, Some(factor)) => Some((RopeScaling::Linear, factor)), + }; + // OpenTelemetry + if let Some(otlp_endpoint) = otlp_endpoint { + shard_args.push("--otlp-endpoint".to_string()); + shard_args.push(otlp_endpoint); + } + + // Copy current process env + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + + // Torch Distributed Env vars + envs.push(("RANK".into(), rank.to_string().into())); + envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); + envs.push(("MASTER_ADDR".into(), master_addr.into())); + envs.push(("MASTER_PORT".into(), master_port.to_string().into())); + envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into())); + + // CUDA memory fraction + envs.push(( + "CUDA_MEMORY_FRACTION".into(), + cuda_memory_fraction.to_string().into(), + )); + + // Safetensors load fast + envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); + + // Enable hf transfer for insane download speeds + let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); + envs.push(( + "HF_HUB_ENABLE_HF_TRANSFER".into(), + enable_hf_transfer.into(), + )); + + // Parse Inference API token + if let Ok(api_token) = env::var("HF_API_TOKEN") { + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + }; + + // Detect rope scaling + // Sending as env instead of CLI args to not bloat everything + // those only can be used by RoPE models, so passing information around + // for all models will complexify code unnecessarily + if let Some((scaling, factor)) = rope { + envs.push(("ROPE_SCALING".into(), scaling.to_string().into())); + envs.push(("ROPE_FACTOR".into(), factor.to_string().into())); + } + + // If huggingface_hub_cache is some, pass it to the shard + // Useful when running inside a docker container + if let Some(huggingface_hub_cache) = huggingface_hub_cache { + envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); + }; + + // If weights_cache_override is some, pass it to the shard + // Useful when running inside a HuggingFace Inference Endpoint + if let Some(weights_cache_override) = weights_cache_override { + envs.push(( + "WEIGHTS_CACHE_OVERRIDE".into(), + weights_cache_override.into(), + )); + }; + + // If disable_custom_kernels is true, pass it to the shard as an env var + if disable_custom_kernels { + envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) + } + + // Watermark Gamma + if let Some(watermark_gamma) = watermark_gamma { + envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into())) + } + + // Watermark Delta + if let Some(watermark_delta) = watermark_delta { + envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into())) + } + // for mpi + let n_devices = match num_cuda_devices() { + Some(value) => value.to_string(), + None => String::from("0"), + }; + // Start process + tracing::info!("Starting shard"); + tracing::info!("run with mpi and use cuda device num: {}", n_devices); + // tracing::info!("{:?}", shard_args); + tracing::info!("{:?}", envs); + let mut p = match Command::new("mpirun") + .args(&["-n", &n_devices, "--allow-run-as-root", "text-generation-server"]) + .args(shard_args) + .envs(envs) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { + Ok(p) => p, + Err(err) => { + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("start mpi failed! 
"); + } + { + tracing::error!("{}", err); + } + + status_sender.send(ShardStatus::Failed(rank)).unwrap(); + return; + } + }; + + // Redirect STDOUT to the console + let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); + let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); + + //stdout tracing thread + thread::spawn(move || { + log_lines(shard_stdout_reader.lines()); + }); + + let mut ready = false; + let start_time = Instant::now(); + let mut wait_time = Instant::now(); + loop { + // Process exited + if let Some(exit_status) = p.try_wait().unwrap() { + // We read stderr in another thread as it seems that lines() can block in some cases + let (err_sender, err_receiver) = mpsc::channel(); + thread::spawn(move || { + for line in shard_stderr_reader.lines().flatten() { + err_sender.send(line).unwrap_or(()); + } + }); + let mut err = String::new(); + while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { + err = err + "\n" + &line; + } + + tracing::error!("Shard complete standard error output:\n{err}"); + + if let Some(signal) = exit_status.signal() { + tracing::error!("Shard process was signaled to shutdown with signal {signal}"); + } + + status_sender.send(ShardStatus::Failed(rank)).unwrap(); + return; + } + + // We received a shutdown signal + if shutdown.load(Ordering::SeqCst) { + p.kill().unwrap(); + let _ = p.wait(); + tracing::info!("Shard terminated"); + return; + } + + // Shard is ready + if uds.exists() && !ready { + tracing::info!("Shard ready in {:?}", start_time.elapsed()); + status_sender.send(ShardStatus::Ready).unwrap(); + ready = true; + } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { + tracing::info!("Waiting for shard to be ready..."); + wait_time = Instant::now(); + } + sleep(Duration::from_millis(100)); + } +} + +fn shutdown_shards(shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>) { + tracing::info!("Shutting down shards"); + // Update shutdown value to true + // This will be picked up by the shard manager + shutdown.store(true, Ordering::SeqCst); + + // Wait for shards to shutdown + // This will block till all shutdown_sender are dropped + let _ = shutdown_receiver.recv(); +} + +fn num_cuda_devices() -> Option { + let devices = match env::var("CUDA_VISIBLE_DEVICES") { + Ok(devices) => devices, + Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?, + }; + let n_devices = devices.split(',').count(); + Some(n_devices) +} + +#[derive(Deserialize)] +#[serde(rename_all = "UPPERCASE")] +enum PythonLogLevelEnum { + Trace, + Debug, + Info, + Success, + Warning, + Error, + Critical, +} + +#[derive(Deserialize)] +struct PythonLogLevel { + name: PythonLogLevelEnum, +} + +#[derive(Deserialize)] +struct PythonLogRecord { + level: PythonLogLevel, +} + +#[derive(Deserialize)] +struct PythonLogMessage { + text: String, + record: PythonLogRecord, +} + +impl PythonLogMessage { + fn trace(&self) { + match self.record.level.name { + PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text), + PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text), + PythonLogLevelEnum::Info => tracing::info!("{}", self.text), + PythonLogLevelEnum::Success => tracing::info!("{}", self.text), + PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text), + PythonLogLevelEnum::Error => tracing::error!("{}", self.text), + PythonLogLevelEnum::Critical => tracing::error!("{}", self.text), + } + } +} + +impl TryFrom<&String> for PythonLogMessage { + type Error = serde_json::Error; + + fn try_from(value: &String) -> Result { + 
serde_json::from_str::(value) + } +} + +fn log_lines(lines: Lines) { + for line in lines.flatten() { + match PythonLogMessage::try_from(&line) { + Ok(log) => log.trace(), + Err(_) => tracing::debug!("{line}"), + } + } +} + +fn find_num_shards( + sharded: Option, + num_shard: Option, +) -> Result { + // get the number of shards given `sharded` and `num_shard` + let num_shard = match (sharded, num_shard) { + (Some(true), None) => { + // try to default to the number of available GPUs + tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES"); + let n_devices = num_cuda_devices() + .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); + if n_devices <= 1 { + return Err(LauncherError::NotEnoughCUDADevices(format!( + "`sharded` is true but only found {n_devices} CUDA devices" + ))); + } + n_devices + } + (Some(true), Some(num_shard)) => { + // we can't have only one shard while sharded + if num_shard <= 1 { + return Err(LauncherError::ArgumentValidation( + "`sharded` is true but `num_shard` <= 1".to_string(), + )); + } + num_shard + } + (Some(false), Some(num_shard)) => num_shard, + (Some(false), None) => 1, + (None, None) => num_cuda_devices().unwrap_or(1), + (None, Some(num_shard)) => num_shard, + }; + if num_shard < 1 { + return Err(LauncherError::ArgumentValidation( + "`num_shard` cannot be < 1".to_string(), + )); + } + Ok(num_shard) +} + +#[derive(Debug)] +enum LauncherError { + ArgumentValidation(String), + NotEnoughCUDADevices(String), + DownloadError, + ShardCannotStart, + ShardDisconnected, + ShardFailed, + WebserverFailed, + WebserverCannotStart, +} + +fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { + // Enter download tracing span + let _span = tracing::span!(tracing::Level::INFO, "download").entered(); + + let mut download_args = vec![ + "download-weights".to_string(), + args.model_id.to_string(), + "--extension".to_string(), + ".safetensors".to_string(), + "--logger-level".to_string(), + "INFO".to_string(), + "--json-output".to_string(), + ]; + + // Model optional revision + if let Some(revision) = &args.revision { + download_args.push("--revision".to_string()); + download_args.push(revision.to_string()) + } + + // Trust remote code for automatic peft fusion + if args.trust_remote_code { + download_args.push("--trust-remote-code".to_string()); + } + + // Copy current process env + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + + // If huggingface_hub_cache is set, pass it to the download process + // Useful when running inside a docker container + if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { + envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); + }; + + // Enable hf transfer for insane download speeds + let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); + envs.push(( + "HF_HUB_ENABLE_HF_TRANSFER".into(), + enable_hf_transfer.into(), + )); + + // Parse Inference API token + if let Ok(api_token) = env::var("HF_API_TOKEN") { + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + }; + + // If args.weights_cache_override is some, pass it to the download process + // Useful when running inside a HuggingFace Inference Endpoint + if let Some(weights_cache_override) = &args.weights_cache_override { + envs.push(( + "WEIGHTS_CACHE_OVERRIDE".into(), + weights_cache_override.into(), + )); + }; + + // Start process + tracing::info!("Starting download process."); + let mut 
download_process = match Command::new("text-generation-server") + .args(download_args) + .envs(envs) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { + Ok(p) => p, + Err(err) => { + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-server not found in PATH"); + tracing::error!("Please install it with `make install-server`") + } else { + tracing::error!("{}", err); + } + + return Err(LauncherError::DownloadError); + } + }; + + // Redirect STDOUT to the console + let download_stdout = download_process.stdout.take().unwrap(); + let stdout = BufReader::new(download_stdout); + + thread::spawn(move || { + log_lines(stdout.lines()); + }); + + loop { + if let Some(status) = download_process.try_wait().unwrap() { + if status.success() { + tracing::info!("Successfully downloaded weights."); + break; + } + + let mut err = String::new(); + download_process + .stderr + .take() + .unwrap() + .read_to_string(&mut err) + .unwrap(); + if let Some(signal) = status.signal() { + tracing::error!( + "Download process was signaled to shutdown with signal {signal}: {err}" + ); + } else { + tracing::error!("Download encountered an error: {err}"); + } + + return Err(LauncherError::DownloadError); + } + if !running.load(Ordering::SeqCst) { + terminate("download", download_process, Duration::from_secs(10)).unwrap(); + return Ok(()); + } + sleep(Duration::from_millis(100)); + } + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn spawn_shards( + num_shard: usize, + args: &Args, + shutdown: Arc, + shutdown_receiver: &mpsc::Receiver<()>, + shutdown_sender: mpsc::Sender<()>, + status_receiver: &mpsc::Receiver, + status_sender: mpsc::Sender, + running: Arc, +) -> Result<(), LauncherError> { + // Start shard processes + for rank in 0..num_shard { + let model_id = args.model_id.clone(); + let revision = args.revision.clone(); + let uds_path = args.shard_uds_path.clone(); + let master_addr = args.master_addr.clone(); + let huggingface_hub_cache = args.huggingface_hub_cache.clone(); + let weights_cache_override = args.weights_cache_override.clone(); + let status_sender = status_sender.clone(); + let shutdown = shutdown.clone(); + let shutdown_sender = shutdown_sender.clone(); + let otlp_endpoint = args.otlp_endpoint.clone(); + let quantize = args.quantize; + let dtype = args.dtype; + let trust_remote_code = args.trust_remote_code; + let master_port = args.master_port; + let disable_custom_kernels = args.disable_custom_kernels; + let watermark_gamma = args.watermark_gamma; + let watermark_delta = args.watermark_delta; + let cuda_memory_fraction = args.cuda_memory_fraction; + let rope_scaling = args.rope_scaling; + let rope_factor = args.rope_factor; + thread::spawn(move || { + shard_manager( + model_id, + revision, + quantize, + dtype, + trust_remote_code, + uds_path, + rank, + num_shard, + master_addr, + master_port, + huggingface_hub_cache, + weights_cache_override, + disable_custom_kernels, + watermark_gamma, + watermark_delta, + cuda_memory_fraction, + rope_scaling, + rope_factor, + otlp_endpoint, + status_sender, + shutdown, + shutdown_sender, + ) + }); + } + drop(shutdown_sender); + + // Wait for shard to start + let mut shard_ready = 0; + while running.load(Ordering::SeqCst) { + match status_receiver.try_recv() { + Ok(ShardStatus::Ready) => { + shard_ready += 1; + if shard_ready == num_shard { + break; + } + } + Err(TryRecvError::Empty) => { + sleep(Duration::from_millis(100)); + } + Ok(ShardStatus::Failed(rank)) => { + tracing::error!("Shard {rank} 
failed to start"); + shutdown_shards(shutdown, shutdown_receiver); + return Err(LauncherError::ShardCannotStart); + } + Err(TryRecvError::Disconnected) => { + tracing::error!("Shard status channel disconnected"); + shutdown_shards(shutdown, shutdown_receiver); + return Err(LauncherError::ShardDisconnected); + } + } + } + Ok(()) +} + +fn spawn_webserver( + args: Args, + shutdown: Arc, + shutdown_receiver: &mpsc::Receiver<()>, +) -> Result { + // All shard started + // Start webserver + tracing::info!("Starting Webserver"); + let mut router_args = vec![ + "--max-concurrent-requests".to_string(), + args.max_concurrent_requests.to_string(), + "--max-best-of".to_string(), + args.max_best_of.to_string(), + "--max-stop-sequences".to_string(), + args.max_stop_sequences.to_string(), + "--max-top-n-tokens".to_string(), + args.max_top_n_tokens.to_string(), + "--max-input-length".to_string(), + args.max_input_length.to_string(), + "--max-total-tokens".to_string(), + args.max_total_tokens.to_string(), + "--max-batch-prefill-tokens".to_string(), + args.max_batch_prefill_tokens.to_string(), + "--waiting-served-ratio".to_string(), + args.waiting_served_ratio.to_string(), + "--max-waiting-tokens".to_string(), + args.max_waiting_tokens.to_string(), + "--validation-workers".to_string(), + args.validation_workers.to_string(), + "--hostname".to_string(), + args.hostname.to_string(), + "--port".to_string(), + args.port.to_string(), + "--master-shard-uds-path".to_string(), + format!("{}-0", args.shard_uds_path), + "--tokenizer-name".to_string(), + args.model_id, + ]; + + // Model optional max batch total tokens + if let Some(max_batch_total_tokens) = args.max_batch_total_tokens { + router_args.push("--max-batch-total-tokens".to_string()); + router_args.push(max_batch_total_tokens.to_string()); + } + + // Model optional revision + if let Some(ref revision) = args.revision { + router_args.push("--revision".to_string()); + router_args.push(revision.to_string()) + } + + if args.json_output { + router_args.push("--json-output".to_string()); + } + + // OpenTelemetry + if let Some(otlp_endpoint) = args.otlp_endpoint { + router_args.push("--otlp-endpoint".to_string()); + router_args.push(otlp_endpoint); + } + + // CORS origins + for origin in args.cors_allow_origin.into_iter() { + router_args.push("--cors-allow-origin".to_string()); + router_args.push(origin); + } + + // Ngrok + if args.ngrok { + router_args.push("--ngrok".to_string()); + router_args.push("--ngrok-authtoken".to_string()); + router_args.push(args.ngrok_authtoken.unwrap()); + router_args.push("--ngrok-edge".to_string()); + router_args.push(args.ngrok_edge.unwrap()); + } + + // Copy current process env + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + + // Parse Inference API token + if let Ok(api_token) = env::var("HF_API_TOKEN") { + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + }; + + let mut webserver = match Command::new("text-generation-router") + .args(router_args) + .envs(envs) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .spawn() + { + Ok(p) => p, + Err(err) => { + tracing::error!("Failed to start webserver: {}", err); + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-router not found in PATH"); + tracing::error!("Please install it with `make install-router`") + } else { + tracing::error!("{}", err); + } + + shutdown_shards(shutdown, shutdown_receiver); + return Err(LauncherError::WebserverCannotStart); + } + }; + + // Redirect STDOUT and STDERR to 
the console + let webserver_stdout = webserver.stdout.take().unwrap(); + let webserver_stderr = webserver.stderr.take().unwrap(); + + thread::spawn(move || { + let stdout = BufReader::new(webserver_stdout); + let stderr = BufReader::new(webserver_stderr); + for line in stdout.lines() { + println!("{}", line.unwrap()); + } + for line in stderr.lines() { + println!("{}", line.unwrap()); + } + }); + Ok(webserver) +} + +fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result { + tracing::info!("Terminating {process_name}"); + + let terminate_time = Instant::now(); + signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap(); + + tracing::info!("Waiting for {process_name} to gracefully shutdown"); + + while terminate_time.elapsed() < timeout { + if let Some(status) = process.try_wait()? { + tracing::info!("{process_name} terminated"); + return Ok(status); + } + sleep(Duration::from_millis(100)); + } + + tracing::info!("Killing {process_name}"); + + process.kill()?; + let exit_status = process.wait()?; + + tracing::info!("{process_name} killed"); + Ok(exit_status) +} + +fn main() -> Result<(), LauncherError> { + // Pattern match configuration + let args: Args = Args::parse(); + + // Filter events with LOG_LEVEL + let env_filter = + EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); + + if args.json_output { + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .json() + .init(); + } else { + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .compact() + .init(); + } + + if args.env { + let env_runtime = env_runtime::Env::new(); + tracing::info!("{}", env_runtime); + } + + tracing::info!("{:?}", args); + + // Validate args + if args.max_input_length >= args.max_total_tokens { + return Err(LauncherError::ArgumentValidation( + "`max_input_length` must be < `max_total_tokens`".to_string(), + )); + } + if args.max_input_length as u32 > args.max_batch_prefill_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}", + args.max_batch_prefill_tokens, args.max_input_length + ))); + } + + if args.validation_workers == 0 { + return Err(LauncherError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + if args.trust_remote_code { + tracing::warn!( + "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.", + args.model_id + ); + } + + let num_shard = find_num_shards(args.sharded, args.num_shard)?; + if num_shard > 1 { + tracing::info!("Sharding model on {num_shard} processes"); + } + + if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { + if args.max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + args.max_batch_prefill_tokens, max_batch_total_tokens + ))); + } + if args.max_total_tokens as u32 > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {} and {}", + args.max_total_tokens, max_batch_total_tokens + ))); + } + } + + if args.ngrok { + if args.ngrok_authtoken.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(), + )); + } + + if args.ngrok_edge.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-edge` must be set when using ngrok tunneling".to_string(), + )); + } + } + + // Signal handler + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + // Download and convert model weights + download_convert_model(&args, running.clone())?; + + if !running.load(Ordering::SeqCst) { + // Launcher was asked to stop + return Ok(()); + } + + // Shared shutdown bool + let shutdown = Arc::new(AtomicBool::new(false)); + // Shared shutdown channel + // When shutting down, the main thread will wait for all senders to be dropped + let (shutdown_sender, shutdown_receiver) = mpsc::channel(); + + // Shared channel to track shard status + let (status_sender, status_receiver) = mpsc::channel(); + + spawn_shards( + num_shard, + &args, + shutdown.clone(), + &shutdown_receiver, + shutdown_sender, + &status_receiver, + status_sender, + running.clone(), + )?; + + // We might have received a termination signal + if !running.load(Ordering::SeqCst) { + shutdown_shards(shutdown, &shutdown_receiver); + return Ok(()); + } + + let mut webserver = + spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| { + shutdown_shards(shutdown.clone(), &shutdown_receiver); + err + })?; + + // Default exit code + let mut exit_code = Ok(()); + + while running.load(Ordering::SeqCst) { + if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() { + tracing::error!("Shard {rank} crashed"); + exit_code = Err(LauncherError::ShardFailed); + break; + }; + + match webserver.try_wait().unwrap() { + Some(_) => { + tracing::error!("Webserver Crashed"); + shutdown_shards(shutdown, &shutdown_receiver); + return Err(LauncherError::WebserverFailed); + } + None => { + sleep(Duration::from_millis(100)); + } + }; + } + + // Graceful termination + terminate("webserver", webserver, Duration::from_secs(90)).unwrap(); + shutdown_shards(shutdown, &shutdown_receiver); + + exit_code +} diff --git a/my_optims/tgi_update/my_dist.py b/my_optims/tgi_update/my_dist.py new file mode 100644 index 00000000..a01da46a --- /dev/null +++ b/my_optims/tgi_update/my_dist.py @@ -0,0 +1,47 @@ +import os +import torch + +from datetime import timedelta +from loguru import logger +from mpi4py import MPI +import my_custom_comm + +# Tensor Parallelism settings +RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) +WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + +# CUDA memory fraction +MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) + +class MyCommGroup: + def __init__(self, rank, size, tp_comm, pp_comm): + self._rank = rank + self._size = size + self.tp_comm = tp_comm + self.pp_comm = pp_comm + + def size(self): + return self._size + + def rank(self): + return self._rank + + +def initialize_mpi_distributed(): + assert torch.cuda.is_available() + # mpi initialize + COMM = MPI.COMM_WORLD + assert COMM.Get_size() == WORLD_SIZE, f"{COMM.Get_size()},{WORLD_SIZE}" + + # Set the device id. 
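+    # Rank comes from Open MPI (OMPI_COMM_WORLD_RANK above), so with e.g. 4 visible
+    # GPUs and WORLD_SIZE=4, rank r is pinned to cuda:r; RANK % device_count is
+    # effectively the identity here because the assert below requires one GPU per
+    # process (WORLD_SIZE <= device_count).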
+ assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" + device = RANK % torch.cuda.device_count() + torch.cuda.set_device(device) + torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) + + # nccl initialize + tp_comm, pp_comm = my_custom_comm.init_nccl(WORLD_SIZE, 1) + process_group = MyCommGroup(RANK, WORLD_SIZE, tp_comm, pp_comm) + + logger.warning("custom mpi and nccl is already initialized.") + return process_group, RANK, WORLD_SIZE, COMM diff --git a/my_optims/tgi_update/server.py b/my_optims/tgi_update/server.py new file mode 100644 index 00000000..5848d3fa --- /dev/null +++ b/my_optims/tgi_update/server.py @@ -0,0 +1,217 @@ +import asyncio +import os +import torch + +from grpc import aio +from loguru import logger + +from grpc_reflection.v1alpha import reflection +from pathlib import Path +from typing import List, Optional + +from text_generation_server.cache import Cache +from text_generation_server.interceptor import ExceptionInterceptor +from text_generation_server.models import Model, get_model +from text_generation_server.pb import generate_pb2_grpc, generate_pb2 +from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor +from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch + + +class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): + def __init__(self, model: Model, cache: Cache, server_urls: List[str]): + self.cache = cache + self.model = model + self.server_urls = server_urls + # For some reason, inference_mode does not work well with GLOO which we use on CPU + if model.device.type == "cuda": + # Force inference mode for the lifetime of TextGenerationService + self._inference_mode_raii_guard = torch._C._InferenceMode(True) + + async def Info(self, request, context): + return self.model.info + + async def Health(self, request, context): + if self.model.device.type == "cuda": + torch.zeros((2, 2)).cuda() + return generate_pb2.HealthResponse() + + async def ServiceDiscovery(self, request, context): + return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) + + async def ClearCache(self, request, context): + if request.HasField("id"): + self.cache.delete(request.id) + else: + self.cache.clear() + return generate_pb2.ClearCacheResponse() + + async def FilterBatch(self, request, context): + batch = self.cache.pop(request.batch_id) + if batch is None: + raise ValueError(f"Batch ID {request.batch_id} not found in cache.") + filtered_batch = batch.filter(request.request_ids) + self.cache.set(filtered_batch) + + return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + + async def Warmup(self, request, context): + if ( + self.model.batch_type == IdeficsCausalLMBatch + ): # Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.dtype, + self.model.device, + ) + else: + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + max_supported_total_tokens = self.model.warmup(batch) + + return generate_pb2.WarmupResponse( + max_supported_total_tokens=max_supported_total_tokens + ) + + async def Prefill(self, request, context): + if ( + self.model.batch_type == IdeficsCausalLMBatch + ): # Hack, i would rather use kwargs in the `from_pb` call + batch = self.model.batch_type.from_pb( + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.dtype, + 
self.model.device, + ) + else: + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + + generations, next_batch = self.model.generate_token(batch) + self.cache.set(next_batch) + + return generate_pb2.PrefillResponse( + generations=[generation.to_pb() for generation in generations], + batch=next_batch.to_pb() if next_batch else None, + ) + + async def Decode(self, request, context): + if len(request.batches) == 0: + raise ValueError("Must provide at least one batch") + + batches = [] + for batch_pb in request.batches: + batch = self.cache.pop(batch_pb.id) + if batch is None: + raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") + batches.append(batch) + + if len(batches) == 0: + raise ValueError("All batches are empty") + + if len(batches) > 1: + batch = self.model.batch_type.concatenate(batches) + else: + batch = batches[0] + + generations, next_batch = self.model.generate_token(batch) + self.cache.set(next_batch) + + return generate_pb2.DecodeResponse( + generations=[generation.to_pb() for generation in generations], + batch=next_batch.to_pb() if next_batch else None, + ) + + +def serve( + model_id: str, + revision: Optional[str], + sharded: bool, + quantize: Optional[str], + dtype: Optional[str], + trust_remote_code: bool, + uds_path: Path, +): + async def serve_inner( + model_id: str, + revision: Optional[str], + sharded: bool = False, + quantize: Optional[str] = None, + dtype: Optional[str] = None, + trust_remote_code: bool = False, + ): + logger.info(os.environ) + unix_socket_template = "unix://{}-{}" + if sharded: + server_urls = [ + unix_socket_template.format(uds_path, rank) + for rank in range(int(os.environ["WORLD_SIZE"])) + ] + local_url = server_urls[int(os.environ["RANK"])] + else: + local_url = unix_socket_template.format(uds_path, 0) + server_urls = [local_url] + + if int(os.environ.get("USE_CUSTOM_NCCL", 0)): + server_urls = [ + unix_socket_template.format(uds_path, rank) + for rank in range(int(os.environ["OMPI_COMM_WORLD_SIZE"])) + ] + local_url = server_urls[int(os.environ["OMPI_COMM_WORLD_RANK"])] + + try: + model = get_model( + model_id, revision, sharded, quantize, dtype, trust_remote_code + ) + except Exception: + logger.exception("Error when initializing model") + raise + + if quantize == "gptq": + try: + # When using GPTQ, Exllama kernels need some global kernels + # For which we have the finale shapes only after the model has loaded + # This will allocate those buffers. + from text_generation_server.utils.gptq.exllama import ( + create_exllama_buffers, + set_device, + ) + + set_device(model.device) + create_exllama_buffers() + except ImportError: + pass + + server = aio.server( + interceptors=[ + ExceptionInterceptor(), + UDSOpenTelemetryAioServerInterceptor(), + ] + ) + generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( + TextGenerationService(model, Cache(), server_urls), server + ) + SERVICE_NAMES = ( + generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name, + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(SERVICE_NAMES, server) + server.add_insecure_port(local_url) + + await server.start() + + logger.info("Server started at {}".format(local_url)) + + try: + await server.wait_for_termination() + except KeyboardInterrupt: + logger.info("Signal received. 
Shutting down") + await server.stop(0) + + asyncio.run( + serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code) + ) \ No newline at end of file diff --git a/my_optims/tp_optims/flash_llama_modeling.py b/my_optims/tp_optims/flash_llama_modeling.py new file mode 100644 index 00000000..080b2171 --- /dev/null +++ b/my_optims/tp_optims/flash_llama_modeling.py @@ -0,0 +1,467 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Tuple + +# Flash attention imports +import dropout_layer_norm +import flash_attn_2_cuda +import torch +import torch.distributed +from torch import nn +from transformers.activations import ACT2FN +# vllm imports +from vllm import attention_ops, cache_ops +from torch.nn import functional as F + +from layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + get_linear, +) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + else: + # faster post attention rms norm + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + if config.model_type == "baichuan": + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.W_pack", + weights=weights, + bias=False, + ) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + 
weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize != "gptq": + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear(weight, bias=None, quantize=config.quantize) + ) + + +class FlashLlamaAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + # self.rotary_emb = PositionRotaryEmbedding.load( + # config=config, prefix=f"{prefix}.rotary_emb", weights=weights + # ) + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, dim=self.head_size, base=config.rope_theta, device=weights.device + ) + + self.softmax_scale = self.head_size ** -0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + + cache_ops.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn_2_cuda.varlen_fwd( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + cu_seqlen_prefill, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + -1, + 0, + False, + None, + ) + # Decode + else: + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + 
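+            # Decode path: vLLM's paged attention kernel reads the KV history
+            # directly from the block cache. `query` is [num_tokens, num_heads,
+            # head_size] (one token per running sequence), `block_tables` maps each
+            # sequence to its cache blocks, and `input_lengths` bounds how many
+            # cached positions each sequence attends over. (Descriptive comment
+            # reflecting the vLLM v1 kernel as generally understood, not this repo.)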
attention_ops.paged_attention_v1( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class LlamaMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none", + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashLlamaLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = FlashLlamaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class FlashLlamaModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + # self.embed_tokens = TensorParallelEmbedding( + # prefix="model.embed_tokens", weights=weights + # ) + embeddings = weights.get_tensor(f"model.embed_tokens.weight") + self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)), + padding_idx=config.pad_token_id) + + self.layers = nn.ModuleList( + [ + FlashLlamaLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + # for layer_id in range(1) + ] + ) + self.norm = LlamaRMSNorm( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + 
position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashLlamaForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = FlashLlamaModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/my_optims/tp_optims/layers.py b/my_optims/tp_optims/layers.py new file mode 100644 index 00000000..b7c70661 --- /dev/null +++ b/my_optims/tp_optims/layers.py @@ -0,0 +1,402 @@ +import os +from typing import List + +import torch +# import torch.distributed +from torch import nn +from torch.nn import functional as F +import my_custom_comm + +class FastLinear(nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.weight = nn.Parameter(weight) + if bias is not None: + self.bias = nn.Parameter(bias) + else: + self.bias = None + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls(weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight, self.bias) + + +def get_linear(weight, bias, quantize): + linear = FastLinear(weight, bias) + return linear + + +class SuperLayer(nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear.forward(x) + + +class TensorParallelHead(SuperLayer): + def __init__(self, linear, process_group, should_gather: bool): + super().__init__(linear) + self.process_group = process_group + self.should_gather = should_gather + + @staticmethod + def load(config, prefix: str, weights): + if weights.process_group.size() > 1: + try: + assert False + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + should_gather = True + except AssertionError: + # If the vocab size is not divisible by number of shards + # just load the entire thing. 
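+ # Note: the assert False above makes this fallback unconditional, so the head
+ # weight is always loaded fully replicated and should_gather stays False. If
+ # should_gather were ever re-enabled, the gather branch in forward() below would
+ # need to be restored first, since it is currently commented out and would return None.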
+ weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + else: + weight = weights.get_tensor(f"{prefix}.weight") + should_gather = False + + # GPTQ doesn't quantize heads (nor embeddings) + if config.quantize == "gptq": + quantize = None + else: + quantize = config.quantize + return TensorParallelHead( + get_linear(weight, bias=None, quantize=quantize), + process_group=weights.process_group, + should_gather=should_gather, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if not self.should_gather: + return super().forward(input) + + # world_size = self.process_group.size() + # if len(input.shape) == 2 and isinstance(self.linear, FastLinear): + # out_dim = self.linear.weight.shape[0] + + # if input.shape[0] == 1: + # world_out = input.new_empty(1, out_dim * world_size) + # local_out = input.new_empty(1, out_dim) + # gather_input = local_out + # else: + # world_out = input.new_empty(out_dim * world_size, input.shape[0]) + # gather_input = input.new_empty(out_dim, input.shape[0]) + # local_out = gather_input.T + + # torch.mm(input, self.linear.weight.T, out=local_out) + + # torch.distributed.all_gather_into_tensor( + # world_out, gather_input, group=self.process_group + # ) + + # if input.shape[0] == 1: + # return world_out + # return world_out.T + + # output = super().forward(input) + # world_output = [ + # torch.empty_like(output) for _ in range(self.process_group.size()) + # ] + # torch.distributed.all_gather(world_output, output, group=self.process_group) + # world_output = torch.cat(world_output, dim=-1) + # return world_output + + +class TensorParallelColumnLinear(SuperLayer): + @classmethod + def load_qkv(cls, config, prefix: str, weights, bias: bool): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_qkv( + prefix, quantize=config.quantize + ) + if bias: + raise NotImplementedError("packed_qkv only implemented for baichuan") + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + return cls.load_multi(config, [prefix], weights, bias, dim=0) + + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + weight = weights.get_multi_weights_col( + prefixes, quantize=config.quantize, dim=dim + ) + + if bias: + b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] + bias = torch.cat(b, dim=dim) + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + + +class TensorParallelRowLinear(SuperLayer): + def __init__(self, linear, process_group): + super().__init__(linear) + self.process_group = process_group + + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): + weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return cls( + get_linear(weight, bias, config.quantize), + process_group=weights.process_group, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super().forward(input) + if self.process_group.size() > 1: + # torch.distributed.all_reduce(out, group=self.process_group) + my_custom_comm.custom_allreduce(out, self.process_group.tp_comm) + return out + + +class TensorParallelEmbedding(nn.Module): + def __init__(self, prefix: str, weights, reduce=True): + super().__init__() 
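+ # Vocab-parallel embedding: this rank keeps rows [min_id, max_id) of the table
+ # plus one extra zero row (null_idx) appended below via F.pad. Token ids owned by
+ # other ranks are remapped to that zero row in forward(), so summing the partial
+ # lookups with the custom all-reduce reconstructs the full embedding output.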
+ weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) + num_embeddings = weights.get_shape(f"{prefix}.weight")[0] + + process_group = weights.process_group + + world_size = process_group.size() + rank = process_group.rank() + + block_size = num_embeddings // world_size + self.min_id = rank * block_size + self.max_id = min(num_embeddings, (rank + 1) * block_size) + self.null_idx = block_size + self.process_group = weights.process_group + self.reduce = reduce + + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = torch.nn.functional.embedding(input, self.weight) + if self.reduce and self.process_group.size() > 1: + # torch.distributed.all_reduce(out, group=self.process_group) + my_custom_comm.custom_allreduce(out, self.process_group.tp_comm) + return out + + +try: + import dropout_layer_norm + + + class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +except ImportError: + pass + +try: + from flash_attn.layers.rotary import RotaryEmbedding + import rotary_emb + + + def _create_inv_freq(dim, base, device): + inv_freq = 1.0 / ( + base + ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) + return inv_freq + + + def _get_rope_config(config): + if os.getenv("ROPE_SCALING", None) is not None: + rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])} + return rope_scaling + return getattr(config, "rope_scaling", None) + + + class PositionRotaryEmbedding(nn.Module): + def __init__(self, inv_freq, scaling_factor): + super().__init__() + self.inv_freq = inv_freq + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.scaling_factor = scaling_factor + self.dynamic_args = None + + @classmethod + def static(cls, config, dim, base, device): + inv_freq = _create_inv_freq(dim, base, device) + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding(dim=dim, + max_position_embeddings=config.max_position_embeddings, + base=base, device=inv_freq.device, + scaling_factor=scaling_factor) + else: + raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid") + return cls(inv_freq, scaling_factor) + + @classmethod + def load(cls, config, prefix, weights): + # XXX: Always load this in float32 ! 
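+ # inv_freq is tiny but precision-sensitive: loading it in fp16 would degrade the
+ # cos/sin cache at long positions, so the weights dtype is temporarily switched
+ # to float32 for this single tensor and restored immediately afterwards.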
+ dtype = weights.dtype + weights.dtype = torch.float32 + inv_freq = weights.get_tensor(f"{prefix}.inv_freq") + weights.dtype = dtype + + scaling_factor = None + rope_scaling = _get_rope_config(config) + if rope_scaling is not None: + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "linear": + pass + elif rope_scaling["type"] == "dynamic": + return DynamicPositionRotaryEmbedding(dim=2 * inv_freq.shape[0], + max_position_embeddings=config.max_position_embeddings, + base=10000.0, device=inv_freq.device, + scaling_factor=scaling_factor) + else: + raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid") + return cls(inv_freq, scaling_factor) + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + if self.scaling_factor is not None: + t /= self.scaling_factor + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin( + self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype + ): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + x1 = x[..., :rotary_dim] + x2 = x[..., rotary_dim: 2 * rotary_dim] + + rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False) + return x + + + class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): + def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): + inv_freq = _create_inv_freq(dim, base, device) + super().__init__(inv_freq, scaling_factor) + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + if seqlen > self.max_position_embeddings: + newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - ( + self.scaling_factor - 1)) ** (self.dim / (self.dim - 2)) + self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device) + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + +except ImportError: + pass diff --git a/my_optims/tp_optims/test_llama_fa_tp.py b/my_optims/tp_optims/test_llama_fa_tp.py new file mode 100644 index 00000000..5de90a42 --- /dev/null +++ 
b/my_optims/tp_optims/test_llama_fa_tp.py @@ -0,0 +1,277 @@ +# encoding:utf-8 +# -------------------------------------------# +# Filename: optims -- test_llama_fa.py +# +# Description: +# Version: 1.0 +# Created: 2023/9/18-20:50 +# Last modified by: +# Author: 'zhaohuayang@myhexin.com' +# Company: 同花顺网络信息股份有限公司 +# -------------------------------------------# +import math +import time +from pathlib import Path +from typing import List, Optional + +import numpy as np +import torch +import transformers + +from flash_llama_modeling import FlashLlamaForCausalLM +from weights import Weights +from mpi4py import MPI +import my_custom_comm + +COMM = None +BLOCK_SIZE = 16 + + +class FakeGroup: + def __init__(self, rank, size, tp_comm, pp_comm): + self._rank = rank + self._size = size + self.tp_comm = tp_comm + self.pp_comm = pp_comm + + def size(self): + return self._size + + def rank(self): + return self._rank + + + +class CacheManager: + def __init__( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.block_size = BLOCK_SIZE + self.num_blocks = num_blocks + self.device = device + + element_size = torch.tensor([], dtype=dtype).element_size() + x = self.block_size // element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") + self.slots = torch.arange( + 0, num_blocks * self.block_size, dtype=torch.int32 + ).view(num_blocks, self.block_size) + + def allocate(self, blocks, max_blocks, needed_blocks_slots): + """ + blocks: 总共需要的blocks数量 + max_blocks: 最大的blocks数量大小 + needed_blocks_slots: 每个序列所需的blocks及其对应的序列长度 + """ + # Get free blocks indices by finding values in mask that are not set to 0 + free_block_indices = self.free_block_mask.nonzero() + assert ( + len(free_block_indices) >= blocks + ), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks" + + # Slice by the number of required blocks + block_indices = free_block_indices[: blocks] + block_indices = block_indices.flatten() + + # Padded block tables + block_tables_tensor = torch.zeros( + (len(needed_blocks_slots), max_blocks), dtype=torch.int32 + ) + + # Allocate paged attention blocks + cumulative_blocks = 0 + slots = [] + block_tables = [] + for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots): + # Get allocated blocks for this sequence + allocated_blocks = block_indices[ + cumulative_blocks: cumulative_blocks + needed_blocks + ] + # Get slots for the allocated blocks + allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots] + + slots.append(allocated_slots) + block_tables.append(allocated_blocks.tolist()) + block_tables_tensor[i, :needed_blocks] = allocated_blocks + cumulative_blocks += needed_blocks + + # Allocate the required number of blocks by setting the mask to 0 + self.free_block_mask[block_indices] = 0 + + return block_tables, block_tables_tensor.to(self.device), torch.concat(slots).to(self.device) + + def free(self, block_indices: Optional[List[int]]): + if block_indices is not None and block_indices: + # Reset mask + self.free_block_mask[block_indices] = 1 + + +def generate(tokenizer, model, config, device, prompt, max_new_tokens=10): + input_ids = 
tokenizer(prompt).input_ids + + def warmup(): + print("start warmup...") + global CACHE_MANAGER + blocks = 260 + CACHE_MANAGER = CacheManager(blocks, + len(model.model.layers), + model.model.num_key_value_heads, + model.model.head_size, + torch.float16, + device) + input_length = 1024 + bs = 4 + warmup_inputs = { + 'input_ids': torch.arange(1, input_length + 1, dtype=torch.int64, device=device).repeat(bs), + 'position_ids': torch.arange(0, input_length, dtype=torch.int32, device=device).repeat(bs), + 'cu_seqlen_prefill': torch.tensor([i * input_length for i in range(bs + 1)], dtype=torch.int32, + device=device), + 'block_tables': torch.arange(0, blocks, dtype=torch.int32, device=device).split(blocks // bs), + 'slots': torch.arange(0, 4144, dtype=torch.int32, device=device), + 'input_lengths': torch.tensor([input_length] * 4, dtype=torch.int32, device=device), + 'max_s': 1024, + 'lm_head_indices': None + } + model.forward(**warmup_inputs, kv_cache=CACHE_MANAGER.kv_cache) + + del CACHE_MANAGER + torch.cuda.empty_cache() + + # 预热 + warmup() + + print("start speed test running") + # 申请缓存空间 + global CACHE_MANAGER + CACHE_MANAGER = CacheManager(100, + len(model.model.layers), + model.model.num_key_value_heads, + model.model.head_size, + torch.float16, + device) + total_tokens = len(input_ids) + max_new_tokens - 1 + needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + needed_blocks_slots = [(needed_blocks, total_tokens)] + _, block_tables_tensor, slots = CACHE_MANAGER.allocate(needed_blocks, needed_blocks, needed_blocks_slots) + # forward循环 + tpss = [] + loops = 10 + for loop in range(loops): + print(f"loop {loop}...") + times = [] + new_tokens = [] + for step in range(max_new_tokens): + if step == 0: + # prefill step + slot_indices = torch.arange(0, 0 + len(input_ids), dtype=torch.int64) + inputs = { + 'input_ids': torch.tensor(input_ids, dtype=torch.int64, device=device), + 'position_ids': torch.arange(0, len(input_ids), dtype=torch.int32, device=device), + 'cu_seqlen_prefill': torch.tensor([0, len(input_ids)], dtype=torch.int32, device=device), + 'block_tables': block_tables_tensor, + 'slots': slots[slot_indices], + 'input_lengths': torch.tensor([len(input_ids)], dtype=torch.int32, device=device), + 'max_s': len(input_ids), + 'lm_head_indices': torch.tensor([0 + len(input_ids) - 1], dtype=torch.int32, device=device) + } + else: + # incremental step + current_length = len(input_ids) + step + inputs = { + 'input_ids': new_tokens[-1], + 'position_ids': torch.tensor([current_length - 1], dtype=torch.int32, device=device), + 'cu_seqlen_prefill': None, + 'block_tables': block_tables_tensor, + 'slots': torch.tensor([current_length - 1], dtype=torch.int32, device=device), + 'input_lengths': torch.tensor([current_length], dtype=torch.int32, device=device), + 'max_s': current_length, + 'lm_head_indices': None + } + torch.cuda.synchronize() + s_time = time.time() + logits = model.forward(**inputs, kv_cache=CACHE_MANAGER.kv_cache) + torch.cuda.synchronize() + cost_time = time.time() - s_time + next_token_id = logits.argmax(dim=-1) + new_tokens.append(next_token_id) + times.append(round(cost_time, 6)) + + if loop == 0: + new_tokens = torch.concat(new_tokens) + print(tokenizer.decode(new_tokens, skip_special_tokens=True)) + + elapsed_time = np.mean(times) + tps = 1 / elapsed_time + tpss.append(tps) + print(times) + print(f"total new tokens: {max_new_tokens}, cost time: {sum(times):.6f} s\n" + f"time_per_token: {elapsed_time * 1000:.3f} ms, tps: {tps:.2f} tokens/s") + print(f'mean tps: {np.mean(tpss):.2f} 
tokens/s') + +def init_dist_env(): + global COMM + COMM = MPI.COMM_WORLD + rank = COMM.Get_rank() + world_size = COMM.Get_size() + tp_ptr, pp_ptr = my_custom_comm.init_nccl(2, 1) + device = rank % torch.cuda.device_count() + torch.cuda.set_device(device) + torch.cuda.set_per_process_memory_fraction(1., device) + process_group = FakeGroup(rank, world_size, tp_ptr, pp_ptr) + return process_group, rank, world_size + + +def main(model_path): + # init env + process_group, rank, world_size = init_dist_env() + print(f'{rank=},{world_size=},{process_group.__dict__=}') + # step 0: 定义路径与属性 + model_path = Path(model_path) + config = transformers.AutoConfig.from_pretrained(model_path) + config.quantize = None + model_files = list(model_path.glob('*.safetensors')) + + # step 1: 定义tokenizer与权重 + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="left", truncation_side="left") + device = torch.device(f"cuda:{rank}") + weights = Weights(model_files, device, torch.float16, process_group=process_group) + + # step2: 定义模型 + COMM.barrier() + model = FlashLlamaForCausalLM(config, weights).eval() + COMM.barrier() + print(model) + + # step3: 推理 + with torch.no_grad(): + prompt = "who are you?" + generate(tokenizer, model, config, device, prompt, max_new_tokens=100) + my_custom_comm.finalize_nccl(process_group.tp_comm, process_group.pp_comm) + + + +if __name__ == '__main__': + CACHE_MANAGER: Optional[CacheManager] = None + main('/code/models/llama-7b-hf') diff --git a/my_optims/tp_optims/utils.py b/my_optims/tp_optims/utils.py new file mode 100644 index 00000000..bcdb46bb --- /dev/null +++ b/my_optims/tp_optims/utils.py @@ -0,0 +1,100 @@ +# encoding:utf-8 +# -------------------------------------------# +# Filename: optims -- utils.py +# +# Description: +# Version: 1.0 +# Created: 2023/9/27-14:43 +# Last modified by: +# Author: 'zhaohuayang@myhexin.com' +# Company: 同花顺网络信息股份有限公司 +# -------------------------------------------# +import os +import time +from datetime import timedelta + +import torch +from loguru import logger +from torch.distributed import ProcessGroupNCCL + +RANK = int(os.getenv("LOCAL_RANK", "0")) +WORLD_SIZE = int(os.getenv("WORLD_SIZE", "2")) +NCCL_PORT = int(os.getenv("NCCL_PORT", "29500")) +MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) + + +class FakeBarrier: + def wait(self): + pass + + +class FakeGroup: + def __init__(self, rank, size): + self._rank = rank + self._size = size + + def allreduce(self, *args, **kwargs): + return FakeBarrier() + + def allgather(self, inputs, local_tensor, **kwargs): + assert ( + len(inputs[0]) == len(local_tensor) == 1 + ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors" + for input_ in inputs: + input_[0].data = local_tensor[0].data + return FakeBarrier() + + def barrier(self, *args, **kwargs): + return FakeBarrier() + + def size(self): + return self._size + + def rank(self): + return self._rank + + +def initialize_torch_distributed(): + # Set the device id. + assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" + device = RANK % torch.cuda.device_count() + torch.cuda.set_device(device) + torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) + backend = "nccl" + options = ProcessGroupNCCL.Options() + options.is_high_priority_stream = True + options._timeout = timedelta(seconds=60) + if not torch.distributed.is_initialized(): + # Call the init process. 
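+ # Baseline path kept alongside the MPI/NCCL extension: a plain torch.distributed
+ # NCCL process group over tcp://localhost:NCCL_PORT, driven by LOCAL_RANK and
+ # WORLD_SIZE from the environment. test_llama_fa_tp.py does not go through this;
+ # it uses init_dist_env() plus my_custom_comm instead.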
+ torch.distributed.init_process_group( + backend=backend, + init_method=f"tcp://localhost:{NCCL_PORT}", + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) + logger.info(f"torch.distributed is initialized on rank {RANK} of {WORLD_SIZE}.") + else: + logger.warning("torch.distributed is already initialized.") + + return torch.distributed.group.WORLD, RANK, WORLD_SIZE + + +class Timer: + def __init__(self): + self.times = [] + self.time = None + + def start(self): + torch.cuda.synchronize() + self.time = time.time() + + def end(self): + torch.cuda.synchronize() + self.times.append(time.time() - self.time) + + @property + def elapsed(self): + self.times.pop(0) + return round(sum(self.times) / len(self.times) * 1000, 2) diff --git a/my_optims/tp_optims/weights.py b/my_optims/tp_optims/weights.py new file mode 100644 index 00000000..dbd4d45d --- /dev/null +++ b/my_optims/tp_optims/weights.py @@ -0,0 +1,336 @@ +# encoding:utf-8 +# -------------------------------------------# +# Filename: optims -- weights.py +# +# Description: +# Version: 1.0 +# Created: 2023/9/27-15:05 +# Last modified by: +# Author: 'zhaohuayang@myhexin.com' +# Company: 同花顺网络信息股份有限公司 +# -------------------------------------------# +import json +import os +from pathlib import Path +from typing import List, Dict, Optional, Tuple + +import torch +from huggingface_hub import hf_hub_download +from loguru import logger +from safetensors import safe_open, SafetensorError + + +class Weights: + def __init__( + self, + filenames: List[Path], + device, + dtype, + process_group, + aliases: Optional[Dict[str, List[str]]] = None, + ): + routing = {} + for filename in filenames: + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases + self.routing = routing + self.device = device + self.dtype = dtype + self.process_group = process_group + self._handles = {} + + def _get_handle(self, filename): + if filename not in self._handles: + f = safe_open(filename, framework="pytorch") + self._handles[filename] = f + + return self._handles[filename] + + def get_filename(self, tensor_name: str) -> (str, str): + filename = self.routing.get(tensor_name, None) + if filename is None: + aliases = self.aliases.get(tensor_name, []) + for alias in aliases: + filename = self.routing.get(alias, None) + if filename is not None: + return str(filename), alias + raise RuntimeError(f"weight {tensor_name} does not exist") + return str(filename), tensor_name + + def _get_slice(self, tensor_name: str): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + return slice_ + + def get_shape(self, tensor_name: str): + return self._get_slice(tensor_name).get_shape() + + def get_tensor(self, tensor_name: str, to_device=True): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + tensor = f.get_tensor(tensor_name) + # Special case for gptq which shouldn't convert + # u4 which are disguised as int32 + if tensor.dtype not in [torch.int32, torch.int64]: + tensor = tensor.to(dtype=self.dtype) + if to_device: + tensor = tensor.to(device=self.device) + return tensor + + def get_partial_sharded(self, tensor_name: str, dim: int): + filename, tensor_name = self.get_filename(tensor_name) + f = 
self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + world_size = self.process_group.size() + rank = self.process_group.rank() + + size = slice_.get_shape()[dim] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + if dim == 0: + tensor = slice_[start:stop] + elif dim == 1: + tensor = slice_[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + # Special case for gptq which shouldn't convert + # u4 which are disguised as int32 + if tensor.dtype != torch.int32: + tensor = tensor.to(dtype=self.dtype) + tensor = tensor.to(device=self.device) + return tensor + + def get_sharded(self, tensor_name: str, dim: int): + filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + world_size = self.process_group.size() + size = slice_.get_shape()[dim] + assert ( + size % world_size == 0 + ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" + return self.get_partial_sharded(tensor_name, dim) + + def _get_qweight(self, name: str): + slice_ = self._get_slice(name) + total_size = slice_.get_shape()[1] + assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3" + single_size = total_size // 3 + world_size = self.process_group.size() + rank = self.process_group.rank() + + assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards" + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + q = slice_[:, start:stop] + k = slice_[:, start + single_size:stop + single_size] + v = slice_[:, start + 2 * single_size:stop + 2 * single_size] + weight = torch.cat([q, k, v], dim=1) + weight = weight.to(device=self.device) + return weight + + def get_weights_col_packed_qkv(self, prefix: str, quantize: str): + """ + Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being + already alternating Q,K,V within the main tensor + """ + if quantize == "gptq": + try: + qweight = self._get_qweight(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + qzeros = self._get_qweight(f"{prefix}.qzeros") + scales = self._get_qweight(f"{prefix}.scales") + scales = scales.to(dtype=self.dtype) + g_idx = self.get_tensor(f"{prefix}.g_idx") + + bits, groupsize = self._get_gptq_params() + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) + else: + slice_ = self._get_slice(f"{prefix}.weight") + total_size = slice_.get_shape()[0] + assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3" + single_size = total_size // 3 + world_size = self.process_group.size() + rank = self.process_group.rank() + + assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards" + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + q = slice_[start:stop] + k = slice_[start + single_size:stop + single_size] + v = slice_[start + 2 * single_size:stop + 2 * single_size] + weight = torch.cat([q, k, v], dim=0) + weight = weight.to(device=self.device) + weight = weight.to(dtype=self.dtype) + return weight + + def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): + if quantize == "gptq": + try: + qweight = torch.cat( + 
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + qzeros = torch.cat( + [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + scales = torch.cat( + [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + bits, groupsize = self._get_gptq_params() + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) + else: + w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + weight = torch.cat(w, dim=dim) + return weight + + def get_tensor_shard(self, var, dim): + world_size = self.process_group.size() + rank = self.process_group.rank() + block_size = var.size()[dim] // world_size + start = rank * block_size + stop = (rank + 1) * block_size + if dim == 0: + tensor = var[start:stop] + elif dim == 1: + tensor = var[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + tensor = tensor.to(dtype=self.dtype) + tensor = tensor.to(device=self.device) + return tensor + + def get_multi_weights_row(self, prefix: str, quantize: str): + if quantize == "gptq": + use_exllama = True + bits, groupsize = self._get_gptq_params() + + if bits != 4: + use_exllama = False + + if self.process_group.size() > 1: + g_idx = self.get_tensor(f"{prefix}.g_idx") + if g_idx is not None: + if ( + not torch.equal( + g_idx.cpu(), + torch.tensor( + [i // groupsize for i in range(g_idx.shape[0])], + dtype=torch.int32, + ), + ) + and not (g_idx == 0).all() + ): + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + use_exllama = False + + try: + qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + logger.warning( + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True" + ) + use_exllama = False + else: + logger.info("Using exllama kernels") + + if use_exllama: + if groupsize >= 0: + # Exllama reorders the weights in advance and the activations on the fly, thus + # the scales and zero-points do not need to be reordered. + qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) + scales = self.get_sharded(f"{prefix}.scales", dim=0) + else: + qzeros = self.get_tensor(f"{prefix}.qzeros") + scales = self.get_tensor(f"{prefix}.scales") + + # For tp > 1, at this point we know we do not use act-order + if self.process_group.size() == 1: + g_idx = self.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None + else: + # The triton kernel reorders the scales/zero points instead of the weight/activation. + # Thus, each rank needs the full qzeros/scales. 
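+ # qzeros/scales therefore stay full-size on every rank, while g_idx is sharded
+ # along dim 0 to line up with the row-sharded qweight loaded above.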
+ qzeros = self.get_tensor(f"{prefix}.qzeros") + scales = self.get_tensor(f"{prefix}.scales") + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) + else: + weight = self.get_sharded(f"{prefix}.weight", dim=1) + return weight + + def _get_gptq_params(self) -> Tuple[int, int]: + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except (SafetensorError, RuntimeError) as e: + try: + bits = self.gptq_bits + groupsize = self.gptq_groupsize + except Exception: + raise e + + return bits, groupsize + + def _set_gptq_params(self, model_id): + filename = "config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename) + with open(filename, "r") as f: + data = json.load(f) + self.gptq_bits = data["quantization_config"]["bits"] + self.gptq_groupsize = data["quantization_config"]["group_size"] + except Exception: + filename = "quantize_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename) + with open(filename, "r") as f: + data = json.load(f) + self.gptq_bits = data["bits"] + self.gptq_groupsize = data["group_size"] + except Exception: + pass diff --git a/server/text_generation_server/utils/my_dist.py b/server/text_generation_server/utils/my_dist.py new file mode 100644 index 00000000..a01da46a --- /dev/null +++ b/server/text_generation_server/utils/my_dist.py @@ -0,0 +1,47 @@ +import os +import torch + +from datetime import timedelta +from loguru import logger +from mpi4py import MPI +import my_custom_comm + +# Tensor Parallelism settings +RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) +WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + +# CUDA memory fraction +MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0")) + +class MyCommGroup: + def __init__(self, rank, size, tp_comm, pp_comm): + self._rank = rank + self._size = size + self.tp_comm = tp_comm + self.pp_comm = pp_comm + + def size(self): + return self._size + + def rank(self): + return self._rank + + +def initialize_mpi_distributed(): + assert torch.cuda.is_available() + # mpi initialize + COMM = MPI.COMM_WORLD + assert COMM.Get_size() == WORLD_SIZE, f"{COMM.Get_size()},{WORLD_SIZE}" + + # Set the device id. + assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu" + device = RANK % torch.cuda.device_count() + torch.cuda.set_device(device) + torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device) + + # nccl initialize + tp_comm, pp_comm = my_custom_comm.init_nccl(WORLD_SIZE, 1) + process_group = MyCommGroup(RANK, WORLD_SIZE, tp_comm, pp_comm) + + logger.warning("custom mpi and nccl is already initialized.") + return process_group, RANK, WORLD_SIZE, COMM
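For reference, a minimal end-to-end smoke test of the MPI plus custom-NCCL path added by this patch could look like the sketch below. It is an illustration rather than part of the patch: the file name, the mpirun launch line, and the import path for my_dist (which depends on how the package is laid out on PYTHONPATH) are assumptions, and the explicit device synchronizations are defensive in case the extension runs the collective on a stream other than PyTorch's current one.

# smoke_test_custom_comm.py (hypothetical file name)
# Launch with something like: mpirun -np 2 python smoke_test_custom_comm.py
import torch

import my_custom_comm
# Import path is an assumption; adjust to wherever my_dist.py is importable from.
from text_generation_server.utils.my_dist import initialize_mpi_distributed


def main():
    # Binds this rank to a GPU and creates the tensor/pipeline NCCL groups.
    process_group, rank, world_size, comm = initialize_mpi_distributed()

    # Each rank contributes a vector of ones; after the tensor-parallel all-reduce
    # every element should equal world_size.
    x = torch.ones(8, dtype=torch.float16, device="cuda")
    torch.cuda.synchronize()  # make sure x is materialized before the collective
    my_custom_comm.custom_allreduce(x, process_group.tp_comm)
    torch.cuda.synchronize()  # wait for the collective before reading the result

    print(f"rank {rank}/{world_size}: {x[:4].tolist()}")

    comm.barrier()
    my_custom_comm.finalize_nccl(process_group.tp_comm, process_group.pp_comm)


if __name__ == "__main__":
    main()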