text-generation-inference/server/text_generation_server/models/santacoder.py
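"""SantaCoder model wrapper for text-generation-inference.

Loads a SantaCoder checkpoint as a `CausalLM` and keeps special tokens when
decoding so callers can apply custom parsing rules to the generated text.
"""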

import torch
import torch.distributed

from typing import Optional, List

from transformers import AutoTokenizer, AutoModelForCausalLM

from text_generation_server.models import CausalLM


class SantaCoder(CausalLM):
    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            # Prefer bfloat16 when the GPU supports it; fall back to float32
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding is required for batched generation with decoder-only models
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )

        self.model = (
            AutoModelForCausalLM.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize,
                # required: SantaCoder ships its modeling code on the Hub,
                # so loading it needs remote code execution enabled
                trust_remote_code=True,
            )
            .to(device)
            .eval()
        )

        # Skip CausalLM.__init__ (which would load the model itself) and call
        # the next class in the MRO, the shared Model base, directly
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules
        # of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )