mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
feat: support for loading peft model
This commit is contained in:
parent
7c2e0af2a6
commit
d6d2b9426b
@ -100,6 +100,13 @@ struct Args {
|
||||
#[clap(default_value = "bigscience/bloom-560m", long, env)]
|
||||
model_id: String,
|
||||
|
||||
/// The name of the peft model to load.
|
||||
/// Can be a PEFT_MODEL_ID as listed on <https://hf.co/models> like
|
||||
/// Or it can be a local directory containing the necessary files (adapter_model.bin)
|
||||
/// as saved by `save_pretrained(...)` methods of transformers
|
||||
#[clap(long, env)]
|
||||
peft_model_id: Option<String>,
|
||||
|
||||
/// The actual revision of the model if you're referring to a model
|
||||
/// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
|
||||
#[clap(long, env)]
|
||||
@ -347,6 +354,7 @@ enum ShardStatus {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn shard_manager(
|
||||
model_id: String,
|
||||
peft_model_id: Option<String>,
|
||||
revision: Option<String>,
|
||||
quantize: Option<Quantization>,
|
||||
dtype: Option<Dtype>,
|
||||
@ -391,6 +399,10 @@ fn shard_manager(
|
||||
"--json-output".to_string(),
|
||||
];
|
||||
|
||||
if let Some(peft_model_id) = peft_model_id {
|
||||
shard_args.push("--peft-model-id".to_string());
|
||||
shard_args.push(peft_model_id)
|
||||
}
|
||||
// Activate trust remote code
|
||||
if trust_remote_code {
|
||||
shard_args.push("--trust-remote-code".to_string());
|
||||
@ -846,6 +858,7 @@ fn spawn_shards(
|
||||
// Start shard processes
|
||||
for rank in 0..num_shard {
|
||||
let model_id = args.model_id.clone();
|
||||
let peft_model_id = args.peft_model_id.clone();
|
||||
let revision = args.revision.clone();
|
||||
let uds_path = args.shard_uds_path.clone();
|
||||
let master_addr = args.master_addr.clone();
|
||||
@ -868,6 +881,7 @@ fn spawn_shards(
|
||||
thread::spawn(move || {
|
||||
shard_manager(
|
||||
model_id,
|
||||
peft_model_id,
|
||||
revision,
|
||||
quantize,
|
||||
dtype,
|
||||
|
@ -27,6 +27,7 @@ class Dtype(str, Enum):
|
||||
@app.command()
|
||||
def serve(
|
||||
model_id: str,
|
||||
peft_model_id: str = None,
|
||||
revision: Optional[str] = None,
|
||||
sharded: bool = False,
|
||||
quantize: Optional[Quantization] = None,
|
||||
@ -78,8 +79,9 @@ def serve(
|
||||
raise RuntimeError(
|
||||
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
||||
)
|
||||
logger.info(f"{peft_model_id=}")
|
||||
server.serve(
|
||||
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
|
||||
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_id
|
||||
)
|
||||
|
||||
|
||||
|
@ -75,6 +75,7 @@ def get_model(
|
||||
quantize: Optional[str],
|
||||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
peft_model_id: str,
|
||||
) -> Model:
|
||||
if dtype is None:
|
||||
dtype = torch.float16
|
||||
@ -85,6 +86,7 @@ def get_model(
|
||||
else:
|
||||
raise RuntimeError(f"Unknown dtype {dtype}")
|
||||
|
||||
logger.info(f"In get models {peft_model_id=}")
|
||||
if "facebook/galactica" in model_id:
|
||||
return GalacticaSharded(
|
||||
model_id,
|
||||
@ -120,6 +122,7 @@ def get_model(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
model_type = config_dict["model_type"]
|
||||
logger.info(f"Model type is {model_type=}")
|
||||
|
||||
if model_type == "gpt_bigcode":
|
||||
if FLASH_ATTENTION:
|
||||
@ -217,20 +220,24 @@ def get_model(
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
|
||||
else:
|
||||
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
||||
logger.error("In Flash RW")
|
||||
return FlashRWSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
peft_model_id=peft_model_id,
|
||||
)
|
||||
else:
|
||||
logger.error(f"Value not none {peft_model_id=}")
|
||||
return RW(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
peft_model_id=peft_model_id
|
||||
)
|
||||
|
||||
elif model_type == "opt":
|
||||
@ -273,6 +280,7 @@ def get_model(
|
||||
"4bit quantization is not supported for AutoModel"
|
||||
)
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
logger.error("In CausalLM")
|
||||
return CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
|
@ -17,6 +17,8 @@ from text_generation_server.models.types import (
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from transformers import BitsAndBytesConfig
|
||||
from peft import PeftModel
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -483,6 +485,7 @@ class CausalLM(Model):
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
peft_model_id: str = None
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
@ -494,6 +497,33 @@ class CausalLM(Model):
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
if peft_model_id:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
trust_remote_code=True,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
|
||||
# load Lora weights
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
peft_model_id,
|
||||
device_map="auto",
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
@ -501,6 +531,7 @@ class CausalLM(Model):
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
@ -511,6 +542,8 @@ class CausalLM(Model):
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
|
||||
|
@ -2,7 +2,7 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
@ -15,6 +15,8 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from transformers import BitsAndBytesConfig
|
||||
from peft import PeftModel
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -27,6 +29,7 @@ class FlashRWSharded(FlashCausalLM):
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
peft_model_id: str = None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
@ -35,14 +38,6 @@ class FlashRWSharded(FlashCausalLM):
|
||||
else:
|
||||
raise NotImplementedError("FlashRW is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = RWConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
@ -61,6 +56,42 @@ class FlashRWSharded(FlashCausalLM):
|
||||
if config.quantize == "gptq":
|
||||
weights._set_gptq_params(model_id)
|
||||
|
||||
if peft_model_id:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
trust_remote_code=True,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
|
||||
# load Lora weights
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
peft_model_id,
|
||||
device_map="auto",
|
||||
)
|
||||
model.eval()
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
model = FlashRWForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import BitsAndBytesConfig
|
||||
from peft import PeftModel
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
@ -14,6 +16,7 @@ class RW(CausalLM):
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
peft_model_id: str = None,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
@ -25,6 +28,34 @@ class RW(CausalLM):
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
if peft_model_id:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
llm_int8_enable_fp32_cpu_offload=True,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
trust_remote_code=True,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
|
||||
assert 1
|
||||
# load Lora weights
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
peft_model_id,
|
||||
device_map="auto",
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
@ -32,6 +63,7 @@ class RW(CausalLM):
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
|
@ -123,7 +123,9 @@ def serve(
|
||||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
uds_path: Path,
|
||||
peft_model_id: str
|
||||
):
|
||||
logger.info(f"Receiving peft {peft_model_id=}")
|
||||
async def serve_inner(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
@ -131,6 +133,7 @@ def serve(
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
peft_model_id: str = None,
|
||||
):
|
||||
unix_socket_template = "unix://{}-{}"
|
||||
if sharded:
|
||||
@ -144,8 +147,9 @@ def serve(
|
||||
server_urls = [local_url]
|
||||
|
||||
try:
|
||||
logger.exception(f"In server {peft_model_id=}")
|
||||
model = get_model(
|
||||
model_id, revision, sharded, quantize, dtype, trust_remote_code
|
||||
model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error when initializing model")
|
||||
@ -193,5 +197,5 @@ def serve(
|
||||
await server.stop(0)
|
||||
|
||||
asyncio.run(
|
||||
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
|
||||
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id)
|
||||
)
|
||||
|
@ -86,6 +86,7 @@ def weight_files(
|
||||
# Local model
|
||||
if Path(model_id).exists() and Path(model_id).is_dir():
|
||||
local_files = list(Path(model_id).glob(f"*{extension}"))
|
||||
logger.info(f"Local files {local_files=}")
|
||||
if not local_files:
|
||||
raise FileNotFoundError(
|
||||
f"No local weights found in {model_id} with extension {extension}"
|
||||
|
@ -31,7 +31,6 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
|
||||
logger.info(f"Merging the lora weights.")
|
||||
|
||||
base_model_id = model.peft_config["default"].base_model_name_or_path
|
||||
|
||||
model = model.merge_and_unload()
|
||||
|
||||
os.makedirs(model_id, exist_ok=True)
|
||||
|
Loading…
Reference in New Issue
Block a user