Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00.
Merge pull request #1 from ohmytofu-ai/impl/4bit-demo

loading models in 4 bit

commit 4d8e47e0e9
server/Makefile
@@ -24,7 +24,6 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"

 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes

 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes
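For reference, the same dev invocation can be driven from Python; a minimal subprocess sketch mirroring the new run-dev line (the /mnt/TOFU/HF_MODELS path is the author's local checkout, and nothing here is a stable text-generation-inference API):

# Sketch: launch the same single-process, bitsandbytes-quantized dev server
# that `make run-dev` starts, from Python instead of make.
import os
import subprocess

env = dict(os.environ, SAFETENSORS_FAST_GPU="1")
subprocess.run(
    [
        "python", "-m", "torch.distributed.run", "--nproc_per_node=1",
        "text_generation_server/cli.py", "serve",
        "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf",
        "--quantize", "bitsandbytes",
    ],
    env=env,
    check=True,  # raise if the server process exits non-zero
)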
server/text_generation_server/cli.py
@@ -141,10 +141,12 @@ def download_weights(
         if not extension == ".safetensors" or not auto_convert:
             raise e

+    logger.warning("attempting to load local model")
     # Try to see if there are local pytorch weights
     try:
         # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
         local_pt_files = utils.weight_files(model_id, revision, ".bin")
+        print(local_pt_files)

     # No local pytorch weights
     except utils.LocalEntryNotFoundError:
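Condensed, the fallback this hunk instruments looks like the sketch below. `utils.weight_files` and `utils.LocalEntryNotFoundError` are the helpers used in this repo; the wrapper function itself is illustrative only:

# Illustrative condensation of the local-weights fallback in download_weights:
# when no .safetensors files can be fetched (or auto-conversion is disabled),
# look for local PyTorch .bin weights instead.
from typing import List, Optional

from loguru import logger
from text_generation_server import utils

def find_local_pt_weights(model_id: str, revision: Optional[str]) -> List:
    logger.warning("attempting to load local model")
    try:
        # Searches a local path, the hub cache, and WEIGHTS_CACHE_OVERRIDE.
        return utils.weight_files(model_id, revision, ".bin")
    except utils.LocalEntryNotFoundError:
        # No local PyTorch weights either; let the caller decide what's next.
        return []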
server/text_generation_server/models/causal_lm.py
@@ -3,7 +3,7 @@ import inspect

 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
@@ -481,11 +481,12 @@ class CausalLM(Model):
             device_map="auto"
             if torch.cuda.is_available() and torch.cuda.device_count() > 1
             else None,
-            load_in_8bit=quantize == "bitsandbytes",
+            load_in_4bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
+        ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
+        # if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+        #     model = model.cuda()

         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
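The commented-out `model.cuda()` is the crux of this hunk: once `load_in_4bit` (or `load_in_8bit`) is set, accelerate places the quantized weights during `from_pretrained`, and any later `.cuda()`/`.to(device)` raises the ValueError quoted in the comment. A standalone sketch of the load this branch switches to; the model id is a placeholder, and recent transformers releases prefer passing a BitsAndBytesConfig over the bare flag:

# Sketch: 4-bit load via bitsandbytes. Device placement happens inside
# from_pretrained, so there must be no .cuda()/.to() calls afterwards.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
    load_in_4bit=True,      # what `quantize == "bitsandbytes"` toggles here
    device_map="auto",      # let accelerate assign devices
    torch_dtype=torch.float16,
)
# model = model.cuda()  # would raise the ValueError quoted in the diff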
server/text_generation_server/server.py
@@ -109,6 +109,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)

+        print(generations)
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
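Note that the added `print(generations)` fires on every decode step, which will flood stdout under load. If the intent is debugging, a hypothetical helper built on the loguru logger this repo already uses is gentler; this is a suggestion, not part of the PR:

# Hypothetical debug helper: log a summary instead of printing full
# generation objects on every decode call.
from loguru import logger

def log_decode(generations) -> None:
    logger.debug("decode step produced {} generations", len(generations))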
server/text_generation_server/utils/__init__.py
@@ -1,7 +1,7 @@
 from text_generation_server.utils.convert import convert_file, convert_files
 from text_generation_server.utils.dist import initialize_torch_distributed
 from text_generation_server.utils.weights import Weights
-from text_generation_server.utils.peft import download_and_unload_peft
+from text_generation_server.utils.peft import download_and_unload_peft, load_local_peft
 from text_generation_server.utils.hub import (
     weight_files,
     weight_hub_files,
@@ -28,6 +28,7 @@ __all__ = [
     "weight_hub_files",
     "download_weights",
     "download_and_unload_peft",
+    "load_local_peft",
     "EntryNotFoundError",
     "HeterogeneousNextTokenChooser",
     "LocalEntryNotFoundError",
server/text_generation_server/utils/peft.py
@@ -6,6 +6,8 @@ import torch
 from transformers import AutoTokenizer
 from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM


+def load_local_peft(model_id, revision, trust_remote_code):
+    return model_id
 def download_and_unload_peft(model_id, revision, trust_remote_code):
     torch_dtype = torch.float16
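As merged, `load_local_peft` is a pass-through stub that simply returns `model_id`. If the eventual intent is to load and fold in a local adapter, a sketch using the `AutoPeftModelForCausalLM` import already at the top of this file might look as follows; this is an assumption about intent, not code from this PR:

# Assumed follow-up for the load_local_peft stub: load a local PEFT adapter
# and merge it into the base model's weights.
import torch
from peft import AutoPeftModelForCausalLM

def load_and_merge_local_peft(model_id, revision, trust_remote_code):
    model = AutoPeftModelForCausalLM.from_pretrained(
        model_id,
        revision=revision,
        torch_dtype=torch.float16,
        trust_remote_code=trust_remote_code,
    )
    # merge_and_unload() folds the adapter into the base weights and returns
    # a plain transformers model.
    return model.merge_and_unload()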