added llama

Yannic Kilcher 2023-03-19 09:31:47 +01:00
parent a3b7db932f
commit 2ae7c337a1
4 changed files with 730 additions and 634 deletions

server/poetry.lock (generated, 1314 lines changed)

File diff suppressed because it is too large

server/pyproject.toml

@@ -23,6 +23,7 @@
 opentelemetry-api = "^1.15.0"
 opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
+sentencepiece = "^0.1.97"
 [tool.poetry.extras]
 bnb = ["bitsandbytes"]

server/text_generation_server/models/__init__.py

@@ -9,6 +9,7 @@
 from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.llama import Llama
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded

@@ -22,6 +23,7 @@ __all__ = [
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
+    "Llama",
     "T5Sharded",
     "get_model",
 ]
@@ -69,6 +71,12 @@ def get_model(
         else:
             return Seq2SeqLM(model_id, revision, quantize=quantize)
 
+    if config.model_type == "llama":
+        if sharded:
+            raise ValueError("sharded is not supported for Llama")
+        else:
+            return Llama(model_id, revision, quantize)
+
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
     try:
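
For context, a hedged sketch of how this branch is reached (the model id is a placeholder, and the get_model signature is assumed from the surrounding file): get_model reads the checkpoint's config.json, and any config with model_type == "llama" is now routed to the dedicated Llama wrapper before the generic AutoModel fallback.

from text_generation_server.models import get_model

# Placeholder model id; dispatch is decided by config.model_type == "llama".
model = get_model("my-org/llama-7b", revision=None, sharded=False, quantize=False)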

server/text_generation_server/models/llama.py (new file)

@@ -0,0 +1,41 @@
+import torch
+
+from transformers import LlamaTokenizer, LlamaForCausalLM
+from typing import Optional
+
+from text_generation_server.models import CausalLM
+
+
+class Llama(CausalLM):
+    def __init__(
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+    ):
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32
+
+        # LLaMA checkpoints ship a SentencePiece tokenizer, loaded via the
+        # dedicated LlamaTokenizer class rather than AutoTokenizer.
+        tokenizer = LlamaTokenizer.from_pretrained(
+            model_id, revision=revision, padding_side="left"
+        )
+
+        self.model = LlamaForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            device_map="auto" if torch.cuda.is_available() else None,
+            load_in_8bit=quantize,
+        ).eval()
+
+        # LLaMA has no pad token; fall back to the EOS token for padding.
+        tokenizer.pad_token_id = (
+            self.model.config.pad_token_id
+            if self.model.config.pad_token_id is not None
+            else self.model.config.eos_token_id
+        )
+
+        # Deliberately skip CausalLM.__init__ (which would reload the model
+        # via AutoModel) and call the Model base initializer directly, since
+        # self.model is already set above.
+        super(CausalLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
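
A hedged usage sketch of the new wrapper (the checkpoint id is a placeholder, and the base Model class is assumed to store the tokenizer and device as self.tokenizer and self.device): because Llama bypasses CausalLM.__init__, the loaded LlamaForCausalLM already sits in self.model, so the server's usual CausalLM generation path sees the same interface as any other model.

import torch

from text_generation_server.models.llama import Llama

llama = Llama("my-org/llama-7b", quantize=False)  # placeholder model id
inputs = llama.tokenizer("The capital of France is", return_tensors="pt").to(llama.device)
with torch.no_grad():
    output_ids = llama.model.generate(**inputs, max_new_tokens=8)
print(llama.tokenizer.decode(output_ids[0], skip_special_tokens=True))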