From 734d33d2232d544eb382538d8418611e131b26c7 Mon Sep 17 00:00:00 2001 From: Twilight Date: Thu, 24 Oct 2024 11:12:23 +0800 Subject: [PATCH] init --- .gitignore | 162 ++++++++++++++++++++++++++++++++++++ classification.py | 31 +++++++ main.py | 29 +++++++ protos/classification.proto | 17 ++++ 4 files changed, 239 insertions(+) create mode 100644 .gitignore create mode 100644 classification.py create mode 100644 main.py create mode 100644 protos/classification.proto diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2d0653a --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + diff --git a/classification.py b/classification.py new file mode 100644 index 0000000..a4b8eee --- /dev/null +++ b/classification.py @@ -0,0 +1,31 @@ +from pprint import pprint + +from pydantic import BaseModel +from transformers import pipeline + +classifier = pipeline("zero-shot-classification", + model="morit/chinese_xlm_xnli") + + +# 返回一个结构化的内容 +class ClassifyResult(BaseModel): + sequence: str + labels: list + scores: list + prediction: str + + +def classify(text: str, labels: list): + output = classifier(text, labels) + + pprint(output) + + # 根据 score,寻找最高的 label + prediction_rank = output['scores'].index(max(output['scores'])) + + return ClassifyResult( + sequence=text, + labels=output['labels'], + scores=output['scores'], + prediction=output['labels'][prediction_rank] + ) diff --git a/main.py b/main.py new file mode 100644 index 0000000..cc9f306 --- /dev/null +++ b/main.py @@ -0,0 +1,29 @@ +from typing import Union + +import uvicorn +from fastapi import FastAPI +from pydantic import BaseModel +import classification + +app = FastAPI() + +class TextClassificationRequest(BaseModel): + text: str + labels: list[str] + +class TextClassificationResponse(BaseModel): + prediction: str + labels: list[str] + + + +@app.post("/classify") +def classify(req: TextClassificationRequest) -> TextClassificationResponse: + result = classification.classify(req.text, req.labels) + + return TextClassificationResponse(prediction=result.labels[0], labels=result.labels) + + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/protos/classification.proto b/protos/classification.proto new file mode 100644 index 0000000..7dee925 --- /dev/null +++ b/protos/classification.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +service TextClassification { + rpc Classify(TextClassificationRequest) returns (TextClassificationResponse) {} +} + +message TextClassificationRequest { + string text = 1; + repeated string labels = 2; +} + +message TextClassificationResponse { + string sequence = 1; + string match = 2; + repeated string labels = 3; + repeated float scores = 4; +} \ No newline at end of file