Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Hotfixes for santacoder/bigcode.

commit 7e11c5d92b (parent b4aa87db58)
@@ -99,7 +99,10 @@ def get_model(
         else:
             return Galactica(model_id, revision, quantize=quantize)
 
-    if "bigcode" in model_id:
+    config = AutoConfig.from_pretrained(model_id, revision=revision)
+    model_type = config.model_type
+
+    if model_type == "gpt_bigcode":
         if sharded:
             if not FLASH_ATTENTION:
                 raise NotImplementedError(
@@ -110,9 +113,6 @@ def get_model(
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
             return santacoder_cls(model_id, revision, quantize=quantize)
 
-    config = AutoConfig.from_pretrained(model_id, revision=revision)
-    model_type = config.model_type
-
     if model_type == "bloom":
         if sharded:
             return BLOOMSharded(model_id, revision, quantize=quantize)
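Read together, the two hunks above move the AutoConfig load ahead of the SantaCoder branch, so bigcode checkpoints are detected via config.model_type == "gpt_bigcode" rather than a substring match on model_id, and the duplicate config load further down is dropped. Below is a minimal sketch of the resulting get_model routing, reconstructed from these hunks rather than copied from the file; FLASH_ATTENTION, the model classes, and the error message are placeholder stand-ins for the real definitions in the server package.

from transformers import AutoConfig

FLASH_ATTENTION = False  # placeholder; the real flag is detected at import time


class SantaCoder:  # placeholder stand-ins so the sketch is self-contained
    def __init__(self, model_id, revision=None, quantize=False):
        self.model_id, self.revision, self.quantize = model_id, revision, quantize


class FlashSantacoder(SantaCoder): ...
class FlashSantacoderSharded(SantaCoder): ...
class BLOOMSharded(SantaCoder): ...


def get_model(model_id, revision=None, sharded=False, quantize=False):
    # ... model_id-based special cases (e.g. Galactica) stay above this point ...

    # Load the config once, up front; every branch below keys off model_type.
    config = AutoConfig.from_pretrained(model_id, revision=revision)
    model_type = config.model_type

    if model_type == "gpt_bigcode":
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError("sharded Santacoder requires flash attention")
            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
        santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
        return santacoder_cls(model_id, revision, quantize=quantize)

    if model_type == "bloom" and sharded:
        return BLOOMSharded(model_id, revision, quantize=quantize)

    # ... the real file continues with the remaining model types ...
    raise ValueError(f"unsupported model type in this sketch: {model_type}")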
@@ -373,6 +373,9 @@ class FlashSantacoderSharded(FlashSantacoder):
             else:
                 module._buffers[param_name] = tensor
 
+
+        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+
         uninitialized_parameters = []
         for n, p in model.named_parameters():
             if p.data.device == torch.device("meta"):
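The third hunk patches FlashSantacoderSharded weight loading: after the shards are materialized, the LM head is tied to the input embedding (model.transformer.wte.weight), so lm_head.weight is not left on the meta device when the checkpoint ships no separate output-projection tensor, and the existing scan for uninitialized parameters then passes. The snippet below is a toy illustration of that idea, not the real loader; TinyModel is a hypothetical stand-in for the FlashSantacoder module tree, and it assumes PyTorch 2.x for the torch.device("meta") context manager.

import torch
import torch.nn as nn


class TinyModel(nn.Module):  # hypothetical stand-in for the real model tree
    def __init__(self, vocab_size=8, hidden=4):
        super().__init__()
        self.transformer = nn.Module()
        self.transformer.wte = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)


# Build the skeleton on the meta device, as a sharded loader would.
with torch.device("meta"):
    model = TinyModel()

# Pretend the checkpoint only materializes the embedding weights.
model.transformer.wte.weight = nn.Parameter(torch.randn(8, 4))

# The hotfix: reuse the embedding weights for the output projection.
model.lm_head.weight = nn.Parameter(model.transformer.wte.weight)

# Same sanity check as in the hunk: nothing may still live on "meta".
uninitialized = [
    n for n, p in model.named_parameters() if p.data.device == torch.device("meta")
]
assert not uninitialized, f"found uninitialized parameters: {uninitialized}"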