Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

commit ab0937b90c
parent 034e39185f

    works end-to-end
@@ -217,7 +217,5 @@ fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
     // Truncate to sequence_length
     encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
     // Decode
-    tokenizer
-        .decode(Vec::from(encoding.get_ids()), false)
-        .unwrap()
+    tokenizer.decode(encoding.get_ids(), false).unwrap()
 }
@@ -72,6 +72,14 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_id: String,
 
+    /// The name of the adapter to load.
+    /// Can be a MODEL_ID as listed on <https://hf.co/models>
+    /// or it can be a local directory containing the necessary files
+    /// as saved by `save_pretrained(...)` methods of transformers.
+    /// Should be compatible with the model specified in `model_id`.
+    #[clap(default_value = "", long, env)]
+    adapter_id: String,
+
     /// The actual revision of the model if you're referring to a model
     /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.
     #[clap(long, env)]
@@ -290,6 +298,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_id: String,
+    adapter_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
     dtype: Option<Dtype>,
@@ -332,6 +341,12 @@ fn shard_manager(
         "--json-output".to_string(),
     ];
 
+    // Check if adapter id is non-empty string
+    if !adapter_id.is_empty() {
+        shard_args.push("--adapter-id".to_string());
+        shard_args.push(adapter_id);
+    }
+
     // Activate trust remote code
     if trust_remote_code {
         shard_args.push("--trust-remote-code".to_string());
@@ -639,13 +654,13 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
-fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+fn download_convert_model(model_id: String, args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     // Enter download tracing span
     let _span = tracing::span!(tracing::Level::INFO, "download").entered();
 
     let mut download_args = vec![
         "download-weights".to_string(),
-        args.model_id.to_string(),
+        model_id,
         "--extension".to_string(),
         ".safetensors".to_string(),
         "--logger-level".to_string(),
@@ -767,6 +782,7 @@ fn spawn_shards(
     // Start shard processes
     for rank in 0..num_shard {
         let model_id = args.model_id.clone();
+        let adapter_id = args.adapter_id.clone();
        let revision = args.revision.clone();
        let uds_path = args.shard_uds_path.clone();
        let master_addr = args.master_addr.clone();
@@ -787,6 +803,7 @@ fn spawn_shards(
         thread::spawn(move || {
             shard_manager(
                 model_id,
+                adapter_id,
                 revision,
                 quantize,
                 dtype,
@@ -1081,7 +1098,13 @@ fn main() -> Result<(), LauncherError> {
     .expect("Error setting Ctrl-C handler");
 
     // Download and convert model weights
-    download_convert_model(&args, running.clone())?;
+    download_convert_model(args.model_id.to_string(), &args, running.clone())?;
+
+    // check if adapter_id is non-empty string
+    if !args.adapter_id.is_empty() {
+        download_convert_model(args.adapter_id.to_string(), &args, running.clone())?;
+    }
+
 
     if !running.load(Ordering::SeqCst) {
         // Launcher was asked to stop
@@ -311,7 +311,7 @@ fn prepare_input(
             // truncate encoding and decode new inputs
             encoding.truncate(truncate, 0, TruncationDirection::Left);
             let inputs = tokenizer
-                .decode(Vec::from(encoding.get_ids()), false)
+                .decode(encoding.get_ids(), false)
                 .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
             (inputs, encoding.len())
         }
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed
 
+from loguru import logger
 from opentelemetry import trace
 from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional
@@ -11,6 +12,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     LlamaConfig,
 )
 from text_generation_server.utils import (
+    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -29,15 +31,6 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        print("ASDFASDF FLASHLLAMA INIT")
-        print("Args:")
-        print(f"model_id: {model_id}")
-        print(f"adapter_id: {adapter_id}")
-        print(f"revision: {revision}")
-        print(f"quantize: {quantize}")
-        print(f"dtype: {dtype}")
-        print(f"trust_remote_code: {trust_remote_code}")
-
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -70,7 +63,23 @@ class FlashLlama(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+
+        # get adapter filenames if adapter_id passed in
+        merged_weight_filenames = None
+        if len(adapter_id) > 0:
+            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
+            merged_weight_filenames = create_merged_weight_files(
+                adapter_id, model_id, model_weight_filenames=filenames
+            )
+
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            merged_weight_filenames=merged_weight_filenames
+        )
+
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)
 
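For orientation, the FlashLlama change above reduces to the standalone sketch below. It is not the actual model code: the wrapper function `load_weights` and its arguments are hypothetical, while `weight_files`, `create_merged_weight_files`, and `Weights` are the items touched by this commit.

```python
# Condensed sketch of the adapter-aware load path introduced above.
from text_generation_server.utils import (
    Weights,
    create_merged_weight_files,
    weight_files,
)


def load_weights(model_id, adapter_id, device, dtype, process_group):
    # Base model shards are always resolved first.
    filenames = weight_files(model_id, extension=".safetensors")

    # If an adapter was requested, merge its LoRA deltas into copies of the
    # base shards and remember where the merged files were written.
    merged_weight_filenames = None
    if len(adapter_id) > 0:
        merged_weight_filenames = create_merged_weight_files(
            adapter_id, model_id, model_weight_filenames=filenames
        )

    # Weights now also receives the merged files, whose tensors take
    # precedence over same-named tensors from the base shards.
    return Weights(
        filenames,
        device,
        dtype,
        process_group=process_group,
        merged_weight_filenames=merged_weight_filenames,
    )
```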
@@ -1,3 +1,4 @@
+from text_generation_server.utils.adapter import create_merged_weight_files
 from text_generation_server.utils.convert import convert_file, convert_files
 from text_generation_server.utils.dist import initialize_torch_distributed
 from text_generation_server.utils.weights import Weights
@@ -20,6 +21,7 @@ from text_generation_server.utils.tokens import (
 )
 
 __all__ = [
+    "create_merged_weight_files",
     "convert_file",
     "convert_files",
     "initialize_torch_distributed",
server/text_generation_server/utils/adapter.py (new file, 127 lines)
@@ -0,0 +1,127 @@
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import List, Dict, Set, Tuple
+
+import torch
+from loguru import logger
+from peft import LoraConfig
+from peft.utils import transpose
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+
+from text_generation_server.utils.hub import weight_files
+
+
+def compute_delta_weight(
+    lora_A: torch.Tensor,
+    lora_B: torch.Tensor,
+    fan_in_fan_out: bool,
+    alpha: float,
+    r: float
+) -> torch.Tensor:
+    """Computes the delta weight for a Linear layer given A and B LoRA matrices.
+
+    TODO: add logic for other module types beyond Linear layers.
+
+    Reference: https://github.com/huggingface/peft/blob/v0.4.0/src/peft/tuners/lora.py#L799-L806
+    """
+    scaling = alpha / r
+    delta_weight = transpose(lora_B @ lora_A, fan_in_fan_out) * scaling
+    return delta_weight
+
+
+def merge_adapter_weights(
+    model_weights: Dict[str, torch.Tensor],
+    adapter_weights: Dict[str, torch.Tensor],
+    adapter_config: LoraConfig
+) -> Tuple[Dict[str, torch.Tensor], Set[str]]:
+    """Merges the adapter weights into the model weights."""
+    module_mapping = defaultdict(dict)
+    processed_adapter_weight_names = set()
+
+    # map the original tensor names to their adapter counterparts
+    for weight_name in model_weights:
+        end_idx = weight_name.rfind(".weight")
+        key = weight_name[:end_idx]
+        for adapter_weight_name in adapter_weights:
+            if key in adapter_weight_name:
+                # example value: 'base_model.model.model.layers.10.self_attn.v_proj.lora_B.weight'
+                # matrix_type gets the second to last element in the module name, i.e. 'lora_B'
+                matrix_type = adapter_weight_name.split(".")[-2]
+                module_mapping[weight_name][matrix_type] = adapter_weight_name
+                processed_adapter_weight_names.add(adapter_weight_name)
+
+    # merge adapter weights into model weights
+    merged_weights = {}
+    for weight_name, adapter_weight_names in tqdm(
+        module_mapping.items(), desc="Merging adapter weights", total=len(module_mapping)):
+
+        # TODO: support adapter types beyond LoRA
+        lora_A = adapter_weights[adapter_weight_names["lora_A"]]
+        lora_B = adapter_weights[adapter_weight_names["lora_B"]]
+        delta_weight = compute_delta_weight(
+            lora_A, lora_B, adapter_config.fan_in_fan_out, adapter_config.lora_alpha, adapter_config.r)
+        merged_weights[weight_name] = model_weights[weight_name] + delta_weight
+    return merged_weights, processed_adapter_weight_names
+
+
+def create_merged_weight_files(
+    adapter_id: str,
+    model_id: str,
+    model_weight_filenames: List[Path]
+) -> List[Path]:
+    """Creates merged weight files for the given adapter ID and filenames."""
+    adapter_filenames = weight_files(adapter_id, extension=".safetensors")
+
+    adapter_config = LoraConfig.from_pretrained(adapter_id)
+    if adapter_config.base_model_name_or_path != model_id:
+        raise ValueError(f"Adapter {adapter_id} is not compatible with model {model_id}")
+
+    # load adapter weights from all shards (should have relatively small memory footprint)
+    adapter_weights = {}
+    for filename in adapter_filenames:
+        adapter_weights.update(load_file(filename))
+    remaining_adapter_weight_names = set(adapter_weights.keys())
+
+    merged_weight_directory = f"/data/{adapter_id.replace('/', '--')}-merged/"
+    # just grab the existing files if they already exist and return immediately
+    if os.path.exists(merged_weight_directory):
+        logger.info("Merged weight files already exist, skipping merge computation.")
+        return weight_files(merged_weight_directory)
+
+    os.makedirs(merged_weight_directory)
+    merged_weight_filenames = []
+    for filename in model_weight_filenames:
+        model_weights = load_file(filename)
+        merged_weights, processed_adapter_weight_names = merge_adapter_weights(
+            model_weights, adapter_weights, adapter_config)
+
+        merged_adapter_filename = Path(merged_weight_directory, os.path.basename(filename))
+        save_file(merged_weights, merged_adapter_filename)
+        logger.debug(f"Saved merged weights into {merged_adapter_filename}")
+
+        merged_weight_filenames.append(merged_adapter_filename)
+        remaining_adapter_weight_names = remaining_adapter_weight_names.difference(
+            processed_adapter_weight_names)
+
+    if len(remaining_adapter_weight_names) > 0:
+        logger.warning("WARNING: The following lora weights were not merged into the model weights:")
+        for lora_name in remaining_adapter_weight_names:
+            logger.warning("\t" + lora_name)
+
+    return merged_weight_filenames
+
+
+def main():
+    adapter_id = "arnavgrg/codealpaca-qlora"
+    adapter_config = LoraConfig.from_pretrained(adapter_id)
+    model_id = adapter_config.base_model_name_or_path
+    model_weight_filenames = weight_files(model_id, extension=".safetensors")
+
+    merged_adapter_filenames = create_merged_weight_files(adapter_id, model_id, model_weight_filenames)
+    print(merged_adapter_filenames)
+
+
+if __name__ == '__main__':
+    main()
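The merge in adapter.py follows the standard LoRA update rule, W_merged = W + (lora_B @ lora_A) * (alpha / r). The toy example below (made-up shapes and values, with fan_in_fan_out assumed False so no transpose is applied) illustrates what `compute_delta_weight` produces for a single Linear weight:

```python
import torch

# Toy illustration of the LoRA merge rule used by compute_delta_weight above:
# delta_W = (lora_B @ lora_A) * (alpha / r). Shapes and values are illustrative
# only, and fan_in_fan_out is assumed False, so no transpose is needed.
hidden, r, alpha = 4, 2, 8

W = torch.zeros(hidden, hidden)         # frozen base weight (out_features x in_features)
lora_A = torch.randn(r, hidden) * 0.01  # low-rank "down" projection (r x in_features)
lora_B = torch.randn(hidden, r) * 0.01  # low-rank "up" projection (out_features x r)

scaling = alpha / r                     # 8 / 2 = 4.0
delta_W = (lora_B @ lora_A) * scaling   # rank-r update with the same shape as W
W_merged = W + delta_W                  # what gets written to the merged safetensors file

print(W_merged.shape)                         # torch.Size([4, 4])
print(torch.allclose(W_merged - W, delta_W))  # True
```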
@@ -15,19 +15,37 @@ class Weights:
         dtype,
         process_group,
         aliases: Optional[Dict[str, List[str]]] = None,
+        merged_weight_filenames: Optional[List] = None,
     ):
-        # idea: maybe we can pass in adapter filenames here and have these take
-        # precedence over the model filenames? If so, then self.routing would
-        # just handle the mapping of tensor names to filenames.
+        # routes to adapter files take precedence over routes to main model files
+        # to ensure that adapter weights are loaded instead of main model weights
         routing = {}
+        if merged_weight_filenames is not None:
+            for filename in merged_weight_filenames:
+                with safe_open(filename, framework="pytorch") as f:
+                    for k in f.keys():
+                        if k in routing:
+                            raise RuntimeError(
+                                f"Key {k} was found in multiple adapter files: {filename} and {routing[k]}"
+                            )
+                        routing[k] = filename
+
+        # set of keys that point to adapter files. Duplicates for these keys found
+        # in main model files will be overridden.
+        adapter_routes = set(routing.keys())
+
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
-                    if k in routing:
+                    if k in adapter_routes:
+                        logger.debug(f"Overriding main model weights with adapter weights for key: {k}")
+                    elif k in routing:
                         raise RuntimeError(
-                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                            f"Key {k} was found in multiple non-adapter files: {filename} and {routing[k]}"
                         )
-                    routing[k] = filename
+                    else:
+                        routing[k] = filename
+
         if aliases is None:
             aliases = {}
         self.aliases = aliases
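The routing change above boils down to a precedence rule: keys found in merged adapter files shadow identical keys in the base model files, while duplicates among the non-adapter files remain an error. A minimal dictionary-only sketch of that rule (file and tensor names are illustrative, and plain dicts stand in for safetensors files):

```python
# Minimal sketch of the routing precedence introduced in Weights.__init__ above.
merged_files = {"adapter-merged-00001.safetensors": ["lm_head.weight"]}
model_files = {"model-00001.safetensors": ["lm_head.weight", "model.embed_tokens.weight"]}

routing = {}
for filename, keys in merged_files.items():
    for k in keys:
        routing[k] = filename

# Keys already routed to adapter files win over the same keys in main model files.
adapter_routes = set(routing.keys())

for filename, keys in model_files.items():
    for k in keys:
        if k in adapter_routes:
            continue  # keep the merged adapter file for this tensor
        if k in routing:
            raise RuntimeError(f"Key {k} was found in multiple non-adapter files")
        routing[k] = filename

print(routing["lm_head.weight"])             # adapter-merged-00001.safetensors
print(routing["model.embed_tokens.weight"])  # model-00001.safetensors
```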