Merge pull request #1 from ohmytofu-ai/impl/4bit-demo

loading models in 4 bit
chris-aeviator committed 2023-08-27 14:42:52 +02:00 (via GitHub)
commit 4d8e47e0e9
6 changed files with 13 additions and 7 deletions
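
For context, the heart of this change is switching the bitsandbytes path from 8-bit to 4-bit loading. A minimal standalone sketch of what that looks like with transformers (the model path is illustrative, taken from the Makefile change below):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative local path, mirroring the Makefile change in this commit.
    model_id = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # load_in_4bit quantizes the linear layers via bitsandbytes at load time;
    # device_map="auto" lets accelerate place the weights, so the model must
    # not be moved manually afterwards (see the CausalLM change below).
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        load_in_4bit=True,
    )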

View File

@@ -24,7 +24,6 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"

 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes

 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes

View File

@@ -141,10 +141,12 @@ def download_weights(
         if not extension == ".safetensors" or not auto_convert:
             raise e

+        logger.warning("attempting to load local model")
         # Try to see if there are local pytorch weights
         try:
             # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
             local_pt_files = utils.weight_files(model_id, revision, ".bin")
+            print(local_pt_files)
         # No local pytorch weights
         except utils.LocalEntryNotFoundError:
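
For orientation, the fallback this hunk instruments behaves roughly as follows; a sketch, where weight_files and LocalEntryNotFoundError come from the surrounding diff and the wrapper function is a hypothetical name:

    from loguru import logger
    from text_generation_server import utils

    # Sketch of the resolution order implied by the diff: try .safetensors
    # first, then fall back to locally cached PyTorch .bin weights.
    def resolve_local_weights(model_id, revision):
        try:
            return utils.weight_files(model_id, revision, ".safetensors")
        except utils.LocalEntryNotFoundError:
            logger.warning("attempting to load local model")
            return utils.weight_files(model_id, revision, ".bin")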

View File

@@ -3,7 +3,7 @@ import inspect

 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
@@ -481,11 +481,12 @@ class CausalLM(Model):
             device_map="auto"
             if torch.cuda.is_available() and torch.cuda.device_count() > 1
             else None,
-            load_in_8bit=quantize == "bitsandbytes",
+            load_in_4bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
+        ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
+        # if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+        #     model = model.cuda()

         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
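
The commented-out cuda() call is the reason for the quoted ValueError: with load_in_4bit (or load_in_8bit) plus device_map, accelerate already places and casts the model during from_pretrained, so moving it afterwards is rejected. A small illustration (the model id is a placeholder):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
        device_map="auto",
        load_in_4bit=True,
    )

    # accelerate records where each module was placed:
    print(model.hf_device_map)

    # Calling model.cuda() here would raise:
    #   ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit`
    #   quantized models. ...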

View File

@@ -109,6 +109,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)

+        print(generations)
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,

View File

@@ -1,7 +1,7 @@
 from text_generation_server.utils.convert import convert_file, convert_files
 from text_generation_server.utils.dist import initialize_torch_distributed
 from text_generation_server.utils.weights import Weights
-from text_generation_server.utils.peft import download_and_unload_peft
+from text_generation_server.utils.peft import download_and_unload_peft, load_local_peft
 from text_generation_server.utils.hub import (
     weight_files,
     weight_hub_files,
@@ -28,6 +28,7 @@ __all__ = [
     "weight_hub_files",
     "download_weights",
     "download_and_unload_peft",
+    "load_local_peft",
     "EntryNotFoundError",
     "HeterogeneousNextTokenChooser",
     "LocalEntryNotFoundError",

View File

@@ -6,6 +6,8 @@ import torch
 from transformers import AutoTokenizer
 from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM

+def load_local_peft(model_id, revision, trust_remote_code):
+    return model_id

 def download_and_unload_peft(model_id, revision, trust_remote_code):
     torch_dtype = torch.float16
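
load_local_peft is currently a stub that hands the path straight back to the caller. For comparison, a merge-and-unload variant in the style of download_and_unload_peft might look like the sketch below (an assumption, not this repo's implementation):

    import torch
    from peft import AutoPeftModelForCausalLM

    def load_local_peft_sketch(model_id, revision, trust_remote_code):
        # Assumption: model_id points at a local directory containing a
        # PEFT adapter alongside (or referencing) its base model.
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch.float16,
            trust_remote_code=trust_remote_code,
        )
        # Fold the adapter deltas into the base weights so the result is a
        # plain transformers model, then persist it back to the same path.
        model = model.merge_and_unload()
        model.save_pretrained(model_id)
        return model_id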