Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
fix main

This commit is contained in:
  parent e33183b118
  commit a0abfa278e
.github/workflows/build.yaml (vendored, 3 changes)
@@ -194,7 +194,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
-        uses: actions/setup-python@v1
+        uses: actions/setup-python@v4.6
         with:
           python-version: 3.9
       - name: Tailscale
@@ -213,6 +213,7 @@ jobs:
       - name: Run tests
        run: |
          export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
+         export HUGGING_FACE_HUB_TOKEN={{ secrets.HUGGING_FACE_HUB_TOKEN }}
          make integration-tests

   stop-runner:
Makefile (2 changes)
@@ -25,7 +25,7 @@ rust-tests: install-router install-launcher
 	cargo test

 integration-tests: install-integration-tests
-	pytest -s -vv integration-tests
+	pytest -s -vv -m "not private" integration-tests

 update-integration-tests: install-integration-tests
 	pytest -s -vv --snapshot-update integration-tests
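The new `-m "not private"` filter relies on pytest marker expressions: tests tagged `@pytest.mark.private` (added in the test files below) are deselected in CI and can still be run explicitly with `-m private`. A minimal sketch of how such a marker is typically registered, assuming a `conftest.py` hook (the exact registration mechanism used by the repository is not shown in this diff):

```python
# conftest.py (illustrative sketch, not necessarily the repository's file):
# register the "private" marker so `pytest -m "not private"` can deselect
# tests that need gated/private Hub models without an unknown-mark warning.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "private: test requires access to a private or gated model"
    )
```

With that in place, `pytest -s -vv -m "not private" integration-tests` runs everything except the marked tests, while `pytest -m private` runs only them.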
@@ -59,7 +59,7 @@ def launcher(event_loop):
         process.terminate()
         process.wait(60)

-        launcher_output = process.stdout.read1().decode("utf-8")
+        launcher_output = process.stdout.read().decode("utf-8")
         print(launcher_output)

         process.stdout.close()
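The switch from `read1()` to `read()` matters because `BufferedReader.read1()` performs at most one underlying read and can return only part of the launcher's log, whereas `read()` blocks until EOF and drains the whole pipe. A small, self-contained illustration of the difference (the child command is just a placeholder):

```python
import subprocess

# Child process that writes far more than one pipe buffer of output.
proc = subprocess.Popen(
    ["python", "-c", "import sys; sys.stdout.write('x' * 1_000_000)"],
    stdout=subprocess.PIPE,
)
partial = proc.stdout.read1()  # at most one buffered chunk; may truncate the log
rest = proc.stdout.read()      # blocks until EOF and returns everything remaining
proc.wait()
proc.stdout.close()
print(len(partial), len(partial) + len(rest))  # e.g. 65536 1000000
```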
@@ -10,6 +10,7 @@ def flash_llama(launcher):


 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_llama(flash_llama, snapshot):
     await health_check(flash_llama, 120)

@@ -20,6 +21,7 @@ async def test_flash_llama(flash_llama, snapshot):


 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_llama_all_params(flash_llama, snapshot):
     await health_check(flash_llama, 120)

@@ -43,6 +45,7 @@ async def test_flash_llama_all_params(flash_llama, snapshot):


 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_llama_load(flash_llama, generate_load, snapshot):
     await health_check(flash_llama, 120)

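The `health_check(flash_llama, 120)` calls in these tests poll the freshly launched server until it responds or the timeout expires; its implementation is not part of this diff, but a generic asyncio version of the pattern looks roughly like this (names, signature, and the one-second interval are assumptions):

```python
import asyncio
from typing import Awaitable, Callable


async def health_check(probe: Callable[[], Awaitable[bool]], timeout: int) -> None:
    """Poll `probe` once per second until it succeeds or `timeout` seconds pass."""
    for _ in range(timeout):
        try:
            if await probe():
                return
        except Exception:
            pass  # server not up yet; keep retrying
        await asyncio.sleep(1)
    raise RuntimeError(f"server did not become healthy within {timeout}s")
```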
@@ -10,6 +10,7 @@ def flash_starcoder(launcher):


 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_starcoder(flash_starcoder, snapshot):
     await health_check(flash_starcoder, 240)

@@ -20,6 +21,7 @@ async def test_flash_starcoder(flash_starcoder, snapshot):


 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_starcoder_default_params(flash_starcoder, snapshot):
     await health_check(flash_starcoder, 240)

@@ -32,6 +34,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, snapshot):


 @pytest.mark.asyncio
+@pytest.mark.private
 async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot):
     await health_check(flash_starcoder, 240)

@@ -129,7 +129,7 @@ class BLOOMSharded(BLOOM):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     full_name = f"transformer.{name}"
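With `quantize` now an `Optional[str]` (for example `"bitsandbytes"`) rather than a bool, the explicit `is None` check keeps the intent clear: weights go straight to the target device only when no quantization is requested, otherwise they are staged on CPU first. A minimal, self-contained sketch of the pattern (the function name, file path, and return shape are illustrative, not the repository's API):

```python
from typing import Optional

import torch
from safetensors import safe_open


def load_file(path: str, device: torch.device, quantize: Optional[str] = None):
    """Load tensors to `device` directly, or to CPU first when quantizing."""
    # When a quantization scheme (e.g. "bitsandbytes") is set, tensors are
    # materialized on CPU and only moved/converted later.
    target = str(device) if quantize is None else "cpu"
    tensors = {}
    with safe_open(path, framework="pt", device=target) as f:
        for name in f.keys():
            tensors[name] = f.get_tensor(name)
    return tensors
```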
@@ -21,14 +21,13 @@
 import torch
 import torch.distributed

-from torch.nn import functional as F
-
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional

 # Flash attention imports
 import flash_attn_cuda
+import dropout_layer_norm

 from text_generation_server.utils.layers import (
     FastLinear,
@@ -331,15 +330,15 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads

-    def post_load_weights(self, load_in_8bit: bool = False):
+    def post_load_weights(self, quantize: Optional[str] = None):
         if isinstance(self.embed_tokens, TensorParallelEmbedding):
             self.embed_tokens.add_null_idx()
         for layer in self.layers:
             layer: FlashLlamaLayer
-            layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
-            layer.self_attn.o_proj.prepare_weights(load_in_8bit)
-            layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
-            layer.mlp.down_proj.prepare_weights(load_in_8bit)
+            layer.self_attn.query_key_value.prepare_weights(quantize)
+            layer.self_attn.o_proj.prepare_weights(quantize)
+            layer.mlp.gate_up_proj.prepare_weights(quantize)
+            layer.mlp.down_proj.prepare_weights(quantize)

     def forward(
         self,
@@ -428,8 +427,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.model.post_load_weights(load_in_8bit)
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.model.post_load_weights(quantize)
         self.lm_head.prepare_weights()

     def forward(
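Across the modeling files, the boolean `load_in_8bit` flag becomes a `quantize: Optional[str]` option that is threaded down to every `prepare_weights` call, leaving room for schemes other than bitsandbytes. A stripped-down sketch of the propagation pattern (the class and layer names here are placeholders, not the repository's):

```python
from typing import Optional


class Projection:
    """Stand-in for a FastLinear-style layer with a prepare_weights hook."""

    def prepare_weights(self, quantize: Optional[str] = None) -> None:
        if quantize == "bitsandbytes":
            ...  # convert the fp16 weight into an 8-bit linear layer
        elif quantize is not None:
            raise ValueError(f"unknown quantization scheme: {quantize}")


class Model:
    def __init__(self, num_layers: int):
        self.layers = [Projection() for _ in range(num_layers)]

    def post_load_weights(self, quantize: Optional[str] = None) -> None:
        # The same option is passed unchanged to every layer.
        for layer in self.layers:
            layer.prepare_weights(quantize)
```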
@@ -345,16 +345,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads

-    def post_load_weights(self, load_in_8bit=False):
+    def post_load_weights(self, quantize: Optional[str] = None):
         if isinstance(self.embed_in, TensorParallelEmbedding):
             self.embed_in.add_null_idx()
         for layer in self.layers:
             layer: FlashNeoXLayer
             layer.attention.shuffle_qkv_dims()
-            layer.attention.query_key_value.prepare_weights(load_in_8bit)
-            layer.attention.dense.prepare_weights(load_in_8bit)
-            layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
-            layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
+            layer.attention.query_key_value.prepare_weights(quantize)
+            layer.attention.dense.prepare_weights(quantize)
+            layer.mlp.dense_h_to_4h.prepare_weights(quantize)
+            layer.mlp.dense_4h_to_h.prepare_weights(quantize)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -457,8 +457,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             config.hidden_size, config.vocab_size, bias=False
         )

-    def post_load_weights(self, load_in_8bit=False):
-        self.gpt_neox.post_load_weights(load_in_8bit)
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.gpt_neox.post_load_weights(quantize)
         self.embed_out.prepare_weights()

     @classmethod
@@ -261,16 +261,16 @@ class FlashSantacoderModel(nn.Module):
         self.head_size = self.h[0].attn.head_size
         self.num_heads = self.h[0].attn.num_heads

-    def post_load_weights(self, load_in_8bit: bool = False):
+    def post_load_weights(self, quantize: Optional[str] = None):
         if self.tp_embeddings:
             self.wte.add_null_idx()
             self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.c_attn.prepare_weights(load_in_8bit)
-            layer.attn.c_proj.prepare_weights(load_in_8bit)
-            layer.mlp.c_fc.prepare_weights(load_in_8bit)
-            layer.mlp.c_proj.prepare_weights(load_in_8bit)
+            layer.attn.c_attn.prepare_weights(quantize)
+            layer.attn.c_proj.prepare_weights(quantize)
+            layer.mlp.c_fc.prepare_weights(quantize)
+            layer.mlp.c_proj.prepare_weights(quantize)

     def forward(
         self,
@@ -347,8 +347,8 @@ class FlashSantacoderForCausalLM(nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.transformer.post_load_weights(load_in_8bit)
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.transformer.post_load_weights(quantize)
         self.lm_head.prepare_weights()

     def forward(
@@ -77,14 +77,14 @@ class FlashLlama(FlashCausalLM):
     def load_weights(
         model,
         filenames: List[Path],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device if not quantize else "cpu").to(dtype)
+                value = value.to(device if quantize is None else "cpu").to(dtype)

                 layer_name = ".".join(key.split(".")[:4])

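The non-sharded loaders use `torch.load` checkpoints and apply the same rule as the safetensors path: move each tensor to the target device only when `quantize is None`, otherwise keep it on CPU, casting to the serving dtype in both cases. A compact sketch of that loop with the surrounding bookkeeping omitted (the function name and return value are illustrative):

```python
from pathlib import Path
from typing import List, Optional

import torch


def load_state_dicts(
    filenames: List[Path],
    device: torch.device,
    dtype: torch.dtype,
    quantize: Optional[str] = None,
):
    tensors = {}
    for filename in filenames:
        state_dict = torch.load(filename, map_location="cpu")
        for key, value in state_dict.items():
            # Quantized loading stages weights on CPU; otherwise go straight
            # to the target device. Cast to the serving dtype either way.
            tensors[key] = value.to(device if quantize is None else "cpu").to(dtype)
    return tensors
```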
@@ -204,7 +204,7 @@ class FlashLlamaSharded(FlashLlama):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -212,7 +212,7 @@ class FlashLlamaSharded(FlashLlama):
     ):
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     slice_ = f.get_slice(name)
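In the sharded loaders the same device rule applies, but tensors are read through `f.get_slice(name)` so each rank only materializes its own shard. A hedged sketch of row-sharding one tensor with the safetensors slice API (the sharding dimension, function name, and parameters are illustrative):

```python
from typing import Optional

import torch
from safetensors import safe_open


def load_row_shard(
    path: str,
    name: str,
    rank: int,
    world_size: int,
    device: torch.device,
    quantize: Optional[str] = None,
) -> torch.Tensor:
    target = str(device) if quantize is None else "cpu"
    with safe_open(path, framework="pt", device=target) as f:
        slice_ = f.get_slice(name)      # lazy view; nothing is loaded yet
        size = slice_.get_shape()[0]
        block = size // world_size
        start, stop = rank * block, (rank + 1) * block
        return slice_[start:stop]       # only this rank's rows are read
```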
@@ -97,7 +97,7 @@ class FlashNeoXSharded(FlashNeoX):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
@@ -89,7 +89,7 @@ class FlashSantacoder(FlashCausalLM):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device if not quantize else "cpu").to(dtype)
+                value = value.to(device if quantize is None else "cpu").to(dtype)

                 layer_name = ".".join(key.split(".")[:4])

@@ -229,7 +229,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -238,7 +238,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     ):
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for key in f.keys():
                     slice_ = f.get_slice(key)
@@ -255,7 +255,7 @@ class GalacticaSharded(Galactica):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     if name == "lm_head.weight":
@@ -94,7 +94,7 @@ class GPTNeoxSharded(CausalLM):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
@@ -110,7 +110,7 @@ class OPTSharded(OPT):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     if name == "lm_head.weight":
@@ -97,7 +97,7 @@ class T5Sharded(Seq2SeqLM):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
@@ -1,6 +1,8 @@
 import torch

 from torch import nn
+from torch.nn import functional as F
+from typing import Optional

 HAS_BITS_AND_BYTES = True
 try:
@@ -22,7 +24,7 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None

-    def prepare_weights(self, quantize: bool = False):
+    def prepare_weights(self, quantize: Optional[str] = None):
         if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
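`FastLinear.prepare_weights` now receives the quantization scheme as a string and only takes the bitsandbytes path when it is exactly `"bitsandbytes"`, raising if the library is missing. A reduced sketch of that dispatch, mirroring the guard visible in the hunk; the actual 8-bit conversion and the `elif` branches are assumptions, not the repository's code:

```python
from typing import Optional

from torch import nn

HAS_BITS_AND_BYTES = True
try:
    from bitsandbytes.nn import Linear8bitLt  # optional dependency
except ImportError:
    HAS_BITS_AND_BYTES = False


class FastLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)
        self.quantized = False
        self.bnb_linear = None

    def prepare_weights(self, quantize: Optional[str] = None):
        if quantize == "bitsandbytes":
            if not HAS_BITS_AND_BYTES:
                raise ImportError(
                    "bitsandbytes is not installed; cannot quantize to 8-bit"
                )
            self.quantized = True
            # ... build an 8-bit bitsandbytes linear from self.weight here ...
        elif quantize is None:
            # Keep the fp16/bf16 weight; just make it contiguous for fast matmul.
            self.weight = nn.Parameter(self.weight.contiguous())
        else:
            raise ValueError(f"unsupported quantization scheme: {quantize}")
```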