Add GPT-2 with flash attention

This change adds `FlashGPT2ForCausalLM` and wires it up. The model
itself is fairly straightforward; the main differences from the other
models are that it uses learned position embeddings and that all of
its weight matrices are transposed relative to the other models
(because the upstream implementation uses Conv1D).
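
As a rough, illustrative sketch only (not part of this change), the snippet below shows both points with toy tensors: GPT-2's position embeddings are an ordinary learned lookup table added to the token embeddings, and Conv1D stores its weight as `(in_features, out_features)`, the transpose of `nn.Linear`'s layout, so checkpoint tensors need a transpose before they can back regular linear layers. All dimensions are toy values.

```python
import torch
from torch import nn

# Toy dimensions for illustration only (GPT-2 small uses 50257 / 1024 / 768).
vocab_size, n_positions, hidden = 50257, 1024, 768

# Learned (absolute) position embeddings: a lookup table indexed by position,
# added to the token embeddings before the first transformer block.
wte = nn.Embedding(vocab_size, hidden)   # token embedding table ("wte")
wpe = nn.Embedding(n_positions, hidden)  # position embedding table ("wpe")

input_ids = torch.tensor([[2061, 318, 2769, 4673, 30]])       # "What is deep learning?"
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)  # positions 0..4
inputs_embeds = wte(input_ids) + wpe(position_ids)

# Upstream GPT-2 uses Conv1D for its projections; its weight is stored as
# (in_features, out_features), the transpose of nn.Linear's
# (out_features, in_features). Loading such a checkpoint tensor into a
# linear layer therefore requires a transpose, which is what the loaders
# in this commit do for c_attn / c_proj / c_fc.
conv1d_weight = torch.randn(hidden, 3 * hidden)  # as stored in the checkpoint
linear = nn.Linear(hidden, 3 * hidden, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv1d_weight.T)  # transpose into Linear's layout

qkv = linear(inputs_embeds)  # (1, 5, 3 * hidden)
print(qkv.shape)
```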
Daniël de Kok 2024-05-10 15:54:18 +00:00
parent e3d765645a
commit 8acd126710
8 changed files with 1098 additions and 1 deletions

View File

@@ -9,6 +9,7 @@ The following models are optimized and can be served with TGI, which uses custom
- [BLOOM](https://huggingface.co/bigscience/bloom)
- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
- [GPT-2](https://huggingface.co/openai-community/gpt2)
- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Llama](https://github.com/facebookresearch/llama)
- [OPT](https://huggingface.co/facebook/opt-66b)

View File

@@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1835938,
"text": " is"
},
{
"id": 2769,
"logprob": -9.171875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6425781,
"text": " learning"
},
{
"id": 30,
"logprob": -0.7314453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.68603516,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.005393982,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.31079102,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08300781,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.58984375,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.953125,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0957031,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8095703,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9375,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
}

View File

@@ -0,0 +1,398 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1835938,
"text": " is"
},
{
"id": 2769,
"logprob": -9.171875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6425781,
"text": " learning"
},
{
"id": 30,
"logprob": -0.7314453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.68603516,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.005672455,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3251953,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08294678,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5854492,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9423828,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0800781,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8369141,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0683594,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9711914,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1660156,
"text": " is"
},
{
"id": 2769,
"logprob": -9.1796875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6376953,
"text": " learning"
},
{
"id": 30,
"logprob": -0.72216797,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.7089844,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.0054779053,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3190918,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08319092,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5839844,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9506836,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0878906,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8496094,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9370117,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1660156,
"text": " is"
},
{
"id": 2769,
"logprob": -9.1796875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6376953,
"text": " learning"
},
{
"id": 30,
"logprob": -0.72216797,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.7089844,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.0054779053,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3190918,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08319092,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5839844,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9506836,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0878906,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8496094,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9370117,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1660156,
"text": " is"
},
{
"id": 2769,
"logprob": -9.1796875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6376953,
"text": " learning"
},
{
"id": 30,
"logprob": -0.72216797,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.7089844,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.0054779053,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3190918,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08319092,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5839844,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9506836,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0878906,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8496094,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9370117,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
}
]

View File

@@ -0,0 +1,44 @@
import pytest


@pytest.fixture(scope="module")
def flash_gpt2_handle(launcher):
    with launcher("openai-community/gpt2", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gpt2(flash_gpt2_handle):
    await flash_gpt2_handle.health(300)
    return flash_gpt2_handle.client


@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
    response = await flash_gpt2.generate(
        "What is deep learning?",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
    responses = await generate_load(
        flash_gpt2,
        "What is deep learning?",
        max_new_tokens=10,
        n=4,
    )

    generated_texts = [r.generated_text for r in responses]

    assert len(generated_texts) == 4
    assert all(
        [text == generated_texts[0] for text in generated_texts]
    ), generated_texts

    assert responses == response_snapshot

View File

@@ -132,6 +132,7 @@ pub enum Config {
    Santacoder,
    Bloom,
    Mpt,
    Gpt2,
    GptNeox,
    Phi,
    #[serde(rename = "phi-msft")]

View File

@@ -51,6 +51,7 @@ FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_gpt2 import FlashGPT2
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
@@ -83,6 +84,7 @@ except ImportError as e:
    HAS_FLASH_ATTN_V2_CUDA = False

if FLASH_ATTENTION:
    __all__.append(FlashGPT2)
    __all__.append(FlashNeoXSharded)
    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoderSharded)
@@ -325,7 +327,27 @@ def get_model(
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "gpt2":
        if FLASH_ATTENTION:
            return FlashGPT2(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(

View File

@@ -0,0 +1,454 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
def load_qkv(config, prefix: str, weights, head_size, num_heads):
if config.quantize == "gptq":
return _load_qkv_gptq(
config,
prefix,
weights,
)
else:
return _load_qkv(config, prefix, weights, head_size, num_heads)
def _load_qkv_gptq(config, prefix: str, weights):
world_size = weights.process_group.size()
rank = weights.process_group.rank()
# Weights
weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
# Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
total_size = shape[0]
assert total_size % 3 == 0, f"Prepacked QKV size {total_size} is not divisible by 3"
single_size = total_size // 3
assert single_size % world_size == 0
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensors = []
for i in range(3):
tensor = slice_[start + i * single_size : stop + i * single_size]
tensors.append(tensor)
bias = torch.cat(tensors, dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def _load_qkv(config, prefix: str, weights, head_size, num_heads):
"""Load QKV from a single, transposed matrix."""
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape()
total_size = shape[1]
assert total_size % 3 == 0, f"Prepacked QKV size {total_size} is not divisible by 3"
world_size = weights.process_group.size()
single_size = total_size // 3
assert single_size % world_size == 0
rank = weights.process_group.rank()
# Weights
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensors = []
for i in range(3):
tensor = slice_[:, start + i * single_size : stop + i * single_size]
tensors.append(tensor)
weight = torch.cat(tensors, dim=1).T
weight = weight.to(dtype=weights.dtype)
weight = weight.to(device=weights.device)
# Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
total_size = shape[0]
single_size = total_size // 3
block_size = single_size // world_size
assert single_size % world_size == 0
start = rank * block_size
stop = (rank + 1) * block_size
b = []
for i in range(3):
tensor = slice_[start + i * single_size : stop + i * single_size]
b.append(tensor)
bias = torch.cat(b, dim=0)
bias = bias.to(dtype=weights.dtype)
bias = bias.to(device=weights.device)
assert list(bias.shape) == [
3 * num_heads * head_size
], f"{weight.shape} != {[3 * num_heads * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_row(config, prefix: str, weights, bias: bool):
"""load_row, but with transposed weight matrices."""
if config.quantize == "gptq":
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
if bias and weights.process_group.rank() == 0:
# Load the bias on the first rank only; row-parallel partial outputs are
# summed across ranks, so the bias must be added exactly once.
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
)
def load_col(config, prefix: str, weights, bias: bool):
"""load_col, but with transposed weight matrices."""
if config.quantize == "gptq":
weight = weights.get_multi_weights_col(
[prefix], quantize=config.quantize, dim=1
)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
class FlashGPT2Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = load_qkv(
config,
prefix=prefix,
weights=weights,
head_size=self.head_size,
num_heads=self.num_heads,
)
self.o_proj = load_row(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=True,
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
)
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
query,
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class GPT2MLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.activation_function
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
self.c_fc = load_col(
config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
)
self.c_proj = load_row(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=True,
)
intermediate_size = (
config.n_inner if config.n_inner is not None else 4 * config.hidden_size
)
self.intermediate_size = intermediate_size // weights.process_group.size()
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states)
class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights
)
self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.ln_2",
weights=weights,
eps=config.layer_norm_epsilon,
)
def forward(
self,
hidden_states,
residual,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_output = self.self_attn(
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(hidden_states)
return residual + mlp_output, residual
class FlashGPT2Model(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.layers = nn.ModuleList(
[
FlashGPT2Layer(
prefix=(
f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
),
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = nn.LayerNorm.load(
prefix="ln_f" if not prefix else f"{prefix}.ln_f",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = inputs_embeds
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class FlashGPT2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=("wte" if not prefix else f"{prefix}.wte"),
weights=weights,
)
self.embed_positions = TensorParallelEmbedding(
prefix=("wpe" if not prefix else f"{prefix}.wpe"),
weights=weights,
)
self.model = FlashGPT2Model(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="wte" if not prefix else f"{prefix}.wte",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
token_embeds = self.embed_tokens(input_ids)
position_embeds = self.embed_positions(position_ids)
inputs_embeds = token_embeds + position_embeds
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@@ -0,0 +1,78 @@
import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.gpt2 import GPT2Tokenizer
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
    FlashGPT2ForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)

from text_generation_server.utils.import_utils import SYSTEM


class FlashGPT2(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        elif SYSTEM == "xpu":
            device = torch.device(f"xpu:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashGPT2 is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = FlashGPT2ForCausalLM(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)
        super(FlashGPT2, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )