feat: support for loading peft model

Subaandh Krishnakumar 2023-08-29 11:06:57 +02:00
parent 7c2e0af2a6
commit d6d2b9426b
9 changed files with 172 additions and 48 deletions

View File

@@ -100,6 +100,13 @@ struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)]
model_id: String,
+/// The name of the PEFT model to load.
+/// Can be a PEFT model ID as listed on <https://hf.co/models>,
+/// or a local directory containing the necessary files (adapter_model.bin)
+/// as saved by `save_pretrained(...)`.
+#[clap(long, env)]
+peft_model_id: Option<String>,
/// The actual revision of the model if you're referring to a model
/// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
#[clap(long, env)]
@@ -347,6 +354,7 @@ enum ShardStatus {
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_id: String,
+peft_model_id: Option<String>,
revision: Option<String>,
quantize: Option<Quantization>,
dtype: Option<Dtype>,
@@ -391,6 +399,10 @@ fn shard_manager(
"--json-output".to_string(),
];
+if let Some(peft_model_id) = peft_model_id {
+shard_args.push("--peft-model-id".to_string());
+shard_args.push(peft_model_id)
+}
// Activate trust remote code
if trust_remote_code {
shard_args.push("--trust-remote-code".to_string());
@@ -846,6 +858,7 @@ fn spawn_shards(
// Start shard processes
for rank in 0..num_shard {
let model_id = args.model_id.clone();
+let peft_model_id = args.peft_model_id.clone();
let revision = args.revision.clone();
let uds_path = args.shard_uds_path.clone();
let master_addr = args.master_addr.clone();
@@ -868,6 +881,7 @@ fn spawn_shards(
thread::spawn(move || {
shard_manager(
model_id,
+peft_model_id,
revision,
quantize,
dtype,
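
For orientation, the new option can be supplied through the `--peft-model-id` flag that is forwarded to each shard above or, since the field is declared with `#[clap(long, env)]`, presumably through a `PEFT_MODEL_ID` environment variable. A rough launch sketch in Python; the launcher binary name (`text-generation-launcher`) and the adapter path are assumptions for illustration, not taken from this diff:

import os
import subprocess

# Assumed binary name and adapter location; only the flag / env var names
# (--peft-model-id / PEFT_MODEL_ID) follow from the clap field in this commit.
env = {**os.environ, "PEFT_MODEL_ID": "/data/my-lora-adapter"}
subprocess.run(
    [
        "text-generation-launcher",
        "--model-id", "bigscience/bloom-560m",
        # equivalently: "--peft-model-id", "/data/my-lora-adapter",
    ],
    env=env,
    check=True,
)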

View File

@@ -27,6 +27,7 @@ class Dtype(str, Enum):
@app.command()
def serve(
model_id: str,
+peft_model_id: Optional[str] = None,
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
@@ -78,8 +79,9 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
+logger.info(f"{peft_model_id=}")
server.serve(
-model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_id
)

View File

@@ -75,6 +75,7 @@ def get_model(
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
+peft_model_id: Optional[str],
) -> Model:
if dtype is None:
dtype = torch.float16
@@ -85,6 +86,7 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")
logger.info(f"In get models {peft_model_id=}")
if "facebook/galactica" in model_id: if "facebook/galactica" in model_id:
return GalacticaSharded( return GalacticaSharded(
model_id, model_id,
@ -120,6 +122,7 @@ def get_model(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
model_type = config_dict["model_type"] model_type = config_dict["model_type"]
logger.info(f"Model type is {model_type=}")
if model_type == "gpt_bigcode":
if FLASH_ATTENTION:
@@ -217,20 +220,24 @@ def get_model(
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
else:
if FLASH_ATTENTION and not config_dict.get("alibi", False):
logger.error("In Flash RW")
return FlashRWSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
+peft_model_id=peft_model_id,
)
else:
logger.error(f"Value not none {peft_model_id=}")
return RW(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
+peft_model_id=peft_model_id,
)
elif model_type == "opt":
@@ -273,6 +280,7 @@ def get_model(
"4bit quantization is not supported for AutoModel"
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
logger.error("In CausalLM")
return CausalLM(
model_id,
revision,

View File

@@ -17,6 +17,8 @@ from text_generation_server.models.types import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from transformers import BitsAndBytesConfig
+from peft import PeftModel
tracer = trace.get_tracer(__name__)
@@ -483,6 +485,7 @@ class CausalLM(Model):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
+peft_model_id: Optional[str] = None,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -494,6 +497,33 @@ class CausalLM(Model):
device = torch.device("cpu")
dtype = torch.float32
+if peft_model_id:
+tokenizer = AutoTokenizer.from_pretrained(
+model_id,
+padding_side="left",
+)
+bnb_config = BitsAndBytesConfig(
+load_in_4bit=True,
+bnb_4bit_use_double_quant=True,
+bnb_4bit_quant_type="nf4",
+bnb_4bit_compute_dtype=torch.bfloat16,
+llm_int8_enable_fp32_cpu_offload=True,
+)
+model = AutoModelForCausalLM.from_pretrained(
+model_id,
+trust_remote_code=True,
+device_map="auto",
+quantization_config=bnb_config
+)
+# load LoRA weights
+model = PeftModel.from_pretrained(
+model,
+peft_model_id,
+device_map="auto",
+)
+else:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
@@ -501,6 +531,7 @@ class CausalLM(Model):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
@@ -511,6 +542,8 @@ class CausalLM(Model):
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
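
For clarity, the `peft_model_id` branch above (and the matching branches in the RW models below) follows the standard transformers + peft loading pattern. A self-contained sketch of that pattern; the base model and adapter IDs are placeholders, and running it assumes a CUDA machine with `bitsandbytes` installed:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base_id = "bigscience/bloom-560m"     # placeholder base model
adapter_id = "/data/my-lora-adapter"  # placeholder: Hub id or local dir with adapter_model.bin

tokenizer = AutoTokenizer.from_pretrained(base_id, padding_side="left")

# Same 4-bit NF4 quantization config as in the branch above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True,
)

model = AutoModelForCausalLM.from_pretrained(
    base_id,
    device_map="auto",
    quantization_config=bnb_config,
)

# Attach the LoRA adapter weights on top of the quantized base model
model = PeftModel.from_pretrained(model, adapter_id)
model.eval()

# Quick generation sanity check
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))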

View File

@@ -2,7 +2,7 @@ import torch
import torch.distributed
from opentelemetry import trace
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Optional
from text_generation_server.models import FlashCausalLM
@@ -15,6 +15,8 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
+from transformers import BitsAndBytesConfig
+from peft import PeftModel
tracer = trace.get_tracer(__name__)
@@ -27,6 +29,7 @@ class FlashRWSharded(FlashCausalLM):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
+peft_model_id: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -35,14 +38,6 @@ class FlashRWSharded(FlashCausalLM):
else:
raise NotImplementedError("FlashRW is only available on GPU")
-tokenizer = AutoTokenizer.from_pretrained(
-model_id,
-revision=revision,
-padding_side="left",
-truncation_side="left",
-trust_remote_code=trust_remote_code,
-)
config = RWConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
@@ -61,6 +56,42 @@ class FlashRWSharded(FlashCausalLM):
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
+if peft_model_id:
+tokenizer = AutoTokenizer.from_pretrained(
+model_id,
+padding_side="left",
+)
+bnb_config = BitsAndBytesConfig(
+load_in_4bit=True,
+bnb_4bit_use_double_quant=True,
+bnb_4bit_quant_type="nf4",
+bnb_4bit_compute_dtype=torch.bfloat16,
+llm_int8_enable_fp32_cpu_offload=True,
+)
+model = AutoModelForCausalLM.from_pretrained(
+model_id,
+trust_remote_code=True,
+device_map="auto",
+quantization_config=bnb_config
+)
+# load LoRA weights
+model = PeftModel.from_pretrained(
+model,
+peft_model_id,
+device_map="auto",
+)
+model.eval()
+else:
+tokenizer = AutoTokenizer.from_pretrained(
+model_id,
+revision=revision,
+padding_side="left",
+truncation_side="left",
+trust_remote_code=trust_remote_code,
+)
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)

View File

@@ -1,6 +1,8 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import BitsAndBytesConfig
+from peft import PeftModel
from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM
@@ -14,6 +16,7 @@ class RW(CausalLM):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
+peft_model_id: Optional[str] = None,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -25,6 +28,34 @@ class RW(CausalLM):
device = torch.device("cpu")
dtype = torch.float32
+if peft_model_id:
+tokenizer = AutoTokenizer.from_pretrained(
+model_id,
+padding_side="left",
+)
+bnb_config = BitsAndBytesConfig(
+load_in_4bit=True,
+bnb_4bit_use_double_quant=True,
+bnb_4bit_quant_type="nf4",
+bnb_4bit_compute_dtype=torch.bfloat16,
+llm_int8_enable_fp32_cpu_offload=True,
+)
+model = AutoModelForCausalLM.from_pretrained(
+model_id,
+trust_remote_code=True,
+device_map="auto",
+quantization_config=bnb_config
+)
+# load LoRA weights
+model = PeftModel.from_pretrained(
+model,
+peft_model_id,
+device_map="auto",
+)
+else:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
@@ -32,6 +63,7 @@ class RW(CausalLM):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,

View File

@@ -123,7 +123,9 @@ def serve(
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
+peft_model_id: Optional[str],
):
+logger.info(f"serve called with {peft_model_id=}")
async def serve_inner(
model_id: str,
revision: Optional[str],
@@ -131,6 +133,7 @@ def serve(
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
+peft_model_id: Optional[str] = None,
):
unix_socket_template = "unix://{}-{}"
if sharded:
@@ -144,8 +147,9 @@ def serve(
server_urls = [local_url]
try:
logger.exception(f"In server {peft_model_id=}")
model = get_model(
-model_id, revision, sharded, quantize, dtype, trust_remote_code
+model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id
)
except Exception:
logger.exception("Error when initializing model")
@@ -193,5 +197,5 @@ def serve(
await server.stop(0)
asyncio.run(
-serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_id)
)

View File

@@ -86,6 +86,7 @@ def weight_files(
# Local model
if Path(model_id).exists() and Path(model_id).is_dir():
local_files = list(Path(model_id).glob(f"*{extension}"))
+logger.info(f"Local files {local_files=}")
if not local_files:
raise FileNotFoundError(
f"No local weights found in {model_id} with extension {extension}"

View File

@@ -31,7 +31,6 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
logger.info(f"Merging the lora weights.")
base_model_id = model.peft_config["default"].base_model_name_or_path
model = model.merge_and_unload()
os.makedirs(model_id, exist_ok=True)
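
The `download_and_unload_peft` helper shown here takes the other route: instead of keeping a `PeftModel` wrapper at serving time, it merges the LoRA weights into the base model. A minimal sketch of that merge step, with placeholder paths and dtype:

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_id = "bigscience/bloom-560m"     # placeholder base model
adapter_id = "/data/my-lora-adapter"  # placeholder adapter repo or directory
out_dir = "/data/merged-model"        # placeholder output directory

# Load the unquantized base model, attach the adapter, then fold the LoRA
# deltas into the base weights so the result can be served like a plain model.
base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, adapter_id)
merged = model.merge_and_unload()
merged.save_pretrained(out_dir)

Merging avoids the adapter wrapper at inference time at the cost of writing out full-size weights, whereas the runtime-loading path added elsewhere in this commit keeps the adapter separate.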