From e4142e4fd57a1ae4b4de7b4bce0f0d59b903399f Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 8 Mar 2023 10:49:18 +0100
Subject: [PATCH] feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES

---
 clients/python/README.md                 | 130 ++++++++++++++++++++++-
 clients/python/text_generation/client.py |  24 ++---
 launcher/src/main.rs                     |  51 ++++++++-
 3 files changed, 188 insertions(+), 17 deletions(-)

diff --git a/clients/python/README.md b/clients/python/README.md
index 414360bf..0f0b32f0 100644
--- a/clients/python/README.md
+++ b/clients/python/README.md
@@ -1,7 +1,8 @@
# Text Generation

The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
-`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub.
+`text-generation-inference` instance running on
+[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.

## Get Started

@@ -11,7 +12,7 @@ The Hugging Face Text Generation Python library provides a convenient way of int
pip install text-generation
```

-### Usage
+### Inference API Usage

```python
from text_generation import InferenceAPIClient
@@ -50,3 +51,128 @@ async for response in client.generate_stream("Why is the sky blue?"):
print(text)
# ' Rayleigh scattering'
```
+
+### Hugging Face Inference Endpoint Usage
+
+```python
+from text_generation import Client
+
+endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
+
+client = Client(endpoint_url)
+text = client.generate("Why is the sky blue?").generated_text
+print(text)
+# ' Rayleigh scattering'
+
+# Token Streaming
+text = ""
+for response in client.generate_stream("Why is the sky blue?"):
+    if not response.token.special:
+        text += response.token.text
+
+print(text)
+# ' Rayleigh scattering'
+```
+
+or with the asynchronous client:
+
+```python
+from text_generation import AsyncClient
+
+endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
+
+client = AsyncClient(endpoint_url)
+response = await client.generate("Why is the sky blue?")
+print(response.generated_text)
+# ' Rayleigh scattering'
+
+# Token Streaming
+text = ""
+async for response in client.generate_stream("Why is the sky blue?"):
+    if not response.token.special:
+        text += response.token.text
+
+print(text)
+# ' Rayleigh scattering'
+```
+
+### Types
+
+```python
+# Prompt tokens
+class PrefillToken:
+    # Token ID from the model tokenizer
+    id: int
+    # Token text
+    text: str
+    # Logprob
+    # Optional since the logprob of the first token cannot be computed
+    logprob: Optional[float]
+
+
+# Generated tokens
+class Token:
+    # Token ID from the model tokenizer
+    id: int
+    # Token text
+    text: str
+    # Logprob
+    logprob: float
+    # Is the token a special token
+    # Can be used to ignore tokens when concatenating
+    special: bool
+
+
+# Generation finish reason
+class FinishReason(Enum):
+    # number of generated tokens == `max_new_tokens`
+    Length = "length"
+    # the model generated its end of sequence token
+    EndOfSequenceToken = "eos_token"
+    # the model generated a text included in `stop_sequences`
+    StopSequence = "stop_sequence"
+
+
+# `generate` details
+class Details:
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
+# `generate` return value
+class Response:
+    # Generated text
+    generated_text: str
+    # Generation details
+    details: Details
+
+
+# `generate_stream` details
+class StreamDetails:
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+
+
+# `generate_stream` return value
+class StreamResponse:
+    # Generated token
+    token: Token
+    # Complete generated text
+    # Only available when the generation is finished
+    generated_text: Optional[str]
+    # Generation details
+    # Only available when the generation is finished
+    details: Optional[StreamDetails]
+```
\ No newline at end of file
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index d0a67915..3e9bbc36 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -63,7 +63,7 @@ class Client:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
    ) -> Response:
        """
        Given a prompt, generate the following text
@@ -91,7 +91,7 @@ class Client:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to
                `top_p` or higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:
@@ -109,7 +109,7 @@ class Client:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -136,7 +136,7 @@ class Client:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
    ) -> Iterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens
@@ -164,7 +164,7 @@ class Client:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to
                `top_p` or higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:
@@ -182,7 +182,7 @@ class Client:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -268,7 +268,7 @@ class AsyncClient:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously
@@ -296,7 +296,7 @@ class AsyncClient:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to
                `top_p` or higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:
@@ -314,7 +314,7 @@ class AsyncClient:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -338,7 +338,7 @@ class AsyncClient:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously
@@ -366,7 +366,7 @@ class AsyncClient:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to
                `top_p` or higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:
@@ -384,7 +384,7 @@ class AsyncClient:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 66dcb2db..2eb4b8ff 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -23,8 +23,10 @@ struct Args {
    model_id: String,
    #[clap(long, env)]
    revision: Option<String>,
-    #[clap(default_value = "1", long, env)]
-    num_shard: usize,
+    #[clap(long, env)]
+    sharded: Option<bool>,
+    #[clap(long, env)]
+    num_shard: Option<usize>,
    #[clap(long, env)]
    quantize: bool,
    #[clap(default_value = "128", long, env)]
@@ -80,6 +82,7 @@ fn main() -> ExitCode {
    let Args {
        model_id,
        revision,
+        sharded,
        num_shard,
        quantize,
        max_concurrent_requests,
@@ -102,13 +105,55 @@
        watermark_delta,
    } = args;

+    // get the number of shards given `sharded` and `num_shard`
+    let num_shard = if let Some(sharded) = sharded {
+        // sharded is set
+        match sharded {
+            // sharded is set and true
+            true => {
+                match num_shard {
+                    None => {
+                        // try to default to the number of available GPUs
+                        tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
+                        let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES").expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
+                        let n_devices = cuda_visible_devices.split(",").count();
+                        if n_devices <= 1 {
+                            panic!("`sharded` is true but only found {n_devices} CUDA devices");
+                        }
+                        tracing::info!("Sharding on {n_devices} found CUDA devices");
+                        n_devices
+                    }
+                    Some(num_shard) => {
+                        // we can't have only one shard while sharded
+                        if num_shard <= 1 {
+                            panic!("`sharded` is true but `num_shard` <= 1");
+                        }
+                        num_shard
+                    }
+                }
+            }
+            // sharded is set and false
+            false => {
+                let num_shard = num_shard.unwrap_or(1);
+                // we can't have more than one shard while not sharded
+                if num_shard != 1 {
+                    panic!("`sharded` is false but `num_shard` != 1");
+                }
+                num_shard
+            }
+        }
+    } else {
+        // default to a single shard
+        num_shard.unwrap_or(1)
+    };
+
    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || {
        r.store(false, Ordering::SeqCst);
    })
-    .expect("Error setting Ctrl-C handler");
+    .expect("Error setting Ctrl-C handler");

    // Check if model_id is a local model
    let local_path = Path::new(&model_id);
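For anyone who wants to exercise the new behaviour outside the launcher, below is a minimal standalone sketch of the shard-count resolution added in this patch. The branching is lifted from the diff above; the helper name `resolve_num_shard`, the `assert_eq!`/`println!` harness in `main`, and the use of `env::set_var` to fake the environment are illustrative assumptions for a self-contained example, not part of the codebase, and the launcher's `tracing` output is omitted.

```rust
use std::env;

// Standalone mirror of the launcher's shard-count resolution
// (illustrative helper, not part of the actual codebase).
fn resolve_num_shard(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
    match sharded {
        // sharding requested: use --num-shard if given, otherwise count CUDA devices
        Some(true) => match num_shard {
            None => {
                let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES")
                    .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
                let n_devices = cuda_visible_devices.split(',').count();
                if n_devices <= 1 {
                    panic!("`sharded` is true but only found {n_devices} CUDA devices");
                }
                n_devices
            }
            Some(num_shard) => {
                // we can't have only one shard while sharded
                if num_shard <= 1 {
                    panic!("`sharded` is true but `num_shard` <= 1");
                }
                num_shard
            }
        },
        // sharding explicitly disabled: only a single shard is allowed
        Some(false) => {
            let num_shard = num_shard.unwrap_or(1);
            if num_shard != 1 {
                panic!("`sharded` is false but `num_shard` != 1");
            }
            num_shard
        }
        // neither flag set: default to a single shard
        None => num_shard.unwrap_or(1),
    }
}

fn main() {
    // Stand-in for a real environment with four visible GPUs.
    env::set_var("CUDA_VISIBLE_DEVICES", "0,1,2,3");

    // --sharded true with no --num-shard: infer 4 shards from CUDA_VISIBLE_DEVICES.
    assert_eq!(resolve_num_shard(Some(true), None), 4);
    // An explicit --num-shard is used as-is when sharding is requested.
    assert_eq!(resolve_num_shard(Some(true), Some(2)), 2);
    // With neither flag set, the launcher keeps the old default of one shard.
    assert_eq!(resolve_num_shard(None, None), 1);

    println!("shard resolution behaves as expected");
}
```

In short, `--num-shard` becomes optional: passing only `--sharded true` makes the launcher count the devices listed in `CUDA_VISIBLE_DEVICES`, while contradictory combinations (for example `--sharded false` together with `--num-shard 2`) fail fast with a panic.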