mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

tp optims

This commit is contained in:
parent dad29f7299
commit 2fd1156d5a
17
my_optims/nccl_test/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

add_executable(test_nccl test_nccl.cc)
target_link_libraries(test_nccl PUBLIC -lcublas -lcublasLt -lcudart
                      nvtx_utils mpi_utils nccl_utils memory_utils)
171
my_optims/nccl_test/my_custom_comm.cc
Normal file
@@ -0,0 +1,171 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <mpi.h>
|
||||
#include <nccl.h>
|
||||
#include <string>
|
||||
#include <torch/extension.h>
|
||||
|
||||
struct NcclParam {
|
||||
int rank_{0};
|
||||
int world_size_{1};
|
||||
ncclUniqueId nccl_uid_;
|
||||
ncclComm_t nccl_comm_ = nullptr;
|
||||
cudaStream_t stream_;
|
||||
|
||||
NcclParam(): rank_(0), world_size_(1), nccl_comm_(nullptr){};
|
||||
NcclParam(int rank, int world_size): rank_(rank), world_size_(world_size){};
|
||||
NcclParam(NcclParam const& param):
|
||||
rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_), stream_(param.stream_){};
|
||||
};
|
||||
|
||||
#define NCCLCHECK(cmd) \
|
||||
do { \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) { \
|
||||
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
std::tuple<NcclParam*, NcclParam*> init_nccl(int tensor_para_size, int pipeline_para_size)
|
||||
{
|
||||
// int argc = 0;
|
||||
// char** argv = nullptr;
|
||||
|
||||
// // Initialize MPI
|
||||
// MPI_Init(&argc, &argv);
|
||||
int rank, world_size;
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);        // get the rank of this process
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);  // get the total number of processes
|
||||
// printf("rank:%d, world_size:%d\n", rank, world_size);
|
||||
|
||||
// Select the GPU for this rank
|
||||
int device, device_count;
|
||||
cudaGetDeviceCount(&device_count);
|
||||
cudaSetDevice(rank % device_count);
|
||||
cudaGetDevice(&device);
|
||||
struct cudaDeviceProp prop;
|
||||
cudaGetDeviceProperties(&prop, device);
|
||||
|
||||
int mpi_initialized;
|
||||
MPI_Initialized(&mpi_initialized);
|
||||
|
||||
static NcclParam tensor_para;
|
||||
static NcclParam pipeline_para;
|
||||
// Convert WORLD communicator into 2D grid (k * n) communicator.
|
||||
// row = a tensor parallel group, col = a pipeline parallel group.
|
||||
MPI_Comm grid_comm, tp_comm, pp_comm;
|
||||
|
||||
int dims[2] = {pipeline_para_size, tensor_para_size};
|
||||
int periods[2] = {0, 0};
|
||||
MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm);
|
||||
|
||||
// Split 2D communicator into rows and cols.
|
||||
int tp_remain_dims[2] = {false, true};
|
||||
int pp_remain_dims[2] = {true, false};
|
||||
MPI_Cart_sub(grid_comm, tp_remain_dims, &tp_comm);
|
||||
MPI_Cart_sub(grid_comm, pp_remain_dims, &pp_comm);
|
||||
|
||||
int tp_rank, pp_rank;
|
||||
MPI_Comm_rank(tp_comm, &tp_rank);
|
||||
MPI_Comm_rank(pp_comm, &pp_rank);
|
||||
printf("tp_rank:%d, pp_rank:%d\n", tp_rank, pp_rank);
|
||||
|
||||
ncclUniqueId tp_uid;
|
||||
ncclUniqueId pp_uid;
|
||||
// The root of each group creates a nccl uid.
|
||||
if (tp_rank == 0) {
|
||||
NCCLCHECK(ncclGetUniqueId(&tp_uid));
|
||||
}
|
||||
if (pp_rank == 0) {
|
||||
NCCLCHECK(ncclGetUniqueId(&pp_uid));
|
||||
}
|
||||
// Broadcast nccl uid to share the same nccl uid across gpus in the same group.
|
||||
MPI_Bcast(&tp_uid, sizeof(tp_uid), MPI_BYTE, 0, tp_comm);
|
||||
MPI_Bcast(&pp_uid, sizeof(pp_uid), MPI_BYTE, 0, pp_comm);
|
||||
|
||||
ncclComm_t tp_nccl_comm, pp_nccl_comm;
|
||||
NCCLCHECK(ncclCommInitRank(&tp_nccl_comm, tensor_para_size, tp_uid, tp_rank));
|
||||
NCCLCHECK(ncclCommInitRank(&pp_nccl_comm, pipeline_para_size, pp_uid, pp_rank));
|
||||
|
||||
tensor_para.world_size_ = tensor_para_size;
|
||||
tensor_para.rank_ = tp_rank;
|
||||
tensor_para.nccl_uid_ = tp_uid;
|
||||
tensor_para.nccl_comm_ = tp_nccl_comm;
|
||||
cudaStreamCreate(&tensor_para.stream_);
|
||||
pipeline_para.world_size_ = pipeline_para_size;
|
||||
pipeline_para.rank_ = pp_rank;
|
||||
pipeline_para.nccl_uid_ = pp_uid;
|
||||
pipeline_para.nccl_comm_ = pp_nccl_comm;
cudaStreamCreate(&pipeline_para.stream_);  // without this, pipeline_para.stream_ is left uninitialized
|
||||
|
||||
NcclParam* tensor_para_ptr = &tensor_para;
|
||||
NcclParam* pipeline_para_ptr = &pipeline_para;
|
||||
|
||||
return std::make_tuple(tensor_para_ptr, pipeline_para_ptr);
|
||||
}
|
||||
|
||||
void finalize_nccl(NcclParam* tensor_para_ptr, NcclParam* pipeline_para_ptr)
|
||||
{
|
||||
// Destroy the NCCL communicators
|
||||
NcclParam tensor_para = *tensor_para_ptr;
|
||||
NcclParam pipeline_para = *pipeline_para_ptr;
|
||||
if (tensor_para.nccl_comm_ != nullptr) {
|
||||
ncclCommDestroy(tensor_para.nccl_comm_);
|
||||
}
|
||||
if (pipeline_para.nccl_comm_ != nullptr) {
|
||||
ncclCommDestroy(pipeline_para.nccl_comm_);
|
||||
}
|
||||
// MPI_Finalize();
|
||||
}
|
||||
|
||||
ncclDataType_t getNcclDataType(torch::ScalarType torch_type)
|
||||
{
|
||||
ncclDataType_t nccl_type;
|
||||
if (torch_type == torch::kFloat16) {
|
||||
nccl_type = ncclHalf;
|
||||
}
|
||||
else if (torch_type == torch::kFloat32) {
|
||||
nccl_type = ncclFloat;
|
||||
}
|
||||
else if (torch_type == torch::kFloat64) {
|
||||
nccl_type = ncclDouble;
|
||||
}
|
||||
else if (torch_type == torch::kInt32) {
|
||||
nccl_type = ncclInt32;
|
||||
}
|
||||
else if (torch_type == torch::kInt64) {
|
||||
nccl_type = ncclInt64;
|
||||
}
|
||||
else if (torch_type == torch::kInt8) {
|
||||
nccl_type = ncclInt8;
|
||||
}
|
||||
else {
|
||||
printf("[ERROR] NCCL only support float, half, int \n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return nccl_type;
|
||||
}
|
||||
void custom_allreduce(torch::Tensor tensor, NcclParam* nccl_param_ptr)
|
||||
{
|
||||
void* data_ptr = tensor.data_ptr();
|
||||
size_t count = tensor.numel();
|
||||
torch::ScalarType torch_type = tensor.scalar_type();
|
||||
|
||||
NcclParam nccl_param = *nccl_param_ptr;
|
||||
// cudaStream_t stream = at::cuda::getCurrentCUDAStream();
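// NOTE: the reduction below is enqueued on nccl_param.stream_, not on PyTorch's
// current stream, so callers must synchronize (e.g. torch.cuda.synchronize())
// before reading the tensor back.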
|
||||
ncclDataType_t nccl_data_type = getNcclDataType(torch_type);
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
NCCLCHECK(ncclAllReduce(
|
||||
(const void*)data_ptr, (void*)data_ptr, count, nccl_data_type, ncclSum, nccl_param.nccl_comm_, nccl_param.stream_));
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
py::class_<NcclParam>(m, "NcclParam").def(py::init<>());
|
||||
m.def("init_nccl", &init_nccl, py::return_value_policy::reference, "");
|
||||
m.def("finalize_nccl", &finalize_nccl, "");
|
||||
m.def("custom_allreduce", &custom_allreduce, "");
|
||||
}
|
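The cartesian-grid code above splits MPI_COMM_WORLD into tensor- and pipeline-parallel groups. A minimal sketch of the resulting rank mapping, assuming MPI_Cart_create keeps the default row-major rank ordering (reorder is 0 here); the helper name is illustrative and not part of the extension:

# grid dims are {pipeline_para_size, tensor_para_size}, row-major
def grid_coords(rank, tensor_para_size, pipeline_para_size):
    assert rank < tensor_para_size * pipeline_para_size
    tp_rank = rank % tensor_para_size    # column index -> rank inside its tensor-parallel group
    pp_rank = rank // tensor_para_size   # row index    -> rank inside its pipeline-parallel group
    return tp_rank, pp_rank

# e.g. with tensor_para_size=2, pipeline_para_size=2:
#   world ranks 0,1 share a TP group; world ranks 0,2 share a PP group.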
12
my_optims/nccl_test/setup.py
Normal file
@@ -0,0 +1,12 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(name='my_custom_comm',
      include_dirs=["/usr/local/cuda/targets/x86_64-linux/include/",
                    ],
      ext_modules=[CUDAExtension('my_custom_comm',
                                 ['my_custom_comm.cc'],
                                 libraries=["mpi"]
                                 ), ],
      cmdclass={'build_ext': BuildExtension},
      )
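The commit does not show how the extension is built or launched; a hedged sketch of one plausible flow (standard setuptools and MPI commands, the exact invocation is an assumption):

# python setup.py build_ext --inplace     # compile my_custom_comm.cc via torch's BuildExtension
# mpirun -np 2 python test_extension.py   # one MPI process per GPU / NCCL rank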
22
my_optims/nccl_test/test_extension.py
Normal file
@@ -0,0 +1,22 @@
import torch
import my_custom_comm
from mpi4py import MPI

COMM = MPI.COMM_WORLD
rank = COMM.Get_rank()
world_size = COMM.Get_size()
tp_ptr, pp_ptr = my_custom_comm.init_nccl(2, 1)
print(tp_ptr)
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
torch.cuda.set_per_process_memory_fraction(1., device)
print(rank, world_size)

t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]],
                 dtype=torch.float16).to('cuda')

print(my_custom_comm.custom_allreduce(t, tp_ptr))
print(t)

torch.cuda.synchronize()
my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)
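For reference, if both ranks construct the identical tensor above, an ncclSum all-reduce over two ranks simply doubles every element (an expected-output sketch, not captured program output):

# input on each rank : [[1.0, 2.0, 3.0, 4.0], [3.0, 3.0, 3.0, 3.1]]
# t after all-reduce : [[2.0, 4.0, 6.0, 8.0], [6.0, 6.0, 6.0, ~6.2]]   (fp16 rounding on 3.1)
# custom_allreduce itself returns None; the reduction happens in place on t.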
32
my_optims/nccl_test/test_extension_torch.py
Normal file
@@ -0,0 +1,32 @@
import torch
import torch.distributed as dist
import my_custom_comm
import os
assert dist.is_mpi_available()

# rank, world_size = int(os.getenv('RANK')), int(os.getenv('WORLD_SIZE'))
print(os.environ)
dist.init_process_group(backend=dist.Backend.MPI,
                        # rank=rank,
                        # world_size=2
                        )
assert dist.is_initialized()
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f'{rank=},{world_size=}')

# tp_ptr, pp_ptr = my_custom_comm.init_nccl(2, 1)
# print(tp_ptr)
# device = rank % torch.cuda.device_count()
# torch.cuda.set_device(device)
# torch.cuda.set_per_process_memory_fraction(1., device)
# print(rank, world_size)

# t = torch.tensor([[1, 2, 3, 4], [3, 3, 3, 3.1]],
#                  dtype=torch.float16).to('cuda')

# print(my_custom_comm.custom_allreduce(t, tp_ptr))
# print(t)

# torch.cuda.synchronize()
# my_custom_comm.finalize_nccl(tp_ptr, pp_ptr)
88
my_optims/nccl_test/test_nccl.cc
Normal file
@@ -0,0 +1,88 @@
|
||||
#include "src/fastertransformer/utils/mpi_utils.h"
|
||||
#include "src/fastertransformer/utils/nccl_utils.h"
|
||||
#include "src/fastertransformer/utils/nvtx_utils.h"
|
||||
#include "src/fastertransformer/utils/memory_utils.h"
|
||||
#include <string>
|
||||
#include <cuda_profiler_api.h>
|
||||
|
||||
|
||||
using namespace fastertransformer;
|
||||
|
||||
template<typename T>
|
||||
void test_nccl();
|
||||
|
||||
int main(int argc, char **argv){
|
||||
mpi::initialize(&argc, &argv);
|
||||
|
||||
test_nccl<half>();
|
||||
|
||||
mpi::finalize();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
void test_nccl()
|
||||
{
|
||||
// int rank = 0;
|
||||
// int world_size = 1;
|
||||
int rank = mpi::getCommWorldRank();
|
||||
int world_size = mpi::getCommWorldSize();
|
||||
if (rank == 0) {
|
||||
printf("Total ranks: %d.\n", world_size);
|
||||
}
|
||||
int device, device_count;
|
||||
check_cuda_error(cudaGetDeviceCount(&device_count));
|
||||
check_cuda_error(cudaSetDevice(rank % device_count));
|
||||
check_cuda_error(cudaGetDevice(&device));
|
||||
struct cudaDeviceProp prop;
|
||||
check_cuda_error(cudaGetDeviceProperties(&prop, device));
|
||||
printf("Device %s\n", prop.name);
|
||||
printf("P%d is running with GPU #%d.\n", rank, device);
|
||||
print_mem_usage();
|
||||
|
||||
// Declare the device data pointers
|
||||
std::vector<T*> weights_ptr = std::vector<T*>(1);
|
||||
size_t shape = 4096 * 2048;
|
||||
deviceMalloc(&weights_ptr[0], shape, true); // 从gpu中开辟空间,并随即初始化
|
||||
// deviceFill(weights_ptr[0], (size_t)shape, (T)2.0); // fill the GPU buffer with a constant value
|
||||
|
||||
// Initialize NCCL
|
||||
int tensor_para_size = 2;
|
||||
int pipeline_para_size = 1;
|
||||
NcclParam tensor_para;
|
||||
NcclParam pipeline_para;
|
||||
ftNcclInitialize(tensor_para, pipeline_para, tensor_para_size, pipeline_para_size);
|
||||
std::cout << "tensor_para info:" << tensor_para.rank_ << std::endl;
|
||||
|
||||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
|
||||
mpi::barrier();
|
||||
cudaProfilerStart();
|
||||
ft_nvtx::setScope("run_time");
|
||||
PUSH_RANGE("run time")
|
||||
for (int i = 0; i < 32; i++) {
|
||||
cudaDeviceSynchronize();
|
||||
ftNcclAllReduceSum(weights_ptr[0], weights_ptr[0], shape, tensor_para, stream);
|
||||
cudaDeviceSynchronize();
|
||||
deviceMalloc(&weights_ptr[0], shape, true);
|
||||
}
|
||||
|
||||
mpi::barrier();
|
||||
POP_RANGE;
|
||||
ft_nvtx::resetScope();
|
||||
|
||||
// T* hBuf = new T[shape]; // allocate host (CPU) memory
|
||||
// cudaD2Hcpy(hBuf, weights_ptr[0], shape);
|
||||
// { // copy back and print from the host to inspect the data
|
||||
// for (size_t i = 0; i < shape; i++) {
|
||||
// printf("%f ", (float)hBuf[i]);
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
// }
|
||||
// delete[] hBuf;
|
||||
|
||||
return;
|
||||
}
|
30
my_optims/nccl_test/torch_nccl.py
Normal file
@@ -0,0 +1,30 @@
import torch
import torch.distributed as dist
import os
import time


def test_nccl():
    rank, world_size = int(os.getenv('RANK')), int(os.getenv('WORLD_SIZE'))
    dist.init_process_group(backend='nccl',
                            rank=rank,
                            world_size=world_size)
    process_group = dist.group.WORLD
    torch.cuda.set_device(rank % world_size)
    shape = 4096 * 2048
    weight = torch.randn([shape], dtype=torch.float16).to("cuda")

    # nccl test
    dist.barrier(process_group)
    for i in range(32):
        torch.cuda.synchronize()
        dist.all_reduce(weight, group=process_group)
        torch.cuda.synchronize()
        weight = torch.randn([shape], dtype=torch.float16).to("cuda")

    dist.barrier(process_group)



if __name__ == '__main__':
    test_nccl()
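Both this script and test_nccl.cc time 32 all-reduces of a 4096 * 2048 fp16 buffer (16 MiB) but never print a figure; a hedged sketch of how the loop could be turned into bandwidth numbers (the 2*(n-1)/n bus-bandwidth factor is the usual ring all-reduce convention, and all names here are illustrative):

import time

import torch
import torch.distributed as dist


def allreduce_bandwidth(weight, iters=32):
    world_size = dist.get_world_size()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        dist.all_reduce(weight)
    torch.cuda.synchronize()
    elapsed = (time.time() - start) / iters
    size_bytes = weight.numel() * weight.element_size()  # 4096 * 2048 * 2 bytes = 16 MiB
    algbw = size_bytes / elapsed / 1e9                   # GB/s of payload per rank
    busbw = algbw * 2 * (world_size - 1) / world_size    # ring all-reduce bus bandwidth
    print(f"algbw={algbw:.2f} GB/s, busbw={busbw:.2f} GB/s")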
709
my_optims/optims/test_llama_fa.py
Normal file
@@ -0,0 +1,709 @@
|
||||
# encoding:utf-8
|
||||
# -------------------------------------------#
|
||||
# Filename: optims -- test_llama_fa.py
|
||||
#
|
||||
# Description:
|
||||
# Version: 1.0
|
||||
# Created: 2023/9/18-20:50
|
||||
# Last modified by:
|
||||
# Author: 'zhaohuayang@myhexin.com'
|
||||
# Company: 同花顺网络信息股份有限公司
|
||||
# -------------------------------------------#
|
||||
import math
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from typing import Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import dropout_layer_norm
|
||||
import flash_attn_2_cuda
|
||||
import numpy as np
|
||||
import rotary_emb
|
||||
import torch
|
||||
import torch.distributed
|
||||
import transformers
|
||||
from safetensors import safe_open
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from transformers.activations import ACT2FN
|
||||
from vllm import attention_ops, cache_ops
|
||||
|
||||
# vllm imports
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
class FastLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
bias,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight)
|
||||
if bias is not None:
|
||||
self.bias = nn.Parameter(bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
if bias:
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return cls(weight, bias)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
|
||||
|
||||
class PositionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, inv_freq, scaling_factor):
|
||||
super().__init__()
|
||||
self.inv_freq = inv_freq
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
self.scaling_factor = scaling_factor
|
||||
self.dynamic_args = None
|
||||
|
||||
@staticmethod
|
||||
def _create_inv_freq(dim, base, device):
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
@classmethod
|
||||
def static(cls, config, dim, base, device):
|
||||
inv_freq = cls._create_inv_freq(dim, base, device)
|
||||
scaling_factor = None
|
||||
return cls(inv_freq, scaling_factor)
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
if self.scaling_factor is not None:
|
||||
t /= self.scaling_factor
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
def get_cos_sin(
|
||||
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
|
||||
):
|
||||
"""
|
||||
Return cos and sin for the requested position ids
|
||||
"""
|
||||
|
||||
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
|
||||
|
||||
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||
rotary_dim = cos.shape[-1]
|
||||
x1 = x[..., :rotary_dim]
|
||||
x2 = x[..., rotary_dim: 2 * rotary_dim]
|
||||
|
||||
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
|
||||
return x
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(
|
||||
self,
|
||||
filenames: List[Path],
|
||||
device,
|
||||
dtype,
|
||||
process_group,
|
||||
aliases: Optional[Dict[str, List[str]]] = None,
|
||||
):
|
||||
routing = {}
|
||||
for filename in filenames:
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
if aliases is None:
|
||||
aliases = {}
|
||||
self.aliases = aliases
|
||||
self.routing = routing
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.process_group = process_group
|
||||
self._handles = {}
|
||||
|
||||
def _get_handle(self, filename):
|
||||
if filename not in self._handles:
|
||||
f = safe_open(filename, framework="pytorch")
|
||||
self._handles[filename] = f
|
||||
|
||||
return self._handles[filename]
|
||||
|
||||
def get_filename(self, tensor_name: str) -> (str, str):
|
||||
filename = self.routing.get(tensor_name, None)
|
||||
if filename is None:
|
||||
aliases = self.aliases.get(tensor_name, [])
|
||||
for alias in aliases:
|
||||
filename = self.routing.get(alias, None)
|
||||
if filename is not None:
|
||||
return str(filename), alias
|
||||
raise RuntimeError(f"weight {tensor_name} does not exist")
|
||||
return str(filename), tensor_name
|
||||
|
||||
def get_tensor(self, tensor_name: str, to_device=True):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
# Special case for gptq which shouldn't convert
|
||||
# u4 which are disguised as int32
|
||||
if tensor.dtype not in [torch.int32, torch.int64]:
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
if to_device:
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def load_multi_linear(self, config, prefixes):
|
||||
weight = torch.cat([self.get_tensor(f"{p}.weight") for p in prefixes], dim=0)
|
||||
return FastLinear(weight, bias=None)
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
# faster post attention rms norm
|
||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True, # Activate RMSNorm
|
||||
)
|
||||
if res is None:
|
||||
res = hidden_states
|
||||
|
||||
return normed_hidden_states, res
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
self.softmax_scale = self.head_size ** -0.5
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
# q,k,v,o and rotary
|
||||
self.query_key_value = weights.load_multi_linear(config,
|
||||
[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"])
|
||||
self.o_proj = FastLinear.load(config, f"{prefix}.o_proj", weights, bias=False)
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
cache_ops.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn_2_cuda.varlen_fwd(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
-1,
|
||||
0,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = kv_cache[1].shape[3]
|
||||
attention_ops.paged_attention_v1(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.act = ACT2FN[act]
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = weights.load_multi_linear(config, [f"{prefix}.gate_proj", f"{prefix}.up_proj"])
|
||||
self.down_proj = FastLinear.load(config, f"{prefix}.down_proj", weights, bias=False)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
|
||||
self.input_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
embeddings = weights.get_tensor(f"model.embed_tokens.weight")
|
||||
self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)),
|
||||
padding_idx=config.pad_token_id)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
# for layer_id in range(1)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid indexing in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.lm_head = FastLinear.load(config, prefix="lm_head", weights=weights, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
||||
|
||||
|
||||
class CacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.block_size = BLOCK_SIZE
|
||||
self.num_blocks = num_blocks
|
||||
self.device = device
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
x = self.block_size // element_size
|
||||
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size // x, self.block_size, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size, self.block_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
|
||||
self.slots = torch.arange(
|
||||
0, num_blocks * self.block_size, dtype=torch.int32
|
||||
).view(num_blocks, self.block_size)
|
||||
|
||||
def allocate(self, blocks, max_blocks, needed_blocks_slots):
|
||||
"""
|
||||
blocks: total number of blocks required
max_blocks: maximum number of blocks for any sequence (width of the padded block table)
needed_blocks_slots: for each sequence, the number of blocks it needs and its corresponding sequence length
|
||||
"""
|
||||
# Get free blocks indices by finding values in mask that are not set to 0
|
||||
free_block_indices = self.free_block_mask.nonzero()
|
||||
assert (
|
||||
len(free_block_indices) >= blocks
|
||||
), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
|
||||
|
||||
# Slice by the number of required blocks
|
||||
block_indices = free_block_indices[: blocks]
|
||||
block_indices = block_indices.flatten()
|
||||
|
||||
# Padded block tables
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(needed_blocks_slots), max_blocks), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Allocate paged attention blocks
|
||||
cumulative_blocks = 0
|
||||
slots = []
|
||||
block_tables = []
|
||||
for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
|
||||
# Get allocated blocks for this sequence
|
||||
allocated_blocks = block_indices[
|
||||
cumulative_blocks: cumulative_blocks + needed_blocks
|
||||
]
|
||||
# Get slots for the allocated blocks
|
||||
allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
|
||||
|
||||
slots.append(allocated_slots)
|
||||
block_tables.append(allocated_blocks.tolist())
|
||||
block_tables_tensor[i, :needed_blocks] = allocated_blocks
|
||||
cumulative_blocks += needed_blocks
|
||||
|
||||
# Allocate the required number of blocks by setting the mask to 0
|
||||
self.free_block_mask[block_indices] = 0
|
||||
|
||||
return block_tables, block_tables_tensor.to(self.device), torch.concat(slots).to(self.device)
|
||||
|
||||
def free(self, block_indices: Optional[List[int]]):
|
||||
if block_indices is not None and block_indices:
|
||||
# Reset mask
|
||||
self.free_block_mask[block_indices] = 1
|
||||
|
||||
|
||||
def generate(tokenizer, model, prompt, max_new_tokens=10):
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
|
||||
def warmup():
|
||||
print("start warmup...")
|
||||
global CACHE_MANAGER
|
||||
blocks = 260
|
||||
CACHE_MANAGER = CacheManager(blocks,
|
||||
model.config.num_hidden_layers,
|
||||
model.config.num_key_value_heads,
|
||||
model.config.hidden_size // model.config.num_attention_heads,
|
||||
torch.float16,
|
||||
device)
|
||||
input_length = 1024
|
||||
bs = 4
|
||||
warmup_inputs = {
|
||||
'input_ids': torch.arange(1, input_length + 1, dtype=torch.int64, device=device).repeat(bs),
|
||||
'position_ids': torch.arange(0, input_length, dtype=torch.int32, device=device).repeat(bs),
|
||||
'cu_seqlen_prefill': torch.tensor([i * input_length for i in range(bs + 1)], dtype=torch.int32,
|
||||
device=device),
|
||||
'block_tables': torch.arange(0, blocks, dtype=torch.int32, device=device).split(blocks // bs),
|
||||
'slots': torch.arange(0, 4144, dtype=torch.int32, device=device),
|
||||
'input_lengths': torch.tensor([input_length] * 4, dtype=torch.int32, device=device),
|
||||
'max_s': 1024,
|
||||
'lm_head_indices': None
|
||||
}
|
||||
model.forward(**warmup_inputs, kv_cache=CACHE_MANAGER.kv_cache)
|
||||
|
||||
del CACHE_MANAGER
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# warmup
|
||||
warmup()
|
||||
|
||||
print("start speed test running")
|
||||
# allocate the KV-cache
|
||||
global CACHE_MANAGER
|
||||
CACHE_MANAGER = CacheManager(100,
|
||||
model.config.num_hidden_layers,
|
||||
model.config.num_key_value_heads,
|
||||
model.config.hidden_size // model.config.num_attention_heads,
|
||||
torch.float16,
|
||||
device)
|
||||
total_tokens = len(input_ids) + max_new_tokens - 1
|
||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
||||
needed_blocks_slots = [(needed_blocks, total_tokens)]
|
||||
_, block_tables_tensor, slots = CACHE_MANAGER.allocate(needed_blocks, needed_blocks, needed_blocks_slots)
|
||||
# forward loop
|
||||
loops = 10
|
||||
for loop in range(loops):
|
||||
print(f"loop {loop}...")
|
||||
times = []
|
||||
new_tokens = []
|
||||
for step in range(max_new_tokens):
|
||||
if step == 0:
|
||||
# prefill step
|
||||
slot_indices = torch.arange(0, 0 + len(input_ids), dtype=torch.int64)
|
||||
inputs = {
|
||||
'input_ids': torch.tensor(input_ids, dtype=torch.int64, device=device),
|
||||
'position_ids': torch.arange(0, len(input_ids), dtype=torch.int32, device=device),
|
||||
'cu_seqlen_prefill': torch.tensor([0, len(input_ids)], dtype=torch.int32, device=device),
|
||||
'block_tables': block_tables_tensor,
|
||||
'slots': slots[slot_indices],
|
||||
'input_lengths': torch.tensor([len(input_ids)], dtype=torch.int32, device=device),
|
||||
'max_s': len(input_ids),
|
||||
'lm_head_indices': torch.tensor([0 + len(input_ids) - 1], dtype=torch.int32, device=device)
|
||||
}
|
||||
else:
|
||||
# incremental step
|
||||
current_length = len(input_ids) + step
|
||||
inputs = {
|
||||
'input_ids': new_tokens[-1],
|
||||
'position_ids': torch.tensor([current_length - 1], dtype=torch.int32, device=device),
|
||||
'cu_seqlen_prefill': None,
|
||||
'block_tables': block_tables_tensor,
|
||||
'slots': torch.tensor([current_length - 1], dtype=torch.int32, device=device),
|
||||
'input_lengths': torch.tensor([current_length], dtype=torch.int32, device=device),
|
||||
'max_s': current_length,
|
||||
'lm_head_indices': None
|
||||
}
|
||||
torch.cuda.synchronize()
|
||||
s_time = time.time()
|
||||
logits = model.forward(**inputs, kv_cache=CACHE_MANAGER.kv_cache)
|
||||
torch.cuda.synchronize()
|
||||
cost_time = time.time() - s_time
|
||||
next_token_id = logits.argmax(dim=-1)
|
||||
new_tokens.append(next_token_id)
|
||||
times.append(round(cost_time, 6))
|
||||
|
||||
if loop == 0:
|
||||
new_tokens = torch.concat(new_tokens)
|
||||
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
|
||||
|
||||
elapsed_time = np.mean(times)
|
||||
print(f"total new tokens: {max_new_tokens}, cost time: {sum(times):.6f} s\n"
|
||||
f"time_per_token: {elapsed_time * 1000:.3f} ms, tps: {1 / elapsed_time:.2f} tokens/s")
|
||||
|
||||
|
||||
def main(model_path):
|
||||
# step 0: define paths and load the model config
|
||||
model_path = Path(model_path)
|
||||
config = transformers.AutoConfig.from_pretrained(model_path)
|
||||
model_files = list(model_path.glob('*.safetensors'))
|
||||
|
||||
# step 1: build the tokenizer and the weights
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="left", truncation_side="left")
|
||||
weights = Weights(model_files, device, torch.float16, process_group=None)
|
||||
|
||||
# step 2: build the model
|
||||
model = FlashLlamaForCausalLM(config, weights).eval()
|
||||
print(model)
|
||||
|
||||
# step 3: inference
|
||||
with torch.no_grad():
|
||||
prompt = "who are you?"
|
||||
generate(tokenizer, model, prompt, max_new_tokens=100)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
CACHE_MANAGER: Optional[CacheManager] = None
|
||||
device = torch.device("cuda")
|
||||
main('/code/models/llama-7b-hf')
|
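CacheManager above hands out KV-cache slots block by block: self.slots is laid out so that physical block i owns slots [i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE). A small illustrative helper (names are hypothetical, not part of the file) for the slot arithmetic implied by that layout:

BLOCK_SIZE = 16

def slot_for_position(block_table, position):
    block = block_table[position // BLOCK_SIZE]   # physical block holding this token
    offset = position % BLOCK_SIZE                # offset inside that block
    return block * BLOCK_SIZE + offset

# e.g. block_table = [7, 3], position = 20  ->  physical block 3, offset 4  ->  slot 52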
42
my_optims/optims/test_llama_hf.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# encoding:utf-8
|
||||
# -------------------------------------------#
|
||||
# Filename: aime-local-inference-server -- test_llama_hf.py
|
||||
#
|
||||
# Description:
|
||||
# Version: 1.0
|
||||
# Created: 2023/9/7-15:19
|
||||
# Last modified by:
|
||||
# Author: 'zhaohuayang@myhexin.com'
|
||||
# Company: 同花顺网络信息股份有限公司
|
||||
# -------------------------------------------#
|
||||
import time
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
|
||||
def test_llama():
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained("/code/models/llama-7b-hf")
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained("/code/models/llama-7b-hf", torch_dtype=torch.float16,
|
||||
device_map="auto").half().eval()
|
||||
prompt = "who are you?"
|
||||
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda')
|
||||
# warm up
|
||||
output_ids = model.generate(input_ids, max_new_tokens=100)
|
||||
i_l = len(input_ids[0])
|
||||
o_l = len(output_ids[0])
|
||||
print(f'input length: {i_l}\t'
|
||||
f'output length: {o_l}')
|
||||
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
|
||||
# tps
|
||||
loop = 10
|
||||
s_time = time.time()
|
||||
for _ in range(loop):
|
||||
model.generate(input_ids, max_new_tokens=100)
|
||||
mean_ = (time.time() - s_time) / loop
|
||||
|
||||
print(f"tps: {(o_l - i_l) / mean_:.4f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_llama()
|
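The tps printed above is (output length - input length) divided by the mean wall-clock time of one generate() call; an illustrative calculation with made-up numbers:

# 100 new tokens, mean generate() time 2.5 s  ->  tps = 100 / 2.5 = 40 tokens/s
# equivalently about 25 ms per generated token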
467
my_optims/optims/test_tp/flash_llama_modeling.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import dropout_layer_norm
|
||||
import flash_attn_2_cuda
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
# vllm imports
|
||||
from vllm import attention_ops, cache_ops
|
||||
from torch.nn import functional as F
|
||||
|
||||
from layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states, residual
|
||||
else:
|
||||
# faster post attention rms norm
|
||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True, # Activate RMSNorm
|
||||
)
|
||||
if res is None:
|
||||
res = hidden_states
|
||||
|
||||
return normed_hidden_states, res
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
if config.model_type == "baichuan":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize != "gptq":
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
# self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
# )
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size ** -0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
cache_ops.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn_2_cuda.varlen_fwd(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
-1,
|
||||
0,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = kv_cache[1].shape[3]
|
||||
attention_ops.paged_attention_v1(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate="tanh"
|
||||
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||
else "none",
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
# self.embed_tokens = TensorParallelEmbedding(
|
||||
# prefix="model.embed_tokens", weights=weights
|
||||
# )
|
||||
embeddings = weights.get_tensor(f"model.embed_tokens.weight")
|
||||
self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)),
|
||||
padding_idx=config.pad_token_id)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
# for layer_id in range(1)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid indexing in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
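In this tensor-parallel version, attention and the MLP each pair a column-parallel projection (query_key_value / gate_up_proj) with a row-parallel projection (o_proj / down_proj), so each sub-block needs exactly one all-reduce. A hedged, shape-only sketch of that composition (illustrative code, not the classes above):

import torch
import torch.distributed as dist


def column_then_row(x, w_col_shard, w_row_shard, group):
    # Column-parallel: each rank computes its own slice of the intermediate
    # features, so no communication is needed here.
    h = x @ w_col_shard.T
    # Row-parallel: each rank contracts over its slice of the intermediate
    # dimension, producing a partial sum of the full output ...
    y = h @ w_row_shard.T
    # ... which a single all-reduce turns into the complete result (this is what
    # TensorParallelRowLinear.forward does in layers.py).
    dist.all_reduce(y, group=group)
    return y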
399
my_optims/optims/test_tp/layers.py
Normal file
@@ -0,0 +1,399 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class FastLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
bias,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight)
|
||||
if bias is not None:
|
||||
self.bias = nn.Parameter(bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
if bias:
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return cls(weight, bias)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
|
||||
|
||||
def get_linear(weight, bias, quantize):
|
||||
linear = FastLinear(weight, bias)
|
||||
return linear
|
||||
|
||||
|
||||
class SuperLayer(nn.Module):
|
||||
def __init__(self, linear):
|
||||
super().__init__()
|
||||
self.linear = linear
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear.forward(x)
|
||||
|
||||
|
||||
class TensorParallelHead(SuperLayer):
|
||||
def __init__(self, linear, process_group, should_gather: bool):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
self.should_gather = should_gather
|
||||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
if weights.process_group.size() > 1:
|
||||
try:
|
||||
assert False
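# NOTE: the `assert False` above always raises, so the sharded load is skipped and
# the full, un-sharded lm_head weight is loaded in the except branch below.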
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
should_gather = True
|
||||
except AssertionError:
|
||||
# If the vocab size is not divisible by number of shards
|
||||
# just load the entire thing.
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
else:
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||

        # GPTQ doesn't quantize heads (nor embeddings)
        if config.quantize == "gptq":
            quantize = None
        else:
            quantize = config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)
        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output

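# Note on TensorParallelHead.forward (illustrative, not part of the original file):
# each rank holds an [out_dim / world_size, hidden] shard of the lm_head weight, so the
# local matmul yields [batch, out_dim / world_size] logits.  For batch == 1 the shards
# are gathered directly into the final [1, vocab] tensor; for batch > 1 the local
# result is written transposed ([out_dim / world_size, batch]) so that
# all_gather_into_tensor can stack shards along dim 0, and the gathered
# [vocab, batch] tensor is transposed back before returning.
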
class TensorParallelColumnLinear(SuperLayer):
|
||||
@classmethod
|
||||
def load_qkv(cls, config, prefix: str, weights, bias: bool):
|
||||
"""Specific method when the QKV was joined after the fact"""
|
||||
weight = weights.get_weights_col_packed_qkv(
|
||||
prefix, quantize=config.quantize
|
||||
)
|
||||
if bias:
|
||||
raise NotImplementedError("packed_qkv only implemented for baichuan")
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
return cls(linear)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
return cls.load_multi(config, [prefix], weights, bias, dim=0)
|
||||
|
||||
@classmethod
|
||||
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes, quantize=config.quantize, dim=dim
|
||||
)
|
||||
|
||||
if bias:
|
||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||
bias = torch.cat(b, dim=dim)
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
return cls(linear)
|
||||
|
||||
|
||||
class TensorParallelRowLinear(SuperLayer):
|
||||
def __init__(self, linear, process_group):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return cls(
|
||||
get_linear(weight, bias, config.quantize),
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
out = super().forward(input)
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
||||
|
||||
class TensorParallelEmbedding(nn.Module):
|
||||
def __init__(self, prefix: str, weights, reduce=True):
|
||||
super().__init__()
|
||||
weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
|
||||
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
|
||||
|
||||
process_group = weights.process_group
|
||||
|
||||
world_size = process_group.size()
|
||||
rank = process_group.rank()
|
||||
|
||||
block_size = num_embeddings // world_size
|
||||
self.min_id = rank * block_size
|
||||
self.max_id = min(num_embeddings, (rank + 1) * block_size)
|
||||
self.null_idx = block_size
|
||||
self.process_group = weights.process_group
|
||||
self.reduce = reduce
|
||||

        """Additional 0 entry used for masking"""
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Default all out-of-range ids to `self.null_idx`, which maps to the extra
        # zero row; translate in-range ids into [0, self.max_id - self.min_id).
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out

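# Worked example for TensorParallelEmbedding (illustrative, not part of the original file):
# with world_size = 2 and num_embeddings = 8, rank 0 owns rows [0, 4) and rank 1 owns
# rows [4, 8); both append one zero row at index null_idx = 4.
#   input ids       : [1, 5]
#   rank 0 lookup   : [1, 4] -> [row1, 0]
#   rank 1 lookup   : [4, 1] -> [0, row5]
#   all_reduce(sum) : [row1, row5]
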
try:
|
||||
import dropout_layer_norm
|
||||
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
return super(FastLayerNorm, self).forward(hidden_states), residual
|
||||
else:
|
||||
(
|
||||
normed_hidden_states,
|
||||
residual,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
self.bias,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.eps,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
|
||||
return normed_hidden_states, residual
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
import rotary_emb
|
||||
|
||||
|
||||
def _create_inv_freq(dim, base, device):
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
|
||||
def _get_rope_config(config):
|
||||
if os.getenv("ROPE_SCALING", None) is not None:
|
||||
rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
|
||||
return rope_scaling
|
||||
return getattr(config, "rope_scaling", None)
|
||||
|
||||
|
||||
class PositionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, inv_freq, scaling_factor):
|
||||
super().__init__()
|
||||
self.inv_freq = inv_freq
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
self.scaling_factor = scaling_factor
|
||||
self.dynamic_args = None
|
||||
|
||||
@classmethod
|
||||
def static(cls, config, dim, base, device):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
scaling_factor = None
|
||||
rope_scaling = _get_rope_config(config)
|
||||
if rope_scaling is not None:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
if rope_scaling["type"] == "linear":
|
||||
pass
|
||||
elif rope_scaling["type"] == "dynamic":
|
||||
return DynamicPositionRotaryEmbedding(dim=dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=base, device=inv_freq.device,
|
||||
scaling_factor=scaling_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
|
||||
return cls(inv_freq, scaling_factor)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix, weights):
|
||||
# XXX: Always load this in float32 !
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
|
||||
weights.dtype = dtype
|
||||
|
||||
scaling_factor = None
|
||||
rope_scaling = _get_rope_config(config)
|
||||
if rope_scaling is not None:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
if rope_scaling["type"] == "linear":
|
||||
pass
|
||||
elif rope_scaling["type"] == "dynamic":
|
||||
return DynamicPositionRotaryEmbedding(dim=2 * inv_freq.shape[0],
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=10000.0, device=inv_freq.device,
|
||||
scaling_factor=scaling_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
|
||||
return cls(inv_freq, scaling_factor)
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
if self.scaling_factor is not None:
|
||||
t /= self.scaling_factor
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
def get_cos_sin(
|
||||
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
|
||||
):
|
||||
"""
|
||||
Return cos and sin for the asked position ids
|
||||
"""
|
||||
|
||||
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
|
||||
|
||||
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||
rotary_dim = cos.shape[-1]
|
||||
x1 = x[..., :rotary_dim]
|
||||
x2 = x[..., rotary_dim: 2 * rotary_dim]
|
||||
|
||||
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
|
||||
return x
|
||||
|
||||
|
||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
super().__init__(inv_freq, scaling_factor)
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
if seqlen > self.max_position_embeddings:
|
||||
newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (
|
||||
self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
|
||||
self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
|
||||
except ImportError:
|
||||
pass
|
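The DynamicPositionRotaryEmbedding above implements dynamic "NTK-aware" RoPE scaling: once the requested sequence length exceeds `max_position_embeddings`, the rotary base is enlarged before the inverse frequencies are rebuilt. A minimal standalone restatement of that rescaling, for reference only (the function name is chosen here for illustration):

import torch

def ntk_scaled_inv_freq(dim: int, base: float, scaling_factor: float,
                        seqlen: int, max_position_embeddings: int) -> torch.Tensor:
    """Restates the `newbase` rescaling used by DynamicPositionRotaryEmbedding."""
    if seqlen > max_position_embeddings:
        base = base * (
            (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    # Same construction as _create_inv_freq: inv_freq_i = base ** (-2i / dim)
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))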
247
my_optims/optims/test_tp/test_llama_fa_tp.py
Normal file
247
my_optims/optims/test_tp/test_llama_fa_tp.py
Normal file
@ -0,0 +1,247 @@
# encoding:utf-8
# -------------------------------------------#
# Filename: optims -- test_llama_fa.py
#
# Description:
# Version: 1.0
# Created: 2023/9/18-20:50
# Last modified by:
# Author: 'zhaohuayang@myhexin.com'
# Company: 同花顺网络信息股份有限公司
# -------------------------------------------#
import math
import time
from pathlib import Path
from typing import List, Optional

# Flash attention imports
import numpy as np
import torch
import torch.distributed
import transformers

from flash_llama_modeling import FlashLlamaForCausalLM
from utils import initialize_torch_distributed
from weights import Weights

BLOCK_SIZE = 16


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks
        self.device = device

        element_size = torch.tensor([], dtype=dtype).element_size()
        x = self.block_size // element_size

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)

    def allocate(self, blocks, max_blocks, needed_blocks_slots):
        """
        blocks: total number of blocks required
        max_blocks: maximum number of blocks per sequence (padding width of the block table)
        needed_blocks_slots: per-sequence (blocks needed, sequence length) pairs
        """
        # Get free blocks indices by finding values in mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        assert (
            len(free_block_indices) >= blocks
        ), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"

        # Slice by the number of required blocks
        block_indices = free_block_indices[:blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks: cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

        return block_tables, block_tables_tensor.to(self.device), torch.concat(slots).to(self.device)

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


def generate(tokenizer, model, config, device, prompt, max_new_tokens=10):
    global CACHE_MANAGER

    input_ids = tokenizer(prompt).input_ids

    def warmup():
        print("start warmup...")
        global CACHE_MANAGER
        blocks = 260
        CACHE_MANAGER = CacheManager(blocks,
                                     len(model.model.layers),
                                     model.model.num_key_value_heads,
                                     model.model.head_size,
                                     torch.float16,
                                     device)
        input_length = 1024
        bs = 4
        warmup_inputs = {
            'input_ids': torch.arange(1, input_length + 1, dtype=torch.int64, device=device).repeat(bs),
            'position_ids': torch.arange(0, input_length, dtype=torch.int32, device=device).repeat(bs),
            'cu_seqlen_prefill': torch.tensor([i * input_length for i in range(bs + 1)], dtype=torch.int32,
                                              device=device),
            'block_tables': torch.arange(0, blocks, dtype=torch.int32, device=device).split(blocks // bs),
            'slots': torch.arange(0, 4144, dtype=torch.int32, device=device),
            'input_lengths': torch.tensor([input_length] * bs, dtype=torch.int32, device=device),
            'max_s': 1024,
            'lm_head_indices': None
        }
        model.forward(**warmup_inputs, kv_cache=CACHE_MANAGER.kv_cache)

        del CACHE_MANAGER
        torch.cuda.empty_cache()

    # Warm up
    warmup()

    print("start speed test running")
    # Allocate the KV-cache blocks
    CACHE_MANAGER = CacheManager(100,
                                 len(model.model.layers),
                                 model.model.num_key_value_heads,
                                 model.model.head_size,
                                 torch.float16,
                                 device)
    total_tokens = len(input_ids) + max_new_tokens - 1
    needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
    needed_blocks_slots = [(needed_blocks, total_tokens)]
    _, block_tables_tensor, slots = CACHE_MANAGER.allocate(needed_blocks, needed_blocks, needed_blocks_slots)
    # Forward loop
    loops = 10
    tpss = []
    for loop in range(loops):
        print(f"loop {loop}...")
        times = []
        new_tokens = []
        for step in range(max_new_tokens):
            if step == 0:
                # prefill step
                slot_indices = torch.arange(0, 0 + len(input_ids), dtype=torch.int64)
                inputs = {
                    'input_ids': torch.tensor(input_ids, dtype=torch.int64, device=device),
                    'position_ids': torch.arange(0, len(input_ids), dtype=torch.int32, device=device),
                    'cu_seqlen_prefill': torch.tensor([0, len(input_ids)], dtype=torch.int32, device=device),
                    'block_tables': block_tables_tensor,
                    'slots': slots[slot_indices],
                    'input_lengths': torch.tensor([len(input_ids)], dtype=torch.int32, device=device),
                    'max_s': len(input_ids),
                    'lm_head_indices': torch.tensor([0 + len(input_ids) - 1], dtype=torch.int32, device=device)
                }
            else:
                # incremental step
                current_length = len(input_ids) + step
                inputs = {
                    'input_ids': new_tokens[-1],
                    'position_ids': torch.tensor([current_length - 1], dtype=torch.int32, device=device),
                    'cu_seqlen_prefill': None,
                    'block_tables': block_tables_tensor,
                    'slots': torch.tensor([current_length - 1], dtype=torch.int32, device=device),
                    'input_lengths': torch.tensor([current_length], dtype=torch.int32, device=device),
                    'max_s': current_length,
                    'lm_head_indices': None
                }
            torch.cuda.synchronize()
            s_time = time.time()
            logits = model.forward(**inputs, kv_cache=CACHE_MANAGER.kv_cache)
            torch.cuda.synchronize()
            cost_time = time.time() - s_time
            next_token_id = logits.argmax(dim=-1)
            new_tokens.append(next_token_id)
            times.append(round(cost_time, 6))

        if loop == 0:
            new_tokens = torch.concat(new_tokens)
            print(tokenizer.decode(new_tokens, skip_special_tokens=True))

        elapsed_time = np.mean(times)
        tps = 1 / elapsed_time
        tpss.append(tps)
        print(times)
        print(f"total new tokens: {max_new_tokens}, cost time: {sum(times):.6f} s\n"
              f"time_per_token: {elapsed_time * 1000:.3f} ms, tps: {tps:.2f} tokens/s")
    print(f'mean tps: {np.mean(tpss):.2f} tokens/s')


def main(model_path):
    # init env
    process_group, rank, world_size = initialize_torch_distributed()
    # step 0: define paths and model attributes
    model_path = Path(model_path)
    config = transformers.AutoConfig.from_pretrained(model_path)
    config.quantize = None
    model_files = list(model_path.glob('*.safetensors'))

    # step 1: build the tokenizer and the weights
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="left", truncation_side="left")
    device = torch.device(f"cuda:{rank}")
    weights = Weights(model_files, device, torch.float16, process_group=process_group)

    # step 2: build the model
    torch.distributed.barrier(group=process_group)
    model = FlashLlamaForCausalLM(config, weights).eval()
    torch.distributed.barrier(group=process_group)
    print(model)

    # step 3: inference
    with torch.no_grad():
        prompt = "who are you?"
        generate(tokenizer, model, config, device, prompt, max_new_tokens=100)


if __name__ == '__main__':
    CACHE_MANAGER: Optional[CacheManager] = None
    main('/code/models/llama-7b-hf')
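The paged KV-cache bookkeeping in `CacheManager` drives the script above: a request that will reach `total_tokens` positions needs `ceil(total_tokens / BLOCK_SIZE)` pages, and `allocate` returns both the padded block table and the flat slot indices. A minimal sketch of that arithmetic, assuming a CUDA device and the classes defined above (all sizes are hypothetical):

import math
import torch

prompt_len, max_new_tokens = 32, 16
total_tokens = prompt_len + max_new_tokens - 1
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)          # 16-token pages

cache = CacheManager(num_blocks=64, num_layers=32, num_heads=32,
                     head_size=128, dtype=torch.float16,
                     device=torch.device("cuda:0"))
block_tables, block_tables_tensor, slots = cache.allocate(
    needed_blocks, needed_blocks, [(needed_blocks, total_tokens)]
)
# ... run prefill/decode against cache.kv_cache, then release the pages:
cache.free(block_tables[0])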
100
my_optims/optims/test_tp/utils.py
Normal file
100
my_optims/optims/test_tp/utils.py
Normal file
@ -0,0 +1,100 @@
# encoding:utf-8
# -------------------------------------------#
# Filename: optims -- utils.py
#
# Description:
# Version: 1.0
# Created: 2023/9/27-14:43
# Last modified by:
# Author: 'zhaohuayang@myhexin.com'
# Company: 同花顺网络信息股份有限公司
# -------------------------------------------#
import os
import time
from datetime import timedelta

import torch
from loguru import logger
from torch.distributed import ProcessGroupNCCL

RANK = int(os.getenv("LOCAL_RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "2"))
NCCL_PORT = int(os.getenv("NCCL_PORT", "29500"))
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))


class FakeBarrier:
    def wait(self):
        pass


class FakeGroup:
    def __init__(self, rank, size):
        self._rank = rank
        self._size = size

    def allreduce(self, *args, **kwargs):
        return FakeBarrier()

    def allgather(self, inputs, local_tensor, **kwargs):
        assert (
            len(inputs[0]) == len(local_tensor) == 1
        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
        for input_ in inputs:
            input_[0].data = local_tensor[0].data
        return FakeBarrier()

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def size(self):
        return self._size

    def rank(self):
        return self._rank


def initialize_torch_distributed():
    # Set the device id.
    assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
    device = RANK % torch.cuda.device_count()
    torch.cuda.set_device(device)
    torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
    backend = "nccl"
    options = ProcessGroupNCCL.Options()
    options.is_high_priority_stream = True
    options._timeout = timedelta(seconds=60)
    if not torch.distributed.is_initialized():
        # Call the init process.
        torch.distributed.init_process_group(
            backend=backend,
            init_method=f"tcp://localhost:{NCCL_PORT}",
            world_size=WORLD_SIZE,
            rank=RANK,
            timeout=timedelta(seconds=60),
            pg_options=options,
        )
        logger.info(f"torch.distributed is initialized on rank {RANK} of {WORLD_SIZE}.")
    else:
        logger.warning("torch.distributed is already initialized.")

    return torch.distributed.group.WORLD, RANK, WORLD_SIZE


class Timer:
    def __init__(self):
        self.times = []
        self.time = None

    def start(self):
        torch.cuda.synchronize()
        self.time = time.time()

    def end(self):
        torch.cuda.synchronize()
        self.times.append(time.time() - self.time)

    @property
    def elapsed(self):
        # Drop the first (warmup) measurement and report the mean in milliseconds.
        self.times.pop(0)
        return round(sum(self.times) / len(self.times) * 1000, 2)
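`Timer` synchronizes CUDA around each measurement, and `elapsed` drops the first (warmup) sample before averaging. A small usage sketch; `some_forward_step` is a placeholder for the real model call:

timer = Timer()
for _ in range(11):               # the first iteration is treated as warmup
    timer.start()
    _ = some_forward_step()       # placeholder, not defined in this repo
    timer.end()
print(f"avg step time: {timer.elapsed} ms")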
336
my_optims/optims/test_tp/weights.py
Normal file
336
my_optims/optims/test_tp/weights.py
Normal file
@ -0,0 +1,336 @@
# encoding:utf-8
# -------------------------------------------#
# Filename: optims -- weights.py
#
# Description:
# Version: 1.0
# Created: 2023/9/27-15:05
# Last modified by:
# Author: 'zhaohuayang@myhexin.com'
# Company: 同花顺网络信息股份有限公司
# -------------------------------------------#
import json
import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple

import torch
from huggingface_hub import hf_hub_download
from loguru import logger
from safetensors import safe_open, SafetensorError


class Weights:
|
||||
def __init__(
|
||||
self,
|
||||
filenames: List[Path],
|
||||
device,
|
||||
dtype,
|
||||
process_group,
|
||||
aliases: Optional[Dict[str, List[str]]] = None,
|
||||
):
|
||||
routing = {}
|
||||
for filename in filenames:
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
if aliases is None:
|
||||
aliases = {}
|
||||
self.aliases = aliases
|
||||
self.routing = routing
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.process_group = process_group
|
||||
self._handles = {}
|
||||
|
||||
def _get_handle(self, filename):
|
||||
if filename not in self._handles:
|
||||
f = safe_open(filename, framework="pytorch")
|
||||
self._handles[filename] = f
|
||||
|
||||
return self._handles[filename]
|
||||
|
||||
def get_filename(self, tensor_name: str) -> (str, str):
|
||||
filename = self.routing.get(tensor_name, None)
|
||||
if filename is None:
|
||||
aliases = self.aliases.get(tensor_name, [])
|
||||
for alias in aliases:
|
||||
filename = self.routing.get(alias, None)
|
||||
if filename is not None:
|
||||
return str(filename), alias
|
||||
raise RuntimeError(f"weight {tensor_name} does not exist")
|
||||
return str(filename), tensor_name
|
||||
|
||||
def _get_slice(self, tensor_name: str):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
return slice_
|
||||
|
||||
def get_shape(self, tensor_name: str):
|
||||
return self._get_slice(tensor_name).get_shape()
|
||||
|
||||
def get_tensor(self, tensor_name: str, to_device=True):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
# Special case for gptq which shouldn't convert
|
||||
# u4 which are disguised as int32
|
||||
if tensor.dtype not in [torch.int32, torch.int64]:
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
if to_device:
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_partial_sharded(self, tensor_name: str, dim: int):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
size = slice_.get_shape()[dim]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
|
||||
if dim == 0:
|
||||
tensor = slice_[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
# Special case for gptq which shouldn't convert
|
||||
# u4 which are disguised as int32
|
||||
if tensor.dtype != torch.int32:
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_sharded(self, tensor_name: str, dim: int):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
size = slice_.get_shape()[dim]
|
||||
        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)
|
||||
|
||||
def _get_qweight(self, name: str):
|
||||
slice_ = self._get_slice(name)
|
||||
total_size = slice_.get_shape()[1]
|
||||
assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
|
||||
single_size = total_size // 3
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
|
||||
block_size = single_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
q = slice_[:, start:stop]
|
||||
k = slice_[:, start + single_size:stop + single_size]
|
||||
v = slice_[:, start + 2 * single_size:stop + 2 * single_size]
|
||||
weight = torch.cat([q, k, v], dim=1)
|
||||
weight = weight.to(device=self.device)
|
||||
return weight
|
||||
|
||||
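# Note (illustrative, not part of the original file): the packed qkv tensor is laid out
# as [Q | K | V] along the packed dimension, each third of size total_size // 3.  Every
# rank therefore takes the same [start, stop) window inside each third and concatenates
# its Q/K/V slices, which keeps the per-rank fused qkv projection contiguous.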
def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
|
||||
"""
|
||||
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
|
||||
already alternating Q,K,V within the main tensor
|
||||
"""
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
qweight = self._get_qweight(f"{prefix}.qweight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
qzeros = self._get_qweight(f"{prefix}.qzeros")
|
||||
scales = self._get_qweight(f"{prefix}.scales")
|
||||
scales = scales.to(dtype=self.dtype)
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
else:
|
||||
slice_ = self._get_slice(f"{prefix}.weight")
|
||||
total_size = slice_.get_shape()[0]
|
||||
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
|
||||
single_size = total_size // 3
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
|
||||
block_size = single_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
q = slice_[start:stop]
|
||||
k = slice_[start + single_size:stop + single_size]
|
||||
v = slice_[start + 2 * single_size:stop + 2 * single_size]
|
||||
weight = torch.cat([q, k, v], dim=0)
|
||||
weight = weight.to(device=self.device)
|
||||
weight = weight.to(dtype=self.dtype)
|
||||
return weight
|
||||
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
scales = torch.cat(
|
||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
return weight
|
||||
|
||||
def get_tensor_shard(self, var, dim):
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
block_size = var.size()[dim] // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
if dim == 0:
|
||||
tensor = var[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = var[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "gptq":
|
||||
use_exllama = True
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
|
||||
if bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
if g_idx is not None:
|
||||
if (
|
||||
not torch.equal(
|
||||
g_idx.cpu(),
|
||||
torch.tensor(
|
||||
[i // groupsize for i in range(g_idx.shape[0])],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
and not (g_idx == 0).all()
|
||||
):
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require to reorder input activations that are split unto several GPUs
|
||||
use_exllama = False
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA
|
||||
|
||||
if use_exllama:
|
||||
if not HAS_EXLLAMA:
|
||||
if CAN_EXLLAMA:
|
||||
logger.warning(
|
||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
|
||||
)
|
||||
use_exllama = False
|
||||
else:
|
||||
logger.info("Using exllama kernels")
|
||||
|
||||
if use_exllama:
|
||||
if groupsize >= 0:
|
||||
# Exllama reorders the weights in advance and the activations on the fly, thus
|
||||
# the scales and zero-points do not need to be reordered.
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
|
||||
# For tp > 1, at this point we know we do not use act-order
|
||||
if self.process_group.size() == 1:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
else:
|
||||
g_idx = None
|
||||
else:
|
||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||
# Thus, each rank needs the full qzeros/scales.
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
||||
def _get_gptq_params(self) -> Tuple[int, int]:
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
except (SafetensorError, RuntimeError) as e:
|
||||
try:
|
||||
bits = self.gptq_bits
|
||||
groupsize = self.gptq_groupsize
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
return bits, groupsize
|
||||
|
||||
def _set_gptq_params(self, model_id):
|
||||
filename = "config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(model_id, filename=filename)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
self.gptq_bits = data["quantization_config"]["bits"]
|
||||
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
||||
except Exception:
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(model_id, filename=filename)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
self.gptq_bits = data["bits"]
|
||||
self.gptq_groupsize = data["group_size"]
|
||||
except Exception:
|
||||
pass
|
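All of the sharding helpers in `Weights` reduce to the same rank-local slice arithmetic: dimension `dim` of size `size` is split into `world_size` equal blocks and rank `r` takes `[r * block, (r + 1) * block)`. A standalone restatement, for reference only:

def shard_bounds(size: int, rank: int, world_size: int):
    """Start/stop indices that get_partial_sharded slices out for this rank."""
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    return start, stop

# e.g. a 4096-wide dimension on 2 ranks: rank 0 -> (0, 2048), rank 1 -> (2048, 4096)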
94
my_optims/tgi_update/flash_llama.py
Normal file
94
my_optims/tgi_update/flash_llama.py
Normal file
@ -0,0 +1,94 @@
import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from transformers.models.llama import LlamaTokenizer
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
    LlamaConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
import os

tracer = trace.get_tracer(__name__)
USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
if USE_CUSTOM_NCCL:
    from text_generation_server.utils.my_dist import initialize_mpi_distributed

class FlashLlama(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if USE_CUSTOM_NCCL:
            self.process_group, rank, world_size, COMM = initialize_mpi_distributed()
        else:
            self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

        try:
            # The bare `raise` deliberately short-circuits this branch so that the
            # AutoTokenizer fallback below is always used.
            raise
            tokenizer = LlamaTokenizer.from_pretrained(
                model_id,
                revision=revision,
                padding_side="left",
                truncation_side="left",
                trust_remote_code=trust_remote_code,
            )
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                revision=revision,
                padding_side="left",
                truncation_side="left",
                trust_remote_code=trust_remote_code,
            )

        config = LlamaConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        if USE_CUSTOM_NCCL:
            COMM.barrier()
        else:
            torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id)

        model = FlashLlamaForCausalLM(config, weights)

        if USE_CUSTOM_NCCL:
            COMM.barrier()
        else:
            torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
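For completeness, a hedged sketch of how this class would be constructed directly (the model path is a placeholder; inside TGI it is normally built by the server's model factory rather than instantiated by hand):

model = FlashLlama(
    model_id="/path/to/llama-7b-hf",   # placeholder path
    quantize=None,
    dtype=torch.float16,
)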
513
my_optims/tgi_update/flash_llama_modeling.py
Normal file
513
my_optims/tgi_update/flash_llama_modeling.py
Normal file
@ -0,0 +1,513 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import dropout_layer_norm
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
|
||||
class LlamaConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_scaling=None,
|
||||
rope_theta=10000.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_scaling = rope_scaling
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states, residual
|
||||
else:
|
||||
# faster post attention rms norm
|
||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True, # Activate RMSNorm
|
||||
)
|
||||
if res is None:
|
||||
res = hidden_states
|
||||
|
||||
return normed_hidden_states, res
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
if config.model_type == "baichuan":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
# self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
# )
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
base=config.rope_theta,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
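# Note (illustrative, not part of the original file): the forward above picks the
# attention kernel by phase.  When `cu_seqlen_prefill` is set, the whole prompt is
# processed with flash_attn over the freshly computed K/V (which reshape_and_cache has
# already written into the paged cache); in decode, only the single new query attends
# against the paged KV cache via paged_attention.attention, with `block_tables`
# locating each sequence's pages.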
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate="tanh"
|
||||
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||
else "none",
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
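# Note (illustrative, not part of the original file): gate_proj and up_proj are fused
# into one column-parallel matmul; the view(-1, 2, intermediate_size) splits them back
# so the output is down_proj(act(gate) * up), i.e. the standard SwiGLU MLP.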
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
|
||||
import os
|
||||
if int(os.getenv("USE_TP_EMBEDDING", "1")) == 1:
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
else:
|
||||
from torch.nn import functional as F
|
||||
from loguru import logger
|
||||
embeddings = weights.get_tensor(f"model.embed_tokens.weight")
|
||||
self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)),
|
||||
padding_idx=config.pad_token_id)
|
||||
logger.info("Disabled embedding tensor parallel! ")
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
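The fused `dropout_add_ln_fwd` call in `LlamaRMSNorm` computes the same quantity as the eager fallback path in that class. A plain-PyTorch restatement for comparison (eps corresponds to `config.rms_norm_eps`):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6,
                       residual: torch.Tensor = None):
    """Eager-mode equivalent of LlamaRMSNorm.forward: residual add, then RMS norm."""
    if residual is not None:
        x = x + residual
    residual = x
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * x.to(weight.dtype), residual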
815
my_optims/tgi_update/layers.py
Normal file
815
my_optims/tgi_update/layers.py
Normal file
@ -0,0 +1,815 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
from functools import lru_cache
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params, Params4bit
|
||||
|
||||
except ImportError:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
|
||||
|
||||
HAS_AWQ = True
|
||||
try:
|
||||
from text_generation_server.utils.awq.quantize.qmodule import WQLinear
|
||||
except ImportError:
|
||||
HAS_AWQ = False
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
except Exception:
|
||||
major = 1
|
||||
HAS_EXLLAMA = False
|
||||
CAN_EXLLAMA = major >= 8
|
||||
if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||
HAS_EXLLAMA = False
|
||||
elif CAN_EXLLAMA:
|
||||
try:
|
||||
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
|
||||
|
||||
HAS_EXLLAMA = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from typing import Optional
|
||||
|
||||
HAS_EETQ = False
|
||||
try:
|
||||
from EETQ import quant_weights, w8_a16_gemm
|
||||
|
||||
HAS_EETQ = True
|
||||
except ImportError:
|
||||
pass
|
||||
import my_custom_comm
|
||||
USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
|
||||
USE_LM_HEAD_PARALLEL = int(os.getenv("USE_LM_HEAD_PARALLEL", "1"))
|
||||
|
||||
# Monkey patching
|
||||
@classmethod
|
||||
def load_layer_norm(cls, prefix, weights, eps):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
with init_empty_weights():
|
||||
ln = cls(weight.shape, eps=eps)
|
||||
|
||||
ln.weight = nn.Parameter(weight)
|
||||
ln.bias = nn.Parameter(bias)
|
||||
return ln
|
||||
|
||||
|
||||
@classmethod
|
||||
def load_layer_norm_no_bias(cls, prefix, weights, eps):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
with init_empty_weights():
|
||||
ln = cls(weight.shape, eps=eps)
|
||||
|
||||
ln.weight = nn.Parameter(weight)
|
||||
ln.bias = None
|
||||
return ln
|
||||
|
||||
|
||||
@classmethod
|
||||
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
with init_empty_weights():
|
||||
conv2d = cls(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
)
|
||||
|
||||
conv2d.weight = nn.Parameter(weight)
|
||||
conv2d.bias = nn.Parameter(bias)
|
||||
return conv2d
|
||||
|
||||
|
||||
@classmethod
|
||||
def load_conv2d_no_bias(
|
||||
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
|
||||
):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
with init_empty_weights():
|
||||
conv2d = cls(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
)
|
||||
|
||||
conv2d.weight = nn.Parameter(weight)
|
||||
conv2d.bias = None
|
||||
return conv2d
|
||||
|
||||
|
||||
torch.nn.Conv2d.load = load_conv2d
|
||||
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
|
||||
torch.nn.LayerNorm.load = load_layer_norm
|
||||
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
|
||||
|
||||
|
||||
class FastLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
bias,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight)
|
||||
if bias is not None:
|
||||
self.bias = nn.Parameter(bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
if bias:
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return cls(weight, bias)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
|
||||
|
||||
class EETQLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
bias,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
device = weight.device
|
||||
weight = torch.t(weight).contiguous().cpu()
|
||||
weight, scale = quant_weights(weight, torch.int8, False)
|
||||
|
||||
self.weight = weight.cuda(device)
|
||||
self.scale = scale.cuda(device)
|
||||
self.bias = bias.cuda(device) if bias is not None else None
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
output = w8_a16_gemm(input, self.weight, self.scale)
|
||||
output = output + self.bias if self.bias is not None else output
|
||||
return output
|
||||
|
||||
|
||||
class Linear8bitLt(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
bias,
|
||||
has_fp16_weights=True,
|
||||
memory_efficient_backward=False,
|
||||
threshold=0.0,
|
||||
index=None,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
not memory_efficient_backward
|
||||
), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index = index
|
||||
|
||||
# Necessary for stacked layers
|
||||
self.state.threshold = threshold
|
||||
self.state.has_fp16_weights = has_fp16_weights
|
||||
self.state.memory_efficient_backward = memory_efficient_backward
|
||||
if threshold > 0.0 and not has_fp16_weights:
|
||||
self.state.use_pool = True
|
||||
|
||||
self.weight = Int8Params(
|
||||
weight.data,
|
||||
has_fp16_weights=has_fp16_weights,
|
||||
requires_grad=has_fp16_weights,
|
||||
)
|
||||
self.weight.cuda(weight.device)
|
||||
self.bias = bias
|
||||
|
||||
def init_8bit_state(self):
|
||||
self.state.CB = self.weight.CB
|
||||
self.state.SCB = self.weight.SCB
|
||||
self.weight.CB = None
|
||||
self.weight.SCB = None
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
self.state.is_training = self.training
|
||||
if self.weight.CB is not None:
|
||||
self.init_8bit_state()
|
||||
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||
|
||||
if not self.state.has_fp16_weights:
|
||||
if self.state.CB is not None and self.state.CxB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del self.state.CB
|
||||
self.weight.data = self.state.CxB
|
||||
return out
|
||||
|
||||
|
||||
class Linear4bit(nn.Module):
|
||||
def __init__(self, weight, bias, quant_type):
|
||||
super().__init__()
|
||||
self.weight = Params4bit(
|
||||
weight.data,
|
||||
requires_grad=False,
|
||||
compress_statistics=True,
|
||||
quant_type=quant_type,
|
||||
)
|
||||
self.compute_dtype = None
|
||||
self.weight.cuda(weight.device)
|
||||
self.bias = bias
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
if getattr(self.weight, "quant_state", None) is None:
|
||||
print(
|
||||
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
|
||||
)
|
||||
inp_dtype = x.dtype
|
||||
if self.compute_dtype is not None:
|
||||
x = x.to(self.compute_dtype)
|
||||
|
||||
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
|
||||
out = bnb.matmul_4bit(
|
||||
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
|
||||
)
|
||||
|
||||
out = out.to(inp_dtype)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def warn_deprecate_bnb():
|
||||
logger.warning(
|
||||
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
|
||||
)
|
||||
|
||||
|
||||
def get_linear(weight, bias, quantize):
|
||||
if quantize is None:
|
||||
linear = FastLinear(weight, bias)
|
||||
elif quantize == "eetq":
|
||||
if HAS_EETQ:
|
||||
linear = EETQLinear(weight, bias)
|
||||
else:
|
||||
raise ImportError(
|
||||
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
|
||||
)
|
||||
elif quantize == "bitsandbytes":
|
||||
warn_deprecate_bnb()
|
||||
linear = Linear8bitLt(
|
||||
weight,
|
||||
bias,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
)
|
||||
if bias is not None:
|
||||
linear.bias = nn.Parameter(bias)
|
||||
elif quantize == "bitsandbytes-fp4":
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="fp4",
|
||||
)
|
||||
elif quantize == "bitsandbytes-nf4":
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="nf4",
|
||||
)
|
||||
elif quantize == "gptq":
|
||||
try:
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
if use_exllama:
|
||||
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
else:
|
||||
linear = QuantLinear(
|
||||
qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
g_idx,
|
||||
bias,
|
||||
bits,
|
||||
groupsize,
|
||||
)
|
||||
elif quantize == "awq":
|
||||
try:
|
||||
qweight, qzeros, scales, _, bits, groupsize, _ = weight
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `awq` compatible, loader needs to be updated."
|
||||
)
|
||||
linear = WQLinear(
|
||||
w_bit=bits,
|
||||
group_size=groupsize,
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
bias=bias is not None,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
||||
return linear
|
||||
|
||||
|
||||
class SuperLayer(nn.Module):
|
||||
def __init__(self, linear):
|
||||
super().__init__()
|
||||
self.linear = linear
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear.forward(x)
|
||||
|
||||
|
||||
class TensorParallelHead(SuperLayer):
|
||||
def __init__(self, linear, process_group, should_gather: bool):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
self.should_gather = should_gather
|
||||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
if weights.process_group.size() > 1:
|
||||
try:
|
||||
assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
should_gather = True
|
||||
except AssertionError:
|
||||
# If the vocab size is not divisible by number of shards
|
||||
# just load the entire thing.
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
logger.info("Disabled lm head parallel! ")
|
||||
else:
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
|
||||
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
|
||||
if config.quantize in ["gptq", "awq", "eetq"]:
|
||||
quantize = None
|
||||
else:
|
||||
quantize = config.quantize
|
||||
return TensorParallelHead(
|
||||
get_linear(weight, bias=None, quantize=quantize),
|
||||
process_group=weights.process_group,
|
||||
should_gather=should_gather,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if not self.should_gather:
|
||||
return super().forward(input)
|
||||
|
||||
world_size = self.process_group.size()
|
||||
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
|
||||
out_dim = self.linear.weight.shape[0]
|
||||
|
||||
if input.shape[0] == 1:
|
||||
world_out = input.new_empty(1, out_dim * world_size)
|
||||
local_out = input.new_empty(1, out_dim)
|
||||
gather_input = local_out
|
||||
else:
|
||||
world_out = input.new_empty(out_dim * world_size, input.shape[0])
|
||||
gather_input = input.new_empty(out_dim, input.shape[0])
|
||||
local_out = gather_input.T
|
||||
|
||||
torch.mm(input, self.linear.weight.T, out=local_out)
|
||||
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
world_out, gather_input, group=self.process_group
|
||||
)
|
||||
|
||||
if input.shape[0] == 1:
|
||||
return world_out
|
||||
return world_out.T
|
||||
|
||||
output = super().forward(input)
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
world_output = torch.cat(world_output, dim=-1)
|
||||
return world_output
|
||||
|
||||
|
||||
class TensorParallelColumnLinear(SuperLayer):
|
||||
@classmethod
|
||||
def load_qkv(cls, config, prefix: str, weights, bias: bool):
|
||||
"""Specific method when the QKV was joined after the fact"""
|
||||
weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
|
||||
if bias:
|
||||
raise NotImplementedError("packed_qkv only implemented for baichuan")
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
return cls(linear)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
return cls.load_multi(config, [prefix], weights, bias, dim=0)
|
||||
|
||||
@classmethod
|
||||
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes, quantize=config.quantize, dim=dim
|
||||
)
|
||||
|
||||
if bias:
|
||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||
bias = torch.cat(b, dim=dim)
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
return cls(linear)
|
||||
|
||||
|
||||
class TensorParallelRowLinear(SuperLayer):
|
||||
def __init__(self, linear, process_group):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return cls(
|
||||
get_linear(weight, bias, config.quantize),
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
out = super().forward(input)
|
||||
if self.process_group.size() > 1:
|
||||
if USE_CUSTOM_NCCL:
|
||||
my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
|
||||
else:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
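# Illustrative sketch, not part of this diff: why the row-parallel forward above ends in
# an all-reduce. Each rank multiplies its shard of the input features against its shard
# of the weight, and the per-rank partial products sum to the full output. Shapes below
# are assumed; this runs on a single process, no communicator needed.
import torch

torch.manual_seed(0)
x = torch.randn(2, 8)      # full activation (batch=2, in_features=8)
w = torch.randn(4, 8)      # full weight of a row-parallel linear (out=4, in=8)

partial0 = x[:, :4] @ w[:, :4].T   # what rank 0 would compute
partial1 = x[:, 4:] @ w[:, 4:].T   # what rank 1 would compute

# The all-reduce (torch.distributed or my_custom_comm.custom_allreduce) sums these.
assert torch.allclose(partial0 + partial1, x @ w.T, atol=1e-5)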
|
||||
|
||||
|
||||
class TensorParallelEmbedding(nn.Module):
|
||||
def __init__(self, prefix: str, weights, reduce=True):
|
||||
super().__init__()
|
||||
weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
|
||||
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
|
||||
|
||||
process_group = weights.process_group
|
||||
|
||||
world_size = process_group.size()
|
||||
rank = process_group.rank()
|
||||
|
||||
block_size = num_embeddings // world_size
|
||||
self.min_id = rank * block_size
|
||||
self.max_id = min(num_embeddings, (rank + 1) * block_size)
|
||||
self.null_idx = block_size
|
||||
self.process_group = weights.process_group
|
||||
self.reduce = reduce
|
||||
|
||||
"""Additional 0 entry used for masking"""
|
||||
self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
|
||||
# translate for [0, self.max_id - self.min_id[
|
||||
input = torch.where(
|
||||
(self.min_id > input) | (input >= self.max_id),
|
||||
self.null_idx,
|
||||
input - self.min_id,
|
||||
)
|
||||
out = torch.nn.functional.embedding(input, self.weight)
|
||||
if self.reduce and self.process_group.size() > 1:
|
||||
if USE_CUSTOM_NCCL:
|
||||
my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
|
||||
else:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
||||
|
||||
try:
|
||||
import dropout_layer_norm
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
return super(FastLayerNorm, self).forward(hidden_states), residual
|
||||
else:
|
||||
(
|
||||
normed_hidden_states,
|
||||
residual,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
self.bias,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.eps,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
|
||||
return normed_hidden_states, residual
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
import rotary_emb
|
||||
|
||||
def _create_inv_freq(dim, base, device):
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
def _get_rope_config(config):
|
||||
if os.getenv("ROPE_SCALING", None) is not None:
|
||||
rope_scaling = {
|
||||
"type": os.environ["ROPE_SCALING"],
|
||||
"factor": float(os.environ["ROPE_FACTOR"]),
|
||||
}
|
||||
return rope_scaling
|
||||
return getattr(config, "rope_scaling", None)
|
||||
|
||||
class PositionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, inv_freq, scaling_factor):
|
||||
super().__init__()
|
||||
self.inv_freq = inv_freq
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
self.scaling_factor = scaling_factor
|
||||
self.dynamic_args = None
|
||||
|
||||
@classmethod
|
||||
def static(cls, config, dim, base, device):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
scaling_factor = None
|
||||
rope_scaling = _get_rope_config(config)
|
||||
if rope_scaling is not None:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
if rope_scaling["type"] == "linear":
|
||||
pass
|
||||
elif rope_scaling["type"] == "dynamic":
|
||||
return DynamicPositionRotaryEmbedding(
|
||||
dim=dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=base,
|
||||
device=inv_freq.device,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
elif rope_scaling["type"] == "yarn":
|
||||
return YarnPositionRotaryEmbedding(
|
||||
dim=2 * inv_freq.shape[0],
|
||||
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
|
||||
base=10000.0,
|
||||
device=inv_freq.device,
|
||||
scaling_factor=scaling_factor,
|
||||
extrapolation_factor=1,
|
||||
attn_factor=1,
|
||||
beta_fast=32,
|
||||
beta_slow=1
|
||||
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||
)
|
||||
return cls(inv_freq, scaling_factor)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix, weights):
|
||||
# XXX: Always load this in float32 !
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
|
||||
weights.dtype = dtype
|
||||
|
||||
scaling_factor = None
|
||||
rope_scaling = _get_rope_config(config)
|
||||
if rope_scaling is not None:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
if rope_scaling["type"] == "linear":
|
||||
pass
|
||||
elif rope_scaling["type"] == "dynamic":
|
||||
return DynamicPositionRotaryEmbedding(
|
||||
dim=2 * inv_freq.shape[0],
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=10000.0,
|
||||
device=inv_freq.device,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
elif rope_scaling["type"] == "yarn":
|
||||
return YarnPositionRotaryEmbedding(
|
||||
dim=2 * inv_freq.shape[0],
|
||||
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
|
||||
base=10000.0,
|
||||
device=inv_freq.device,
|
||||
scaling_factor=scaling_factor,
|
||||
extrapolation_factor=1,
|
||||
attn_factor=1,
|
||||
beta_fast=32,
|
||||
beta_slow=1
|
||||
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||
)
|
||||
return cls(inv_freq, scaling_factor)
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
if self.scaling_factor is not None:
|
||||
t /= self.scaling_factor
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
def get_cos_sin(
|
||||
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
|
||||
):
|
||||
"""
|
||||
Return cos and sin for the asked position ids
|
||||
"""
|
||||
|
||||
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
|
||||
|
||||
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||
rotary_dim = cos.shape[-1]
|
||||
x1 = x[..., :rotary_dim]
|
||||
x2 = x[..., rotary_dim : 2 * rotary_dim]
|
||||
|
||||
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
|
||||
return x
|
||||
|
||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
super().__init__(inv_freq, scaling_factor)
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
if seqlen > self.max_position_embeddings:
|
||||
newbase = self.base * (
|
||||
(self.scaling_factor * seqlen / self.max_position_embeddings)
|
||||
- (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
self.inv_freq = _create_inv_freq(
|
||||
self.dim, newbase, self.inv_freq.device
|
||||
)
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
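# Illustrative sketch, not part of this diff: the dynamic NTK rescaling above inflates
# the rotary base once the sequence exceeds the trained context. Values are assumed
# (base 10000, dim 128, trained on 4096 positions, now seeing 8192, scaling factor 1).
base, dim = 10000.0, 128
max_position_embeddings, seqlen, scaling_factor = 4096, 8192, 1.0

newbase = base * (
    (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
) ** (dim / (dim - 2))
print(round(newbase))   # ~20221: roughly a doubled base for a doubled context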
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
import math
|
||||
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
||||
return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
||||
low = math.floor(find_correction_dim(
|
||||
low_rot, dim, base, max_position_embeddings))
|
||||
high = math.ceil(find_correction_dim(
|
||||
high_rot, dim, base, max_position_embeddings))
|
||||
return max(low, 0), min(high, dim-1) # Clamp values just in case
|
||||
|
||||
def linear_ramp_mask(min, max, dim):
|
||||
if min == max:
|
||||
max += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
def get_mscale(scale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
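# Illustrative sketch, not part of this diff: how the ramp above blends the two
# frequency sets inside the Yarn embedding below. Bounds and scaling factor are
# assumed (dim // 2 = 4, correction range [1, 3], scaling factor 4).
import torch

ramp = torch.clamp((torch.arange(4, dtype=torch.float32) - 1) / (3 - 1), 0, 1)
inv_freq_mask = 1 - ramp    # ~1 -> keep extrapolated (original) freqs, ~0 -> interpolated

inv_freq_extrapolation = torch.tensor([1.0, 0.5, 0.25, 0.125])
inv_freq_interpolation = inv_freq_extrapolation / 4.0

inv_freq = (
    inv_freq_interpolation * (1 - inv_freq_mask)
    + inv_freq_extrapolation * inv_freq_mask
)
# High-frequency dims (mask ~1) stay extrapolated, low-frequency dims get interpolated.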
|
||||
|
||||
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
super().__init__(inv_freq, scaling_factor)
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
if seqlen > self.max_position_embeddings:
|
||||
inv_freq_extrapolation = _create_inv_freq(
|
||||
self.dim, self.base, self.inv_freq.device
|
||||
)
|
||||
freqs = 1.0 / inv_freq_extrapolation
|
||||
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
|
||||
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
|
||||
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
|
||||
self.inv_freq = inv_freq
|
||||
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
||||
|
||||
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
|
||||
|
||||
except ImportError:
|
||||
pass
|
1268 my_optims/tgi_update/main.rs Normal file
File diff suppressed because it is too large.
47 my_optims/tgi_update/my_dist.py Normal file
@@ -0,0 +1,47 @@
import os
import torch

from datetime import timedelta
from loguru import logger
from mpi4py import MPI
import my_custom_comm

# Tensor Parallelism settings
RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))

# CUDA memory fraction
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))


class MyCommGroup:
    def __init__(self, rank, size, tp_comm, pp_comm):
        self._rank = rank
        self._size = size
        self.tp_comm = tp_comm
        self.pp_comm = pp_comm

    def size(self):
        return self._size

    def rank(self):
        return self._rank


def initialize_mpi_distributed():
    assert torch.cuda.is_available()
    # MPI initialization
    COMM = MPI.COMM_WORLD
    assert COMM.Get_size() == WORLD_SIZE, f"{COMM.Get_size()},{WORLD_SIZE}"

    # Set the device id.
    assert WORLD_SIZE <= torch.cuda.device_count(), "Each process gets one GPU"
    device = RANK % torch.cuda.device_count()
    torch.cuda.set_device(device)
    torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)

    # NCCL initialization through the custom extension
    tp_comm, pp_comm = my_custom_comm.init_nccl(WORLD_SIZE, 1)
    process_group = MyCommGroup(RANK, WORLD_SIZE, tp_comm, pp_comm)

    logger.warning("Custom MPI and NCCL are initialized.")
    return process_group, RANK, WORLD_SIZE, COMM
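# Illustrative sketch, not part of this diff: how the helper above could be consumed.
# The import path `my_dist` and this call site are assumptions for illustration only.
from my_dist import initialize_mpi_distributed

process_group, rank, world_size, comm = initialize_mpi_distributed()

# The returned object mimics the torch.distributed group API used by the TP layers:
assert process_group.rank() == rank
assert process_group.size() == world_size
# process_group.tp_comm is the handle that my_custom_comm.custom_allreduce expects.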
217 my_optims/tgi_update/server.py Normal file
@@ -0,0 +1,217 @@
|
||||
import asyncio
|
||||
import os
|
||||
import torch
|
||||
|
||||
from grpc import aio
|
||||
from loguru import logger
|
||||
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from text_generation_server.cache import Cache
|
||||
from text_generation_server.interceptor import ExceptionInterceptor
|
||||
from text_generation_server.models import Model, get_model
|
||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
||||
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
|
||||
self.cache = cache
|
||||
self.model = model
|
||||
self.server_urls = server_urls
|
||||
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
||||
if model.device.type == "cuda":
|
||||
# Force inference mode for the lifetime of TextGenerationService
|
||||
self._inference_mode_raii_guard = torch._C._InferenceMode(True)
|
||||
|
||||
async def Info(self, request, context):
|
||||
return self.model.info
|
||||
|
||||
async def Health(self, request, context):
|
||||
if self.model.device.type == "cuda":
|
||||
torch.zeros((2, 2)).cuda()
|
||||
return generate_pb2.HealthResponse()
|
||||
|
||||
async def ServiceDiscovery(self, request, context):
|
||||
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
||||
|
||||
async def ClearCache(self, request, context):
|
||||
if request.HasField("id"):
|
||||
self.cache.delete(request.id)
|
||||
else:
|
||||
self.cache.clear()
|
||||
return generate_pb2.ClearCacheResponse()
|
||||
|
||||
async def FilterBatch(self, request, context):
|
||||
batch = self.cache.pop(request.batch_id)
|
||||
if batch is None:
|
||||
raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
|
||||
filtered_batch = batch.filter(request.request_ids)
|
||||
self.cache.set(filtered_batch)
|
||||
|
||||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||
|
||||
async def Warmup(self, request, context):
|
||||
if (
|
||||
self.model.batch_type == IdeficsCausalLMBatch
|
||||
): # Hack, I would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch,
|
||||
self.model.tokenizer,
|
||||
self.model.processor,
|
||||
self.model.dtype,
|
||||
self.model.device,
|
||||
)
|
||||
else:
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
max_supported_total_tokens = self.model.warmup(batch)
|
||||
|
||||
return generate_pb2.WarmupResponse(
|
||||
max_supported_total_tokens=max_supported_total_tokens
|
||||
)
|
||||
|
||||
async def Prefill(self, request, context):
|
||||
if (
|
||||
self.model.batch_type == IdeficsCausalLMBatch
|
||||
): # Hack, I would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch,
|
||||
self.model.tokenizer,
|
||||
self.model.processor,
|
||||
self.model.dtype,
|
||||
self.model.device,
|
||||
)
|
||||
else:
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.PrefillResponse(
|
||||
generations=[generation.to_pb() for generation in generations],
|
||||
batch=next_batch.to_pb() if next_batch else None,
|
||||
)
|
||||
|
||||
async def Decode(self, request, context):
|
||||
if len(request.batches) == 0:
|
||||
raise ValueError("Must provide at least one batch")
|
||||
|
||||
batches = []
|
||||
for batch_pb in request.batches:
|
||||
batch = self.cache.pop(batch_pb.id)
|
||||
if batch is None:
|
||||
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
|
||||
batches.append(batch)
|
||||
|
||||
if len(batches) == 0:
|
||||
raise ValueError("All batches are empty")
|
||||
|
||||
if len(batches) > 1:
|
||||
batch = self.model.batch_type.concatenate(batches)
|
||||
else:
|
||||
batch = batches[0]
|
||||
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.DecodeResponse(
|
||||
generations=[generation.to_pb() for generation in generations],
|
||||
batch=next_batch.to_pb() if next_batch else None,
|
||||
)
|
||||
|
||||
|
||||
def serve(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
sharded: bool,
|
||||
quantize: Optional[str],
|
||||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
uds_path: Path,
|
||||
):
|
||||
async def serve_inner(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
sharded: bool = False,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
logger.info(os.environ)
|
||||
unix_socket_template = "unix://{}-{}"
|
||||
if sharded:
|
||||
server_urls = [
|
||||
unix_socket_template.format(uds_path, rank)
|
||||
for rank in range(int(os.environ["WORLD_SIZE"]))
|
||||
]
|
||||
local_url = server_urls[int(os.environ["RANK"])]
|
||||
else:
|
||||
local_url = unix_socket_template.format(uds_path, 0)
|
||||
server_urls = [local_url]
|
||||
|
||||
if int(os.environ.get("USE_CUSTOM_NCCL", 0)):
|
||||
server_urls = [
|
||||
unix_socket_template.format(uds_path, rank)
|
||||
for rank in range(int(os.environ["OMPI_COMM_WORLD_SIZE"]))
|
||||
]
|
||||
local_url = server_urls[int(os.environ["OMPI_COMM_WORLD_RANK"])]
|
||||
|
||||
try:
|
||||
model = get_model(
|
||||
model_id, revision, sharded, quantize, dtype, trust_remote_code
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error when initializing model")
|
||||
raise
|
||||
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
# When using GPTQ, Exllama kernels need some global kernels
|
||||
# For which we have the final shapes only after the model has loaded
|
||||
# This will allocate those buffers.
|
||||
from text_generation_server.utils.gptq.exllama import (
|
||||
create_exllama_buffers,
|
||||
set_device,
|
||||
)
|
||||
|
||||
set_device(model.device)
|
||||
create_exllama_buffers()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
server = aio.server(
|
||||
interceptors=[
|
||||
ExceptionInterceptor(),
|
||||
UDSOpenTelemetryAioServerInterceptor(),
|
||||
]
|
||||
)
|
||||
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
|
||||
TextGenerationService(model, Cache(), server_urls), server
|
||||
)
|
||||
SERVICE_NAMES = (
|
||||
generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
|
||||
reflection.SERVICE_NAME,
|
||||
)
|
||||
reflection.enable_server_reflection(SERVICE_NAMES, server)
|
||||
server.add_insecure_port(local_url)
|
||||
|
||||
await server.start()
|
||||
|
||||
logger.info("Server started at {}".format(local_url))
|
||||
|
||||
try:
|
||||
await server.wait_for_termination()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Signal received. Shutting down")
|
||||
await server.stop(0)
|
||||
|
||||
asyncio.run(
|
||||
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
|
||||
)
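# Illustrative sketch, not part of this diff: the per-rank Unix socket mapping used in
# serve_inner above when USE_CUSTOM_NCCL is set; rank and world size come from the
# Open MPI environment. The uds_path and environment values here are assumed.
import os

uds_path = "/tmp/text-generation-server"
os.environ.setdefault("OMPI_COMM_WORLD_SIZE", "2")
os.environ.setdefault("OMPI_COMM_WORLD_RANK", "0")

unix_socket_template = "unix://{}-{}"
server_urls = [
    unix_socket_template.format(uds_path, rank)
    for rank in range(int(os.environ["OMPI_COMM_WORLD_SIZE"]))
]
local_url = server_urls[int(os.environ["OMPI_COMM_WORLD_RANK"])]
print(local_url)   # unix:///tmp/text-generation-server-0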
|
467 my_optims/tp_optims/flash_llama_modeling.py Normal file
@@ -0,0 +1,467 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import dropout_layer_norm
|
||||
import flash_attn_2_cuda
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
# vllm imports
|
||||
from vllm import attention_ops, cache_ops
|
||||
from torch.nn import functional as F
|
||||
|
||||
from layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states, residual
|
||||
else:
|
||||
# faster post attention rms norm
|
||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True, # Activate RMSNorm
|
||||
)
|
||||
if res is None:
|
||||
res = hidden_states
|
||||
|
||||
return normed_hidden_states, res
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
if config.model_type == "baichuan":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize != "gptq":
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
# self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
# )
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size ** -0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
cache_ops.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn_2_cuda.varlen_fwd(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
-1,
|
||||
0,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = kv_cache[1].shape[3]
|
||||
attention_ops.paged_attention_v1(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
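# Illustrative sketch, not part of this diff: the shapes behind the qkv.split above.
# With grouped-query attention the fused projection yields num_heads query heads but
# only num_key_value_heads K/V heads per rank. All sizes below are assumed.
import torch

num_heads, num_key_value_heads, head_size, tokens = 8, 2, 64, 5

qkv = torch.randn(tokens, (num_heads + 2 * num_key_value_heads) * head_size)
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads], dim=1
)
query = query.view(-1, num_heads, head_size)           # (5, 8, 64)
kv = kv.view(-1, 2, num_key_value_heads, head_size)    # (5, 2, 2, 64)

# kv_head_mapping repeats each K/V head over its query group: tensor([0,0,0,0,1,1,1,1])
kv_head_mapping = torch.arange(num_key_value_heads).repeat_interleave(
    num_heads // num_key_value_heads
)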
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate="tanh"
|
||||
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||
else "none",
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
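# Illustrative sketch, not part of this diff: the fused gate/up projection above packs
# gate_proj and up_proj along the output dimension so a single matmul feeds the SwiGLU.
# Sizes are assumed (hidden 16, intermediate 32, 4 tokens).
import torch
import torch.nn.functional as F

hidden, intermediate, tokens = 16, 32, 4
gate_up_weight = torch.randn(2 * intermediate, hidden)   # [gate_proj; up_proj] stacked

x = torch.randn(tokens, hidden)
gate_up_states = (x @ gate_up_weight.T).view(-1, 2, intermediate)
pre_down = F.silu(gate_up_states[:, 0]) * gate_up_states[:, 1]
print(pre_down.shape)   # torch.Size([4, 32]) -> down_proj then maps back to hidden size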
|
||||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
# self.embed_tokens = TensorParallelEmbedding(
|
||||
# prefix="model.embed_tokens", weights=weights
|
||||
# )
|
||||
embeddings = weights.get_tensor(f"model.embed_tokens.weight")
|
||||
self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)),
|
||||
padding_idx=config.pad_token_id)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
# for layer_id in range(1)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
402 my_optims/tp_optims/layers.py Normal file
@@ -0,0 +1,402 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
# import torch.distributed
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import my_custom_comm
|
||||
|
||||
class FastLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
bias,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight)
|
||||
if bias is not None:
|
||||
self.bias = nn.Parameter(bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
if bias:
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return cls(weight, bias)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
|
||||
|
||||
def get_linear(weight, bias, quantize):
|
||||
linear = FastLinear(weight, bias)
|
||||
return linear
|
||||
|
||||
|
||||
class SuperLayer(nn.Module):
|
||||
def __init__(self, linear):
|
||||
super().__init__()
|
||||
self.linear = linear
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear.forward(x)
|
||||
|
||||
|
||||
class TensorParallelHead(SuperLayer):
|
||||
def __init__(self, linear, process_group, should_gather: bool):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
self.should_gather = should_gather
|
||||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
if weights.process_group.size() > 1:
|
||||
try:
|
||||
assert False
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
should_gather = True
|
||||
except AssertionError:
|
||||
# If the vocab size is not divisible by number of shards
|
||||
# just load the entire thing.
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
else:
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
|
||||
# GPTQ doesn't quantize heads (nor embeddings)
|
||||
if config.quantize == "gptq":
|
||||
quantize = None
|
||||
else:
|
||||
quantize = config.quantize
|
||||
return TensorParallelHead(
|
||||
get_linear(weight, bias=None, quantize=quantize),
|
||||
process_group=weights.process_group,
|
||||
should_gather=should_gather,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if not self.should_gather:
|
||||
return super().forward(input)
|
||||
|
||||
# world_size = self.process_group.size()
|
||||
# if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
|
||||
# out_dim = self.linear.weight.shape[0]
|
||||
|
||||
# if input.shape[0] == 1:
|
||||
# world_out = input.new_empty(1, out_dim * world_size)
|
||||
# local_out = input.new_empty(1, out_dim)
|
||||
# gather_input = local_out
|
||||
# else:
|
||||
# world_out = input.new_empty(out_dim * world_size, input.shape[0])
|
||||
# gather_input = input.new_empty(out_dim, input.shape[0])
|
||||
# local_out = gather_input.T
|
||||
|
||||
# torch.mm(input, self.linear.weight.T, out=local_out)
|
||||
|
||||
# torch.distributed.all_gather_into_tensor(
|
||||
# world_out, gather_input, group=self.process_group
|
||||
# )
|
||||
|
||||
# if input.shape[0] == 1:
|
||||
# return world_out
|
||||
# return world_out.T
|
||||
|
||||
# output = super().forward(input)
|
||||
# world_output = [
|
||||
# torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
# ]
|
||||
# torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
# world_output = torch.cat(world_output, dim=-1)
|
||||
# return world_output
|
||||
|
||||
|
||||
class TensorParallelColumnLinear(SuperLayer):
|
||||
@classmethod
|
||||
def load_qkv(cls, config, prefix: str, weights, bias: bool):
|
||||
"""Specific method when the QKV was joined after the fact"""
|
||||
weight = weights.get_weights_col_packed_qkv(
|
||||
prefix, quantize=config.quantize
|
||||
)
|
||||
if bias:
|
||||
raise NotImplementedError("packed_qkv only implemented for baichuan")
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
return cls(linear)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
return cls.load_multi(config, [prefix], weights, bias, dim=0)
|
||||
|
||||
@classmethod
|
||||
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes, quantize=config.quantize, dim=dim
|
||||
)
|
||||
|
||||
if bias:
|
||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||
bias = torch.cat(b, dim=dim)
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
return cls(linear)
|
||||
|
||||
|
||||
class TensorParallelRowLinear(SuperLayer):
|
||||
def __init__(self, linear, process_group):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return cls(
|
||||
get_linear(weight, bias, config.quantize),
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
out = super().forward(input)
|
||||
if self.process_group.size() > 1:
|
||||
# torch.distributed.all_reduce(out, group=self.process_group)
|
||||
my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
|
||||
return out
|
||||
|
||||
|
||||
class TensorParallelEmbedding(nn.Module):
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        self.null_idx = block_size
        self.process_group = weights.process_group
        self.reduce = reduce

        """Additional 0 entry used for masking"""
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
        # translate for [0, self.max_id - self.min_id[
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            # torch.distributed.all_reduce(out, group=self.process_group)
            my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
        return out


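def _tp_embedding_masking_example():
    # Illustrative sketch (an addition for clarity, not part of the original diff):
    # shows how the TensorParallelEmbedding masking above behaves for
    # world_size=2, num_embeddings=8 and rank=0, i.e. min_id=0, max_id=4, null_idx=4.
    import torch

    ids = torch.tensor([1, 3, 5, 7])
    min_id, max_id, null_idx = 0, 4, 4
    local = torch.where((min_id > ids) | (ids >= max_id), null_idx, ids - min_id)
    # ids 5 and 7 fall outside this rank's vocabulary shard and are redirected to
    # the padded all-zero row (index 4), so the cross-rank allreduce restores the
    # full embedding lookup.
    return local  # tensor([1, 3, 4, 4])

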
try:
    import dropout_layer_norm


    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual

except ImportError:
    pass

try:
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb


    def _create_inv_freq(dim, base, device):
        inv_freq = 1.0 / (
            base
            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        return inv_freq


    def _get_rope_config(config):
        if os.getenv("ROPE_SCALING", None) is not None:
            rope_scaling = {
                "type": os.environ["ROPE_SCALING"],
                "factor": float(os.environ["ROPE_FACTOR"]),
            }
            return rope_scaling
        return getattr(config, "rope_scaling", None)


    class PositionRotaryEmbedding(nn.Module):
        def __init__(self, inv_freq, scaling_factor):
            super().__init__()
            self.inv_freq = inv_freq
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
            self.scaling_factor = scaling_factor
            self.dynamic_args = None

        @classmethod
        def static(cls, config, dim, base, device):
            inv_freq = _create_inv_freq(dim, base, device)
            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=dim,
                        max_position_embeddings=config.max_position_embeddings,
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        @classmethod
        def load(cls, config, prefix, weights):
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            weights.dtype = dtype

            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=config.max_position_embeddings,
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                if self.scaling_factor is not None:
                    t /= self.scaling_factor
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """
            Return cos and sin for the asked position ids
            """

            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            return cos.unsqueeze(1), sin.unsqueeze(1)

        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
            rotary_dim = cos.shape[-1]
            x1 = x[..., :rotary_dim]
            x2 = x[..., rotary_dim: 2 * rotary_dim]

            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
            return x


    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if (
                seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    newbase = self.base * (
                        (self.scaling_factor * seqlen / self.max_position_embeddings)
                        - (self.scaling_factor - 1)
                    ) ** (self.dim / (self.dim - 2))
                    self.inv_freq = _create_inv_freq(
                        self.dim, newbase, self.inv_freq.device
                    )
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)


except ImportError:
    pass

277
my_optims/tp_optims/test_llama_fa_tp.py
Normal file
@ -0,0 +1,277 @@
# encoding:utf-8
# -------------------------------------------#
# Filename: optims -- test_llama_fa.py
#
# Description:
# Version: 1.0
# Created: 2023/9/18-20:50
# Last modified by:
# Author: 'zhaohuayang@myhexin.com'
# Company: 同花顺网络信息股份有限公司
# -------------------------------------------#
import math
import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
import transformers

from flash_llama_modeling import FlashLlamaForCausalLM
from weights import Weights
from mpi4py import MPI
import my_custom_comm

COMM = None
BLOCK_SIZE = 16


class FakeGroup:
    def __init__(self, rank, size, tp_comm, pp_comm):
        self._rank = rank
        self._size = size
        self.tp_comm = tp_comm
        self.pp_comm = pp_comm

    def size(self):
        return self._size

    def rank(self):
        return self._rank


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks
        self.device = device

        element_size = torch.tensor([], dtype=dtype).element_size()
        x = self.block_size // element_size

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)

    def allocate(self, blocks, max_blocks, needed_blocks_slots):
        """
        blocks: total number of blocks required
        max_blocks: maximum number of blocks for any single sequence (used to pad the block tables)
        needed_blocks_slots: for each sequence, the number of blocks it needs and its corresponding length in slots
        """
        # Get free blocks indices by finding values in mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        assert (
            len(free_block_indices) >= blocks
        ), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"

        # Slice by the number of required blocks
        block_indices = free_block_indices[:blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks: cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

        return block_tables, block_tables_tensor.to(self.device), torch.concat(slots).to(self.device)

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


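def _cache_blocks_example(prompt_len: int = 100, max_new_tokens: int = 10):
    # Illustrative sketch (an addition for clarity, not part of the original diff):
    # the number of paged-attention blocks a single request needs, computed the
    # same way as in generate() below. With BLOCK_SIZE = 16, a 100-token prompt
    # generating 10 tokens occupies ceil((100 + 10 - 1) / 16) = 7 blocks.
    total_tokens = prompt_len + max_new_tokens - 1
    return math.ceil(total_tokens / BLOCK_SIZE)

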
def generate(tokenizer, model, config, device, prompt, max_new_tokens=10):
    input_ids = tokenizer(prompt).input_ids

    def warmup():
        print("start warmup...")
        global CACHE_MANAGER
        blocks = 260
        CACHE_MANAGER = CacheManager(blocks,
                                     len(model.model.layers),
                                     model.model.num_key_value_heads,
                                     model.model.head_size,
                                     torch.float16,
                                     device)
        input_length = 1024
        bs = 4
        warmup_inputs = {
            'input_ids': torch.arange(1, input_length + 1, dtype=torch.int64, device=device).repeat(bs),
            'position_ids': torch.arange(0, input_length, dtype=torch.int32, device=device).repeat(bs),
            'cu_seqlen_prefill': torch.tensor([i * input_length for i in range(bs + 1)], dtype=torch.int32,
                                              device=device),
            'block_tables': torch.arange(0, blocks, dtype=torch.int32, device=device).split(blocks // bs),
            'slots': torch.arange(0, 4144, dtype=torch.int32, device=device),
            'input_lengths': torch.tensor([input_length] * 4, dtype=torch.int32, device=device),
            'max_s': 1024,
            'lm_head_indices': None
        }
        model.forward(**warmup_inputs, kv_cache=CACHE_MANAGER.kv_cache)

        del CACHE_MANAGER
        torch.cuda.empty_cache()

    # warmup
    warmup()

    print("start speed test running")
    # allocate the kv-cache blocks
    global CACHE_MANAGER
    CACHE_MANAGER = CacheManager(100,
                                 len(model.model.layers),
                                 model.model.num_key_value_heads,
                                 model.model.head_size,
                                 torch.float16,
                                 device)
    total_tokens = len(input_ids) + max_new_tokens - 1
    needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
    needed_blocks_slots = [(needed_blocks, total_tokens)]
    _, block_tables_tensor, slots = CACHE_MANAGER.allocate(needed_blocks, needed_blocks, needed_blocks_slots)
    # forward loop
    tpss = []
    loops = 10
    for loop in range(loops):
        print(f"loop {loop}...")
        times = []
        new_tokens = []
        for step in range(max_new_tokens):
            if step == 0:
                # prefill step
                slot_indices = torch.arange(0, 0 + len(input_ids), dtype=torch.int64)
                inputs = {
                    'input_ids': torch.tensor(input_ids, dtype=torch.int64, device=device),
                    'position_ids': torch.arange(0, len(input_ids), dtype=torch.int32, device=device),
                    'cu_seqlen_prefill': torch.tensor([0, len(input_ids)], dtype=torch.int32, device=device),
                    'block_tables': block_tables_tensor,
                    'slots': slots[slot_indices],
                    'input_lengths': torch.tensor([len(input_ids)], dtype=torch.int32, device=device),
                    'max_s': len(input_ids),
                    'lm_head_indices': torch.tensor([0 + len(input_ids) - 1], dtype=torch.int32, device=device)
                }
            else:
                # incremental step
                current_length = len(input_ids) + step
                inputs = {
                    'input_ids': new_tokens[-1],
                    'position_ids': torch.tensor([current_length - 1], dtype=torch.int32, device=device),
                    'cu_seqlen_prefill': None,
                    'block_tables': block_tables_tensor,
                    'slots': torch.tensor([current_length - 1], dtype=torch.int32, device=device),
                    'input_lengths': torch.tensor([current_length], dtype=torch.int32, device=device),
                    'max_s': current_length,
                    'lm_head_indices': None
                }
            torch.cuda.synchronize()
            s_time = time.time()
            logits = model.forward(**inputs, kv_cache=CACHE_MANAGER.kv_cache)
            torch.cuda.synchronize()
            cost_time = time.time() - s_time
            next_token_id = logits.argmax(dim=-1)
            new_tokens.append(next_token_id)
            times.append(round(cost_time, 6))

        if loop == 0:
            new_tokens = torch.concat(new_tokens)
            print(tokenizer.decode(new_tokens, skip_special_tokens=True))

        elapsed_time = np.mean(times)
        tps = 1 / elapsed_time
        tpss.append(tps)
        print(times)
        print(f"total new tokens: {max_new_tokens}, cost time: {sum(times):.6f} s\n"
              f"time_per_token: {elapsed_time * 1000:.3f} ms, tps: {tps:.2f} tokens/s")
    print(f'mean tps: {np.mean(tpss):.2f} tokens/s')


def init_dist_env():
|
||||
global COMM
|
||||
COMM = MPI.COMM_WORLD
|
||||
rank = COMM.Get_rank()
|
||||
world_size = COMM.Get_size()
|
||||
tp_ptr, pp_ptr = my_custom_comm.init_nccl(2, 1)
|
||||
device = rank % torch.cuda.device_count()
|
||||
torch.cuda.set_device(device)
|
||||
torch.cuda.set_per_process_memory_fraction(1., device)
|
||||
process_group = FakeGroup(rank, world_size, tp_ptr, pp_ptr)
|
||||
return process_group, rank, world_size
|
||||
|
||||
|
||||
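# Note (added for clarity; an assumption based on how the communicators are used here):
# init_nccl(2, 1) in init_dist_env above appears to request a tensor-parallel group of
# 2 ranks and no pipeline parallelism, so this test script is expected to be launched
# with exactly two MPI processes, e.g. `mpirun -np 2 python test_llama_fa_tp.py`.

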
def main(model_path):
    # init env
    process_group, rank, world_size = init_dist_env()
    print(f'{rank=},{world_size=},{process_group.__dict__=}')
    # step 0: define paths and model attributes
    model_path = Path(model_path)
    config = transformers.AutoConfig.from_pretrained(model_path)
    config.quantize = None
    model_files = list(model_path.glob('*.safetensors'))

    # step 1: build the tokenizer and weights
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="left", truncation_side="left")
    device = torch.device(f"cuda:{rank}")
    weights = Weights(model_files, device, torch.float16, process_group=process_group)

    # step 2: build the model
    COMM.barrier()
    model = FlashLlamaForCausalLM(config, weights).eval()
    COMM.barrier()
    print(model)

    # step 3: inference
    with torch.no_grad():
        prompt = "who are you?"
        generate(tokenizer, model, config, device, prompt, max_new_tokens=100)
    my_custom_comm.finalize_nccl(process_group.tp_comm, process_group.pp_comm)


if __name__ == '__main__':
    CACHE_MANAGER: Optional[CacheManager] = None
    main('/code/models/llama-7b-hf')
100
my_optims/tp_optims/utils.py
Normal file
@ -0,0 +1,100 @@
# encoding:utf-8
# -------------------------------------------#
# Filename: optims -- utils.py
#
# Description:
# Version: 1.0
# Created: 2023/9/27-14:43
# Last modified by:
# Author: 'zhaohuayang@myhexin.com'
# Company: 同花顺网络信息股份有限公司
# -------------------------------------------#
import os
import time
from datetime import timedelta

import torch
from loguru import logger
from torch.distributed import ProcessGroupNCCL

RANK = int(os.getenv("LOCAL_RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "2"))
NCCL_PORT = int(os.getenv("NCCL_PORT", "29500"))
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))


class FakeBarrier:
    def wait(self):
        pass


class FakeGroup:
    def __init__(self, rank, size):
        self._rank = rank
        self._size = size

    def allreduce(self, *args, **kwargs):
        return FakeBarrier()

    def allgather(self, inputs, local_tensor, **kwargs):
        assert (
            len(inputs[0]) == len(local_tensor) == 1
        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
        for input_ in inputs:
            input_[0].data = local_tensor[0].data
        return FakeBarrier()

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def size(self):
        return self._size

    def rank(self):
        return self._rank


def initialize_torch_distributed():
    # Set the device id.
    assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
    device = RANK % torch.cuda.device_count()
    torch.cuda.set_device(device)
    torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
    backend = "nccl"
    options = ProcessGroupNCCL.Options()
    options.is_high_priority_stream = True
    options._timeout = timedelta(seconds=60)
    if not torch.distributed.is_initialized():
        # Call the init process.
        torch.distributed.init_process_group(
            backend=backend,
            init_method=f"tcp://localhost:{NCCL_PORT}",
            world_size=WORLD_SIZE,
            rank=RANK,
            timeout=timedelta(seconds=60),
            pg_options=options,
        )
        logger.info(f"torch.distributed is initialized on rank {RANK} of {WORLD_SIZE}.")
    else:
        logger.warning("torch.distributed is already initialized.")

    return torch.distributed.group.WORLD, RANK, WORLD_SIZE


class Timer:
    def __init__(self):
        self.times = []
        self.time = None

    def start(self):
        torch.cuda.synchronize()
        self.time = time.time()

    def end(self):
        torch.cuda.synchronize()
        self.times.append(time.time() - self.time)

    @property
    def elapsed(self):
        self.times.pop(0)
        return round(sum(self.times) / len(self.times) * 1000, 2)
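

def _timer_usage_example():
    # Illustrative usage sketch (an addition for clarity, not part of the original
    # diff; assumes a CUDA device is available since Timer synchronizes the GPU).
    # Timer.elapsed discards the first measurement, so it reports the mean of the
    # remaining (post-warmup) iterations in milliseconds.
    timer = Timer()
    a = torch.randn(1024, 1024, device="cuda")
    for _ in range(4):
        timer.start()
        _ = a @ a
        timer.end()
    return timer.elapsed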
336
my_optims/tp_optims/weights.py
Normal file
@ -0,0 +1,336 @@
# encoding:utf-8
# -------------------------------------------#
# Filename: optims -- weights.py
#
# Description:
# Version: 1.0
# Created: 2023/9/27-15:05
# Last modified by:
# Author: 'zhaohuayang@myhexin.com'
# Company: 同花顺网络信息股份有限公司
# -------------------------------------------#
import json
import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple

import torch
from huggingface_hub import hf_hub_download
from loguru import logger
from safetensors import safe_open, SafetensorError


class Weights:
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self._handles = {}

    def _get_handle(self, filename):
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> (str, str):
        filename = self.routing.get(tensor_name, None)
        if filename is None:
            aliases = self.aliases.get(tensor_name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return str(filename), tensor_name

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32
        if tensor.dtype not in [torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # Special case for gptq which shouldn't convert
        # u4 which are disguised as int32
        if tensor.dtype != torch.int32:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

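    # Added for clarity (not in the original diff): get_sharded/get_partial_sharded
    # split a tensor evenly along `dim` across the tensor-parallel group. For
    # example, with world_size=2 and a weight of shape [8192, 4096] sharded on
    # dim=0, rank 0 receives rows [0:4096] and rank 1 receives rows [4096:8192];
    # get_sharded additionally asserts that the size divides evenly.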
    def _get_qweight(self, name: str):
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
        single_size = total_size // 3
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        q = slice_[:, start:stop]
        k = slice_[:, start + single_size:stop + single_size]
        v = slice_[:, start + 2 * single_size:stop + 2 * single_size]
        weight = torch.cat([q, k, v], dim=1)
        weight = weight.to(device=self.device)
        return weight

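    # Added for clarity (not in the original diff): for a packed QKV tensor laid
    # out as [Q | K | V] along the packed dimension, each rank takes the slice
    # [rank * block_size : (rank + 1) * block_size] from each of the three thirds
    # and re-concatenates them, so every shard keeps matching Q/K/V columns.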
    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
        """
        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
        already alternating Q,K,V within the main tensor
        """
        if quantize == "gptq":
            try:
                qweight = self._get_qweight(f"{prefix}.qweight")
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            qzeros = self._get_qweight(f"{prefix}.qzeros")
            scales = self._get_qweight(f"{prefix}.scales")
            scales = scales.to(dtype=self.dtype)
            g_idx = self.get_tensor(f"{prefix}.g_idx")

            bits, groupsize = self._get_gptq_params()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
            single_size = total_size // 3
            world_size = self.process_group.size()
            rank = self.process_group.rank()

            assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
            block_size = single_size // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            q = slice_[start:stop]
            k = slice_[start + single_size:stop + single_size]
            v = slice_[start + 2 * single_size:stop + 2 * single_size]
            weight = torch.cat([q, k, v], dim=0)
            weight = weight.to(device=self.device)
            weight = weight.to(dtype=self.dtype)
        return weight

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        if quantize == "gptq":
            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            qzeros = torch.cat(
                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
            )
            scales = torch.cat(
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )
            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]

            bits, groupsize = self._get_gptq_params()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
        return weight

    def get_tensor_shard(self, var, dim):
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
        if quantize == "gptq":
            use_exllama = True
            bits, groupsize = self._get_gptq_params()

            if bits != 4:
                use_exllama = False

            if self.process_group.size() > 1:
                g_idx = self.get_tensor(f"{prefix}.g_idx")
                if g_idx is not None:
                    if (
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
                                [i // groupsize for i in range(g_idx.shape[0])],
                                dtype=torch.int32,
                            ),
                        )
                        and not (g_idx == 0).all()
                    ):
                        # Exllama implementation does not support row tensor parallelism with act-order, as
                        # it would require to reorder input activations that are split onto several GPUs
                        use_exllama = False

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )

            from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA

            if use_exllama:
                if not HAS_EXLLAMA:
                    if CAN_EXLLAMA:
                        logger.warning(
                            "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                        )
                    use_exllama = False
                else:
                    logger.info("Using exllama kernels")

            if use_exllama:
                if groupsize >= 0:
                    # Exllama reorders the weights in advance and the activations on the fly, thus
                    # the scales and zero-points do not need to be reordered.
                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
                else:
                    qzeros = self.get_tensor(f"{prefix}.qzeros")
                    scales = self.get_tensor(f"{prefix}.scales")

                # For tp > 1, at this point we know we do not use act-order
                if self.process_group.size() == 1:
                    g_idx = self.get_tensor(f"{prefix}.g_idx")
                else:
                    g_idx = None
            else:
                # The triton kernel reorders the scales/zero points instead of the weight/activation.
                # Thus, each rank needs the full qzeros/scales.
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")
                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

    def _get_gptq_params(self) -> Tuple[int, int]:
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
            except Exception:
                raise e

        return bits, groupsize

    def _set_gptq_params(self, model_id):
        filename = "config.json"
        try:
            if os.path.exists(os.path.join(model_id, filename)):
                filename = os.path.join(model_id, filename)
            else:
                filename = hf_hub_download(model_id, filename=filename)
            with open(filename, "r") as f:
                data = json.load(f)
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
        except Exception:
            filename = "quantize_config.json"
            try:
                if os.path.exists(os.path.join(model_id, filename)):
                    filename = os.path.join(model_id, filename)
                else:
                    filename = hf_hub_download(model_id, filename=filename)
                with open(filename, "r") as f:
                    data = json.load(f)
                self.gptq_bits = data["bits"]
                self.gptq_groupsize = data["group_size"]
            except Exception:
                pass
47
server/text_generation_server/utils/my_dist.py
Normal file
@ -0,0 +1,47 @@
import os
import torch

from datetime import timedelta
from loguru import logger
from mpi4py import MPI
import my_custom_comm

# Tensor Parallelism settings
RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))

# CUDA memory fraction
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))


class MyCommGroup:
    def __init__(self, rank, size, tp_comm, pp_comm):
        self._rank = rank
        self._size = size
        self.tp_comm = tp_comm
        self.pp_comm = pp_comm

    def size(self):
        return self._size

    def rank(self):
        return self._rank


def initialize_mpi_distributed():
    assert torch.cuda.is_available()
    # mpi initialize
    COMM = MPI.COMM_WORLD
    assert COMM.Get_size() == WORLD_SIZE, f"{COMM.Get_size()},{WORLD_SIZE}"

    # Set the device id.
    assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
    device = RANK % torch.cuda.device_count()
    torch.cuda.set_device(device)
    torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)

    # nccl initialize
    tp_comm, pp_comm = my_custom_comm.init_nccl(WORLD_SIZE, 1)
    process_group = MyCommGroup(RANK, WORLD_SIZE, tp_comm, pp_comm)

    logger.warning("custom MPI and NCCL are initialized.")
    return process_group, RANK, WORLD_SIZE, COMM
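

# Illustrative usage sketch (an addition for clarity, not part of the original diff):
# this helper stands in for the usual torch.distributed initialization, assuming the
# server is launched with one MPI process per GPU, e.g. `mpirun -np 2 ...`.
#
# process_group, rank, world_size, comm = initialize_mpi_distributed()
# # process_group.tp_comm / process_group.pp_comm hold the NCCL communicators
# # returned by my_custom_comm.init_nccl and can be passed to custom_allreduce.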