mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Fix GPT2 detection.
This commit is contained in:
parent
ed95f1982d
commit
680a52f2f2
@ -181,7 +181,11 @@ def get_model(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_type in {"gpt_bigcode", "gpt2"}:
|
if (
|
||||||
|
model_type == "gpt_bigcode"
|
||||||
|
or model_type == "gpt2"
|
||||||
|
and model_id.startswith("bigcode/")
|
||||||
|
):
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashSantacoderSharded(
|
return FlashSantacoderSharded(
|
||||||
model_id,
|
model_id,
|
||||||
|
Loading…
Reference in New Issue
Block a user