mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

Adding prefix test.

parent c1fe28d694
commit 3669d078e0
integration-tests/conftest.py

@@ -19,6 +19,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
 from text_generation import AsyncClient
 from text_generation.types import (
     BestOfSequence,
+    Message,
     ChatComplete,
     ChatCompletionChunk,
     ChatCompletionComplete,
@@ -97,7 +98,10 @@ class ResponseComparator(JSONSnapshotExtension):
     ) -> bool:
         def convert_data(data):
             data = json.loads(data)
-            if isinstance(data, Dict) and "choices" in data:
+            return _convert_data(data)
+
+        def _convert_data(data):
+            if isinstance(data, Dict):
                 choices = data["choices"]
                 if isinstance(choices, List) and len(choices) >= 1:
                     if "delta" in choices[0]:
@@ -105,17 +109,10 @@ class ResponseComparator(JSONSnapshotExtension):
                     if "text" in choices[0]:
                         return Completion(**data)
                 return ChatComplete(**data)
-
             if isinstance(data, Dict):
                 return Response(**data)
             if isinstance(data, List):
-                if (
-                    len(data) > 0
-                    and "object" in data[0]
-                    and data[0]["object"] == "text_completion"
-                ):
-                    return [Completion(**d) for d in data]
-                return [Response(**d) for d in data]
+                return [_convert_data(d) for d in data]
             raise NotImplementedError

         def eq_token(token: Token, other: Token) -> bool:
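The refactor above extracts the body of convert_data into a recursive _convert_data helper, so a list payload is now converted element by element instead of being special-cased on data[0]["object"] == "text_completion". A minimal standalone sketch of that dispatch pattern (hypothetical names, not the repository's code):

from typing import Any

def _convert(payload: Any) -> Any:
    # One function handles a single payload; lists recurse per element, so a
    # nested or mixed list needs no dedicated branch of its own.
    if isinstance(payload, dict):
        return {"converted": payload}  # stand-in for Response / ChatComplete / ...
    if isinstance(payload, list):
        return [_convert(item) for item in payload]
    raise NotImplementedError(type(payload))

print(_convert([{"a": 1}, [{"b": 2}]]))

Because the recursion reuses the single-payload branch, new payload kinds only need to be handled once rather than once for scalars and once for lists.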
@@ -400,6 +397,7 @@ def launcher(event_loop):

         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
+            env["USE_PREFIX_CACHING"] = "0"

         with tempfile.TemporaryFile("w+") as tmp:
             # We'll output stdout/stderr to a temporary file. Using a pipe
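The launcher fixture now pins USE_PREFIX_CACHING to "0" whenever flash attention is disabled, presumably because the prefix-caching path relies on flash-attention kernels. A hedged sketch of how such an environment would reach the launcher subprocess (the command line here is a placeholder; the real fixture builds a much longer one):

import os
import subprocess

def start_launcher(model_id: str, use_flash_attention: bool) -> subprocess.Popen:
    env = os.environ.copy()
    if not use_flash_attention:
        env["USE_FLASH_ATTENTION"] = "false"
        # Added by this commit: prefix caching is switched off together with
        # flash attention (assumption: the caching path needs flash kernels).
        env["USE_PREFIX_CACHING"] = "0"
    return subprocess.Popen(["text-generation-launcher", "--model-id", model_id], env=env)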
@@ -571,3 +569,38 @@ def generate_load():
         return await asyncio.gather(*futures)

     return generate_load_inner
+
+
+@pytest.fixture(scope="module")
+def generate_multi():
+    async def generate_load_inner(
+        client: AsyncClient,
+        prompts: List[str],
+        max_new_tokens: int,
+        seed: Optional[int] = None,
+    ) -> List[Response]:
+
+        import numpy as np
+
+        arange = np.arange(len(prompts))
+        perm = np.random.permutation(arange)
+        rperm = [-1] * len(perm)
+        for i, p in enumerate(perm):
+            rperm[p] = i
+
+        shuffled_prompts = [prompts[p] for p in perm]
+        futures = [
+            client.chat(
+                messages=[Message(role="user", content=prompt)],
+                max_tokens=max_new_tokens,
+                temperature=0,
+                seed=seed,
+            )
+            for prompt in shuffled_prompts
+        ]
+
+        shuffled_responses = await asyncio.gather(*futures)
+        responses = [shuffled_responses[p] for p in rperm]
+        return responses
+
+    return generate_load_inner
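generate_multi fires the chat requests in a random order and then restores the original order before returning, so callers can still line each response up with its prompt; the shuffling is what makes this load pattern useful for prefix-caching bugs, since overlapping prompts no longer arrive in a fixed sequence. The hand-rolled inverse permutation (rperm[p] = i) is equivalent to numpy's argsort; a small illustration of that equivalence (not part of the commit):

import numpy as np

perm = np.random.permutation(5)  # the shuffle order used for the requests
rperm = np.argsort(perm)         # inverse permutation: rperm[perm[i]] == i
assert list(perm[rperm]) == list(range(5))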
integration-tests/models/test_flash_llama_prefix.py (new file, +229 lines; diff suppressed because one or more lines are too long)
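The new test's contents are suppressed in this view. Purely as an illustration of how the generate_multi fixture above would be consumed (fixture names and prompts here are invented, not the actual file contents):

import pytest

@pytest.mark.asyncio
async def test_flash_llama_prefix(flash_llama, generate_multi, response_snapshot):
    prompts = ["What is Deep Learning?", "Write a short poem about rust."]
    responses = await generate_multi(flash_llama, prompts, max_new_tokens=10, seed=0)
    assert len(responses) == len(prompts)
    assert responses == response_snapshot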
router/src/server.rs

@@ -318,7 +318,10 @@ pub(crate) async fn generate_internal(
     metrics::counter!("tgi_request_count").increment(1);

     // Do not long ultra long inputs, like image payloads.
-    tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
+    tracing::debug!(
+        "Input: {}",
+        &req.inputs.chars().take(1000).collect::<String>()
+    );

     let compute_characters = req.inputs.chars().count();
     let mut add_prompt = None;
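The old debug line sliced req.inputs by byte index, and a Rust byte slice like &req.inputs[..1000] panics whenever offset 1000 falls inside a multi-byte UTF-8 character; chars().take(1000) truncates on character boundaries instead. The same boundary hazard can be shown in Python by cutting the UTF-8 encoding at a fixed byte offset (a standalone illustration, not code from the commit):

# "€" is 3 bytes in UTF-8, so byte offset 1000 (= 3 * 333 + 1) lands inside
# the 334th character and the byte-level cut is no longer valid UTF-8.
payload = "€" * 600
cut = payload.encode("utf-8")[:1000]
try:
    cut.decode("utf-8")
except UnicodeDecodeError as exc:
    print("byte-level truncation split a character:", exc)

# Character-level truncation, the analogue of chars().take(1000), is always
# valid; here the string is only 600 characters long, so it is unchanged.
print(len(payload[:1000]))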