From f5a6691d0e25933372fbbdbd6b26c55965b40af5 Mon Sep 17 00:00:00 2001 From: erikkaum Date: Thu, 1 Aug 2024 13:37:45 +0200 Subject: [PATCH] pre-commit --- .../models/test_no_repeat_ngram.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/integration-tests/models/test_no_repeat_ngram.py b/integration-tests/models/test_no_repeat_ngram.py index 281812f9..6a6384a1 100644 --- a/integration-tests/models/test_no_repeat_ngram.py +++ b/integration-tests/models/test_no_repeat_ngram.py @@ -1,7 +1,7 @@ -from typing import List import pytest import requests + def bloom_560_handle(launcher): with launcher("bigscience/bloom-560m") as handle: yield handle @@ -12,38 +12,43 @@ async def bloom_560(bloom_560_handle): await bloom_560_handle.health(240) return bloom_560_handle.client + @pytest.mark.release @pytest.mark.asyncio async def test_bloom_560m(bloom_560): - base_url = bloom_560.base_url + base_url = bloom_560.base_url prompt = "The cat sat on the mat. The cat" repeated_2grams_control = await call_model(base_url, prompt, 0) - assert len(repeated_2grams_control) > 0, "Expected to find repeated bi-grams in control case" + assert ( + len(repeated_2grams_control) > 0 + ), "Expected to find repeated bi-grams in control case" repeated_2grams_test = await call_model(base_url, prompt, 2) - assert len(repeated_2grams_test) == 0, f"Expected no repeated bi-grams, but found: {repeated_2grams_test}" + assert ( + len(repeated_2grams_test) == 0 + ), f"Expected no repeated bi-grams, but found: {repeated_2grams_test}" async def call_model(base_url, prompt, n_grams): data = { "inputs": prompt, - "parameters" : { + "parameters": { "max_new_tokens": 20, "seed": 42, "no_repeat_ngram_size": n_grams, - "details": True - } + "details": True, + }, } res = requests.post(f"{base_url}/generate", json=data) res = res.json() - tokens = res['details']['tokens'] - token_texts = [token['text'] for token in tokens] + tokens = res["details"]["tokens"] + token_texts = [token["text"] for token in tokens] - # find repeated 2grams - ngrams = [tuple(token_texts[i:i+2]) for i in range(len(token_texts)-2+1)] + # find repeated 2grams + ngrams = [tuple(token_texts[i : i + 2]) for i in range(len(token_texts) - 2 + 1)] ngram_counts = {} for ngram in ngrams: if ngram in ngram_counts: @@ -54,4 +59,3 @@ async def call_model(base_url, prompt, n_grams): repeated = [list(ngram) for ngram, count in ngram_counts.items() if count > 1] return repeated -