Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

Commit ab0937b90c (parent 034e39185f): works end-to-end
@@ -217,7 +217,5 @@ fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
     // Truncate to sequence_length
     encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
     // Decode
-    tokenizer
-        .decode(Vec::from(encoding.get_ids()), false)
-        .unwrap()
+    tokenizer.decode(encoding.get_ids(), false).unwrap()
 }
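The same truncate-on-the-left-then-decode pattern is easy to try from Python. A minimal sketch (illustrative only, not part of this commit), assuming a recent `tokenizers` release in which `Encoding.truncate` accepts a `direction` argument and `Tokenizer.decode` takes a list of ids; "gpt2" is just an example tokenizer:

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("gpt2")
encoding = tokenizer.encode("a long prompt whose oldest tokens should be clipped away")

# Keep only the last `sequence_length` tokens, mirroring TruncationDirection::Left above,
# then turn the surviving ids back into text.
sequence_length = 8
encoding.truncate(sequence_length, 0, direction="left")
print(tokenizer.decode(encoding.ids, skip_special_tokens=False))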
@@ -72,6 +72,14 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_id: String,
 
+    /// The name of the adapter to load.
+    /// Can be a MODEL_ID as listed on <https://hf.co/models>
+    /// or it can be a local directory containing the necessary files
+    /// as saved by `save_pretrained(...)` methods of transformers.
+    /// Should be compatible with the model specified in `model_id`.
+    #[clap(default_value = "", long, env)]
+    adapter_id: String,
+
     /// The actual revision of the model if you're referring to a model
     /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
     #[clap(long, env)]
@@ -290,6 +298,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_id: String,
+    adapter_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
     dtype: Option<Dtype>,
@@ -332,6 +341,12 @@ fn shard_manager(
         "--json-output".to_string(),
     ];
 
+    // Check if adapter id is non-empty string
+    if !adapter_id.is_empty() {
+        shard_args.push("--adapter-id".to_string());
+        shard_args.push(adapter_id);
+    }
+
     // Activate trust remote code
     if trust_remote_code {
         shard_args.push("--trust-remote-code".to_string());
@@ -639,13 +654,13 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
-fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+fn download_convert_model(model_id: String, args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     // Enter download tracing span
     let _span = tracing::span!(tracing::Level::INFO, "download").entered();
 
     let mut download_args = vec![
         "download-weights".to_string(),
-        args.model_id.to_string(),
+        model_id,
         "--extension".to_string(),
         ".safetensors".to_string(),
         "--logger-level".to_string(),
@@ -767,6 +782,7 @@ fn spawn_shards(
     // Start shard processes
     for rank in 0..num_shard {
        let model_id = args.model_id.clone();
+       let adapter_id = args.adapter_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
@@ -787,6 +803,7 @@ fn spawn_shards(
        thread::spawn(move || {
            shard_manager(
                model_id,
+               adapter_id,
                revision,
                quantize,
                dtype,
@@ -1081,7 +1098,13 @@ fn main() -> Result<(), LauncherError> {
     .expect("Error setting Ctrl-C handler");
 
     // Download and convert model weights
-    download_convert_model(&args, running.clone())?;
+    download_convert_model(args.model_id.to_string(), &args, running.clone())?;
+
+    // check if adapter_id is non-empty string
+    if !args.adapter_id.is_empty() {
+        download_convert_model(args.adapter_id.to_string(), &args, running.clone())?;
+    }
+
 
     if !running.load(Ordering::SeqCst) {
         // Launcher was asked to stop
@@ -311,7 +311,7 @@ fn prepare_input(
             // truncate encoding and decode new inputs
             encoding.truncate(truncate, 0, TruncationDirection::Left);
             let inputs = tokenizer
-                .decode(Vec::from(encoding.get_ids()), false)
+                .decode(encoding.get_ids(), false)
                 .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
             (inputs, encoding.len())
         }
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed
 
+from loguru import logger
 from opentelemetry import trace
 from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional
@@ -11,6 +12,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     LlamaConfig,
 )
 from text_generation_server.utils import (
+    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -29,15 +31,6 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        print("ASDFASDF FLASHLLAMA INIT")
-        print("Args:")
-        print(f"model_id: {model_id}")
-        print(f"adapter_id: {adapter_id}")
-        print(f"revision: {revision}")
-        print(f"quantize: {quantize}")
-        print(f"dtype: {dtype}")
-        print(f"trust_remote_code: {trust_remote_code}")
-
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -70,7 +63,23 @@ class FlashLlama(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+
+        # get adapter filenames if adapter_id passed in
+        merged_weight_filenames = None
+        if len(adapter_id) > 0:
+            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
+            merged_weight_filenames = create_merged_weight_files(
+                adapter_id, model_id, model_weight_filenames=filenames
+            )
+
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            merged_weight_filenames=merged_weight_filenames
+        )
 
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)
@@ -1,3 +1,4 @@
+from text_generation_server.utils.adapter import create_merged_weight_files
 from text_generation_server.utils.convert import convert_file, convert_files
 from text_generation_server.utils.dist import initialize_torch_distributed
 from text_generation_server.utils.weights import Weights
@@ -20,6 +21,7 @@ from text_generation_server.utils.tokens import (
 )
 
 __all__ = [
+    "create_merged_weight_files",
     "convert_file",
     "convert_files",
     "initialize_torch_distributed",
server/text_generation_server/utils/adapter.py (new file, 127 lines)
@@ -0,0 +1,127 @@
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import List, Dict, Set, Tuple
+
+import torch
+from loguru import logger
+from peft import LoraConfig
+from peft.utils import transpose
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+
+from text_generation_server.utils.hub import weight_files
+
+
+def compute_delta_weight(
+    lora_A: torch.Tensor,
+    lora_B: torch.Tensor,
+    fan_in_fan_out: bool,
+    alpha: float,
+    r: float
+) -> torch.Tensor:
+    """Computes the delta weight for a Linear layer given A and B LoRA matrices.
+
+    TODO: add logic for other module types beyond Linear layers.
+
+    Reference: https://github.com/huggingface/peft/blob/v0.4.0/src/peft/tuners/lora.py#L799-L806
+    """
+    scaling = alpha / r
+    delta_weight = transpose(lora_B @ lora_A, fan_in_fan_out) * scaling
+    return delta_weight
+
+
+def merge_adapter_weights(
+    model_weights: Dict[str, torch.Tensor],
+    adapter_weights: Dict[str, torch.Tensor],
+    adapter_config: LoraConfig
+) -> Tuple[Dict[str, torch.Tensor], Set[str]]:
+    """Merges the adapter weights into the model weights."""
+    module_mapping = defaultdict(dict)
+    processed_adapter_weight_names = set()
+
+    # map the original tensor names to their adapter counterparts
+    for weight_name in model_weights:
+        end_idx = weight_name.rfind(".weight")
+        key = weight_name[:end_idx]
+        for adapter_weight_name in adapter_weights:
+            if key in adapter_weight_name:
+                # example value: 'base_model.model.model.layers.10.self_attn.v_proj.lora_B.weight'
+                # matrix_type gets the second to last element in the module name, i.e. 'lora_B'
+                matrix_type = adapter_weight_name.split(".")[-2]
+                module_mapping[weight_name][matrix_type] = adapter_weight_name
+                processed_adapter_weight_names.add(adapter_weight_name)
+
+    # merge adapter weights into model weights
+    merged_weights = {}
+    for weight_name, adapter_weight_names in tqdm(
+        module_mapping.items(), desc="Merging adapter weights", total=len(module_mapping)):
+
+        # TODO: support adapter types beyond LoRA
+        lora_A = adapter_weights[adapter_weight_names["lora_A"]]
+        lora_B = adapter_weights[adapter_weight_names["lora_B"]]
+        delta_weight = compute_delta_weight(
+            lora_A, lora_B, adapter_config.fan_in_fan_out, adapter_config.lora_alpha, adapter_config.r)
+        merged_weights[weight_name] = model_weights[weight_name] + delta_weight
+    return merged_weights, processed_adapter_weight_names
+
+
+def create_merged_weight_files(
+    adapter_id: str,
+    model_id: str,
+    model_weight_filenames: List[Path]
+) -> List[Path]:
+    """Creates merged weight files for the given adapter ID and filenames."""
+    adapter_filenames = weight_files(adapter_id, extension=".safetensors")
+
+    adapter_config = LoraConfig.from_pretrained(adapter_id)
+    if adapter_config.base_model_name_or_path != model_id:
+        raise ValueError(f"Adapter {adapter_id} is not compatible with model {model_id}")
+
+    # load adapter weights from all shards (should have relatively small memory footprint)
+    adapter_weights = {}
+    for filename in adapter_filenames:
+        adapter_weights.update(load_file(filename))
+    remaining_adapter_weight_names = set(adapter_weights.keys())
+
+    merged_weight_directory = f"/data/{adapter_id.replace('/', '--')}-merged/"
+    # just grab the existing files if they already exist and return immediately
+    if os.path.exists(merged_weight_directory):
+        logger.info("Merged weight files already exist, skipping merge computation.")
+        return weight_files(merged_weight_directory)
+
+    os.makedirs(merged_weight_directory)
+    merged_weight_filenames = []
+    for filename in model_weight_filenames:
+        model_weights = load_file(filename)
+        merged_weights, processed_adapter_weight_names = merge_adapter_weights(
+            model_weights, adapter_weights, adapter_config)
+
+        merged_adapter_filename = Path(merged_weight_directory, os.path.basename(filename))
+        save_file(merged_weights, merged_adapter_filename)
+        logger.debug(f"Saved merged weights into {merged_adapter_filename}")
+
+        merged_weight_filenames.append(merged_adapter_filename)
+        remaining_adapter_weight_names = remaining_adapter_weight_names.difference(
+            processed_adapter_weight_names)
+
+    if len(remaining_adapter_weight_names) > 0:
+        logger.warning("WARNING: The following lora weights were not merged into the model weights:")
+        for lora_name in remaining_adapter_weight_names:
+            logger.warning("\t" + lora_name)
+
+    return merged_weight_filenames


+def main():
+    adapter_id = "arnavgrg/codealpaca-qlora"
+    adapter_config = LoraConfig.from_pretrained(adapter_id)
+    model_id = adapter_config.base_model_name_or_path
+    model_weight_filenames = weight_files(model_id, extension=".safetensors")
+
+    merged_adapter_filenames = create_merged_weight_files(adapter_id, model_id, model_weight_filenames)
+    print(merged_adapter_filenames)
+
+
+if __name__ == '__main__':
+    main()
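To see what the merge produces, here is a small self-contained sketch (illustrative only, not part of the commit) that mimics merge_adapter_weights on toy tensors: it pairs a base Linear weight with its lora_A/lora_B factors using the same name-matching convention, applies the same alpha / r scaled update as compute_delta_weight (with fan_in_fan_out=False, so no transpose), and checks that the merged layer equals the base output plus the LoRA path. The layer names and shapes below are made up for the example.

import torch

# Toy shapes: one Linear layer of size d_out x d_in with a rank-r LoRA adapter.
d_in, d_out, r, alpha = 16, 32, 4, 8
model_weights = {"model.layers.0.self_attn.v_proj.weight": torch.randn(d_out, d_in)}
adapter_weights = {
    "base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight": torch.randn(r, d_in),
    "base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight": torch.randn(d_out, r),
}

merged = {}
for weight_name, base in model_weights.items():
    # Same matching as merge_adapter_weights: strip ".weight", find adapter tensors
    # containing that key, and read the matrix type from the second-to-last component.
    key = weight_name[: weight_name.rfind(".weight")]
    factors = {name.split(".")[-2]: t for name, t in adapter_weights.items() if key in name}

    # Same update as compute_delta_weight: W_merged = W + (lora_B @ lora_A) * alpha / r
    delta = (factors["lora_B"] @ factors["lora_A"]) * (alpha / r)
    merged[weight_name] = base + delta

# The merged weight behaves like the base layer plus the scaled LoRA path.
x = torch.randn(3, d_in)
w = model_weights["model.layers.0.self_attn.v_proj.weight"]
a = adapter_weights["base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight"]
b = adapter_weights["base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight"]
expected = x @ w.T + (x @ a.T @ b.T) * (alpha / r)
assert torch.allclose(x @ merged["model.layers.0.self_attn.v_proj.weight"].T, expected, atol=1e-4)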
@@ -15,19 +15,37 @@ class Weights:
         dtype,
         process_group,
         aliases: Optional[Dict[str, List[str]]] = None,
+        merged_weight_filenames: Optional[List] = None,
     ):
-        # idea: maybe we can pass in adapter filenames here and have these take
-        # precedence over the model filenames? If so, then self.routing would
-        # just handle the mapping of tensor names to filenames.
+        # routes to adapter files take precedence over routes to main model files
+        # to ensure that adapter weights are loaded instead of main model weights
         routing = {}
+        if merged_weight_filenames is not None:
+            for filename in merged_weight_filenames:
+                with safe_open(filename, framework="pytorch") as f:
+                    for k in f.keys():
+                        if k in routing:
+                            raise RuntimeError(
+                                f"Key {k} was found in multiple adapter files: {filename} and {routing[k]}"
+                            )
+                        routing[k] = filename
+
+        # set of keys that point to adapter files. Duplicates for these keys found
+        # in main model files will be overridden.
+        adapter_routes = set(routing.keys())
+
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
-                    if k in routing:
+                    if k in adapter_routes:
+                        logger.debug(f"Overriding main model weights with adapter weights for key: {k}")
+                    elif k in routing:
                         raise RuntimeError(
-                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                            f"Key {k} was found in multiple non-adapter files: {filename} and {routing[k]}"
                         )
-                    routing[k] = filename
+                    else:
+                        routing[k] = filename
+
         if aliases is None:
             aliases = {}
         self.aliases = aliases
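The routing rule above is easy to see in isolation. A minimal standalone sketch (not the actual Weights class; the helper name and file names are made up) of the same precedence logic: keys claimed by merged adapter files win over the same keys in base model files, while duplicates among base files alone remain an error.

def build_routing(base_files, adapter_files, keys_in):
    """keys_in(filename) returns the tensor names stored in that file."""
    routing = {}
    for filename in adapter_files or []:
        for k in keys_in(filename):
            if k in routing:
                raise RuntimeError(f"Key {k} was found in multiple adapter files: {filename} and {routing[k]}")
            routing[k] = filename
    adapter_routes = set(routing)

    for filename in base_files:
        for k in keys_in(filename):
            if k in adapter_routes:
                continue  # adapter route already set; the base model copy of this key is ignored
            if k in routing:
                raise RuntimeError(f"Key {k} was found in multiple non-adapter files: {filename} and {routing[k]}")
            routing[k] = filename
    return routing

# Toy usage with in-memory "files" instead of safetensors shards.
contents = {
    "model-00001.safetensors": ["q_proj.weight", "k_proj.weight"],
    "adapter-merged.safetensors": ["q_proj.weight"],
}
routes = build_routing(["model-00001.safetensors"], ["adapter-merged.safetensors"], contents.__getitem__)
assert routes["q_proj.weight"] == "adapter-merged.safetensors"
assert routes["k_proj.weight"] == "model-00001.safetensors"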