Loading models in 4-bit

This enables Pascal GPUs to load Llama 2 when used with --quantize bitsandbytes.
Chris 2023-08-27 14:38:25 +02:00
parent 3062fa035d
commit cf178a278a
6 changed files with 13 additions and 7 deletions
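The commit description refers to loading through bitsandbytes in 4-bit rather than 8-bit. For orientation, the sketch below shows what 4-bit loading looks like at the transformers level; it is illustrative only, and the model id is a placeholder rather than anything taken from this commit.

# Minimal sketch of 4-bit loading via transformers + bitsandbytes (illustrative).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model id

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",          # let accelerate place the quantized weights
    load_in_4bit=True,          # quantize with bitsandbytes at load time
    torch_dtype=torch.float16,
)
# The model is already on the correct device(s); do not call model.cuda().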

View File

@@ -24,7 +24,6 @@ install: gen-server install-torch
pip install -e ".[bnb, accelerate]"
run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes
export-requirements:
poetry export -o requirements.txt -E bnb -E quantize --without-hashes

View File

@@ -141,10 +141,12 @@ def download_weights(
if not extension == ".safetensors" or not auto_convert:
raise e
logger.warning("attempting to load local model")
# Try to see if there are local pytorch weights
try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
local_pt_files = utils.weight_files(model_id, revision, ".bin")
print(local_pt_files)
# No local pytorch weights
except utils.LocalEntryNotFoundError:
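The hunk above adds a fallback that looks for local pytorch weights when no safetensors files are available. Isolated, the pattern is roughly the following; find_local_pt_files is a hypothetical helper name, and only the calls visible in the hunk are used.

# Sketch of the fallback added above; find_local_pt_files is a hypothetical
# helper, and conversion of the discovered .bin files is omitted.
from text_generation_server import utils

def find_local_pt_files(model_id: str, revision: str):
    try:
        # Covers a local model, a hub-cached model, and WEIGHTS_CACHE_OVERRIDE.
        return utils.weight_files(model_id, revision, ".bin")
    except utils.LocalEntryNotFoundError:
        # No local pytorch weights were found.
        return []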

View File

@@ -3,7 +3,7 @@ import inspect
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
@@ -481,11 +481,12 @@ class CausalLM(Model):
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
load_in_8bit=quantize == "bitsandbytes",
load_in_4bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
# if torch.cuda.is_available() and torch.cuda.device_count() == 1:
# model = model.cuda()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
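The switch from load_in_8bit to load_in_4bit means --quantize bitsandbytes now quantizes to 4-bit, which is what the commit description refers to, and the explicit .cuda() call is commented out because transformers refuses to move a bitsandbytes-quantized model after loading. A minimal reproduction of that refusal, with a placeholder model id:

# Sketch: calling .cuda() on a 4-bit (or 8-bit) quantized model raises the
# ValueError quoted in the commented-out block above, because bitsandbytes
# has already placed and cast the weights during from_pretrained.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
    device_map="auto",
    load_in_4bit=True,
)

try:
    model = model.cuda()  # not supported for quantized models
except ValueError as err:
    print(err)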

View File

@@ -109,6 +109,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
generations, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch)
print(generations)
return generate_pb2.DecodeResponse(
generations=[generation.to_pb() for generation in generations],
batch=next_batch.to_pb() if next_batch else None,

View File

@@ -1,7 +1,7 @@
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.peft import download_and_unload_peft
from text_generation_server.utils.peft import download_and_unload_peft, load_local_peft
from text_generation_server.utils.hub import (
weight_files,
weight_hub_files,
@@ -28,6 +28,7 @@ __all__ = [
"weight_hub_files",
"download_weights",
"download_and_unload_peft",
"load_local_peft",
"EntryNotFoundError",
"HeterogeneousNextTokenChooser",
"LocalEntryNotFoundError",

View File

@@ -6,6 +6,8 @@ import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
def load_local_peft(model_id, revision, trust_remote_code):
return model_id
def download_and_unload_peft(model_id, revision, trust_remote_code):
torch_dtype = torch.float16
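load_local_peft is introduced here only as a stub that returns the model_id unchanged. Purely as an illustration of what loading a local PEFT adapter with the classes already imported in this file could look like (this is not the commit's implementation, and merge_local_peft is a hypothetical name):

# Illustrative sketch only -- the stub above returns model_id as-is.
import torch
from peft import AutoPeftModelForCausalLM

def merge_local_peft(model_id, revision=None, trust_remote_code=False):
    # Hypothetical helper, not part of this commit.
    model = AutoPeftModelForCausalLM.from_pretrained(
        model_id,
        revision=revision,
        torch_dtype=torch.float16,
        trust_remote_code=trust_remote_code,
    )
    # merge_and_unload folds the adapter weights back into the base model.
    return model.merge_and_unload()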