Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Commit bada345055 (parent 0533e67ea6)

@@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?",
@@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
     responses = await generate_load(
         flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4

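Every hunk above, and the test-file hunks that follow, make the same one-line change: the `@pytest.mark.private` marker is removed, leaving only `@pytest.mark.asyncio` on each integration test. How that marker was registered and filtered is not shown in this diff; purely as an illustrative sketch, a custom marker like this is usually declared in a `conftest.py` and used to deselect tests on the command line:

# conftest.py -- illustrative sketch only, not part of this commit
def pytest_configure(config):
    # Register a hypothetical "private" marker so pytest does not warn about it.
    config.addinivalue_line(
        "markers", "private: tests gated behind private model weights or credentials"
    )

With such a setup, running `pytest -m "not private"` would have skipped these tests; after this commit they run with the rest of the suite.
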
@@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     response = await flash_llama_awq_sharded.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_load_sharded(
     flash_llama_awq_sharded, generate_load, response_snapshot
 ):

@@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_simple(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?",
@@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
     responses = await generate_load(
         flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4

@@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request",
@@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
     responses = await generate_load(
         flash_mistral, "Test request", max_new_tokens=10, n=4

@@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi_all_params(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request",
@@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
     responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
 

@@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
@@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq_default_params(
     flash_starcoder_gptq, generous_response_snapshot
 ):
@@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params(
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq_load(
     flash_starcoder_gptq, generate_load, generous_response_snapshot
 ):

@@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Whats Googles DNS",
@@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "info: david holtz like trees and has two cats. ",
@@ -98,7 +95,6 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_load(
     flash_llama_grammar, generate_load, response_snapshot
 ):
@@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load(
 # this is the same as the above test, but only fires off a single request
 # this is only to ensure that the parallel and single inference produce the same result
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_single_load_instance(
     flash_llama_grammar, generate_load, response_snapshot
 ):

@@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "What is Deep Learning?", max_new_tokens=10
@@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "blue, red, yellow, ",
@@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba_load(
     fused_kernel_mamba, generate_load, generous_response_snapshot
 ):

@@ -45,37 +45,6 @@ impl HubTokenizerConfig {
     }
 }
 
-mod json_object_or_string_to_string {
-    use jsonschema::{Draft, JSONSchema};
-    use serde::{Deserialize, Deserializer};
-    use serde_json::Value;
-
-    // A custom deserializer that treats both strings and objects as strings.
-    // This provides flexibility with input formats for the 'grammar' field.
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-
-        JSONSchema::options()
-            .with_draft(Draft::Draft202012)
-            .compile(&value)
-            .map_err(|e| serde::de::Error::custom(format!("invalid JSONSchema: {e}")))?;
-
-        match value {
-            Value::String(s) => Ok(s),
-            // Safely handle serialization and return an error if it fails
-            Value::Object(o) => {
-                serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
-            }
-            _ => Err(serde::de::Error::custom(
-                "expected string or object for grammar",
-            )),
-        }
-    }
-}
-
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
@@ -83,12 +52,9 @@ pub(crate) enum GrammarType {
     ///
     /// JSON Schema is a declarative language that allows to annotate JSON documents
     /// with types and descriptions.
-    #[serde(
-        rename = "json",
-        deserialize_with = "json_object_or_string_to_string::deserialize"
-    )]
+    #[serde(rename = "json")]
     #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
-    Json(String),
+    Json(serde_json::Value),
     #[serde(rename = "regex")]
     Regex(String),
 }

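With `GrammarType::Json` now holding a `serde_json::Value` instead of a pre-validated `String`, the `grammar` field of a request can carry the JSON Schema as an inline object; the router serializes it before handing it to the shards (see the validation change below). A minimal sketch of such a request, assuming a TGI instance on localhost:8080 and the `type`/`value` representation declared by the serde attributes above:

import json
import urllib.request

# Hypothetical local endpoint; adjust host/port to your deployment.
url = "http://localhost:8080/generate"

payload = {
    "inputs": "info: david holtz like trees and has two cats. ",
    "parameters": {
        "max_new_tokens": 100,
        # "type"/"value" mirror the #[serde(tag = "type", content = "value")] enum layout.
        "grammar": {
            "type": "json",
            "value": {  # an inline JSON Schema object, no longer a pre-serialized string
                "type": "object",
                "properties": {"name": {"type": "string"}, "cats": {"type": "integer"}},
                "required": ["name", "cats"],
            },
        },
    },
}

req = urllib.request.Request(
    url, data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"}
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["generated_text"])
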
@@ -314,7 +314,18 @@ impl Validation {
         }
         match grammar {
             // currently both are handled the same way since compilation is done in Python
-            GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
+            GrammarType::Json(json) => {
+                // JSONSchema::options()
+                //     .with_draft(Draft::Draft202012)
+                //     .compile(&json)
+                //     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
+
+                (
+                    serde_json::to_string(&json)
+                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
+                    ProtoGrammarType::Json.into(),
+                )
+            }
             GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
         }
     }
@@ -486,6 +497,8 @@ pub enum ValidationError {
     Tokenizer(String),
     #[error("grammar is not supported")]
     Grammar,
+    #[error("grammar is not a valid JSONSchema: {0}")]
+    InvalidGrammar(String),
 }
 
 #[cfg(test)]

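The Draft 2020-12 compile check that used to run during deserialization is left commented out here, so the router no longer rejects malformed schemas up front; `serde_json::to_string` only fails on values that cannot be serialized, and that failure surfaces as the new `InvalidGrammar` error. A rough Python equivalent of the dropped check, shown only to illustrate what it did (the `jsonschema` package and the helper below are assumptions, not part of this commit):

import json

from jsonschema import Draft202012Validator
from jsonschema.exceptions import SchemaError

def serialize_grammar(schema: dict) -> str:
    """Sketch of the router-side step: optionally check the schema, then
    serialize it to the string form that is passed down to the shards."""
    try:
        # Comparable in spirit to JSONSchema::options().with_draft(Draft::Draft202012).compile(&value)
        Draft202012Validator.check_schema(schema)
    except SchemaError as e:
        raise ValueError(f"grammar is not a valid JSONSchema: {e}") from e
    return json.dumps(schema)

print(serialize_grammar({"type": "object", "properties": {"name": {"type": "string"}}}))
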
@@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser:
             scores = scores.view(B, S, -1)
 
         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
-        mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
 
         for j in range(S):
             _scores = scores[:, j]
@@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.repetition_processor(input_ids, _scores)
             if self.frequency_processor is not None:
                 _scores = self.frequency_processor(input_ids, _scores)
-            for warper in self.warpers:
-                _scores = warper(input_ids, _scores)
             if self.grammar_processor is not None:
                 _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
             _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids

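The reordering above changes where the logit warpers run: they previously acted before the grammar processor, and now they act on the already grammar-masked scores, right before sampling. A simplified sketch of the per-position order after this change; the function and its parameter names are illustrative, only the processor/grammar/warper/choice ordering mirrors the diff:

import torch

def pick_next_token(
    input_ids: torch.Tensor,
    scores: torch.Tensor,
    processors: list,          # repetition / frequency penalties, etc.
    grammar_processor,         # masks tokens the grammar FSM disallows, or None
    fsm_grammar_states,        # per-request grammar state
    warpers: list,             # temperature, top-k, top-p, ...
    choice,                    # greedy argmax or multinomial sampling
) -> torch.Tensor:
    # 1. penalties and other processors
    for processor in processors:
        scores = processor(input_ids, scores)
    # 2. grammar mask (new position: before the warpers)
    if grammar_processor is not None:
        scores = grammar_processor(scores, fsm_grammar_states)
    # 3. warpers now act on the grammar-constrained distribution
    for warper in warpers:
        scores = warper(input_ids, scores)
    # 4. sample / argmax
    return choice(scores)
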