Support HF_TOKEN environment variable (#2066)

* Support HF_TOKEN environment variable

* Load test.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Lucain authored 2024-06-25 09:23:12 +02:00, committed by yuanwu
parent 4b25048b75
commit 931ff16c7a
5 changed files with 28 additions and 28 deletions


@@ -147,7 +147,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     tracing::info!("Downloading tokenizer");
 
     // Parse Huggingface hub token
-    let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
+    let auth_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
 
     // Download and instantiate tokenizer
     // We need to download it outside of the Tokio runtime
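
The new expression is worth unpacking: `std::env::var` returns `Err` when a variable is absent (or not valid Unicode), so `HF_TOKEN` always takes precedence and the legacy `HUGGING_FACE_HUB_TOKEN` is read only as a fallback. A minimal standalone sketch of that lookup order; `resolve_hub_token` is a hypothetical name for illustration, the commit inlines the expression:

```rust
// Sketch of the fallback chain introduced above: prefer HF_TOKEN, fall back
// to the legacy HUGGING_FACE_HUB_TOKEN, and yield None when neither is set.
fn resolve_hub_token() -> Option<String> {
    std::env::var("HF_TOKEN")
        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
        .ok()
}

fn main() {
    match resolve_hub_token() {
        Some(token) => println!("using a Hub token ({} chars)", token.len()),
        None => println!("no Hub token set; requests will be anonymous"),
    }
}
```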


@@ -2,13 +2,13 @@
 
 If the model you wish to serve is behind gated access or the model repository on Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens)
 
-If you're using the CLI, set the `HUGGING_FACE_HUB_TOKEN` environment variable. For example:
+If you're using the CLI, set the `HF_TOKEN` environment variable. For example:
 
 ```
-export HUGGING_FACE_HUB_TOKEN=<YOUR READ TOKEN>
+export HF_TOKEN=<YOUR READ TOKEN>
 ```
 
-If you would like to do it through Docker, you can provide your token by specifying `HUGGING_FACE_HUB_TOKEN` as shown below.
+If you would like to do it through Docker, you can provide your token by specifying `HF_TOKEN` as shown below.
 
 ```bash
 model=meta-llama/Llama-2-7b-chat-hf
@@ -17,7 +17,7 @@ token=<your READ token>
 docker run --gpus all \
     --shm-size 1g \
-    -e HUGGING_FACE_HUB_TOKEN=$token \
+    -e HF_TOKEN=$token \
     -p 8080:80 \
     -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
     --model-id $model


@@ -1,38 +1,38 @@
-import sys
-import subprocess
-import contextlib
-import pytest
 import asyncio
-import os
-import docker
+import contextlib
 import json
 import math
+import os
+import random
+import re
+import shutil
+import subprocess
+import sys
+import tempfile
 import time
-import random
 
+from typing import Dict, List, Optional
 
-from docker.errors import NotFound
-from typing import Optional, List, Dict
-from syrupy.extensions.json import JSONSnapshotExtension
+import docker
+import pytest
+from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
+from docker.errors import NotFound
+from syrupy.extensions.json import JSONSnapshotExtension
 from text_generation import AsyncClient
 from text_generation.types import (
-    Response,
-    Details,
-    InputToken,
-    Token,
     BestOfSequence,
-    Grammar,
     ChatComplete,
     ChatCompletionChunk,
     ChatCompletionComplete,
     Completion,
+    Details,
+    Grammar,
+    InputToken,
+    Response,
+    Token,
 )
 
 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
-HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
+HF_TOKEN = os.getenv("HF_TOKEN", None)
 DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
 DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")
@@ -447,8 +447,8 @@ def launcher(event_loop):
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
 
-        if HUGGING_FACE_HUB_TOKEN is not None:
-            env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
+        if HF_TOKEN is not None:
+            env["HF_TOKEN"] = HF_TOKEN
 
         volumes = []
         if DOCKER_VOLUME:
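
On the test-harness side, the guard above injects the token into the container environment only when it is actually set on the host, so an unset variable never turns into an empty string inside the container. The same guard-then-forward pattern, sketched in Rust against a plain `std::process::Command` child (an assumption; the harness itself drives Docker through the Python SDK):

```rust
use std::process::{Child, Command};

// Forward HF_TOKEN to a child process only if the host actually has it,
// mirroring the `if HF_TOKEN is not None` guard in the test harness above.
fn spawn_with_hub_token(mut cmd: Command) -> std::io::Result<Child> {
    if let Ok(token) = std::env::var("HF_TOKEN") {
        cmd.env("HF_TOKEN", token);
    }
    cmd.spawn()
}
```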


@@ -595,7 +595,7 @@ fn shard_manager(
 
     // Parse Inference API token
     if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+        envs.push(("HF_TOKEN".into(), api_token.into()))
     };
 
     // Detect rope scaling
@@ -929,7 +929,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
 
     // Parse Inference API token
     if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+        envs.push(("HF_TOKEN".into(), api_token.into()))
     };
 
     // If args.weights_cache_override is some, pass it to the download process
@@ -1231,7 +1231,7 @@ fn spawn_webserver(
 
     // Parse Inference API token
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
-        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
+        envs.push(("HF_TOKEN".into(), api_token.into()))
     };
 
     // Parse Compute type
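
All three launcher hunks make the same substitution: the launcher keeps reading the legacy `HF_API_TOKEN` from its own environment, but now re-exports it to the shard, download, and webserver subprocesses under the `HF_TOKEN` name. A condensed sketch of the pattern; `hub_token_envs` is a hypothetical helper, the commit performs the push inline in each spawn path:

```rust
use std::env;
use std::ffi::OsString;

// Build the extra (key, value) pairs handed to a subprocess: the legacy
// HF_API_TOKEN, when set, is re-exported under the new HF_TOKEN name.
fn hub_token_envs() -> Vec<(OsString, OsString)> {
    let mut envs: Vec<(OsString, OsString)> = Vec::new();
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HF_TOKEN".into(), api_token.into()));
    }
    envs
}
```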


@@ -156,7 +156,7 @@ async fn main() -> Result<(), RouterError> {
     });
 
     // Parse Huggingface hub token
-    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
+    let authorization_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
 
     // Tokenizer instance
     // This will only be used to validate payloads
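
The router picks up the same fallback chain as the benchmark binary. A hypothetical unit test, not part of the commit, pinning down the intended precedence when both variables are set; it assumes the 2021 edition, where `std::env::set_var` is still a safe function:

```rust
#[cfg(test)]
mod tests {
    #[test]
    fn hf_token_wins_over_legacy_variable() {
        // Both variables set: the new name must take precedence.
        std::env::set_var("HUGGING_FACE_HUB_TOKEN", "legacy-token");
        std::env::set_var("HF_TOKEN", "new-token");

        let token = std::env::var("HF_TOKEN")
            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
            .ok();

        assert_eq!(token.as_deref(), Some("new-token"));
    }
}
```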