mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
???
This commit is contained in:
parent 0533e67ea6
commit bada345055
@@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
    response = await flash_llama_awq.generate(
        "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True

@@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
    response = await flash_llama_awq.generate(
        "What is Deep Learning?",

@@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
    responses = await generate_load(
        flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
@@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
    response = await flash_llama_awq_sharded.generate(
        "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True

@@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_load_sharded(
    flash_llama_awq_sharded, generate_load, response_snapshot
):
@@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True

@@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "What is Deep Learning?",

@@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
    responses = await generate_load(
        flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
@@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral(flash_mistral, response_snapshot):
    response = await flash_mistral.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True

@@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
    response = await flash_mistral.generate(
        "Test request",

@@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
    responses = await generate_load(
        flash_mistral, "Test request", max_new_tokens=10, n=4
@@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
    response = await flash_phi.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True

@@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
    response = await flash_phi.generate(
        "Test request",

@@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
    responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
@@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
    response = await flash_starcoder_gptq.generate(
        "def geometric_mean(L: List[float]):",

@@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq_default_params(
    flash_starcoder_gptq, generous_response_snapshot
):

@@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params(


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq_load(
    flash_starcoder_gptq, generate_load, generous_response_snapshot
):
@@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
    response = await flash_llama_grammar.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True

@@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
    response = await flash_llama_grammar.generate(
        "Whats Googles DNS",

@@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
    response = await flash_llama_grammar.generate(
        "info: david holtz like trees and has two cats. ",

@@ -98,7 +95,6 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_load(
    flash_llama_grammar, generate_load, response_snapshot
):

@@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load(
# this is the same as the above test, but only fires off a single request
# this is only to ensure that the parallel and single inference produce the same result
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_single_load_instance(
    flash_llama_grammar, generate_load, response_snapshot
):
@@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):


@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba(fused_kernel_mamba, response_snapshot):
    response = await fused_kernel_mamba.generate(
        "What is Deep Learning?", max_new_tokens=10

@@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
    response = await fused_kernel_mamba.generate(
        "blue, red, yellow, ",

@@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):


@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_load(
    fused_kernel_mamba, generate_load, generous_response_snapshot
):
@@ -45,37 +45,6 @@ impl HubTokenizerConfig {
    }
}

-mod json_object_or_string_to_string {
-    use jsonschema::{Draft, JSONSchema};
-    use serde::{Deserialize, Deserializer};
-    use serde_json::Value;
-
-    // A custom deserializer that treats both strings and objects as strings.
-    // This provides flexibility with input formats for the 'grammar' field.
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-
-        JSONSchema::options()
-            .with_draft(Draft::Draft202012)
-            .compile(&value)
-            .map_err(|e| serde::de::Error::custom(format!("invalid JSONSchema: {e}")))?;
-
-        match value {
-            Value::String(s) => Ok(s),
-            // Safely handle serialization and return an error if it fails
-            Value::Object(o) => {
-                serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
-            }
-            _ => Err(serde::de::Error::custom(
-                "expected string or object for grammar",
-            )),
-        }
-    }
-}
-
#[derive(Clone, Debug, Deserialize, ToSchema)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
@@ -83,12 +52,9 @@ pub(crate) enum GrammarType {
    ///
    /// JSON Schema is a declarative language that allows to annotate JSON documents
    /// with types and descriptions.
-    #[serde(
-        rename = "json",
-        deserialize_with = "json_object_or_string_to_string::deserialize"
-    )]
+    #[serde(rename = "json")]
    #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
-    Json(String),
+    Json(serde_json::Value),
    #[serde(rename = "regex")]
    Regex(String),
}
@@ -314,7 +314,18 @@ impl Validation {
        }
        match grammar {
            // currently both are handled the same way since compilation is done in Python
-            GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
+            GrammarType::Json(json) => {
+                // JSONSchema::options()
+                //     .with_draft(Draft::Draft202012)
+                //     .compile(&json)
+                //     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
+
+                (
+                    serde_json::to_string(&json)
+                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
+                    ProtoGrammarType::Json.into(),
+                )
+            }
            GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
        }
    }
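Per the comment in the hunk above, grammar compilation still happens on the Python side, so the router side now just serializes the schema Value back to a string and maps a serialization failure to the InvalidGrammar error. A loose Python analogue of that step, illustrative only and with invented names, not TGI code:

import json

class InvalidGrammar(ValueError):
    """Stand-in for ValidationError::InvalidGrammar."""

def prepare_json_grammar(schema: dict) -> str:
    # serialize the schema object into the string handed to the backend
    try:
        return json.dumps(schema)
    except (TypeError, ValueError) as err:
        raise InvalidGrammar(str(err)) from err

print(prepare_json_grammar({"properties": {"location": {"type": "string"}}}))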
@@ -486,6 +497,8 @@ pub enum ValidationError {
    Tokenizer(String),
    #[error("grammar is not supported")]
    Grammar,
+    #[error("grammar is not a valid JSONSchema: {0}")]
+    InvalidGrammar(String),
}

#[cfg(test)]
@@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser:
        scores = scores.view(B, S, -1)

        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
        mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)

        for j in range(S):
            _scores = scores[:, j]
@@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser:
                _scores = self.repetition_processor(input_ids, _scores)
            if self.frequency_processor is not None:
                _scores = self.frequency_processor(input_ids, _scores)
-            for warper in self.warpers:
-                _scores = warper(input_ids, _scores)
            if self.grammar_processor is not None:
                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
            _next_ids = self.choice(_scores)
            scores[:, j] = _scores
            next_ids[:, j] = _next_ids
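In the reordered loop above, the warpers now appear to run after the grammar processor, so temperature, top-k, and top-p act on scores that the grammar state has already masked. A minimal, self-contained sketch of that ordering with toy stand-ins, not the TGI implementation:

import torch

def toy_repetition_penalty(input_ids, scores):
    # no-op placeholder for the repetition/frequency penalties
    return scores

def toy_grammar_mask(scores, allowed_token_ids):
    # forbid every token the current grammar state does not allow
    mask = torch.full_like(scores, float("-inf"))
    mask[:, allowed_token_ids] = 0.0
    return scores + mask

def toy_temperature_warper(input_ids, scores, temperature=0.7):
    # placeholder for the warpers (temperature / top-k / top-p)
    return scores / temperature

input_ids = torch.tensor([[1, 2, 3]])
scores = torch.randn(1, 8)  # (batch, vocab) logits for one position

scores = toy_repetition_penalty(input_ids, scores)
scores = toy_grammar_mask(scores, allowed_token_ids=[2, 5, 7])
scores = toy_temperature_warper(input_ids, scores)
next_id = torch.argmax(scores, dim=-1)
print(next_id)  # always one of the grammar-allowed ids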