mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

commit 3669d078e0 (parent c1fe28d694): Adding prefix test.
@@ -19,6 +19,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
from text_generation import AsyncClient
from text_generation.types import (
    BestOfSequence,
    Message,
    ChatComplete,
    ChatCompletionChunk,
    ChatCompletionComplete,
@@ -97,7 +98,10 @@ class ResponseComparator(JSONSnapshotExtension):
    ) -> bool:
        def convert_data(data):
            data = json.loads(data)
            if isinstance(data, Dict) and "choices" in data:
            return _convert_data(data)

        def _convert_data(data):
            if isinstance(data, Dict):
                choices = data["choices"]
                if isinstance(choices, List) and len(choices) >= 1:
                    if "delta" in choices[0]:
@@ -105,17 +109,10 @@ class ResponseComparator(JSONSnapshotExtension):
                    if "text" in choices[0]:
                        return Completion(**data)
                return ChatComplete(**data)

            if isinstance(data, Dict):
                return Response(**data)
            if isinstance(data, List):
                if (
                    len(data) > 0
                    and "object" in data[0]
                    and data[0]["object"] == "text_completion"
                ):
                    return [Completion(**d) for d in data]
                return [Response(**d) for d in data]
                return [_convert_data(d) for d in data]
            raise NotImplementedError

        def eq_token(token: Token, other: Token) -> bool:
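Editor's note: the two comparator hunks above replace the special-cased check for lists of "text_completion" objects with a helper that lists re-enter element-wise. A condensed, self-contained sketch of the resulting shape dispatch; this is a paraphrase of the hunks, not verbatim repository code, and the surrounding context the diff omits may differ:

    import json
    from typing import Dict, List

    from text_generation.types import (
        ChatComplete,
        ChatCompletionChunk,
        Completion,
        Response,
    )

    def convert_data(raw: str):
        # Snapshots are stored as JSON strings: parse, then dispatch on shape.
        return _convert_data(json.loads(raw))

    def _convert_data(data):
        if isinstance(data, Dict):
            if "choices" in data:
                choices = data["choices"]
                if isinstance(choices, List) and len(choices) >= 1:
                    if "delta" in choices[0]:
                        return ChatCompletionChunk(**data)  # streaming chat chunk
                    if "text" in choices[0]:
                        return Completion(**data)           # completions-style payload
                return ChatComplete(**data)                 # full chat completion
            return Response(**data)                         # /generate payload
        if isinstance(data, List):
            # Recursion converts lists of any of the above, including mixed lists,
            # without a per-type special case.
            return [_convert_data(d) for d in data]
        raise NotImplementedError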
@@ -400,6 +397,7 @@ def launcher(event_loop):

        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"
            env["USE_PREFIX_CACHING"] = "0"

        with tempfile.TemporaryFile("w+") as tmp:
            # We'll output stdout/stderr to a temporary file. Using a pipe
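Editor's note: the launcher hunk turns prefix caching off whenever flash attention is off, presumably because the prefix-caching path depends on the flash-attention kernels. A minimal sketch of how such environment toggles typically reach the server subprocess; the launch_tgi helper and its signature are hypothetical, not the fixture's real code:

    import os
    import subprocess
    import tempfile

    def launch_tgi(args, use_flash_attention=True):
        # Hypothetical helper mirroring the fixture: flip both switches together
        # when flash attention is disabled.
        env = os.environ.copy()
        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"
            env["USE_PREFIX_CACHING"] = "0"
        with tempfile.TemporaryFile("w+") as tmp:
            # stdout/stderr go to a temp file; a pipe could fill up and block
            # the server once the OS buffer is full.
            proc = subprocess.Popen(args, env=env, stdout=tmp, stderr=tmp)
            proc.wait()
            tmp.seek(0)
            return proc.returncode, tmp.read()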
@@ -571,3 +569,38 @@ def generate_load():
        return await asyncio.gather(*futures)

    return generate_load_inner


@pytest.fixture(scope="module")
def generate_multi():
    async def generate_load_inner(
        client: AsyncClient,
        prompts: List[str],
        max_new_tokens: int,
        seed: Optional[int] = None,
    ) -> List[Response]:

        import numpy as np

        arange = np.arange(len(prompts))
        perm = np.random.permutation(arange)
        rperm = [-1] * len(perm)
        for i, p in enumerate(perm):
            rperm[p] = i

        shuffled_prompts = [prompts[p] for p in perm]
        futures = [
            client.chat(
                messages=[Message(role="user", content=prompt)],
                max_tokens=max_new_tokens,
                temperature=0,
                seed=seed,
            )
            for prompt in shuffled_prompts
        ]

        shuffled_responses = await asyncio.gather(*futures)
        responses = [shuffled_responses[p] for p in rperm]
        return responses

    return generate_load_inner
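Editor's note: the new generate_multi fixture shuffles the prompts before firing the requests (presumably so request arrival order, and hence prefix-cache behavior, does not depend on prompt order) and then restores the original order through the inverse permutation rperm. A self-contained check of that bookkeeping, with stand-in string "responses" in place of real chat calls:

    import numpy as np

    prompts = ["p0", "p1", "p2", "p3", "p4"]

    perm = np.random.permutation(len(prompts))
    rperm = [-1] * len(perm)
    for i, p in enumerate(perm):
        rperm[p] = i  # rperm is the inverse permutation: perm[rperm[j]] == j

    shuffled_prompts = [prompts[p] for p in perm]
    # Stand-in for awaiting the shuffled chat requests:
    shuffled_responses = [f"response to {p}" for p in shuffled_prompts]

    # Undo the shuffle so responses line up with the original prompt order.
    responses = [shuffled_responses[p] for p in rperm]
    assert responses == [f"response to {p}" for p in prompts]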
integration-tests/models/test_flash_llama_prefix.py: new file, 229 lines (diff suppressed because it is too large and one or more lines are too long)
@@ -318,7 +318,10 @@ pub(crate) async fn generate_internal(
    metrics::counter!("tgi_request_count").increment(1);

    // Do not log ultra long inputs, like image payloads.
    tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
    tracing::debug!(
        "Input: {}",
        &req.inputs.chars().take(1000).collect::<String>()
    );

    let compute_characters = req.inputs.chars().count();
    let mut add_prompt = None;
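Editor's note: the router hunk swaps a byte-based slice for a character-based one. In Rust, &req.inputs[..1000.min(req.inputs.len())] indexes by bytes (len() counts bytes), and slicing panics at runtime if byte 1000 falls inside a multi-byte UTF-8 character, whereas chars().take(1000) always cuts on a character boundary. A small Python illustration of the same hazard (illustrative only, not TGI code; Python raises where Rust would panic):

    text = "é" * 600                  # each "é" encodes to 2 bytes in UTF-8
    raw = text.encode("utf-8")        # 1200 bytes

    # Byte-based truncation can split a multi-byte character in half.
    try:
        raw[:999].decode("utf-8")     # byte 999 falls inside an "é"
    except UnicodeDecodeError as err:
        print("byte slice broke a character:", err)

    # Character-based truncation (the fix's approach) is always boundary-safe.
    safe = text[:1000]                # analogous to chars().take(1000)
    print(len(safe), "characters kept")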