Adding prefix test.

Nicolas Patry 2024-09-05 10:33:50 +02:00
parent c1fe28d694
commit 3669d078e0
4 changed files with 2851 additions and 10 deletions

@@ -19,6 +19,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
 from text_generation import AsyncClient
 from text_generation.types import (
     BestOfSequence,
+    Message,
     ChatComplete,
     ChatCompletionChunk,
     ChatCompletionComplete,
@@ -97,7 +98,10 @@ class ResponseComparator(JSONSnapshotExtension):
     ) -> bool:
         def convert_data(data):
             data = json.loads(data)
-            if isinstance(data, Dict) and "choices" in data:
+            return _convert_data(data)
+
+        def _convert_data(data):
+            if isinstance(data, Dict):
                 choices = data["choices"]
                 if isinstance(choices, List) and len(choices) >= 1:
                     if "delta" in choices[0]:
@@ -105,17 +109,10 @@ class ResponseComparator(JSONSnapshotExtension):
                     if "text" in choices[0]:
                         return Completion(**data)
                     return ChatComplete(**data)
             if isinstance(data, Dict):
                 return Response(**data)
             if isinstance(data, List):
-                if (
-                    len(data) > 0
-                    and "object" in data[0]
-                    and data[0]["object"] == "text_completion"
-                ):
-                    return [Completion(**d) for d in data]
-                return [Response(**d) for d in data]
+                return [_convert_data(d) for d in data]
             raise NotImplementedError
 
         def eq_token(token: Token, other: Token) -> bool:
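
The reworked comparator routes every deserialized payload through a single recursive _convert_data helper, so a list response (for example a batch of completions) is converted element by element instead of being special-cased on its "object" field. A minimal standalone sketch of that dispatch pattern, using hypothetical dataclass stand-ins rather than the real text_generation.types models:

import json
from dataclasses import dataclass
from typing import Any, Dict, List, Union


# Hypothetical stand-ins for the pydantic models used by the test suite.
@dataclass
class FakeCompletion:
    choices: List[Dict[str, Any]]


@dataclass
class FakeChatComplete:
    choices: List[Dict[str, Any]]


def convert(raw: str) -> Union[FakeCompletion, FakeChatComplete, list]:
    # Top-level entry point: parse once, then dispatch recursively.
    return _convert(json.loads(raw))


def _convert(data):
    # Dicts are dispatched on the shape of their first choice.
    if isinstance(data, dict) and "choices" in data:
        first = data["choices"][0]
        if "text" in first:
            return FakeCompletion(**data)
        return FakeChatComplete(**data)
    # Lists recurse element by element, so mixed batches also work.
    if isinstance(data, list):
        return [_convert(d) for d in data]
    raise NotImplementedError(f"unsupported payload: {data!r}")


if __name__ == "__main__":
    batch = json.dumps([
        {"choices": [{"text": "hello"}]},
        {"choices": [{"message": {"role": "assistant", "content": "hi"}}]},
    ])
    print(convert(batch))
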
@@ -400,6 +397,7 @@ def launcher(event_loop):
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
+            env["USE_PREFIX_CACHING"] = "0"
 
         with tempfile.TemporaryFile("w+") as tmp:
             # We'll output stdout/stderr to a temporary file. Using a pipe
@@ -571,3 +569,38 @@ def generate_load():
         return await asyncio.gather(*futures)
 
     return generate_load_inner
+
+
+@pytest.fixture(scope="module")
+def generate_multi():
+    async def generate_load_inner(
+        client: AsyncClient,
+        prompts: List[str],
+        max_new_tokens: int,
+        seed: Optional[int] = None,
+    ) -> List[Response]:
+        import numpy as np
+
+        arange = np.arange(len(prompts))
+        perm = np.random.permutation(arange)
+        rperm = [-1] * len(perm)
+        for i, p in enumerate(perm):
+            rperm[p] = i
+
+        shuffled_prompts = [prompts[p] for p in perm]
+        futures = [
+            client.chat(
+                messages=[Message(role="user", content=prompt)],
+                max_tokens=max_new_tokens,
+                temperature=0,
+                seed=seed,
+            )
+            for prompt in shuffled_prompts
+        ]
+
+        shuffled_responses = await asyncio.gather(*futures)
+        responses = [shuffled_responses[p] for p in rperm]
+        return responses
+
+    return generate_load_inner
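
The new generate_multi fixture shuffles the prompts before sending them concurrently, then restores the original order through the inverse permutation, so callers can compare results positionally. A hedged sketch of how a test might consume it; the model_client fixture name, the pytest.mark.asyncio decorator, and the choices[0].message.content access path are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch (not part of this commit). Assumes a fixture
# `model_client` yielding an AsyncClient, and that chat responses expose
# choices[0].message.content.
import pytest


@pytest.mark.asyncio
async def test_prefix_caching_batch(model_client, generate_multi):
    # Prompts sharing a long common prefix, the case prefix caching targets.
    prefix = "The quick brown fox jumps over the lazy dog. " * 50
    prompts = [prefix + q for q in ("Summarize it.", "Translate it.", "Count the words.")]

    responses = await generate_multi(
        model_client, prompts, max_new_tokens=10, seed=0
    )

    # Responses come back in the original prompt order despite the shuffle.
    assert len(responses) == len(prompts)
    assert all(r.choices[0].message.content for r in responses)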

File diff suppressed because one or more lines are too long

@@ -318,7 +318,10 @@ pub(crate) async fn generate_internal(
     metrics::counter!("tgi_request_count").increment(1);
 
     // Do not log ultra long inputs, like image payloads.
-    tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
+    tracing::debug!(
+        "Input: {}",
+        &req.inputs.chars().take(1000).collect::<String>()
+    );
 
     let compute_characters = req.inputs.chars().count();
     let mut add_prompt = None;
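
This router change matters because Rust's &str[..n] indexes by bytes and panics when n does not land on a UTF-8 character boundary, while chars().take(1000) truncates by characters. A conceptual illustration of the same pitfall in Python (byte-level truncation of the encoded string, not the router's actual code):

# Conceptual sketch: cutting a UTF-8 string at a fixed byte offset can land
# in the middle of a multi-byte character; cutting by characters cannot.
text = "é" * 600                    # each "é" encodes to 2 bytes -> 1200 bytes
raw = text.encode("utf-8")

try:
    raw[:1001].decode("utf-8")      # byte 1001 splits a character
except UnicodeDecodeError as err:
    print("byte truncation corrupted the text:", err)

print(len(text[:1000]))             # character truncation always stays valid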