feat(server): support hf endpoint weight layout (#266)

This commit is contained in:
OlivierDehaene 2023-05-03 11:36:24 +02:00 committed by GitHub
parent 4096000e34
commit 85aa7e2e7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 146 additions and 26 deletions

View File

@ -512,7 +512,11 @@ enum LauncherError {
WebserverCannotStart, WebserverCannotStart,
} }
fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> { fn download_convert_model(
args: &Args,
auto_convert: bool,
running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
let mut download_argv = vec![ let mut download_argv = vec![
"text-generation-server".to_string(), "text-generation-server".to_string(),
"download-weights".to_string(), "download-weights".to_string(),
@ -524,6 +528,11 @@ fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherE
"--json-output".to_string(), "--json-output".to_string(),
]; ];
// Auto convert weights to safetensors
if auto_convert {
download_argv.push("--auto-convert".to_string());
}
// Model optional revision // Model optional revision
if let Some(revision) = &args.revision { if let Some(revision) = &args.revision {
download_argv.push("--revision".to_string()); download_argv.push("--revision".to_string());
@ -855,14 +864,11 @@ fn main() -> Result<(), LauncherError> {
}) })
.expect("Error setting Ctrl-C handler"); .expect("Error setting Ctrl-C handler");
// Check if model_id is a local model // auto_convert is only needed for sharded models as we do not require safetensors in
let local_path = Path::new(&args.model_id); // single shard mode
let is_local_model = local_path.exists() && local_path.is_dir(); let auto_convert = num_shard > 1;
// Download and convert model weights
// Download weights for sharded models download_convert_model(&args, auto_convert, running.clone())?;
if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
download_model(&args, running.clone())?;
}
// Shared shutdown bool // Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false)); let shutdown = Arc::new(Mutex::new(false));

View File

@ -63,6 +63,7 @@ def download_weights(
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
extension: str = ".safetensors", extension: str = ".safetensors",
auto_convert: bool = True,
logger_level: str = "INFO", logger_level: str = "INFO",
json_output: bool = False, json_output: bool = False,
): ):
@ -84,31 +85,55 @@ def download_weights(
# Test if files were already download # Test if files were already download
try: try:
utils.weight_files(model_id, revision, extension) utils.weight_files(model_id, revision, extension)
logger.info( logger.info("Files are already present on the host. " "Skipping download.")
"Files are already present in the local cache. " "Skipping download."
)
return return
# Local files not found # Local files not found
except utils.LocalEntryNotFoundError: except (utils.LocalEntryNotFoundError, FileNotFoundError):
pass pass
# Download weights directly is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
"WEIGHTS_CACHE_OVERRIDE", None
) is not None
if not is_local_model:
# Try to download weights from the hub
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(filenames, model_id, revision)
# Successfully downloaded weights
return
# No weights found on the hub with this extension
except utils.EntryNotFoundError as e:
# Check if we want to automatically convert to safetensors or if we can use .bin weights instead
if not extension == ".safetensors" or not auto_convert:
raise e
# Try to see if there are local pytorch weights
try: try:
filenames = utils.weight_hub_files(model_id, revision, extension) # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
utils.download_weights(filenames, model_id, revision) local_pt_files = utils.weight_files(model_id, revision, ".bin")
except utils.EntryNotFoundError as e:
if not extension == ".safetensors":
raise e
logger.warning( # No local pytorch weights
f"No safetensors weights found for model {model_id} at revision {revision}. " except utils.LocalEntryNotFoundError:
f"Converting PyTorch weights instead." if extension == ".safetensors":
) logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Downloading PyTorch weights."
)
# Try to see if there are pytorch weights # Try to see if there are pytorch weights on the hub
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin") pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
# Download pytorch weights # Download pytorch weights
local_pt_files = utils.download_weights(pt_filenames, model_id, revision) local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
if auto_convert:
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights to safetensors."
)
# Safetensors final filenames
local_st_files = [ local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files for p in local_pt_files

View File

@ -223,6 +223,15 @@ class BLOOMSharded(BLOOM):
if name == "word_embeddings.weight": if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor model.lm_head._parameters["weight"] = tensor
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):

View File

@ -139,6 +139,15 @@ class FlashLlama(FlashCausalLM):
del value del value
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
torch.cuda.empty_cache() torch.cuda.empty_cache()
model.post_load_weights(quantize) model.post_load_weights(quantize)
@ -300,5 +309,15 @@ class FlashLlamaSharded(FlashLlama):
else: else:
module._buffers[param_name] = tensor module._buffers[param_name] = tensor
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
torch.cuda.empty_cache() torch.cuda.empty_cache()
model.post_load_weights(quantize) model.post_load_weights(quantize)

View File

@ -149,4 +149,14 @@ class FlashNeoXSharded(FlashNeoX):
module._parameters[param_name] = tensor module._parameters[param_name] = tensor
else: else:
module._buffers[param_name] = tensor module._buffers[param_name] = tensor
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
model.post_load_weights(quantize) model.post_load_weights(quantize)

View File

@ -372,5 +372,15 @@ class FlashSantacoderSharded(FlashSantacoder):
module._parameters[param_name] = tensor module._parameters[param_name] = tensor
else: else:
module._buffers[param_name] = tensor module._buffers[param_name] = tensor
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
torch.cuda.empty_cache() torch.cuda.empty_cache()
model.post_load_weights(quantize) model.post_load_weights(quantize)

View File

@ -355,6 +355,15 @@ class GalacticaSharded(Galactica):
if name == "model.decoder.embed_tokens.weight": if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor model.lm_head._parameters["weight"] = tensor
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):

View File

@ -205,6 +205,15 @@ class GPTNeoxSharded(CausalLM):
else: else:
module._buffers[param_name] = tensor module._buffers[param_name] = tensor
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):

View File

@ -210,6 +210,15 @@ class OPTSharded(OPT):
if name == "model.decoder.embed_tokens.weight": if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor model.lm_head._parameters["weight"] = tensor
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):

View File

@ -211,6 +211,15 @@ class T5Sharded(Seq2SeqLM):
else: else:
module._buffers[param_name] = tensor module._buffers[param_name] = tensor
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
def forward( def forward(
self, self,
input_ids, input_ids,

View File

@ -77,7 +77,12 @@ def weight_files(
"""Get the local files""" """Get the local files"""
# Local model # Local model
if Path(model_id).exists() and Path(model_id).is_dir(): if Path(model_id).exists() and Path(model_id).is_dir():
return list(Path(model_id).glob(f"*{extension}")) local_files = list(Path(model_id).glob(f"*{extension}"))
if not local_files:
raise FileNotFoundError(
f"No local weights found in {model_id} with extension {extension}"
)
return local_files
try: try:
filenames = weight_hub_files(model_id, revision, extension) filenames = weight_hub_files(model_id, revision, extension)
@ -98,7 +103,7 @@ def weight_files(
for filename in filenames: for filename in filenames:
p = Path(WEIGHTS_CACHE_OVERRIDE) / filename p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
if not p.exists(): if not p.exists():
raise LocalEntryNotFoundError( raise FileNotFoundError(
f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}." f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
) )
files.append(p) files.append(p)