mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 16:32:12 +00:00
Just trying to get the integration tests to pass. # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com>
295 lines
11 KiB
Plaintext
295 lines
11 KiB
Plaintext
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
|
|
|
#ifndef _matrix_cuh
|
|
#define _matrix_cuh
|
|
|
|
#include <cuda_runtime.h>
|
|
#include <cuda_fp16.h>
|
|
|
|
// Read-only, row-major view over a buffer of `half` values, accessed from device code.
// Element (row, column) is data[row * width + column]. No bounds checking is performed,
// and `height` is stored but not consulted by any accessor.
class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Single element at (row, column).
    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }

    // The half2 pair whose index is (row * width + column) / 2. The flat index is expected to be
    // even (pair-aligned); for an odd index this returns the pair starting one element earlier.
    // Cast keeps constness: the original C-style cast silently stripped `const` from `data`.
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return reinterpret_cast<const half2*>(data)[(row * width + column) / 2]; }

    // Element (row, column) broadcast into both lanes of a half2.
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }

    // Address of element (row, column).
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
|
|
|
|
// Mutable, row-major view over a buffer of `half` values, accessed from device code.
// Element (row, column) lives at data[row * width + column]; accessors do no bounds checking.
class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Read the element at (row, column).
    __device__ __forceinline__ half item(int row, int column) const
    {
        return data[flat(row, column)];
    }

    // Read the half2 slot number flat(row, column) / 2 (pair-aligned access).
    __device__ __forceinline__ half2 item_half2(int row, int column) const
    {
        const half2* pairs = reinterpret_cast<const half2*>(data);
        return pairs[flat(row, column) / 2];
    }

    // Writable address of the element at (row, column).
    __device__ __forceinline__ half* item_ptr(int row, int column)
    {
        return data + flat(row, column);
    }

    // Overwrite the element at (row, column).
    __device__ __forceinline__ void set(int row, int column, half value)
    {
        data[flat(row, column)] = value;
    }

    // Overwrite the half2 slot number flat(row, column) / 2.
    __device__ __forceinline__ void set_half2(int row, int column, half2 value)
    {
        half2* pairs = reinterpret_cast<half2*>(data);
        pairs[flat(row, column) / 2] = value;
    }

private:
    // Row-major flat offset of (row, column) into `data`.
    __device__ __forceinline__ int flat(int row, int column) const
    {
        return row * width + column;
    }
};
|
|
|
|
// Row-packed view of a 4-bit quantized matrix: 8 consecutive columns of a row share one
// uint32_t. Element (row, column) is nibble (column % 8), lowest nibble first, of word
// data[row * width / 8 + column / 8]. No bounds checking is performed.
// The accessors are __host__ __device__ so host-side reference code (and unit tests) can
// decode exactly the same layout the kernels use.
class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;   // in 4-bit elements; presumably a multiple of 8 so rows start on a word boundary — TODO confirm

    __host__ __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Unpacked 4-bit value at (row, column), in [0, 15].
    __host__ __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
    }
};
|
|
|
|
// Column-packed view of a 4-bit quantized matrix: 8 consecutive rows of one column share a
// uint32_t. Element (row, column) is nibble (row % 8), lowest nibble first, of word
// data[(row / 8) * width + column], so stepping `width` words moves down 8 rows.
// No bounds checking is performed. Accessors are __host__ __device__ so host-side reference
// code can decode the same layout, and the read-only accessors are const-qualified so they
// work through const views.
class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __host__ __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Unpacked 4-bit value at (row, column), in [0, 15].
    __host__ __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    // Whole packed word holding rows 8*(row/8) .. 8*(row/8)+7 of `column`.
    __host__ __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) const { return data[row / 8 * width + column]; }

    // Address of that packed word.
    __host__ __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) const { return &data[row / 8 * width + column]; }
};
|
|
|
|
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
|
|
|
|
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
|
|
|
|
// Accumulates `count` groups of 8 products between a contiguous slice of one row of h_ and
// one column of the 4-bit matrix v_, using a single (scale, zero) pair for all groups.
//
// Group i reads halves h_[h_row][h_column + 8*i .. h_column + 8*i + 7] (as four half2) and
// the packed word v_.item_uint32_ptr(v_row, v_column) + i * v_.width, i.e. rows
// v_row + 8*i .. v_row + 8*i + 7 of column v_column. Each nibble n is dequantized as
// (n - v_zero); per the "+ 1 (!!)" note the caller passes the stored zero point plus one
// (TODO confirm against the caller). The group sum is scaled by v_scale_2 and added to acc.
//
// The result is a half2 whose .x lane accumulates the even-offset products and whose .y lane
// the odd-offset products; the caller presumably adds the two lanes. Reinterpreting the h_
// pointer as half2 requires h_.item_ptr(h_row, h_column) to be half2-aligned.
__device__ __forceinline__ half2 dot_product_8
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,     // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,        // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,  // + 1 (!!)
    const int count
)
{
    const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = v_.item_uint32_ptr(v_row, v_column);

    // Signed zero point: `int - uint32_t` would be evaluated in unsigned arithmetic, wrap for
    // negative results, and then depend on an implementation-defined narrowing back to int.
    const int zero = (int) v_zero;
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        const uint32_t v_read = *v_ptr;
        v_ptr += v_.width;   // next packed word of this column = next 8 rows

        // Unpack the 8 nibbles (lowest first) and remove the zero point.
        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        // half2 v_01 = q4_table[v_zero - 1][(v_read      ) & 0xff]; // (constant memory is too slow apparently)
        // half2 v_23 = q4_table[v_zero - 1][(v_read >>  8) & 0xff];
        // half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
        // half2 v_67 = q4_table[v_zero - 1][(v_read >> 24)       ];

        half2 tmp = __hmul2(h_ptr[0], v_01);
        tmp = __hfma2(h_ptr[1], v_23, tmp);
        tmp = __hfma2(h_ptr[2], v_45, tmp);
        tmp = __hfma2(h_ptr[3], v_67, tmp);
        h_ptr += 4;   // 4 half2 = the 8 halves consumed by this group

        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}
|
|
|
|
// Scalar-half variant of the 8-wide quantized dot product: accumulates `count` groups of 8
// products between a contiguous slice of row h_row of h_ and column v_column of the 4-bit
// matrix v_, with a single (scale, zero) pair for all groups.
//
// Group i reads halves h_[h_row][h_column + 8*i .. h_column + 8*i + 7] and the packed word
// v_.item_uint32_ptr(v_row, v_column) + i * v_.width (rows v_row + 8*i .. +7). Each nibble n
// is dequantized as (n - v_zero); per the "+ 1 (!!)" note the caller passes the stored zero
// point plus one (TODO confirm). The group sum is multiplied by v_scale and added to acc.
__device__ __forceinline__ half dot_product_8_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,     // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,        // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,  // + 1 (!!)
    const int count
)
{
    const half* h_ptr = h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = v_.item_uint32_ptr(v_row, v_column);

    // Signed zero point: avoids unsigned wrap-around in `int - uint32_t` followed by an
    // implementation-defined narrowing back to int.
    const int zero = (int) v_zero;
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        const uint32_t v_read = *v_ptr;
        v_ptr += v_.width;   // next packed word of this column = next 8 rows

        // Unpack the 8 nibbles (lowest first) and remove the zero point.
        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - zero);

        half tmp = __hmul(h_ptr[0], v_0);
        tmp = __hfma(h_ptr[1], v_1, tmp);
        tmp = __hfma(h_ptr[2], v_2, tmp);
        tmp = __hfma(h_ptr[3], v_3, tmp);
        tmp = __hfma(h_ptr[4], v_4, tmp);
        tmp = __hfma(h_ptr[5], v_5, tmp);
        tmp = __hfma(h_ptr[6], v_6, tmp);
        tmp = __hfma(h_ptr[7], v_7, tmp);
        h_ptr += 8;   // the 8 halves consumed by this group

        result = __hfma(v_scale, tmp, result);
    }

    return result;
}
|
|
|
|
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
|
|
|
|
// Like dot_product_8, but the h_ operand is gathered through a column map: the k-th logical
// column of group i is h_[h_row][x_map[h_column + 8*i + k]] (presumably an act-order column
// permutation — TODO confirm against the caller).
//
// Group i pairs those 8 gathered halves with the packed word
// v_.item_uint32_ptr(v_row, v_column) + i * v_.width (rows v_row + 8*i .. +7). Each nibble n
// is dequantized as (n - v_zero); per the "+ 1 (!!)" note the caller passes the stored zero
// point plus one. The group sum is scaled by v_scale_2 and added to acc. The .x lane of the
// result accumulates even-offset products and .y the odd-offset ones; the caller presumably
// adds the two lanes.
__device__ __forceinline__ half2 dot_product_8_x_map
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,     // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,        // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,  // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);          // start of row h_row; columns come from x_map
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = v_.item_uint32_ptr(v_row, v_column);

    // Signed zero point: avoids unsigned wrap-around in `int - uint32_t` followed by an
    // implementation-defined narrowing back to int.
    const int zero = (int) v_zero;
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        const uint32_t v_read = *v_ptr;
        v_ptr += v_.width;   // next packed word of this column = next 8 rows

        // Unpack the 8 nibbles (lowest first) and remove the zero point.
        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        // Gather the 8 mapped activations. Explicit offsets (instead of `*x_map_ptr++` per
        // element) keep the reads well-defined even when two share one call expression.
        half2 h_01 = __halves2half2(h_ptr[x_map_ptr[0]], h_ptr[x_map_ptr[1]]);
        half2 h_23 = __halves2half2(h_ptr[x_map_ptr[2]], h_ptr[x_map_ptr[3]]);
        half2 h_45 = __halves2half2(h_ptr[x_map_ptr[4]], h_ptr[x_map_ptr[5]]);
        half2 h_67 = __halves2half2(h_ptr[x_map_ptr[6]], h_ptr[x_map_ptr[7]]);
        x_map_ptr += 8;

        half2 tmp = __hmul2(h_01, v_01);
        tmp = __hfma2(h_23, v_23, tmp);
        tmp = __hfma2(h_45, v_45, tmp);
        tmp = __hfma2(h_67, v_67, tmp);

        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}
|
|
|
|
// Scalar-half variant of the column-mapped quantized dot product: the k-th logical column of
// group i is h_[h_row][x_map[h_column + 8*i + k]] (presumably an act-order column permutation
// — TODO confirm against the caller).
//
// Group i pairs those 8 gathered halves with the packed word
// v_.item_uint32_ptr(v_row, v_column) + i * v_.width (rows v_row + 8*i .. +7). Each nibble n
// is dequantized as (n - v_zero); per the "+ 1 (!!)" note the caller passes the stored zero
// point plus one. The group sum is multiplied by v_scale and added to acc.
__device__ __forceinline__ half dot_product_8_x_map_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,     // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,        // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,  // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);          // start of row h_row; columns come from x_map
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = v_.item_uint32_ptr(v_row, v_column);

    // Signed zero point: avoids unsigned wrap-around in `int - uint32_t` followed by an
    // implementation-defined narrowing back to int.
    const int zero = (int) v_zero;
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        const uint32_t v_read = *v_ptr;
        v_ptr += v_.width;   // next packed word of this column = next 8 rows

        // Unpack the 8 nibbles (lowest first) and remove the zero point.
        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - zero);

        half tmp = __hmul(h_ptr[x_map_ptr[0]], v_0);
        tmp = __hfma(h_ptr[x_map_ptr[1]], v_1, tmp);
        tmp = __hfma(h_ptr[x_map_ptr[2]], v_2, tmp);
        tmp = __hfma(h_ptr[x_map_ptr[3]], v_3, tmp);
        tmp = __hfma(h_ptr[x_map_ptr[4]], v_4, tmp);
        tmp = __hfma(h_ptr[x_map_ptr[5]], v_5, tmp);
        tmp = __hfma(h_ptr[x_map_ptr[6]], v_6, tmp);
        tmp = __hfma(h_ptr[x_map_ptr[7]], v_7, tmp);
        x_map_ptr += 8;   // the 8 map entries consumed by this group

        result = __hfma(v_scale, tmp, result);
    }

    return result;
}
|
|
|
|
#endif
|