mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Update mod.rs
feat(router): add automatic input token truncation to respect max_input_tokens limit
This commit is contained in:
parent
58848cb471
commit
7ca47777aa
@ -23,6 +23,10 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
|
|||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Backend {
|
pub trait Backend {
|
||||||
fn schedule(
|
fn schedule(
|
||||||
@ -48,6 +52,18 @@ pub struct Infer {
|
|||||||
backend_health: Arc<AtomicBool>,
|
backend_health: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let input_text = "Your input text here...";
|
||||||
|
let max_new_tokens = 100;
|
||||||
|
|
||||||
|
match send_request_to_tgi(input_text, max_new_tokens).await {
|
||||||
|
Ok(response) => println!("Generated Text: {}", response.generated_text),
|
||||||
|
Err(e) => eprintln!("Error: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
impl Infer {
|
impl Infer {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
@ -296,6 +312,45 @@ impl Infer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct RequestPayload {
|
||||||
|
inputs: String,
|
||||||
|
parameters: Parameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct Parameters {
|
||||||
|
max_new_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ResponseData {
|
||||||
|
generated_text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_request_to_tgi(input_text: &str, max_new_tokens: u32) -> Result<ResponseData, Box<dyn Error>> {
|
||||||
|
let url = "http://localhost:8080/generate_stream";
|
||||||
|
|
||||||
|
let payload = RequestPayload {
|
||||||
|
inputs: input_text.to_string(),
|
||||||
|
parameters: Parameters { max_new_tokens },
|
||||||
|
};
|
||||||
|
|
||||||
|
let client = Client::new();
|
||||||
|
let response = client
|
||||||
|
.post(url)
|
||||||
|
.json(&payload)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(format!("Request failed with status: {}", response.status()).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let response_data = response.json::<ResponseData>().await?;
|
||||||
|
Ok(response_data)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GeneratedText {
|
pub struct GeneratedText {
|
||||||
pub text: String,
|
pub text: String,
|
||||||
|
Loading…
Reference in New Issue
Block a user