Update mod.rs

feat(router): add automatic input token truncation to respect max_input_tokens limit
smith518 · 2024-10-15 11:22:34 +05:30 · committed by GitHub
parent 58848cb471
commit 7ca47777aa

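As a rough illustration of the behaviour described in the commit message, here is a minimal sketch of input-side truncation, assuming the prompt is already tokenized and a max_input_tokens limit is available; the helper name truncate_input is hypothetical and not taken from the router code:

/// Hypothetical sketch only: clamp an already-tokenized prompt to `max_input_tokens`
/// by keeping the most recent tokens. `truncate_input` is an illustrative name,
/// not an identifier from the actual router.
fn truncate_input(mut input_ids: Vec<u32>, max_input_tokens: usize) -> Vec<u32> {
    if input_ids.len() > max_input_tokens {
        // Drop the oldest tokens so the most recent context is preserved.
        let excess = input_ids.len() - max_input_tokens;
        input_ids.drain(..excess);
    }
    input_ids
}

fn main() {
    // Ten tokens truncated to a limit of four keeps the last four.
    let ids: Vec<u32> = (0..10).collect();
    assert_eq!(truncate_input(ids, 4), vec![6, 7, 8, 9]);
}

Keeping the tail rather than the head of the prompt is the usual policy for generation, since the tokens closest to the generation point carry the most relevant context.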

@@ -23,6 +23,10 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::error::Error;
#[async_trait]
pub trait Backend {
    fn schedule(
@@ -48,6 +52,18 @@ pub struct Infer {
    backend_health: Arc<AtomicBool>,
}
#[tokio::main]
async fn main() {
    // Example driver: send one prompt to the local TGI server and print the result.
    let input_text = "Your input text here...";
    let max_new_tokens = 100;

    match send_request_to_tgi(input_text, max_new_tokens).await {
        Ok(response) => println!("Generated Text: {}", response.generated_text),
        Err(e) => eprintln!("Error: {}", e),
    }
}
impl Infer {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
@@ -296,6 +312,45 @@ impl Infer {
    }
}
#[derive(Serialize)]
struct RequestPayload {
    inputs: String,
    parameters: Parameters,
}

#[derive(Serialize)]
struct Parameters {
    max_new_tokens: u32,
}

#[derive(Deserialize)]
struct ResponseData {
    generated_text: String,
}
/// Send a single, non-streaming generation request to a locally running TGI server.
async fn send_request_to_tgi(input_text: &str, max_new_tokens: u32) -> Result<ResponseData, Box<dyn Error>> {
    // Use the non-streaming `/generate` endpoint: it returns one JSON object, which is
    // what `ResponseData` deserializes. (`/generate_stream` emits server-sent events
    // and cannot be parsed with `response.json()`.)
    let url = "http://localhost:8080/generate";

    let payload = RequestPayload {
        inputs: input_text.to_string(),
        parameters: Parameters { max_new_tokens },
    };

    let client = Client::new();
    let response = client
        .post(url)
        .json(&payload)
        .send()
        .await?;

    if !response.status().is_success() {
        return Err(format!("Request failed with status: {}", response.status()).into());
    }

    let response_data = response.json::<ResponseData>().await?;
    Ok(response_data)
}
#[derive(Debug)]
pub struct GeneratedText {
    pub text: String,