mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
added llama
This commit is contained in:
parent
a3b7db932f
commit
2ae7c337a1
1314 server/poetry.lock (generated)
File diff suppressed because it is too large
server/pyproject.toml
@@ -23,6 +23,7 @@ opentelemetry-api = "^1.15.0"
 opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
+sentencepiece = "^0.1.97"
 
 [tool.poetry.extras]
 bnb = ["bitsandbytes"]
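Context, not part of the diff: the new sentencepiece pin is most likely needed because the slow LlamaTokenizer in transformers is backed by a SentencePiece model. A minimal sketch checking that assumption:

    import sentencepiece as spm

    # The slow LlamaTokenizer wraps a SentencePieceProcessor, so the server
    # now needs the sentencepiece package available at runtime.
    print(spm.SentencePieceProcessor, spm.__version__)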
server/text_generation_server/models/__init__.py
@@ -9,6 +9,7 @@ from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.llama import Llama
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.t5 import T5Sharded
 
@@ -22,6 +23,7 @@ __all__ = [
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
+    "Llama",
     "T5Sharded",
     "get_model",
 ]
@@ -69,6 +71,12 @@ def get_model(
         else:
             return Seq2SeqLM(model_id, revision, quantize=quantize)
 
+    if config.model_type == "llama":
+        if sharded:
+            raise ValueError("sharded is not supported for Llama")
+        else:
+            return Llama(model_id, revision, quantize)
+
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
     try:
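Context, not part of the diff: with this change, checkpoints whose config.json reports model_type == "llama" are routed to the new Llama wrapper, and sharded loading is rejected for now. A rough sketch of how the new branch is reached, assuming get_model(model_id, revision, sharded, quantize) as the surrounding hunk suggests, with a placeholder model id:

    from text_generation_server.models import get_model

    # config.model_type is read from the checkpoint's config.json; for a LLaMA
    # checkpoint it is "llama", so the added branch returns a Llama instance.
    model = get_model(
        "path/or/hub-id-of-llama-checkpoint", revision=None, sharded=False, quantize=False
    )
    print(type(model).__name__)  # expected: "Llama"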
41 server/text_generation_server/models/llama.py (new file)
@@ -0,0 +1,41 @@
+import torch
+
+from transformers import LlamaTokenizer, LlamaForCausalLM
+from typing import Optional
+from text_generation_server.models import CausalLM
+
+
+class Llama(CausalLM):
+    def __init__(
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+    ):
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32
+
+        tokenizer = LlamaTokenizer.from_pretrained(
+            model_id, revision=revision, padding_side="left"
+        )
+        self.model = LlamaForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            device_map="auto" if torch.cuda.is_available() else None,
+            load_in_8bit=quantize,
+        ).eval()
+        tokenizer.pad_token_id = (
+            self.model.config.pad_token_id
+            if self.model.config.pad_token_id is not None
+            else self.model.config.eos_token_id
+        )
+
+        super(CausalLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
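Context, not part of the diff: the wrapper loads the weights itself via LlamaForCausalLM.from_pretrained and then calls super(CausalLM, self).__init__, which skips CausalLM's own constructor and goes straight to the shared base class, presumably because that constructor would otherwise load the checkpoint a second time through AutoModelForCausalLM. A minimal usage sketch with a placeholder checkpoint path:

    from text_generation_server.models.llama import Llama

    # Loads tokenizer and weights; device and dtype are picked automatically
    # (bfloat16 on GPUs that support it, float32 otherwise), and int8 loading
    # is enabled when quantize=True on GPU.
    model = Llama("path/to/llama-checkpoint", revision=None, quantize=False)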