Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 22:32:07 +00:00)
fix(server): Fix stop sequences (#11)

commit 611e21cb13
parent 3e2e6240b8

This commit reworks stop-sequence handling: StopSequenceCriteria now matches the decoded output text with a regex instead of walking a list of token ids, StoppingCriteria takes the EOS token id explicitly and is called per step with the last token and its decoded text, and the CausalLM and Seq2SeqLM generation loops decode each new token (with clean_up_tokenization_spaces=False) before evaluating the criteria. The launcher integration test also doubles its health-check retries; the remaining changes are rustfmt/black reformatting.
@@ -1,13 +1,13 @@
-use std::fs::File;
+use float_eq::assert_float_eq;
+use serde::Deserialize;
 use serde_json::Value;
+use std::fs::File;
 use std::io::{BufRead, BufReader};
 use std::path::PathBuf;
 use std::thread;
 use std::thread::sleep;
 use std::time::Duration;
-use float_eq::assert_float_eq;
 use subprocess::{Popen, PopenConfig, Redirection};
-use serde::Deserialize;
 
 #[derive(Deserialize)]
 struct Details {
@@ -22,7 +22,6 @@ struct GeneratedText {
     details: Details,
 }
 
-
 fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
     let argv = vec![
         "text-generation-launcher".to_string(),
@@ -63,7 +62,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
         }
     });
 
-    for _ in 0..30 {
+    for _ in 0..60 {
         let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
         if health.is_ok() {
             return launcher;
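Note: the only functional change in this hunk is the health-check budget, raised from 30 to 60 attempts so slower model loads do not fail the integration test; assuming the one-second sleep between attempts that the surrounding loop uses (not shown in this hunk), this doubles the startup timeout to about a minute. A minimal Python sketch of the same poll-until-healthy pattern, for illustration only (the port value and helper name are not part of the commit):

import time

import requests


def wait_until_healthy(port: int, attempts: int = 60) -> bool:
    # Poll the launcher's /health endpoint once per second,
    # mirroring the retry loop in the Rust integration test.
    for _ in range(attempts):
        try:
            if requests.get(f"http://localhost:{port}/health").ok:
                return True
        except requests.ConnectionError:
            pass  # server not up yet
        time.sleep(1)
    return False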
@@ -76,7 +75,12 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
     panic!("failed to launch {}", model_name)
 }
 
-fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText {
+fn test_model(
+    model_name: String,
+    num_shard: usize,
+    port: usize,
+    master_port: usize,
+) -> GeneratedText {
     let mut launcher = start_launcher(model_name, num_shard, port, master_port);
 
     let data = r#"
@@ -101,7 +105,6 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText {
     results.pop().unwrap()
 }
 
-
 fn read_json(name: &str) -> GeneratedText {
     let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
     d.push("tests/");
@@ -117,9 +120,17 @@ fn read_json(name: &str) -> GeneratedText {
 fn compare_results(result: GeneratedText, expected: GeneratedText) {
     assert_eq!(result.generated_text, expected.generated_text);
     assert_eq!(result.details.finish_reason, expected.details.finish_reason);
-    assert_eq!(result.details.generated_tokens, expected.details.generated_tokens);
+    assert_eq!(
+        result.details.generated_tokens,
+        expected.details.generated_tokens
+    );
 
-    for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) {
+    for (token, expected_token) in result
+        .details
+        .tokens
+        .into_iter()
+        .zip(expected.details.tokens.into_iter())
+    {
         assert_eq!(token.0, expected_token.0);
         assert_eq!(token.1, expected_token.1);
         if let Some(logprob) = token.2 {
@@ -11,46 +11,33 @@ from text_generation.utils import (
 
 
 def test_stop_sequence_criteria():
-    criteria = StopSequenceCriteria([1, 2, 3])
+    criteria = StopSequenceCriteria("/test;")
 
-    assert not criteria(1)
-    assert criteria.current_token_idx == 1
-    assert not criteria(2)
-    assert criteria.current_token_idx == 2
-    assert criteria(3)
-    assert criteria.current_token_idx == 3
-
-
-def test_stop_sequence_criteria_reset():
-    criteria = StopSequenceCriteria([1, 2, 3])
-
-    assert not criteria(1)
-    assert criteria.current_token_idx == 1
-    assert not criteria(2)
-    assert criteria.current_token_idx == 2
-    assert not criteria(4)
-    assert criteria.current_token_idx == 0
-
-
-def test_stop_sequence_criteria_empty():
-    with pytest.raises(ValueError):
-        StopSequenceCriteria([])
+    assert not criteria("/")
+    assert not criteria("/test")
+    assert criteria("/test;")
+    assert not criteria("/test; ")
 
 
 def test_stopping_criteria():
-    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
-    assert criteria([1]) == (False, None)
-    assert criteria([1, 2]) == (False, None)
-    assert criteria([1, 2, 3]) == (True, "stop_sequence")
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(65827, "/test") == (False, None)
+    assert criteria(30, ";") == (True, "stop_sequence")
+
+
+def test_stopping_criteria_eos():
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(1, "") == (False, None)
+    assert criteria(0, "") == (True, "eos_token")
 
 
 def test_stopping_criteria_max():
-    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
-    assert criteria([1]) == (False, None)
-    assert criteria([1, 1]) == (False, None)
-    assert criteria([1, 1, 1]) == (False, None)
-    assert criteria([1, 1, 1, 1]) == (False, None)
-    assert criteria([1, 1, 1, 1, 1]) == (True, "length")
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (True, "length")
 
 
 def test_weight_hub_files():
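Note on the new test shape: StopSequenceCriteria is now built from a string and called with decoded text, and StoppingCriteria is constructed with the EOS token id first and called once per step with (last_token, last_output), accumulating the output internally. The ids 65827 and 30 above are arbitrary non-EOS tokens; the stop fires on the second call because the accumulated text "/test" + ";" ends with "/test;". A small sketch of driving the criteria from a decode loop (the token/text pairs are illustrative):

from text_generation.utils import StoppingCriteria, StopSequenceCriteria

criteria = StoppingCriteria(
    eos_token_id=0,
    stop_sequence_criterias=[StopSequenceCriteria("/test;")],
    max_new_tokens=5,
)

# Feed (token_id, decoded_text) pairs the way a generation loop would.
for token_id, token_text in [(65827, "/test"), (30, ";")]:
    stop, reason = criteria(token_id, token_text)

assert (stop, reason) == (True, "stop_sequence")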
@@ -345,7 +345,12 @@ class CausalLM(Model):
             all_logprobs = torch.cat([all_logprobs, next_token_logprob])
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(all_input_ids)
+            stop, reason = stopping_criteria(
+                next_token.squeeze(),
+                self.tokenizer.decode(
+                    next_token.squeeze(), clean_up_tokenization_spaces=False
+                ),
+            )
             if stop:
                 # Decode all tokens
                 output_text = self.tokenizer.decode(
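Note: instead of re-running the criteria over all_input_ids, the loop now passes the freshly sampled token id together with its decoded text. clean_up_tokenization_spaces=False matters here: decoding token by token must reproduce the generated text exactly, or a stop sequence could straddle a cleanup artifact (such as a removed space before punctuation) and never match. A quick illustration, assuming a GPT-2 tokenizer purely as an example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any model; gpt2 is an assumption

ids = tokenizer(" hello world").input_ids
# Decoding one token at a time with cleanup disabled keeps leading spaces,
# so the concatenation of per-token pieces equals the full decoded text.
pieces = [tokenizer.decode([i], clean_up_tokenization_spaces=False) for i in ids]
assert "".join(pieces) == tokenizer.decode(ids, clean_up_tokenization_spaces=False)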
@@ -441,7 +441,12 @@ class Seq2SeqLM(Model):
             decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(decoder_input_ids)
+            stop, reason = stopping_criteria(
+                next_token.squeeze(),
+                self.tokenizer.decode(
+                    next_token.squeeze(), clean_up_tokenization_spaces=False
+                ),
+            )
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
@@ -1,5 +1,6 @@
 import concurrent
 import os
+import re
 import torch
 import torch.distributed
 
@@ -74,43 +75,39 @@ class NextTokenChooser:
 
 
 class StopSequenceCriteria:
-    def __init__(self, tokens: List[int]):
-        if not tokens:
-            raise ValueError("tokens cannot be empty")
-
-        self.tokens = tokens
-        self.current_token_idx = 0
-
-    def __call__(self, last_token: int) -> bool:
-        if last_token == self.tokens[self.current_token_idx]:
-            # Increase idx to go to next token
-            self.current_token_idx += 1
-        else:
-            # Reset to first token of the stopping sequence
-            self.current_token_idx = 0
-
-        if self.current_token_idx == len(self.tokens):
-            # We matched the entire sequence without resetting
+    def __init__(self, stop_sequence: str):
+        self.regex = re.compile(f".*{stop_sequence}$")
+
+    def __call__(self, output: str) -> bool:
+        if self.regex.findall(output):
             return True
         return False
 
 
 class StoppingCriteria:
     def __init__(
-        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens=20,
     ):
+        self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
+        self.current_output = ""
 
-    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
+    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
             return True, "length"
 
-        last_token = all_ids[-1]
+        if last_token == self.eos_token_id:
+            return True, "eos_token"
+
+        self.current_output += last_output
         for stop_sequence_criteria in self.stop_sequence_criterias:
-            if stop_sequence_criteria(last_token):
+            if stop_sequence_criteria(self.current_output):
                 return True, "stop_sequence"
 
         return False, None
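Note: the token-id state machine (whose reset behavior test_stop_sequence_criteria_reset used to cover) is gone; matching now happens on the accumulated decoded string, so stop sequences that span token boundaries are found naturally. The pattern f".*{stop_sequence}$" matches exactly when the output so far ends with the stop sequence. One caveat, as an observation rather than part of the commit: the sequence is interpolated into the regex unescaped, so regex metacharacters in a user-supplied stop sequence are treated as syntax; re.escape would give a strictly literal match:

import re

stop_sequence = "/test;"
regex = re.compile(f".*{stop_sequence}$")
assert regex.findall("/test;")       # output ends with the stop sequence
assert not regex.findall("/test; ")  # trailing character, so no match

# Metacharacters in a stop sequence are interpreted as regex syntax here;
# re.escape makes the match literal (illustration, not part of the commit).
literal = re.compile(f".*{re.escape('f(x)')}$")
assert literal.findall("call f(x)")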
@@ -119,16 +116,12 @@ class StoppingCriteria:
     def from_pb(
         cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
     ) -> "StoppingCriteria":
-        stop_sequence_criterias = []
-        for stop_sequence in pb.stop_sequences:
-            tokens = tokenizer(
-                stop_sequence, padding=False, return_attention_mask=False
-            ).input_ids
-            if tokens:
-                stop_sequence_criterias.append(StopSequenceCriteria(tokens))
-        stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id]))
-
-        return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens)
+        stop_sequence_criterias = [
+            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
+        ]
+        return StoppingCriteria(
+            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
+        )
 
 
 def initialize_torch_distributed():
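Note: from_pb no longer tokenizes stop sequences into id lists or appends the EOS token as a one-token stop sequence; EOS handling lives in StoppingCriteria itself, keyed on tokenizer.eos_token_id. A minimal usage sketch with a stand-in for the protobuf message (the real generate_pb2.StoppingCriteriaParameters is only assumed to expose the stop_sequences and max_new_tokens fields used above, and the EOS id below is hypothetical):

from types import SimpleNamespace

from text_generation.utils import StoppingCriteria, StopSequenceCriteria

# Stand-in for generate_pb2.StoppingCriteriaParameters (illustration only).
pb = SimpleNamespace(stop_sequences=["/test;"], max_new_tokens=20)

criteria = StoppingCriteria(
    2,  # hypothetical eos_token_id; from_pb reads tokenizer.eos_token_id
    [StopSequenceCriteria(sequence) for sequence in pb.stop_sequences],
    pb.max_new_tokens,
)
assert criteria(7, "/test;") == (True, "stop_sequence")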