From 27a5d6b5f90dda4cc26fe9069be5df4e88b17fde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 15 May 2024 13:31:22 +0200 Subject: [PATCH] Add GPT-2 with flash attention (#1889) # What does this PR do? This change adds `FlashGPT2ForCausalLM` and wires it up. The model itself is pretty straightforward, the main difference from other models is that it uses trained position embeddings and that all weight matrices are transposed compared to other models (due to the use of Conv1D in the upstream model). Fixes # (issue) ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [x] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [x] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @Narsil --- docs/source/supported_models.md | 1 + .../test_flash_gpt2/test_flash_gpt2.json | 99 ++++ .../test_flash_gpt2/test_flash_gpt2_load.json | 398 +++++++++++++++ integration-tests/models/test_flash_gpt2.py | 44 ++ router/src/config.rs | 1 + .../custom_modeling/flash_gpt2_modeling.py | 454 ++++++++++++++++++ .../models/flash_gpt2.py | 78 +++ 7 files changed, 1075 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2_load.json create mode 100644 integration-tests/models/test_flash_gpt2.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py create mode 100644 server/text_generation_server/models/flash_gpt2.py diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index fa1f9f61..ceb25cfd 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -9,6 +9,7 @@ The following models are optimized and can be served with TGI, which uses custom - [BLOOM](https://huggingface.co/bigscience/bloom) - [FLAN-T5](https://huggingface.co/google/flan-t5-xxl) - [Galactica](https://huggingface.co/facebook/galactica-120b) +- [GPT-2](https://huggingface.co/openai-community/gpt2) - [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) - [Llama](https://github.com/facebookresearch/llama) - [OPT](https://huggingface.co/facebook/opt-66b) diff --git a/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2.json b/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2.json new file mode 100644 index 00000000..ca7393a3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1835938, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.171875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6425781, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.7314453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.68603516, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.005393982, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.31079102, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08300781, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.58984375, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.953125, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0957031, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8095703, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0673828, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9375, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2_load.json b/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2_load.json new file mode 100644 index 00000000..7bd15b90 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2_load.json @@ -0,0 +1,398 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1835938, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.171875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6425781, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.7314453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.68603516, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.005672455, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.3251953, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08294678, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.5854492, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.9423828, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0800781, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8369141, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0683594, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9711914, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1660156, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.1796875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6376953, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.72216797, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.7089844, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.0054779053, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.3190918, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08319092, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.5839844, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.9506836, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0878906, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8496094, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0673828, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9370117, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1660156, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.1796875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6376953, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.72216797, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.7089844, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.0054779053, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.3190918, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08319092, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.5839844, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.9506836, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0878906, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8496094, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0673828, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9370117, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2061, + "logprob": null, + "text": "What" + }, + { + "id": 318, + "logprob": -3.1660156, + "text": " is" + }, + { + "id": 2769, + "logprob": -9.1796875, + "text": " deep" + }, + { + "id": 4673, + "logprob": -1.6376953, + "text": " learning" + }, + { + "id": 30, + "logprob": -0.72216797, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 198, + "logprob": -0.7089844, + "special": false, + "text": "\n" + }, + { + "id": 198, + "logprob": -0.0054779053, + "special": false, + "text": "\n" + }, + { + "id": 29744, + "logprob": -0.3190918, + "special": false, + "text": "Deep" + }, + { + "id": 4673, + "logprob": -0.08319092, + "special": false, + "text": " learning" + }, + { + "id": 318, + "logprob": -0.5839844, + "special": false, + "text": " is" + }, + { + "id": 257, + "logprob": -0.9506836, + "special": false, + "text": " a" + }, + { + "id": 649, + "logprob": -2.0878906, + "special": false, + "text": " new" + }, + { + "id": 2214, + "logprob": -1.8496094, + "special": false, + "text": " field" + }, + { + "id": 286, + "logprob": -1.0673828, + "special": false, + "text": " of" + }, + { + "id": 2267, + "logprob": -0.9370117, + "special": false, + "text": " research" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new field of research" + } +] diff --git a/integration-tests/models/test_flash_gpt2.py b/integration-tests/models/test_flash_gpt2.py new file mode 100644 index 00000000..0c7977d0 --- /dev/null +++ b/integration-tests/models/test_flash_gpt2.py @@ -0,0 +1,44 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gpt2_handle(launcher): + with launcher("openai-community/gpt2", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gpt2(flash_gpt2_handle): + await flash_gpt2_handle.health(300) + return flash_gpt2_handle.client + + +@pytest.mark.asyncio +async def test_flash_gpt2(flash_gpt2, response_snapshot): + response = await flash_gpt2.generate( + "What is deep learning?", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot): + responses = await generate_load( + flash_gpt2, + "What is deep learning?", + max_new_tokens=10, + n=4, + ) + + generated_texts = [r.generated_text for r in responses] + + assert len(generated_texts) == 4 + assert all( + [text == generated_texts[0] for text in generated_texts] + ), generated_texts + + assert responses == response_snapshot diff --git a/router/src/config.rs b/router/src/config.rs index 8640ede9..989f0e31 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -132,6 +132,7 @@ pub enum Config { Santacoder, Bloom, Mpt, + Gpt2, GptNeox, Phi, #[serde(rename = "phi-msft")] diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py new file mode 100644 index 00000000..d2599f7a --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -0,0 +1,454 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from typing import Optional, List, Tuple + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) + + +def load_qkv(config, prefix: str, weights, head_size, num_heads): + if config.quantize == "gptq": + return _load_qkv_gptq( + config, + prefix, + weights, + ) + else: + return _load_qkv(config, prefix, weights, head_size, num_heads) + + +def _load_qkv_gptq(config, prefix: str, weights): + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + # Weights + weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize) + + # Bias + slice_ = weights._get_slice(f"{prefix}.c_attn.bias") + shape = slice_.get_shape() + total_size = shape[0] + assert total_size % 3 == 0, f"Prepacked is not divisible by {3}" + single_size = total_size // 3 + assert single_size % world_size == 0 + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensors = [] + for i in range(3): + tensor = slice_[start + i * single_size : stop + i * single_size] + tensors.append(tensor) + bias = torch.cat(tensors, dim=0) + bias = bias.to(device=weights.device) + + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + + +def _load_qkv(config, prefix: str, weights, head_size, num_heads): + """Load QKV from a single, transposed matrix.""" + + slice_ = weights._get_slice(f"{prefix}.c_attn.weight") + shape = slice_.get_shape() + total_size = shape[1] + assert total_size % 3 == 0, f"Prepacked is not divisible by {3}" + world_size = weights.process_group.size() + single_size = total_size // 3 + assert single_size % world_size == 0 + rank = weights.process_group.rank() + + # Weights + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensors = [] + for i in range(3): + tensor = slice_[:, start + i * single_size : stop + i * single_size] + tensors.append(tensor) + weight = torch.cat(tensors, dim=1).T + weight = weight.to(dtype=weights.dtype) + weight = weight.to(device=weights.device) + + # Bias + slice_ = weights._get_slice(f"{prefix}.c_attn.bias") + shape = slice_.get_shape() + total_size = shape[0] + single_size = total_size // 3 + block_size = single_size // world_size + assert single_size % world_size == 0 + start = rank * block_size + stop = (rank + 1) * block_size + b = [] + for i in range(3): + tensor = slice_[start + i * single_size : stop + i * single_size] + b.append(tensor) + bias = torch.cat(b, dim=0) + bias = bias.to(dtype=weights.dtype) + bias = bias.to(device=weights.device) + assert list(bias.shape) == [ + 3 * num_heads * head_size + ], f"{weight.shape} != {[3 * num_heads * head_size]}" + + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + + +def load_row(config, prefix: str, weights, bias: bool): + """load_row, but with transposed weight matrices.""" + + if config.quantize == "gptq": + weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + else: + weight = weights.get_sharded(f"{prefix}.weight", dim=0).T + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + return TensorParallelRowLinear( + get_linear(weight, bias, config.quantize), process_group=weights.process_group + ) + + +def load_col(config, prefix: str, weights, bias: bool): + """load_col, but with transposed weight matrices.""" + if config.quantize == "gptq": + weight = weights.get_multi_weights_col( + [prefix], quantize=config.quantize, dim=1 + ) + else: + weight = weights.get_sharded(f"{prefix}.weight", dim=1).T + + if bias: + bias = weights.get_sharded(f"{prefix}.bias", dim=0) + else: + bias = None + + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + + +class FlashGPT2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + + self.head_size = self.hidden_size // self.num_heads + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + + self.query_key_value = load_qkv( + config, + prefix=prefix, + weights=weights, + head_size=self.head_size, + num_heads=self.num_heads, + ) + + self.o_proj = load_row( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=True, + ) + + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + + def forward( + self, + hidden_states, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + query, key, value = self.query_key_value(hidden_states).split( + self.head_size * self.num_heads, dim=1 + ) + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) + + paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + key, + value, + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class GPT2MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.activation_function + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + + self.c_fc = load_col( + config, prefix=f"{prefix}.c_fc", weights=weights, bias=True + ) + self.c_proj = load_row( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=True, + ) + + intermediate_size = ( + config.n_inner if config.n_inner is not None else 4 * config.hidden_size + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + return self.c_proj(hidden_states) + + +class FlashGPT2Layer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.self_attn = FlashGPT2Attention( + prefix=f"{prefix}.attn", config=config, weights=weights + ) + self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.ln_2", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + def forward( + self, + hidden_states, + residual, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_output = self.self_attn( + hidden_states, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states = attn_output + residual + residual = hidden_states + + hidden_states = self.post_attention_layernorm(hidden_states) + + mlp_output = self.mlp(hidden_states) + + return residual + mlp_output, residual + + +class FlashGPT2Model(torch.nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.layers = nn.ModuleList( + [ + FlashGPT2Layer( + prefix=( + f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}" + ), + config=config, + weights=weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + + self.norm = nn.LayerNorm.load( + prefix="ln_f" if not prefix else f"{prefix}.ln_f", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = inputs_embeds + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class FlashGPT2ForCausalLM(torch.nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=("wte" if not prefix else f"{prefix}.wte"), + weights=weights, + ) + self.embed_positions = TensorParallelEmbedding( + prefix=("wpe" if not prefix else f"{prefix}.wpe"), + weights=weights, + ) + + self.model = FlashGPT2Model(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, + prefix="wte" if not prefix else f"{prefix}.wte", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + token_embeds = self.embed_tokens(input_ids) + position_embeds = self.embed_positions(position_ids) + inputs_embeds = token_embeds + position_embeds + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py new file mode 100644 index 00000000..5781f55e --- /dev/null +++ b/server/text_generation_server/models/flash_gpt2.py @@ -0,0 +1,78 @@ +import torch +import torch.distributed + +from opentelemetry import trace +from transformers import AutoConfig, AutoTokenizer, GenerationConfig +from transformers.models.gpt2 import GPT2Tokenizer +from typing import Optional + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( + FlashGPT2ForCausalLM, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + +from text_generation_server.utils.import_utils import SYSTEM + + +class FlashGPT2(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + elif SYSTEM == "xpu": + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashGPT2 is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = AutoConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.speculator = speculator + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id, revision) + + prefix = "" + model = FlashGPT2ForCausalLM(prefix, config, weights) + torch.distributed.barrier(group=self.process_group) + super(FlashGPT2, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + )