This commit is contained in:
OlivierDehaene 2024-02-20 14:55:30 +01:00
parent 0533e67ea6
commit bada345055
11 changed files with 18 additions and 65 deletions

View File

@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
"What is Deep Learning?",
@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
responses = await generate_load(
flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4

View File

@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
response = await flash_llama_awq_sharded.generate(
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_load_sharded(
flash_llama_awq_sharded, generate_load, response_snapshot
):

View File

@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"What is Deep Learning?",
@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
responses = await generate_load(
flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4

View File

@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral(flash_mistral, response_snapshot):
response = await flash_mistral.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
response = await flash_mistral.generate(
"Test request",
@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
responses = await generate_load(
flash_mistral, "Test request", max_new_tokens=10, n=4

View File

@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request",
@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)

View File

@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
response = await flash_starcoder_gptq.generate(
"def geometric_mean(L: List[float]):",
@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq_default_params(
flash_starcoder_gptq, generous_response_snapshot
):
@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params(
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq_load(
flash_starcoder_gptq, generate_load, generous_response_snapshot
):

View File

@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Whats Googles DNS",
@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"info: david holtz like trees and has two cats. ",
@ -98,7 +95,6 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_load(
flash_llama_grammar, generate_load, response_snapshot
):
@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load(
# this is the same as the above test, but only fires off a single request
# this is only to ensure that the parallel and single inference produce the same result
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_single_load_instance(
flash_llama_grammar, generate_load, response_snapshot
):

View File

@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
"What is Deep Learning?", max_new_tokens=10
@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
"blue, red, yellow, ",
@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_load(
fused_kernel_mamba, generate_load, generous_response_snapshot
):

View File

@ -45,37 +45,6 @@ impl HubTokenizerConfig {
}
}
mod json_object_or_string_to_string {
use jsonschema::{Draft, JSONSchema};
use serde::{Deserialize, Deserializer};
use serde_json::Value;
// A custom deserializer that treats both strings and objects as strings.
// This provides flexibility with input formats for the 'grammar' field.
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
JSONSchema::options()
.with_draft(Draft::Draft202012)
.compile(&value)
.map_err(|e| serde::de::Error::custom(format!("invalid JSONSchema: {e}")))?;
match value {
Value::String(s) => Ok(s),
// Safely handle serialization and return an error if it fails
Value::Object(o) => {
serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
}
_ => Err(serde::de::Error::custom(
"expected string or object for grammar",
)),
}
}
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
@ -83,12 +52,9 @@ pub(crate) enum GrammarType {
///
/// JSON Schema is a declarative language that allows to annotate JSON documents
/// with types and descriptions.
#[serde(
rename = "json",
deserialize_with = "json_object_or_string_to_string::deserialize"
)]
#[serde(rename = "json")]
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
Json(String),
Json(serde_json::Value),
#[serde(rename = "regex")]
Regex(String),
}

View File

@ -314,7 +314,18 @@ impl Validation {
}
match grammar {
// currently both are handled the same way since compilation is done in Python
GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
GrammarType::Json(json) => {
// JSONSchema::options()
// .with_draft(Draft::Draft202012)
// .compile(&json)
// .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
(
serde_json::to_string(&json)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
ProtoGrammarType::Json.into(),
)
}
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
}
}
@ -486,6 +497,8 @@ pub enum ValidationError {
Tokenizer(String),
#[error("grammar is not supported")]
Grammar,
#[error("grammar is not a valid JSONSchema: {0}")]
InvalidGrammar(String),
}
#[cfg(test)]

View File

@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser:
scores = scores.view(B, S, -1)
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
for j in range(S):
_scores = scores[:, j]
@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser:
_scores = self.repetition_processor(input_ids, _scores)
if self.frequency_processor is not None:
_scores = self.frequency_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids