Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

pre-commit

commit f5a6691d0e (parent 98e790e32a)
@@ -1,7 +1,7 @@
 from typing import List
 
 import pytest
 
 import requests
 
 
 def bloom_560_handle(launcher):
     with launcher("bigscience/bloom-560m") as handle:
         yield handle
@@ -12,38 +12,43 @@ async def bloom_560(bloom_560_handle):
     await bloom_560_handle.health(240)
     return bloom_560_handle.client
 
 
 @pytest.mark.release
 @pytest.mark.asyncio
 async def test_bloom_560m(bloom_560):
     base_url = bloom_560.base_url
     prompt = "The cat sat on the mat. The cat"
 
     repeated_2grams_control = await call_model(base_url, prompt, 0)
-    assert len(repeated_2grams_control) > 0, "Expected to find repeated bi-grams in control case"
+    assert (
+        len(repeated_2grams_control) > 0
+    ), "Expected to find repeated bi-grams in control case"
 
     repeated_2grams_test = await call_model(base_url, prompt, 2)
-    assert len(repeated_2grams_test) == 0, f"Expected no repeated bi-grams, but found: {repeated_2grams_test}"
+    assert (
+        len(repeated_2grams_test) == 0
+    ), f"Expected no repeated bi-grams, but found: {repeated_2grams_test}"
 
 
 async def call_model(base_url, prompt, n_grams):
     data = {
         "inputs": prompt,
-        "parameters" : {
+        "parameters": {
             "max_new_tokens": 20,
             "seed": 42,
             "no_repeat_ngram_size": n_grams,
-            "details": True
-        }
+            "details": True,
+        },
     }
     res = requests.post(f"{base_url}/generate", json=data)
     res = res.json()
 
-    tokens = res['details']['tokens']
-    token_texts = [token['text'] for token in tokens]
+    tokens = res["details"]["tokens"]
+    token_texts = [token["text"] for token in tokens]
 
     # find repeated 2grams
-    ngrams = [tuple(token_texts[i:i+2]) for i in range(len(token_texts)-2+1)]
+    ngrams = [tuple(token_texts[i : i + 2]) for i in range(len(token_texts) - 2 + 1)]
     ngram_counts = {}
     for ngram in ngrams:
         if ngram in ngram_counts:
@@ -54,4 +59,3 @@ async def call_model(base_url, prompt, n_grams):
     repeated = [list(ngram) for ngram, count in ngram_counts.items() if count > 1]
 
     return repeated
-
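For readers scanning the diff: the repeated-bigram check that call_model performs can be written more compactly with the standard library's collections.Counter. The sketch below is purely illustrative and not part of the commit; the helper name repeated_bigrams is hypothetical.

import collections


def repeated_bigrams(token_texts):
    # Slide a window of size 2 over the generated token strings, exactly as
    # the test's list comprehension does (range(len(token_texts) - 2 + 1)).
    bigrams = [tuple(token_texts[i : i + 2]) for i in range(len(token_texts) - 1)]
    counts = collections.Counter(bigrams)
    # Any bigram seen more than once is a repeat; with no_repeat_ngram_size=2
    # the server should never emit one, so the test expects an empty list.
    return [list(bigram) for bigram, count in counts.items() if count > 1]


# Example: ("the", "cat") occurs twice, so it is reported as repeated.
# repeated_bigrams(["the", "cat", "sat", "the", "cat"]) == [["the", "cat"]]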