Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)

Commit 2ae7c337a1 ("added llama"), parent a3b7db932f
server/poetry.lock (generated, 1314 lines changed)
File diff suppressed because it is too large.
server/pyproject.toml

@@ -23,6 +23,7 @@ opentelemetry-api = "^1.15.0"
 opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
+sentencepiece = "^0.1.97"

 [tool.poetry.extras]
 bnb = ["bitsandbytes"]
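The sentencepiece dependency is presumably what this commit needs: the slow LlamaTokenizer imported in llama.py below is backed by a sentencepiece model. A minimal sketch of that assumption (the model id is made up for illustration):

# Hedged sketch: LlamaTokenizer loads a sentencepiece vocabulary, so the
# sentencepiece package must be installed alongside transformers.
# "some-org/llama-7b" is a hypothetical checkpoint, not one used by this commit.
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("some-org/llama-7b", padding_side="left")
print(type(tokenizer.sp_model))  # sentencepiece.SentencePieceProcessor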
server/text_generation_server/models/__init__.py

@@ -9,6 +9,7 @@ from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.llama import Llama
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded

@@ -22,6 +23,7 @@ __all__ = [
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
+    "Llama",
     "T5Sharded",
     "get_model",
 ]
@@ -69,6 +71,12 @@ def get_model(
         else:
             return Seq2SeqLM(model_id, revision, quantize=quantize)

+    if config.model_type == "llama":
+        if sharded:
+            raise ValueError("sharded is not supported for Llama")
+        else:
+            return Llama(model_id, revision, quantize)
+
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
     try:
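A hedged sketch of how the new branch is reached, assuming a checkpoint whose config.json reports model_type == "llama"; the model id is hypothetical, and the parameter names (revision, sharded, quantize) are assumed from how this hunk uses them:

# Hypothetical call into the dispatcher extended above.
from text_generation_server.models import get_model

model = get_model("some-org/llama-7b", revision=None, sharded=False, quantize=False)
assert type(model).__name__ == "Llama"

# Per the diff, a sharded llama is rejected:
#   get_model("some-org/llama-7b", revision=None, sharded=True, quantize=False)
#   raises ValueError("sharded is not supported for Llama")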
server/text_generation_server/models/llama.py (new file, 41 lines)

@@ -0,0 +1,41 @@
import torch

from transformers import LlamaTokenizer, LlamaForCausalLM
from typing import Optional
from text_generation_server.models import CausalLM


class Llama(CausalLM):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = LlamaTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        self.model = LlamaForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )
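For completeness, a hedged usage sketch of the new wrapper on its own; the model id is hypothetical and the comments simply restate the constructor logic above:

# Hypothetical example; "some-org/llama-7b" stands in for a real llama checkpoint.
from text_generation_server.models.llama import Llama

llama = Llama("some-org/llama-7b", revision=None, quantize=False)

# On a bf16-capable GPU the weights load in torch.bfloat16 with device_map="auto";
# on CPU they stay in float32, and quantize=True raises
# ValueError("quantization is not available on CPU").
print(llama.model.dtype)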