mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
loading models in 4 bit
This enables Pascal GPUs to load Llama 2 when used with --quantize bitsandbytes.
This commit is contained in:
parent 3062fa035d
commit cf178a278a
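
For context, a minimal sketch of what 4-bit loading with bitsandbytes looks like through transformers (not part of the diff below; the hub model id and the generate call are illustrative, while the commit itself points at a local copy of Llama-2-7b-chat-hf):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model id; substitute a local path such as the one used in run-dev.
model_id = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# load_in_4bit quantizes the linear layers with bitsandbytes (NF4 by default)
# and lets accelerate place the weights on the available GPU(s).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    device_map="auto",
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))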
@@ -24,7 +24,6 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes
 
 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes
@@ -141,10 +141,12 @@ def download_weights(
         if not extension == ".safetensors" or not auto_convert:
             raise e
 
+    logger.warning("attempting to load local model")
     # Try to see if there are local pytorch weights
     try:
         # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
         local_pt_files = utils.weight_files(model_id, revision, ".bin")
+        print(local_pt_files)
 
     # No local pytorch weights
     except utils.LocalEntryNotFoundError:
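
A small standalone check of the local fallback that the two added lines are meant to surface (a sketch; the local model path is an assumption):

from text_generation_server import utils

# Hypothetical local model directory; substitute your own path.
model_id = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf"

try:
    # Same call as in download_weights: model_id, revision, extension.
    pt_files = utils.weight_files(model_id, None, ".bin")
    print(f"found {len(pt_files)} PyTorch weight file(s):", pt_files)
except utils.LocalEntryNotFoundError:
    print("no local .bin weights found")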
@@ -3,7 +3,7 @@ import inspect
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model
@@ -481,11 +481,12 @@ class CausalLM(Model):
             device_map="auto"
             if torch.cuda.is_available() and torch.cuda.device_count() > 1
             else None,
-            load_in_8bit=quantize == "bitsandbytes",
+            load_in_4bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
+        ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
+        # if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+        #     model = model.cuda()
 
         if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
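
The commented-out .cuda() call follows from how transformers handles quantized models: with load_in_4bit or load_in_8bit, accelerate has already placed the weights, and moving the model again raises the ValueError quoted in the comment. A minimal sketch of a guard that keeps the old single-GPU behaviour only for unquantized loads (variable names follow the surrounding CausalLM.__init__; the quantize check is an assumption about intent, not what the commit does):

# Only move the model manually when bitsandbytes quantization was not requested;
# a 4-bit/8-bit model is already on the correct device and dtype.
if quantize is None and torch.cuda.is_available() and torch.cuda.device_count() == 1:
    model = model.cuda()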
@@ -109,6 +109,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
+        print(generations)
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
@@ -1,7 +1,7 @@
 from text_generation_server.utils.convert import convert_file, convert_files
 from text_generation_server.utils.dist import initialize_torch_distributed
 from text_generation_server.utils.weights import Weights
-from text_generation_server.utils.peft import download_and_unload_peft
+from text_generation_server.utils.peft import download_and_unload_peft, load_local_peft
 from text_generation_server.utils.hub import (
     weight_files,
     weight_hub_files,
@@ -28,6 +28,7 @@ __all__ = [
     "weight_hub_files",
     "download_weights",
     "download_and_unload_peft",
+    "load_local_peft",
     "EntryNotFoundError",
     "HeterogeneousNextTokenChooser",
     "LocalEntryNotFoundError",
@@ -6,6 +6,8 @@ import torch
 from transformers import AutoTokenizer
 from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
 
+def load_local_peft(model_id, revision, trust_remote_code):
+    return model_id
 def download_and_unload_peft(model_id, revision, trust_remote_code):
     torch_dtype = torch.float16
 
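
load_local_peft as added here is a pass-through that simply returns model_id. For comparison, a sketch of what a full local merge in the spirit of download_and_unload_peft could look like (the helper name and arguments are hypothetical and not part of this commit):

import torch
from peft import AutoPeftModelForCausalLM

def merge_local_peft(model_id, trust_remote_code=False):
    # Hypothetical helper: load the PEFT adapter on top of its base model,
    # fold the adapter weights in, and save the merged model back to the same
    # directory so the server can load it like a plain model.
    model = AutoPeftModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        trust_remote_code=trust_remote_code,
        low_cpu_mem_usage=True,
    )
    model = model.merge_and_unload()
    model.save_pretrained(model_id, safe_serialization=True)
    return model_id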