feat: support for loading peft model

This commit is contained in:
Subaandh Krishnakumar 2023-08-29 11:06:57 +02:00
parent 7c2e0af2a6
commit d6d2b9426b
9 changed files with 172 additions and 48 deletions

View File

@ -100,6 +100,13 @@ struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)]
model_id: String,
/// The name of the peft model to load.
/// Can be a PEFT_MODEL_ID as listed on <https://hf.co/models> like
/// Or it can be a local directory containing the necessary files (adapter_model.bin)
/// as saved by `save_pretrained(...)` methods of transformers
#[clap(long, env)]
peft_model_id: Option<String>,
/// The actual revision of the model if you're referring to a model
/// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
#[clap(long, env)]
@ -347,6 +354,7 @@ enum ShardStatus {
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_id: String,
peft_model_id: Option<String>,
revision: Option<String>,
quantize: Option<Quantization>,
dtype: Option<Dtype>,
@ -391,6 +399,10 @@ fn shard_manager(
"--json-output".to_string(),
];
if let Some(peft_model_id) = peft_model_id {
shard_args.push("--peft-model-id".to_string());
shard_args.push(peft_model_id)
}
// Activate trust remote code
if trust_remote_code {
shard_args.push("--trust-remote-code".to_string());
@ -846,6 +858,7 @@ fn spawn_shards(
// Start shard processes
for rank in 0..num_shard {
let model_id = args.model_id.clone();
let peft_model_id = args.peft_model_id.clone();
let revision = args.revision.clone();
let uds_path = args.shard_uds_path.clone();
let master_addr = args.master_addr.clone();
@ -868,6 +881,7 @@ fn spawn_shards(
thread::spawn(move || {
shard_manager(
model_id,
peft_model_id,
revision,
quantize,
dtype,

View File

@ -27,6 +27,7 @@ class Dtype(str, Enum):
@app.command()
def serve(
model_id: str,
peft_model_id: str = None,
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
@ -78,8 +79,9 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
logger.info(f"{peft_model_id=}")
server.serve(
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_id
)

View File

@ -75,6 +75,7 @@ def get_model(
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
peft_model_id: str,
) -> Model:
if dtype is None:
dtype = torch.float16
@ -85,6 +86,7 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")
logger.info(f"In get models {peft_model_id=}")
if "facebook/galactica" in model_id:
return GalacticaSharded(
model_id,
@ -120,6 +122,7 @@ def get_model(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
model_type = config_dict["model_type"]
logger.info(f"Model type is {model_type=}")
if model_type == "gpt_bigcode":
if FLASH_ATTENTION:
@ -217,20 +220,24 @@ def get_model(
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
else:
if FLASH_ATTENTION and not config_dict.get("alibi", False):
logger.error("In Flash RW")
return FlashRWSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
peft_model_id=peft_model_id,
)
else:
logger.error(f"Value not none {peft_model_id=}")
return RW(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
peft_model_id=peft_model_id
)
elif model_type == "opt":
@ -273,6 +280,7 @@ def get_model(
"4bit quantization is not supported for AutoModel"
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
logger.error("In CausalLM")
return CausalLM(
model_id,
revision,

View File

@ -17,6 +17,8 @@ from text_generation_server.models.types import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from transformers import BitsAndBytesConfig
from peft import PeftModel
tracer = trace.get_tracer(__name__)
@ -483,6 +485,7 @@ class CausalLM(Model):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
peft_model_id: str = None
):
if torch.cuda.is_available():
device = torch.device("cuda")
@ -494,6 +497,33 @@ class CausalLM(Model):
device = torch.device("cpu")
dtype = torch.float32
if peft_model_id:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
padding_side="left",
)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
llm_int8_enable_fp32_cpu_offload=True,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
device_map="auto",
quantization_config=bnb_config
)
# load Lora weights
model = PeftModel.from_pretrained(
model,
peft_model_id,
device_map="auto",
)
else:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
@ -501,6 +531,7 @@ class CausalLM(Model):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
@ -511,6 +542,8 @@ class CausalLM(Model):
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()

View File

@ -2,7 +2,7 @@ import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Optional
from text_generation_server.models import FlashCausalLM
@ -15,6 +15,8 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from transformers import BitsAndBytesConfig
from peft import PeftModel
tracer = trace.get_tracer(__name__)
@ -27,6 +29,7 @@ class FlashRWSharded(FlashCausalLM):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
peft_model_id: str = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -35,14 +38,6 @@ class FlashRWSharded(FlashCausalLM):
else:
raise NotImplementedError("FlashRW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
@ -61,6 +56,42 @@ class FlashRWSharded(FlashCausalLM):
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
if peft_model_id:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
padding_side="left",
)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
llm_int8_enable_fp32_cpu_offload=True,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
device_map="auto",
quantization_config=bnb_config
)
# load Lora weights
model = PeftModel.from_pretrained(
model,
peft_model_id,
device_map="auto",
)
model.eval()
else:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)

View File

@ -1,6 +1,8 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from peft import PeftModel
from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM
@ -14,6 +16,7 @@ class RW(CausalLM):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
peft_model_id: str = None,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@ -25,6 +28,34 @@ class RW(CausalLM):
device = torch.device("cpu")
dtype = torch.float32
if peft_model_id:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
padding_side="left",
)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
llm_int8_enable_fp32_cpu_offload=True,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
device_map="auto",
quantization_config=bnb_config
)
assert 1
# load Lora weights
model = PeftModel.from_pretrained(
model,
peft_model_id,
device_map="auto",
)
else:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
@ -32,6 +63,7 @@ class RW(CausalLM):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,

View File

@ -123,7 +123,9 @@ def serve(
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
peft_model_id: str
):
logger.info(f"Receiving peft {peft_model_id=}")
async def serve_inner(
model_id: str,
revision: Optional[str],
@ -131,6 +133,7 @@ def serve(
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
peft_model_id: str = None,
):
unix_socket_template = "unix://{}-{}"
if sharded:
@ -144,8 +147,9 @@ def serve(
server_urls = [local_url]
try:
logger.exception(f"In server {peft_model_id=}")
model = get_model(
model_id, revision, sharded, quantize, dtype, trust_remote_code
model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id
)
except Exception:
logger.exception("Error when initializing model")
@ -193,5 +197,5 @@ def serve(
await server.stop(0)
asyncio.run(
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id)
)

View File

@ -86,6 +86,7 @@ def weight_files(
# Local model
if Path(model_id).exists() and Path(model_id).is_dir():
local_files = list(Path(model_id).glob(f"*{extension}"))
logger.info(f"Local files {local_files=}")
if not local_files:
raise FileNotFoundError(
f"No local weights found in {model_id} with extension {extension}"

View File

@ -31,7 +31,6 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
logger.info(f"Merging the lora weights.")
base_model_id = model.peft_config["default"].base_model_name_or_path
model = model.merge_and_unload()
os.makedirs(model_id, exist_ok=True)