mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix llama tokenizer
This commit is contained in:
parent
3c272aefc0
commit
7816a47697
57
server/poetry.lock
generated
57
server/poetry.lock
generated
@ -538,6 +538,19 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g
|
||||
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
||||
testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.13.3"
|
||||
description = "Fast and Customizable Tokenizers"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
|
||||
docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
|
||||
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
@ -638,7 +651,7 @@ bnb = ["bitsandbytes"]
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "1963b706a875d5c9090fee0db7dc38be6770c6b037dc225b62c1d06537a5a69a"
|
||||
content-hash = "1c57379c7b9349d2a860b50b3ab125737a0f6f94f4303d7cb55248cb86ff8b8e"
|
||||
|
||||
[metadata.files]
|
||||
accelerate = [
|
||||
@ -1167,6 +1180,48 @@ setuptools = [
|
||||
{file = "setuptools-67.4.0-py3-none-any.whl", hash = "sha256:f106dee1b506dee5102cc3f3e9e68137bbad6d47b616be7991714b0c62204251"},
|
||||
{file = "setuptools-67.4.0.tar.gz", hash = "sha256:e5fd0a713141a4a105412233c63dc4e17ba0090c8e8334594ac790ec97792330"},
|
||||
]
|
||||
tokenizers = [
|
||||
{file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"},
|
||||
{file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"},
|
||||
{file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"},
|
||||
{file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"},
|
||||
{file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"},
|
||||
{file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"},
|
||||
{file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"},
|
||||
{file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"},
|
||||
{file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"},
|
||||
{file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"},
|
||||
{file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"},
|
||||
{file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"},
|
||||
{file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"},
|
||||
{file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"},
|
||||
{file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"},
|
||||
{file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"},
|
||||
{file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"},
|
||||
{file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"},
|
||||
{file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"},
|
||||
{file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"},
|
||||
{file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"},
|
||||
{file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"},
|
||||
{file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"},
|
||||
{file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"},
|
||||
{file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"},
|
||||
{file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"},
|
||||
{file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"},
|
||||
{file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"},
|
||||
{file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"},
|
||||
{file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"},
|
||||
{file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"},
|
||||
{file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"},
|
||||
{file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"},
|
||||
{file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"},
|
||||
{file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"},
|
||||
{file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"},
|
||||
{file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"},
|
||||
{file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"},
|
||||
{file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"},
|
||||
{file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"},
|
||||
]
|
||||
tomli = [
|
||||
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
||||
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
||||
|
@ -24,6 +24,7 @@ opentelemetry-exporter-otlp = "^1.15.0"
|
||||
opentelemetry-instrumentation-grpc = "^0.36b0"
|
||||
hf-transfer = "^0.1.2"
|
||||
sentencepiece = "^0.1.97"
|
||||
tokenizers = "0.13.3"
|
||||
|
||||
[tool.poetry.extras]
|
||||
bnb = ["bitsandbytes"]
|
||||
|
@ -5,8 +5,9 @@ from accelerate import init_empty_weights
|
||||
from opentelemetry import trace
|
||||
from pathlib import Path
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, Tuple, List
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.llama import LlamaTokenizer
|
||||
from typing import Optional, List
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
@ -37,7 +38,7 @@ class FlashLlama(FlashCausalLM):
|
||||
if quantize:
|
||||
raise NotImplementedError("FlashLlama does not support quantization")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left"
|
||||
)
|
||||
|
||||
@ -59,8 +60,10 @@ class FlashLlama(FlashCausalLM):
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
device,
|
||||
dtype
|
||||
)
|
||||
self.model = model.eval().to(device).to(dtype)
|
||||
self.model = model.eval()
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
@ -71,10 +74,14 @@ class FlashLlama(FlashCausalLM):
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[Path],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
value = value.to(device).to(dtype)
|
||||
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
|
||||
# Fused qkv
|
||||
@ -130,6 +137,8 @@ class FlashLlama(FlashCausalLM):
|
||||
else:
|
||||
module._buffers[param_name] = value
|
||||
|
||||
del value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights()
|
||||
|
||||
@ -149,7 +158,7 @@ class FlashLlamaSharded(FlashLlama):
|
||||
if quantize:
|
||||
raise NotImplementedError("FlashLlama does not support quantization")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left"
|
||||
)
|
||||
|
||||
@ -169,10 +178,11 @@ class FlashLlamaSharded(FlashLlama):
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
self.model = model.eval().to(dtype)
|
||||
self.model = model.eval()
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
@ -185,6 +195,7 @@ class FlashLlamaSharded(FlashLlama):
|
||||
filenames: List[str],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
@ -240,7 +251,7 @@ class FlashLlamaSharded(FlashLlama):
|
||||
except:
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
tensor = tensor.contiguous()
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
|
@ -56,6 +56,8 @@ class FlashSantacoder(FlashCausalLM):
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
self.model = model.eval().to(device).to(dtype)
|
||||
|
||||
@ -68,10 +70,14 @@ class FlashSantacoder(FlashCausalLM):
|
||||
def load_weights(
|
||||
model: FlashSantacoderForCausalLM,
|
||||
filenames: List[Path],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
value = value.to(device).to(dtype)
|
||||
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
|
||||
# Fused qkv
|
||||
@ -141,6 +147,8 @@ class FlashSantacoder(FlashCausalLM):
|
||||
else:
|
||||
module._buffers[param_name] = value
|
||||
|
||||
del value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user