Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

commit d6d2b9426b (parent 7c2e0af2a6)

feat: support for loading peft model
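The change threads a new `peft_model_id` option from the Rust launcher (exposed as `--peft-model-id`, also readable from the environment via clap's `env` attribute) through the Python CLI, `server.serve`, and `get_model`, down to the `CausalLM`, `RW`, and `FlashRWSharded` constructors. When the option is set, the base model is loaded in 4-bit NF4 via bitsandbytes and the LoRA adapter is attached with peft. Below is a minimal, self-contained sketch of that loading pattern, condensed from the hunks that follow; it is not the server code itself, and the model ids are placeholders.

# Minimal sketch of the loading pattern this commit wires in (placeholder ids).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base_model_id = "bigscience/bloom-560m"   # base model (placeholder)
peft_model_id = "my-org/my-lora-adapter"  # adapter repo or local dir (placeholder)

tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side="left")

# Quantize the base model to 4-bit NF4, as the patched constructors do.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    quantization_config=bnb_config,
)
# Attach the LoRA weights on top of the quantized base model.
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()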
@@ -100,6 +100,13 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_id: String,
 
+    /// The name of the peft model to load.
+    /// Can be a PEFT_MODEL_ID as listed on <https://hf.co/models> like
+    /// Or it can be a local directory containing the necessary files (adapter_model.bin)
+    /// as saved by `save_pretrained(...)` methods of transformers
+    #[clap(long, env)]
+    peft_model_id: Option<String>,
+
     /// The actual revision of the model if you're referring to a model
     /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
     #[clap(long, env)]
@@ -347,6 +354,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_id: String,
+    peft_model_id: Option<String>,
     revision: Option<String>,
     quantize: Option<Quantization>,
     dtype: Option<Dtype>,
@@ -391,6 +399,10 @@ fn shard_manager(
         "--json-output".to_string(),
     ];
 
+    if let Some(peft_model_id) = peft_model_id {
+        shard_args.push("--peft-model-id".to_string());
+        shard_args.push(peft_model_id)
+    }
     // Activate trust remote code
     if trust_remote_code {
         shard_args.push("--trust-remote-code".to_string());
@@ -846,6 +858,7 @@ fn spawn_shards(
     // Start shard processes
     for rank in 0..num_shard {
         let model_id = args.model_id.clone();
+        let peft_model_id = args.peft_model_id.clone();
         let revision = args.revision.clone();
         let uds_path = args.shard_uds_path.clone();
         let master_addr = args.master_addr.clone();
@@ -868,6 +881,7 @@ fn spawn_shards(
         thread::spawn(move || {
             shard_manager(
                 model_id,
+                peft_model_id,
                 revision,
                 quantize,
                 dtype,

@@ -27,6 +27,7 @@ class Dtype(str, Enum):
 @app.command()
 def serve(
     model_id: str,
+    peft_model_id: str = None,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
@@ -78,8 +79,9 @@ def serve(
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
+    logger.info(f"{peft_model_id=}")
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_id
     )

@@ -75,6 +75,7 @@ def get_model(
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
+    peft_model_id: str,
 ) -> Model:
     if dtype is None:
         dtype = torch.float16
@@ -85,6 +86,7 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
+    logger.info(f"In get models {peft_model_id=}")
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
             model_id,
@@ -120,6 +122,7 @@ def get_model(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
     model_type = config_dict["model_type"]
+    logger.info(f"Model type is {model_type=}")
 
     if model_type == "gpt_bigcode":
         if FLASH_ATTENTION:
@@ -217,20 +220,24 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
+                logger.error("In Flash RW")
                 return FlashRWSharded(
                     model_id,
                     revision,
                     quantize=quantize,
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
+                    peft_model_id=peft_model_id,
                 )
             else:
+                logger.error(f"Value not none {peft_model_id=}")
                 return RW(
                     model_id,
                     revision,
                     quantize=quantize,
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
+                    peft_model_id=peft_model_id
                 )
 
     elif model_type == "opt":
@@ -273,6 +280,7 @@ def get_model(
             "4bit quantization is not supported for AutoModel"
         )
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+        logger.error("In CausalLM")
         return CausalLM(
             model_id,
             revision,

@@ -17,6 +17,8 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from transformers import BitsAndBytesConfig
+from peft import PeftModel
 
 tracer = trace.get_tracer(__name__)
 
@@ -483,6 +485,7 @@ class CausalLM(Model):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_id: str = None
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -494,6 +497,33 @@ class CausalLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32
 
+        if peft_model_id:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                padding_side="left",
+            )
+
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                llm_int8_enable_fp32_cpu_offload=True,
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                trust_remote_code=True,
+                device_map="auto",
+                quantization_config=bnb_config
+            )
+
+            # load Lora weights
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_id,
+                device_map="auto",
+            )
+        else:
             tokenizer = AutoTokenizer.from_pretrained(
                 model_id,
                 revision=revision,
@@ -501,6 +531,7 @@ class CausalLM(Model):
                 truncation_side="left",
                 trust_remote_code=trust_remote_code,
             )
 
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 revision=revision,
@@ -511,6 +542,8 @@ class CausalLM(Model):
                 load_in_8bit=quantize == "bitsandbytes",
                 trust_remote_code=trust_remote_code,
             )
+
+
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
 
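As an aside (not part of the patch): once the PEFT branch above has attached the adapter, the wrapped model behaves like any other causal LM at inference time. A hedged smoke-test sketch, continuing from the loading sketch near the top of this page; the prompt and token budget are arbitrary.

# Quick generation check for the PEFT-wrapped model from the earlier sketch.
prompt = tokenizer("Hello, my name is", return_tensors="pt")
if torch.cuda.is_available():
    prompt = prompt.to("cuda")
with torch.no_grad():
    output_ids = model.generate(**prompt, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))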
@@ -2,7 +2,7 @@ import torch
 import torch.distributed
 
 from opentelemetry import trace
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from typing import Optional
 
 from text_generation_server.models import FlashCausalLM
@@ -15,6 +15,8 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from transformers import BitsAndBytesConfig
+from peft import PeftModel
 
 tracer = trace.get_tracer(__name__)
 
@@ -27,6 +29,7 @@ class FlashRWSharded(FlashCausalLM):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_id: str = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -35,14 +38,6 @@ class FlashRWSharded(FlashCausalLM):
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
 
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
         config = RWConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
@@ -61,6 +56,42 @@ class FlashRWSharded(FlashCausalLM):
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)
 
+        if peft_model_id:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                padding_side="left",
+            )
+
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                llm_int8_enable_fp32_cpu_offload=True,
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                trust_remote_code=True,
+                device_map="auto",
+                quantization_config=bnb_config
+            )
+
+            # load Lora weights
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_id,
+                device_map="auto",
+            )
+            model.eval()
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+
             model = FlashRWForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)

@@ -1,6 +1,8 @@
 import torch
 
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import BitsAndBytesConfig
+from peft import PeftModel
 from typing import List, Optional, Tuple
 
 from text_generation_server.models import CausalLM
@@ -14,6 +16,7 @@ class RW(CausalLM):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_id: str = None,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -25,6 +28,34 @@ class RW(CausalLM):
             device = torch.device("cpu")
             dtype = torch.float32
 
+        if peft_model_id:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                padding_side="left",
+            )
+
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                llm_int8_enable_fp32_cpu_offload=True,
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                trust_remote_code=True,
+                device_map="auto",
+                quantization_config=bnb_config
+            )
+
+            assert 1
+            # load Lora weights
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_id,
+                device_map="auto",
+            )
+        else:
             tokenizer = AutoTokenizer.from_pretrained(
                 model_id,
                 revision=revision,
@@ -32,6 +63,7 @@ class RW(CausalLM):
                 truncation_side="left",
                 trust_remote_code=trust_remote_code,
             )
 
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 revision=revision,

@@ -123,7 +123,9 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    peft_model_id: str
 ):
+    logger.info(f"Receiving peft {peft_model_id=}")
     async def serve_inner(
         model_id: str,
         revision: Optional[str],
@@ -131,6 +133,7 @@ def serve(
         quantize: Optional[str] = None,
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
+        peft_model_id: str = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
@@ -144,8 +147,9 @@ def serve(
             server_urls = [local_url]
 
         try:
+            logger.exception(f"In server {peft_model_id=}")
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -193,5 +197,5 @@ def serve(
         await server.stop(0)
 
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id)
     )

@@ -86,6 +86,7 @@ def weight_files(
     # Local model
     if Path(model_id).exists() and Path(model_id).is_dir():
         local_files = list(Path(model_id).glob(f"*{extension}"))
+        logger.info(f"Local files {local_files=}")
         if not local_files:
             raise FileNotFoundError(
                 f"No local weights found in {model_id} with extension {extension}"
             )

@@ -31,7 +31,6 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     logger.info(f"Merging the lora weights.")
-
     base_model_id = model.peft_config["default"].base_model_name_or_path
 
     model = model.merge_and_unload()
 
     os.makedirs(model_id, exist_ok=True)
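For context, the pre-existing `download_and_unload_peft` helper touched in this last hunk takes a different route from the runtime `PeftModel` wrapper added above: it merges the LoRA weights into the base model and saves a plain checkpoint. A rough sketch of that merge-and-save flow, with placeholder ids and output path; only `peft_config["default"].base_model_name_or_path`, `merge_and_unload()`, and the directory creation are shown in the hunk itself.

# Rough sketch of the merge-and-save approach (placeholder ids and paths).
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # placeholder
merged = PeftModel.from_pretrained(base, "my-org/my-lora-adapter")    # placeholder
base_model_id = merged.peft_config["default"].base_model_name_or_path
merged = merged.merge_and_unload()  # fold LoRA deltas into the base weights
merged.save_pretrained("./merged-model", safe_serialization=True)
AutoTokenizer.from_pretrained(base_model_id).save_pretrained("./merged-model")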