♻️ refactor: provider refactor (#41)
* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
parent
0bfe1f5779
commit
ef041e28a1
2
.gitignore
vendored
2
.gitignore
vendored
@ -8,5 +8,5 @@ build
|
||||
logs
|
||||
data
|
||||
tmp/
|
||||
test/
|
||||
/test/
|
||||
.env
|
299
common/client.go
299
common/client.go
@ -1,299 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
var clientPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &http.Client{}
|
||||
},
|
||||
}
|
||||
|
||||
func GetHttpClient(proxyAddr string) *http.Client {
|
||||
client := clientPool.Get().(*http.Client)
|
||||
|
||||
if RelayTimeout > 0 {
|
||||
client.Timeout = time.Duration(RelayTimeout) * time.Second
|
||||
}
|
||||
|
||||
if proxyAddr != "" {
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
SysError("Error parsing proxy address: " + err.Error())
|
||||
return client
|
||||
}
|
||||
|
||||
switch proxyURL.Scheme {
|
||||
case "http", "https":
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
case "socks5":
|
||||
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
||||
if err != nil {
|
||||
SysError("Error creating SOCKS5 dialer: " + err.Error())
|
||||
return client
|
||||
}
|
||||
client.Transport = &http.Transport{
|
||||
Dial: dialer.Dial,
|
||||
}
|
||||
default:
|
||||
SysError("Unsupported proxy scheme: " + proxyURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
return client
|
||||
|
||||
}
|
||||
|
||||
func PutHttpClient(c *http.Client) {
|
||||
clientPool.Put(c)
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
requestBuilder RequestBuilder
|
||||
CreateFormBuilder func(io.Writer) FormBuilder
|
||||
}
|
||||
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
requestBuilder: NewRequestBuilder(),
|
||||
CreateFormBuilder: func(body io.Writer) FormBuilder {
|
||||
return NewFormBuilder(body)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type requestOptions struct {
|
||||
body any
|
||||
header http.Header
|
||||
}
|
||||
|
||||
type requestOption func(*requestOptions)
|
||||
|
||||
type Stringer interface {
|
||||
GetString() *string
|
||||
}
|
||||
|
||||
func WithBody(body any) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
args.body = body
|
||||
}
|
||||
}
|
||||
|
||||
func WithHeader(header map[string]string) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
for k, v := range header {
|
||||
args.header.Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithContentType(contentType string) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
args.header.Set("Content-Type", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
type RequestError struct {
|
||||
HTTPStatusCode int
|
||||
Err error
|
||||
}
|
||||
|
||||
func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
|
||||
// Default Options
|
||||
args := &requestOptions{
|
||||
body: nil,
|
||||
header: make(http.Header),
|
||||
}
|
||||
for _, setter := range setters {
|
||||
setter(args)
|
||||
}
|
||||
req, err := c.requestBuilder.Build(method, url, args.body, args.header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func SendRequest(req *http.Request, response any, outputResp bool, proxyAddr string) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
// 发送请求
|
||||
client := GetHttpClient(proxyAddr)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
PutHttpClient(client)
|
||||
|
||||
if !outputResp {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// 处理响应
|
||||
if IsFailureStatusCode(resp) {
|
||||
return nil, HandleErrorResp(resp)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
if outputResp {
|
||||
var buf bytes.Buffer
|
||||
tee := io.TeeReader(resp.Body, &buf)
|
||||
err = DecodeResponse(tee, response)
|
||||
|
||||
// 将响应体重新写入 resp.Body
|
||||
resp.Body = io.NopCloser(&buf)
|
||||
} else {
|
||||
err = DecodeResponse(resp.Body, response)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if outputResp {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error types.OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
Response struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
func (e GeneralErrorResponse) ToMessage() string {
|
||||
if e.Error.Message != "" {
|
||||
return e.Error.Message
|
||||
}
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
if e.Msg != "" {
|
||||
return e.Msg
|
||||
}
|
||||
if e.Err != "" {
|
||||
return e.Err
|
||||
}
|
||||
if e.ErrorMsg != "" {
|
||||
return e.ErrorMsg
|
||||
}
|
||||
if e.Header.Message != "" {
|
||||
return e.Header.Message
|
||||
}
|
||||
if e.Response.Error.Message != "" {
|
||||
return e.Response.Error.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: "",
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// var errorResponse types.OpenAIErrorResponse
|
||||
var errorResponse GeneralErrorResponse
|
||||
err = json.Unmarshal(responseBody, &errorResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if errorResponse.Error.Message != "" {
|
||||
// OpenAI format error, so we override the default one
|
||||
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
|
||||
} else {
|
||||
openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage()
|
||||
}
|
||||
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
|
||||
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) SendRequestRaw(req *http.Request, proxyAddr string) (body io.ReadCloser, err error) {
|
||||
client := GetHttpClient(proxyAddr)
|
||||
resp, err := client.Do(req)
|
||||
PutHttpClient(client)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func IsFailureStatusCode(resp *http.Response) bool {
|
||||
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
|
||||
}
|
||||
|
||||
func DecodeResponse(body io.Reader, v any) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if result, ok := v.(*string); ok {
|
||||
return DecodeString(body, result)
|
||||
}
|
||||
|
||||
if stringer, ok := v.(Stringer); ok {
|
||||
return DecodeString(body, stringer.GetString())
|
||||
}
|
||||
|
||||
return json.NewDecoder(body).Decode(v)
|
||||
}
|
||||
|
||||
func DecodeString(body io.Reader, output *string) error {
|
||||
b, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*output = string(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
@ -37,6 +37,14 @@ func ErrorWrapper(err error, code string, statusCode int) *types.OpenAIErrorWith
|
||||
return StringErrorWrapper(err.Error(), code, statusCode)
|
||||
}
|
||||
|
||||
func ErrorToOpenAIError(err error) *types.OpenAIError {
|
||||
return &types.OpenAIError{
|
||||
Code: "system error",
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
}
|
||||
}
|
||||
|
||||
func StringErrorWrapper(err string, code string, statusCode int) *types.OpenAIErrorWithStatusCode {
|
||||
openAIError := types.OpenAIError{
|
||||
Message: err,
|
||||
|
@ -1,59 +0,0 @@
|
||||
package common
|
||||
|
||||
// type Quota struct {
|
||||
// ModelName string
|
||||
// ModelRatio float64
|
||||
// GroupRatio float64
|
||||
// Ratio float64
|
||||
// UserQuota int
|
||||
// }
|
||||
|
||||
// func CreateQuota(modelName string, userQuota int, group string) *Quota {
|
||||
// modelRatio := GetModelRatio(modelName)
|
||||
// groupRatio := GetGroupRatio(group)
|
||||
|
||||
// return &Quota{
|
||||
// ModelName: modelName,
|
||||
// ModelRatio: modelRatio,
|
||||
// GroupRatio: groupRatio,
|
||||
// Ratio: modelRatio * groupRatio,
|
||||
// UserQuota: userQuota,
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
// if ApproximateTokenEnabled {
|
||||
// return int(float64(len(text)) * 0.38)
|
||||
// }
|
||||
// return len(tokenEncoder.Encode(text, nil, nil))
|
||||
// }
|
||||
|
||||
// func (q *Quota) CountTokenMessages(messages []Message, model string) int {
|
||||
// tokenEncoder := q.getTokenEncoder(model)
|
||||
// // Reference:
|
||||
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// // https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
// //
|
||||
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
// var tokensPerMessage int
|
||||
// var tokensPerName int
|
||||
// if model == "gpt-3.5-turbo-0301" {
|
||||
// tokensPerMessage = 4
|
||||
// tokensPerName = -1 // If there's a name, the role is omitted
|
||||
// } else {
|
||||
// tokensPerMessage = 3
|
||||
// tokensPerName = 1
|
||||
// }
|
||||
// tokenNum := 0
|
||||
// for _, message := range messages {
|
||||
// tokenNum += tokensPerMessage
|
||||
// tokenNum += q.getTokenNum(tokenEncoder, message.StringContent())
|
||||
// tokenNum += q.getTokenNum(tokenEncoder, message.Role)
|
||||
// if message.Name != nil {
|
||||
// tokenNum += tokensPerName
|
||||
// tokenNum += q.getTokenNum(tokenEncoder, *message.Name)
|
||||
// }
|
||||
// }
|
||||
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
// return tokenNum
|
||||
// }
|
@ -1,4 +1,4 @@
|
||||
package common
|
||||
package requester
|
||||
|
||||
import (
|
||||
"fmt"
|
68
common/requester/http_client.go
Normal file
68
common/requester/http_client.go
Normal file
@ -0,0 +1,68 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type HTTPClient struct{}
|
||||
|
||||
var clientPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &http.Client{}
|
||||
},
|
||||
}
|
||||
|
||||
func (h *HTTPClient) getClientFromPool(proxyAddr string) *http.Client {
|
||||
client := clientPool.Get().(*http.Client)
|
||||
|
||||
if common.RelayTimeout > 0 {
|
||||
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
|
||||
}
|
||||
|
||||
if proxyAddr != "" {
|
||||
err := h.setProxy(client, proxyAddr)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
return client
|
||||
}
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func (h *HTTPClient) returnClientToPool(client *http.Client) {
|
||||
clientPool.Put(client)
|
||||
}
|
||||
|
||||
func (h *HTTPClient) setProxy(client *http.Client, proxyAddr string) error {
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing proxy address: %w", err)
|
||||
}
|
||||
|
||||
switch proxyURL.Scheme {
|
||||
case "http", "https":
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
case "socks5":
|
||||
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating socks5 dialer: %w", err)
|
||||
}
|
||||
client.Transport = &http.Transport{
|
||||
Dial: dialer.Dial,
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
229
common/requester/http_requester.go
Normal file
229
common/requester/http_requester.go
Normal file
@ -0,0 +1,229 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type HttpErrorHandler func(*http.Response) *types.OpenAIError
|
||||
|
||||
type HTTPRequester struct {
|
||||
HTTPClient HTTPClient
|
||||
requestBuilder RequestBuilder
|
||||
CreateFormBuilder func(io.Writer) FormBuilder
|
||||
ErrorHandler HttpErrorHandler
|
||||
proxyAddr string
|
||||
}
|
||||
|
||||
// NewHTTPRequester 创建一个新的 HTTPRequester 实例。
|
||||
// proxyAddr: 是代理服务器的地址。
|
||||
// errorHandler: 是一个错误处理函数,它接收一个 *http.Response 参数并返回一个 *types.OpenAIErrorResponse。
|
||||
// 如果 errorHandler 为 nil,那么会使用一个默认的错误处理函数。
|
||||
func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequester {
|
||||
return &HTTPRequester{
|
||||
HTTPClient: HTTPClient{},
|
||||
requestBuilder: NewRequestBuilder(),
|
||||
CreateFormBuilder: func(body io.Writer) FormBuilder {
|
||||
return NewFormBuilder(body)
|
||||
},
|
||||
ErrorHandler: errorHandler,
|
||||
proxyAddr: proxyAddr,
|
||||
}
|
||||
}
|
||||
|
||||
type requestOptions struct {
|
||||
body any
|
||||
header http.Header
|
||||
}
|
||||
|
||||
type requestOption func(*requestOptions)
|
||||
|
||||
// 创建请求
|
||||
func (r *HTTPRequester) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
|
||||
args := &requestOptions{
|
||||
body: nil,
|
||||
header: make(http.Header),
|
||||
}
|
||||
for _, setter := range setters {
|
||||
setter(args)
|
||||
}
|
||||
req, err := r.requestBuilder.Build(method, url, args.body, args.header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp bool) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
|
||||
resp, err := client.Do(req)
|
||||
r.HTTPClient.returnClientToPool(client)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if !outputResp {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// 处理响应
|
||||
if r.IsFailureStatusCode(resp) {
|
||||
return nil, HandleErrorResp(resp, r.ErrorHandler)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
if outputResp {
|
||||
var buf bytes.Buffer
|
||||
tee := io.TeeReader(resp.Body, &buf)
|
||||
err = DecodeResponse(tee, response)
|
||||
|
||||
// 将响应体重新写入 resp.Body
|
||||
resp.Body = io.NopCloser(&buf)
|
||||
} else {
|
||||
err = json.NewDecoder(resp.Body).Decode(response)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// 发送请求 RAW
|
||||
func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
// 发送请求
|
||||
client := r.HTTPClient.getClientFromPool(r.proxyAddr)
|
||||
resp, err := client.Do(req)
|
||||
r.HTTPClient.returnClientToPool(client)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 处理响应
|
||||
if r.IsFailureStatusCode(resp) {
|
||||
return nil, HandleErrorResp(resp, r.ErrorHandler)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// 获取流式响应
|
||||
func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) {
|
||||
// 如果返回的头是json格式 说明有错误
|
||||
if resp.Header.Get("Content-Type") == "application/json" {
|
||||
return nil, HandleErrorResp(resp, requester.ErrorHandler)
|
||||
}
|
||||
|
||||
return &streamReader[T]{
|
||||
reader: bufio.NewReader(resp.Body),
|
||||
response: resp,
|
||||
handlerPrefix: handlerPrefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 设置请求体
|
||||
func (r *HTTPRequester) WithBody(body any) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
args.body = body
|
||||
}
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
func (r *HTTPRequester) WithHeader(header map[string]string) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
for k, v := range header {
|
||||
args.header.Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置Content-Type
|
||||
func (r *HTTPRequester) WithContentType(contentType string) requestOption {
|
||||
return func(args *requestOptions) {
|
||||
args.header.Set("Content-Type", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
// 判断是否为失败状态码
|
||||
func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool {
|
||||
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode {
|
||||
|
||||
openAIErrorWithStatusCode := &types.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: "",
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
if toOpenAIError != nil {
|
||||
errorResponse := toOpenAIError(resp)
|
||||
|
||||
if errorResponse != nil && errorResponse.Message != "" {
|
||||
openAIErrorWithStatusCode.OpenAIError = *errorResponse
|
||||
}
|
||||
}
|
||||
|
||||
if openAIErrorWithStatusCode.OpenAIError.Message == "" {
|
||||
openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return openAIErrorWithStatusCode
|
||||
}
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
type Stringer interface {
|
||||
GetString() *string
|
||||
}
|
||||
|
||||
func DecodeResponse(body io.Reader, v any) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if result, ok := v.(*string); ok {
|
||||
return DecodeString(body, result)
|
||||
}
|
||||
|
||||
if stringer, ok := v.(Stringer); ok {
|
||||
return DecodeString(body, stringer.GetString())
|
||||
}
|
||||
|
||||
return json.NewDecoder(body).Decode(v)
|
||||
}
|
||||
|
||||
func DecodeString(body io.Reader, output *string) error {
|
||||
b, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*output = string(b)
|
||||
return nil
|
||||
}
|
79
common/requester/http_stream_reader.go
Normal file
79
common/requester/http_stream_reader.go
Normal file
@ -0,0 +1,79 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 流处理函数,判断依据如下:
|
||||
// 1.如果有错误信息,则直接返回错误信息
|
||||
// 2.如果isFinished=true,则返回io.EOF,并且如果response不为空,还将返回response
|
||||
// 3.如果rawLine=nil 或者 response长度为0,则直接跳过
|
||||
// 4.如果以上条件都不满足,则返回response
|
||||
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
|
||||
|
||||
type streamable interface {
|
||||
// types.ChatCompletionStreamResponse | types.CompletionResponse
|
||||
any
|
||||
}
|
||||
|
||||
type StreamReaderInterface[T streamable] interface {
|
||||
Recv() (*[]T, error)
|
||||
Close()
|
||||
}
|
||||
|
||||
type streamReader[T streamable] struct {
|
||||
isFinished bool
|
||||
|
||||
reader *bufio.Reader
|
||||
response *http.Response
|
||||
|
||||
handlerPrefix HandlerPrefix[T]
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Recv() (response *[]T, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
response, err = stream.processLines()
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (stream *streamReader[T]) processLines() (*[]T, error) {
|
||||
for {
|
||||
rawLine, readErr := stream.reader.ReadBytes('\n')
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
}
|
||||
|
||||
noSpaceLine := bytes.TrimSpace(rawLine)
|
||||
|
||||
var response []T
|
||||
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stream.isFinished {
|
||||
if len(response) > 0 {
|
||||
return &response, io.EOF
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if noSpaceLine == nil || len(response) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Close() {
|
||||
stream.response.Body.Close()
|
||||
}
|
@ -1,9 +1,10 @@
|
||||
package common
|
||||
package requester
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type RequestBuilder interface {
|
||||
@ -11,12 +12,12 @@ type RequestBuilder interface {
|
||||
}
|
||||
|
||||
type HTTPRequestBuilder struct {
|
||||
marshaller Marshaller
|
||||
marshaller common.Marshaller
|
||||
}
|
||||
|
||||
func NewRequestBuilder() *HTTPRequestBuilder {
|
||||
return &HTTPRequestBuilder{
|
||||
marshaller: &JSONMarshaller{},
|
||||
marshaller: &common.JSONMarshaller{},
|
||||
}
|
||||
}
|
||||
|
53
common/requester/ws_client.go
Normal file
53
common/requester/ws_client.go
Normal file
@ -0,0 +1,53 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
func GetWSClient(proxyAddr string) *websocket.Dialer {
|
||||
dialer := &websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
if proxyAddr != "" {
|
||||
err := setWSProxy(dialer, proxyAddr)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
return dialer
|
||||
}
|
||||
}
|
||||
|
||||
return dialer
|
||||
}
|
||||
|
||||
func setWSProxy(dialer *websocket.Dialer, proxyAddr string) error {
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing proxy address: %w", err)
|
||||
}
|
||||
|
||||
switch proxyURL.Scheme {
|
||||
case "http", "https":
|
||||
dialer.Proxy = http.ProxyURL(proxyURL)
|
||||
case "socks5":
|
||||
socks5Proxy, err := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating socks5 dialer: %w", err)
|
||||
}
|
||||
dialer.NetDial = func(network, addr string) (net.Conn, error) {
|
||||
return socks5Proxy.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
58
common/requester/ws_reader.go
Normal file
58
common/requester/ws_reader.go
Normal file
@ -0,0 +1,58 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type wsReader[T streamable] struct {
|
||||
isFinished bool
|
||||
|
||||
reader *websocket.Conn
|
||||
handlerPrefix HandlerPrefix[T]
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) Recv() (response *[]T, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
response, err = stream.processLines()
|
||||
return
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) processLines() (*[]T, error) {
|
||||
for {
|
||||
_, msg, err := stream.reader.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var response []T
|
||||
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stream.isFinished {
|
||||
if len(response) > 0 {
|
||||
return &response, io.EOF
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if msg == nil || len(response) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) Close() {
|
||||
stream.reader.Close()
|
||||
}
|
54
common/requester/ws_requester.go
Normal file
54
common/requester/ws_requester.go
Normal file
@ -0,0 +1,54 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type WSRequester struct {
|
||||
WSClient *websocket.Dialer
|
||||
}
|
||||
|
||||
func NewWSRequester(proxyAddr string) *WSRequester {
|
||||
return &WSRequester{
|
||||
WSClient: GetWSClient(proxyAddr),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WSRequester) NewRequest(url string, header http.Header) (*websocket.Conn, error) {
|
||||
conn, resp, err := w.WSClient.Dial(url, header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return nil, errors.New("ws unexpected status code")
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPrefix HandlerPrefix[T]) (*wsReader[T], *types.OpenAIErrorWithStatusCode) {
|
||||
err := conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return &wsReader[T]{
|
||||
reader: conn,
|
||||
handlerPrefix: handlerPrefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
func (r *WSRequester) WithHeader(headers map[string]string) http.Header {
|
||||
header := make(http.Header)
|
||||
for k, v := range headers {
|
||||
header.Set(k, v)
|
||||
}
|
||||
return header
|
||||
}
|
55
common/test/api.go
Normal file
55
common/test/api.go
Normal file
@ -0,0 +1,55 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RequestJSONConfig() map[string]string {
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
}
|
||||
|
||||
func GetContext(method, path string, headers map[string]string, body io.Reader) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
var req *http.Request
|
||||
req, _ = http.NewRequest(method, path, body)
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
return c, w
|
||||
}
|
||||
|
||||
func GetGinRouter(method, path string, headers map[string]string, body *io.Reader) *httptest.ResponseRecorder {
|
||||
var req *http.Request
|
||||
r := gin.Default()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ = http.NewRequest(method, path, *body)
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
func GetChannel(channelType int, baseUrl, other, porxy, modelMapping string) model.Channel {
|
||||
return model.Channel{
|
||||
Type: channelType,
|
||||
BaseURL: &baseUrl,
|
||||
Other: other,
|
||||
Proxy: porxy,
|
||||
ModelMapping: &modelMapping,
|
||||
Key: GetTestToken(),
|
||||
}
|
||||
}
|
132
common/test/chat_config.go
Normal file
132
common/test/chat_config.go
Normal file
@ -0,0 +1,132 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetChatCompletionRequest(chatType, modelName, stream string) *types.ChatCompletionRequest {
|
||||
chatJSON := GetChatRequest(chatType, modelName, stream)
|
||||
chatCompletionRequest := &types.ChatCompletionRequest{}
|
||||
json.NewDecoder(chatJSON).Decode(chatCompletionRequest)
|
||||
return chatCompletionRequest
|
||||
}
|
||||
|
||||
func GetChatRequest(chatType, modelName, stream string) *strings.Reader {
|
||||
var chatJSON string
|
||||
switch chatType {
|
||||
case "image":
|
||||
chatJSON = `{
|
||||
"model": "` + modelName + `",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What’s in this image?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 300,
|
||||
"stream": ` + stream + `
|
||||
}`
|
||||
case "default":
|
||||
chatJSON = `{
|
||||
"model": "` + modelName + `",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
],
|
||||
"stream": ` + stream + `
|
||||
}`
|
||||
case "function":
|
||||
chatJSON = `{
|
||||
"model": "` + modelName + `",
|
||||
"stream": ` + stream + `,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Boston?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": "auto"
|
||||
}`
|
||||
|
||||
case "tools":
|
||||
chatJSON = `{
|
||||
"model": "` + modelName + `",
|
||||
"stream": ` + stream + `,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Boston?"
|
||||
}
|
||||
],
|
||||
"functions": [
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"celsius",
|
||||
"fahrenheit"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
}
|
||||
|
||||
return strings.NewReader(chatJSON)
|
||||
}
|
65
common/test/check_chat.go
Normal file
65
common/test/check_chat.go
Normal file
@ -0,0 +1,65 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"one-api/types"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func CheckChat(t *testing.T, response *types.ChatCompletionResponse, modelName string, usage *types.Usage) {
|
||||
assert.NotEmpty(t, response.ID)
|
||||
assert.NotEmpty(t, response.Object)
|
||||
assert.NotEmpty(t, response.Created)
|
||||
assert.Equal(t, response.Model, modelName)
|
||||
assert.IsType(t, []types.ChatCompletionChoice{}, response.Choices)
|
||||
// check choices 长度大于1
|
||||
assert.True(t, len(response.Choices) > 0)
|
||||
for _, choice := range response.Choices {
|
||||
assert.NotNil(t, choice.Index)
|
||||
assert.IsType(t, types.ChatCompletionMessage{}, choice.Message)
|
||||
assert.NotEmpty(t, choice.Message.Role)
|
||||
assert.NotEmpty(t, choice.FinishReason)
|
||||
|
||||
// check message
|
||||
if choice.Message.Content != nil {
|
||||
multiContents, ok := choice.Message.Content.([]types.ChatMessagePart)
|
||||
if ok {
|
||||
for _, content := range multiContents {
|
||||
assert.NotEmpty(t, content.Type)
|
||||
if content.Type == "text" {
|
||||
assert.NotEmpty(t, content.Text)
|
||||
} else if content.Type == "image_url" {
|
||||
assert.IsType(t, types.ChatMessageImageURL{}, content.ImageURL)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
content, ok := choice.Message.Content.(string)
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, content)
|
||||
}
|
||||
} else if choice.Message.FunctionCall != nil {
|
||||
assert.NotEmpty(t, choice.Message.FunctionCall.Name)
|
||||
assert.Equal(t, choice.FinishReason, types.FinishReasonFunctionCall)
|
||||
} else if choice.Message.ToolCalls != nil {
|
||||
assert.IsType(t, []types.ChatCompletionToolCalls{}, choice.Message.ToolCalls)
|
||||
assert.NotEmpty(t, choice.Message.ToolCalls[0].Id)
|
||||
assert.NotEmpty(t, choice.Message.ToolCalls[0].Function)
|
||||
assert.Equal(t, choice.Message.ToolCalls[0].Function, "function")
|
||||
|
||||
assert.IsType(t, types.ChatCompletionToolCallsFunction{}, choice.Message.ToolCalls[0].Function)
|
||||
assert.NotEmpty(t, choice.Message.ToolCalls[0].Function.Name)
|
||||
|
||||
assert.Equal(t, choice.FinishReason, types.FinishReasonToolCalls)
|
||||
} else {
|
||||
assert.Fail(t, "message content is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// check usage
|
||||
assert.IsType(t, &types.Usage{}, response.Usage)
|
||||
assert.Equal(t, response.Usage.PromptTokens, usage.PromptTokens)
|
||||
assert.Equal(t, response.Usage.CompletionTokens, usage.CompletionTokens)
|
||||
assert.Equal(t, response.Usage.TotalTokens, usage.TotalTokens)
|
||||
|
||||
}
|
48
common/test/checks.go
Normal file
48
common/test/checks.go
Normal file
@ -0,0 +1,48 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func NoError(t *testing.T, err error, message ...string) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Error(err, message)
|
||||
}
|
||||
}
|
||||
|
||||
func HasError(t *testing.T, err error, message ...string) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
t.Error(err, message)
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorIs(t *testing.T, err, target error, msg ...string) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, target) {
|
||||
t.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorIsF(t *testing.T, err, target error, format string, msg ...string) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, target) {
|
||||
t.Fatalf(format, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorIsNot(t *testing.T, err, target error, msg ...string) {
|
||||
t.Helper()
|
||||
if errors.Is(err, target) {
|
||||
t.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) {
|
||||
t.Helper()
|
||||
if errors.Is(err, target) {
|
||||
t.Fatalf(format, msg)
|
||||
}
|
||||
}
|
7
common/test/init/init.go
Normal file
7
common/test/init/init.go
Normal file
@ -0,0 +1,7 @@
|
||||
package init
|
||||
|
||||
import "testing"
|
||||
|
||||
func init() {
|
||||
testing.Init()
|
||||
}
|
63
common/test/server.go
Normal file
63
common/test/server.go
Normal file
@ -0,0 +1,63 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const testAPI = "this-is-my-secure-token-do-not-steal!!"
|
||||
|
||||
func GetTestToken() string {
|
||||
return testAPI
|
||||
}
|
||||
|
||||
type ServerTest struct {
|
||||
handlers map[string]handler
|
||||
}
|
||||
type handler func(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
func NewTestServer() *ServerTest {
|
||||
return &ServerTest{handlers: make(map[string]handler)}
|
||||
}
|
||||
|
||||
func OpenAICheck(w http.ResponseWriter, r *http.Request) bool {
|
||||
if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *ServerTest) RegisterHandler(path string, handler handler) {
|
||||
// to make the registered paths friendlier to a regex match in the route handler
|
||||
// in OpenAITestServer
|
||||
path = strings.ReplaceAll(path, "*", ".*")
|
||||
ts.handlers[path] = handler
|
||||
}
|
||||
|
||||
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
|
||||
func (ts *ServerTest) TestServer(headerCheck func(w http.ResponseWriter, r *http.Request) bool) *httptest.Server {
|
||||
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path)
|
||||
|
||||
// check auth
|
||||
if headerCheck != nil && !headerCheck(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle /path/* routes.
|
||||
// Note: the * is converted to a .* in register handler for proper regex handling
|
||||
for route, handler := range ts.handlers {
|
||||
// Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered
|
||||
pattern, _ := regexp.Compile("^" + route + "$")
|
||||
if pattern.MatchString(r.URL.Path) {
|
||||
handler(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
http.Error(w, "the resource path doesn't exist", http.StatusNotFound)
|
||||
}))
|
||||
}
|
@ -67,7 +67,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
return 0, errors.New("provider not implemented")
|
||||
}
|
||||
|
||||
return balanceProvider.Balance(channel)
|
||||
return balanceProvider.Balance()
|
||||
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -38,30 +39,36 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
||||
if provider == nil {
|
||||
return errors.New("channel not implemented"), nil
|
||||
}
|
||||
|
||||
newModelName, err := provider.ModelMappingHandler(request.Model)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
|
||||
request.Model = newModelName
|
||||
|
||||
chatProvider, ok := provider.(providers_base.ChatInterface)
|
||||
if !ok {
|
||||
return errors.New("channel not implemented"), nil
|
||||
}
|
||||
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if modelMap != nil && modelMap[request.Model] != "" {
|
||||
request.Model = modelMap[request.Model]
|
||||
}
|
||||
chatProvider.SetUsage(&types.Usage{})
|
||||
|
||||
response, openAIErrorWithStatusCode := chatProvider.CreateChatCompletion(&request)
|
||||
|
||||
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
||||
Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return errors.New(openAIErrorWithStatusCode.Message), &openAIErrorWithStatusCode.OpenAIError
|
||||
}
|
||||
|
||||
if Usage.CompletionTokens == 0 {
|
||||
usage := chatProvider.GetUsage()
|
||||
|
||||
if usage.CompletionTokens == 0 {
|
||||
return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, w.Body.String()))
|
||||
// 转换为JSON字符串
|
||||
jsonBytes, _ := json.Marshal(response)
|
||||
common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, string(jsonBytes)))
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@ -74,9 +81,9 @@ func buildTestRequest() *types.ChatCompletionRequest {
|
||||
Content: "You just need to output 'hi' next.",
|
||||
},
|
||||
},
|
||||
Model: "",
|
||||
MaxTokens: 1,
|
||||
Stream: false,
|
||||
Model: "",
|
||||
// MaxTokens: 1,
|
||||
Stream: false,
|
||||
}
|
||||
return testRequest
|
||||
}
|
||||
|
164
controller/quota.go
Normal file
164
controller/quota.go
Normal file
@ -0,0 +1,164 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type QuotaInfo struct {
|
||||
modelName string
|
||||
promptTokens int
|
||||
preConsumedTokens int
|
||||
modelRatio float64
|
||||
groupRatio float64
|
||||
ratio float64
|
||||
preConsumedQuota int
|
||||
userId int
|
||||
channelId int
|
||||
tokenId int
|
||||
HandelStatus bool
|
||||
}
|
||||
|
||||
func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
|
||||
quotaInfo := &QuotaInfo{
|
||||
modelName: modelName,
|
||||
promptTokens: promptTokens,
|
||||
userId: c.GetInt("id"),
|
||||
channelId: c.GetInt("channel_id"),
|
||||
tokenId: c.GetInt("token_id"),
|
||||
HandelStatus: false,
|
||||
}
|
||||
quotaInfo.initQuotaInfo(c.GetString("group"))
|
||||
|
||||
errWithCode := quotaInfo.preQuotaConsumption()
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return quotaInfo, nil
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) initQuotaInfo(groupName string) {
|
||||
modelRatio := common.GetModelRatio(q.modelName)
|
||||
groupRatio := common.GetGroupRatio(groupName)
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
|
||||
|
||||
q.preConsumedTokens = preConsumedTokens
|
||||
q.modelRatio = modelRatio
|
||||
q.groupRatio = groupRatio
|
||||
q.ratio = ratio
|
||||
q.preConsumedQuota = preConsumedQuota
|
||||
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
|
||||
userQuota, err := model.CacheGetUserQuota(q.userId)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if userQuota < q.preConsumedQuota {
|
||||
return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
|
||||
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if userQuota > 100*q.preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
q.preConsumedQuota = 0
|
||||
// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||
}
|
||||
|
||||
if q.preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
q.HandelStatus = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
|
||||
quota := 0
|
||||
completionRatio := common.GetCompletionRatio(q.modelName)
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio))
|
||||
if q.ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
}
|
||||
quotaDelta := quota - q.preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
return errors.New("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(q.userId)
|
||||
if err != nil {
|
||||
return errors.New("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
requestTime := 0
|
||||
requestStartTimeValue := ctx.Value("requestStartTime")
|
||||
if requestStartTimeValue != nil {
|
||||
requestStartTime, ok := requestStartTimeValue.(time.Time)
|
||||
if ok {
|
||||
requestTime = int(time.Since(requestStartTime).Milliseconds())
|
||||
}
|
||||
}
|
||||
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio)
|
||||
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
|
||||
model.UpdateChannelUsedQuota(q.channelId, quota)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if q.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err := q.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
@ -1,11 +1,11 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/common/requester"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -20,33 +20,18 @@ func RelayChat(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, chatRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[chatRequest.Model] != "" {
|
||||
chatRequest.Model = modelMap[chatRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeChatCompletions)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, chatRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
chatRequest.Model = modelName
|
||||
|
||||
chatProvider, ok := provider.(providersBase.ChatInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -56,39 +41,42 @@ func RelayChat(c *gin.Context) {
|
||||
// 获取Input Tokens
|
||||
promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, chatRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, chatRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens)
|
||||
if chatRequest.Stream {
|
||||
var response requester.StreamReaderInterface[types.ChatCompletionStreamResponse]
|
||||
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseStreamClient[types.ChatCompletionStreamResponse](c, response)
|
||||
} else {
|
||||
var response *types.ChatCompletionResponse
|
||||
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
}
|
||||
|
||||
fmt.Println(usage)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,11 +1,9 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -20,33 +18,18 @@ func RelayCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, completionRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[completionRequest.Model] != "" {
|
||||
completionRequest.Model = modelMap[completionRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeCompletions)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, completionRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
completionRequest.Model = modelName
|
||||
|
||||
completionProvider, ok := provider.(providersBase.CompletionInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -56,39 +39,38 @@ func RelayCompletions(c *gin.Context) {
|
||||
// 获取Input Tokens
|
||||
promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, completionRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, completionRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens)
|
||||
if completionRequest.Stream {
|
||||
response, errWithCode := completionProvider.CreateCompletionStream(&completionRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseStreamClient[types.CompletionResponse](c, response)
|
||||
} else {
|
||||
response, errWithCode := completionProvider.CreateCompletion(&completionRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
}
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,10 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
@ -24,28 +22,13 @@ func RelayEmbeddings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, embeddingsRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
|
||||
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeEmbeddings)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, embeddingsRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
embeddingsRequest.Model = modelName
|
||||
|
||||
embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -55,39 +38,29 @@ func RelayEmbeddings(c *gin.Context) {
|
||||
// 获取Input Tokens
|
||||
promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
|
||||
response, errWithCode := embeddingsProvider.CreateEmbeddings(&embeddingsRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,10 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -33,28 +31,13 @@ func RelayImageEdits(c *gin.Context) {
|
||||
imageEditRequest.Size = "1024x1024"
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, imageEditRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
|
||||
imageEditRequest.Model = modelMap[imageEditRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeImagesEdits)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
imageEditRequest.Model = modelName
|
||||
|
||||
imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -68,39 +51,29 @@ func RelayImageEdits(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens)
|
||||
response, errWithCode := imageEditsProvider.CreateImageEdits(&imageEditRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,10 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -36,28 +34,13 @@ func RelayImageGenerations(c *gin.Context) {
|
||||
imageRequest.Quality = "standard"
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, imageRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[imageRequest.Model] != "" {
|
||||
imageRequest.Model = modelMap[imageRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeImagesGenerations)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, imageRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
imageRequest.Model = modelName
|
||||
|
||||
imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -71,39 +54,29 @@ func RelayImageGenerations(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, imageRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, imageRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
|
||||
response, errWithCode := imageGenerationsProvider.CreateImageGenerations(&imageRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,10 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -28,28 +26,13 @@ func RelayImageVariations(c *gin.Context) {
|
||||
imageEditRequest.Size = "1024x1024"
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, imageEditRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
|
||||
imageEditRequest.Model = modelMap[imageEditRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeImagesVariations)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
imageEditRequest.Model = modelName
|
||||
|
||||
imageVariations, ok := provider.(providersBase.ImageVariationsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -63,39 +46,29 @@ func RelayImageVariations(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens)
|
||||
response, errWithCode := imageVariations.CreateImageVariations(&imageEditRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,10 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -24,28 +22,13 @@ func RelayModerations(c *gin.Context) {
|
||||
moderationRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, moderationRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[moderationRequest.Model] != "" {
|
||||
moderationRequest.Model = modelMap[moderationRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeModerations)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, moderationRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
moderationRequest.Model = modelName
|
||||
|
||||
moderationProvider, ok := provider.(providersBase.ModerationInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -55,39 +38,29 @@ func RelayModerations(c *gin.Context) {
|
||||
// 获取Input Tokens
|
||||
promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, moderationRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, moderationRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens)
|
||||
response, errWithCode := moderationProvider.CreateModeration(&moderationRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,10 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -20,28 +18,13 @@ func RelaySpeech(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, speechRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[speechRequest.Model] != "" {
|
||||
speechRequest.Model = modelMap[speechRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeAudioSpeech)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, speechRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
speechRequest.Model = modelName
|
||||
|
||||
speechProvider, ok := provider.(providersBase.SpeechInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -51,39 +34,29 @@ func RelaySpeech(c *gin.Context) {
|
||||
// 获取Input Tokens
|
||||
promptTokens := len(speechRequest.Input)
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, speechRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, speechRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens)
|
||||
response, errWithCode := speechProvider.CreateSpeech(&speechRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseMultipart(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,10 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -20,28 +18,13 @@ func RelayTranscriptions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, audioRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[audioRequest.Model] != "" {
|
||||
audioRequest.Model = modelMap[audioRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeAudioTranscription)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, audioRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
audioRequest.Model = modelName
|
||||
|
||||
transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -51,39 +34,29 @@ func RelayTranscriptions(c *gin.Context) {
|
||||
// 获取Input Tokens
|
||||
promptTokens := 0
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = transcriptionsProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens)
|
||||
response, errWithCode := transcriptionsProvider.CreateTranscriptions(&audioRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseCustom(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,10 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -20,28 +18,13 @@ func RelayTranslations(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
channel, pass := fetchChannel(c, audioRequest.Model)
|
||||
if pass {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if modelMap != nil && modelMap[audioRequest.Model] != "" {
|
||||
audioRequest.Model = modelMap[audioRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, pass := getProvider(c, channel, common.RelayModeAudioTranslation)
|
||||
if pass {
|
||||
provider, modelName, fail := getProvider(c, audioRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
audioRequest.Model = modelName
|
||||
|
||||
translationProvider, ok := provider.(providersBase.TranslationInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
@ -51,39 +34,29 @@ func RelayTranslations(c *gin.Context) {
|
||||
// 获取Input Tokens
|
||||
promptTokens := 0
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
var errWithCode *types.OpenAIErrorWithStatusCode
|
||||
var usage *types.Usage
|
||||
quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens)
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
usage, errWithCode = translationProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens)
|
||||
response, errWithCode := translationProvider.CreateTranslation(&audioRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseCustom(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if quotaInfo.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
} else {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
|
@ -1,25 +1,47 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail bool) {
|
||||
channel, fail := fetchChannel(c, modeName)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
|
||||
provider = providers.GetProvider(channel, c)
|
||||
if provider == nil {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
|
||||
fail = true
|
||||
return
|
||||
}
|
||||
|
||||
newModelName, err := provider.ModelMappingHandler(modeName)
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
fail = true
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func GetValidFieldName(err error, obj interface{}) string {
|
||||
getObj := reflect.TypeOf(obj)
|
||||
if errs, ok := err.(validator.ValidationErrors); ok {
|
||||
@ -32,17 +54,17 @@ func GetValidFieldName(err error, obj interface{}) string {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pass bool) {
|
||||
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail bool) {
|
||||
channelId, ok := c.Get("channelId")
|
||||
if ok {
|
||||
channel, pass = fetchChannelById(c, channelId.(int))
|
||||
if pass {
|
||||
channel, fail = fetchChannelById(c, channelId.(int))
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
channel, pass = fetchChannelByModel(c, modelName)
|
||||
if pass {
|
||||
channel, fail = fetchChannelByModel(c, modelName)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
|
||||
@ -86,21 +108,6 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool
|
||||
return channel, false
|
||||
}
|
||||
|
||||
func getProvider(c *gin.Context, channel *model.Channel, relayMode int) (providersBase.ProviderInterface, bool) {
|
||||
provider := providers.GetProvider(channel, c)
|
||||
if provider == nil {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
|
||||
return nil, true
|
||||
}
|
||||
|
||||
if !provider.SupportAPI(relayMode) {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel does not support this API")
|
||||
return nil, true
|
||||
}
|
||||
|
||||
return provider, false
|
||||
}
|
||||
|
||||
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
||||
if !common.AutomaticDisableChannelEnabled {
|
||||
return false
|
||||
@ -130,138 +137,81 @@ func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func parseModelMapping(modelMapping string) (map[string]string, error) {
|
||||
if modelMapping == "" || modelMapping == "{}" {
|
||||
return nil, nil
|
||||
}
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode {
|
||||
// 将data转换为 JSON
|
||||
responseBody, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return modelMap, nil
|
||||
}
|
||||
|
||||
type QuotaInfo struct {
|
||||
modelName string
|
||||
promptTokens int
|
||||
preConsumedTokens int
|
||||
modelRatio float64
|
||||
groupRatio float64
|
||||
ratio float64
|
||||
preConsumedQuota int
|
||||
userId int
|
||||
channelId int
|
||||
tokenId int
|
||||
HandelStatus bool
|
||||
}
|
||||
|
||||
func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
|
||||
quotaInfo := &QuotaInfo{
|
||||
modelName: modelName,
|
||||
promptTokens: promptTokens,
|
||||
userId: c.GetInt("id"),
|
||||
channelId: c.GetInt("channel_id"),
|
||||
tokenId: c.GetInt("token_id"),
|
||||
HandelStatus: false,
|
||||
}
|
||||
quotaInfo.initQuotaInfo(c.GetString("group"))
|
||||
|
||||
errWithCode := quotaInfo.preQuotaConsumption()
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return quotaInfo, nil
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) initQuotaInfo(groupName string) {
|
||||
modelRatio := common.GetModelRatio(q.modelName)
|
||||
groupRatio := common.GetGroupRatio(groupName)
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
|
||||
|
||||
q.preConsumedTokens = preConsumedTokens
|
||||
q.modelRatio = modelRatio
|
||||
q.groupRatio = groupRatio
|
||||
q.ratio = ratio
|
||||
q.preConsumedQuota = preConsumedQuota
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
|
||||
userQuota, err := model.CacheGetUserQuota(q.userId)
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, err = c.Writer.Write(responseBody)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if userQuota < q.preConsumedQuota {
|
||||
return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
|
||||
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if userQuota > 100*q.preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
q.preConsumedQuota = 0
|
||||
// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||
}
|
||||
|
||||
if q.preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
q.HandelStatus = true
|
||||
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
|
||||
quota := 0
|
||||
completionRatio := common.GetCompletionRatio(q.modelName)
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio))
|
||||
if q.ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
}
|
||||
quotaDelta := quota - q.preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
return errors.New("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(q.userId)
|
||||
if err != nil {
|
||||
return errors.New("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
requestTime := 0
|
||||
requestStartTimeValue := ctx.Value("requestStartTime")
|
||||
if requestStartTimeValue != nil {
|
||||
requestStartTime, ok := requestStartTimeValue.(time.Time)
|
||||
if ok {
|
||||
requestTime = int(time.Since(requestStartTime).Milliseconds())
|
||||
func responseStreamClient[T any](c *gin.Context, stream requester.StreamReaderInterface[T]) *types.OpenAIErrorWithStatusCode {
|
||||
requester.SetEventStreamHeaders(c)
|
||||
defer stream.Close()
|
||||
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
if response != nil && len(*response) > 0 {
|
||||
for _, streamResponse := range *response {
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
|
||||
}
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
break
|
||||
}
|
||||
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio)
|
||||
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
|
||||
model.UpdateChannelUsedQuota(q.channelId, quota)
|
||||
if err != nil {
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + err.Error()})
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
break
|
||||
}
|
||||
|
||||
for _, streamResponse := range *response {
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
|
||||
defer resp.Body.Close()
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err := io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types.OpenAIErrorWithStatusCode {
|
||||
for k, v := range response.Headers {
|
||||
c.Writer.Header().Set(k, v)
|
||||
}
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err := c.Writer.Write(response.Body)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -2,29 +2,26 @@ package aigc2d
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
)
|
||||
|
||||
func (p *Aigc2dProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
func (p *Aigc2dProvider) Balance() (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response base.BalanceResponse
|
||||
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||
_, errWithCode := p.Requester.SendRequest(req, &response, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
|
||||
channel.UpdateBalance(response.TotalAvailable)
|
||||
p.Channel.UpdateBalance(response.TotalAvailable)
|
||||
|
||||
return response.TotalAvailable, nil
|
||||
}
|
||||
|
@ -1,20 +1,19 @@
|
||||
package aigc2d
|
||||
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Aigc2dProviderFactory struct{}
|
||||
|
||||
func (f Aigc2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
return &Aigc2dProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aigc2d.com"),
|
||||
}
|
||||
}
|
||||
|
||||
type Aigc2dProvider struct {
|
||||
*openai.OpenAIProvider
|
||||
}
|
||||
|
||||
func (f Aigc2dProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &Aigc2dProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.aigc2d.com"),
|
||||
}
|
||||
}
|
||||
|
@ -3,24 +3,21 @@ package aiproxy
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
)
|
||||
|
||||
func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
func (p *AIProxyProvider) Balance() (float64, error) {
|
||||
fullRequestURL := "https://aiproxy.io/api/report/getUserOverview"
|
||||
headers := make(map[string]string)
|
||||
headers["Api-Key"] = channel.Key
|
||||
headers["Api-Key"] = p.Channel.Key
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response AIProxyUserOverviewResponse
|
||||
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||
_, errWithCode := p.Requester.SendRequest(req, &response, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
@ -29,7 +26,7 @@ func (p *AIProxyProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
|
||||
}
|
||||
|
||||
channel.UpdateBalance(response.Data.TotalPoints)
|
||||
p.Channel.UpdateBalance(response.Data.TotalPoints)
|
||||
|
||||
return response.Data.TotalPoints, nil
|
||||
}
|
||||
|
@ -1,20 +1,19 @@
|
||||
package aiproxy
|
||||
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AIProxyProviderFactory struct{}
|
||||
|
||||
func (f AIProxyProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
return &AIProxyProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"),
|
||||
}
|
||||
}
|
||||
|
||||
type AIProxyProvider struct {
|
||||
*openai.OpenAIProvider
|
||||
}
|
||||
|
||||
func (f AIProxyProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &AIProxyProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.aiproxy.io"),
|
||||
}
|
||||
}
|
||||
|
24
providers/ali/ali_test.go
Normal file
24
providers/ali/ali_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
package ali_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/test"
|
||||
"one-api/model"
|
||||
)
|
||||
|
||||
func setupAliTestServer() (baseUrl string, server *test.ServerTest, teardown func()) {
|
||||
server = test.NewTestServer()
|
||||
ts := server.TestServer(func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return test.OpenAICheck(w, r)
|
||||
})
|
||||
ts.Start()
|
||||
teardown = ts.Close
|
||||
|
||||
baseUrl = ts.URL
|
||||
return
|
||||
}
|
||||
|
||||
func getAliChannel(baseUrl string) model.Channel {
|
||||
return test.GetChannel(common.ChannelTypeAli, baseUrl, "", "", "")
|
||||
}
|
@ -1,32 +1,66 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
// 定义供应商工厂
|
||||
type AliProviderFactory struct{}
|
||||
|
||||
type AliProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
// 创建 AliProvider
|
||||
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
|
||||
func (f AliProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f AliProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &AliProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://dashscope.aliyuncs.com",
|
||||
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
|
||||
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type AliProvider struct {
|
||||
base.BaseProvider
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://dashscope.aliyuncs.com",
|
||||
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
|
||||
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var aliError *AliError
|
||||
err := json.NewDecoder(resp.Body).Decode(aliError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(aliError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(aliError *AliError) *types.OpenAIError {
|
||||
if aliError.Code == "" {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: aliError.Message,
|
||||
Type: aliError.Code,
|
||||
Param: aliError.RequestId,
|
||||
Code: aliError.Code,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AliProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
|
@ -1,51 +1,116 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 阿里云响应处理
|
||||
func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if aliResponse.Code != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
OpenAIResponse = types.ChatCompletionResponse{
|
||||
ID: aliResponse.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: aliResponse.Model,
|
||||
Choices: aliResponse.Output.ToChatCompletionChoices(),
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: aliResponse.Usage.InputTokens,
|
||||
CompletionTokens: aliResponse.Usage.OutputTokens,
|
||||
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
return
|
||||
type aliStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
lastStreamResponse string
|
||||
}
|
||||
|
||||
const AliEnableSearchModelSuffix = "-internet"
|
||||
|
||||
// 获取聊天请求体
|
||||
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
|
||||
func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getAliChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
aliResponse := &AliChatResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, aliResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return p.convertToChatOpenai(aliResponse, request)
|
||||
}
|
||||
|
||||
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getAliChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &aliStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
headers["X-DashScope-SSE"] = "enable"
|
||||
}
|
||||
|
||||
aliRequest := convertFromChatOpenai(request)
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天请求体
|
||||
func (p *AliProvider) convertToChatOpenai(response *AliChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.AliError)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
openaiResponse = &types.ChatCompletionResponse{
|
||||
ID: response.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: request.Model,
|
||||
Choices: response.Output.ToChatCompletionChoices(),
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
|
||||
*p.Usage = *openaiResponse.Usage
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 阿里云聊天请求体
|
||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
@ -96,163 +161,68 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
|
||||
}
|
||||
}
|
||||
|
||||
// 聊天
|
||||
func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
headers["X-DashScope-SSE"] = "enable"
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data:") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[5:]
|
||||
|
||||
var aliResponse AliChatResponse
|
||||
err := json.Unmarshal(*rawLine, &aliResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
usage, errWithCode = p.sendStreamRequest(req, request.Model)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &types.Usage{
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
aliResponse := &AliChatResponse{
|
||||
Model: request.Model,
|
||||
}
|
||||
errWithCode = p.SendRequest(req, aliResponse, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: aliResponse.Usage.InputTokens,
|
||||
CompletionTokens: aliResponse.Usage.OutputTokens,
|
||||
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
|
||||
}
|
||||
error := errorHandle(&aliResponse.AliError)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
return
|
||||
|
||||
return h.convertToOpenaiStream(&aliResponse, response)
|
||||
|
||||
}
|
||||
|
||||
// 阿里云响应转OpenAI响应
|
||||
func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
|
||||
// chatChoice := aliResponse.Output.ToChatCompletionChoices()
|
||||
// jsonBody, _ := json.MarshalIndent(chatChoice, "", " ")
|
||||
// fmt.Println("requestBody:", string(jsonBody))
|
||||
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
content := aliResponse.Output.Choices[0].Message.StringContent()
|
||||
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Index = aliResponse.Output.Choices[0].Index
|
||||
choice.Delta.Content = aliResponse.Output.Choices[0].Message.StringContent()
|
||||
// fmt.Println("choice.Delta.Content:", chatChoice[0].Message)
|
||||
if aliResponse.Output.Choices[0].FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.Choices[0].FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
choice.Delta.Content = strings.TrimPrefix(content, h.lastStreamResponse)
|
||||
if aliResponse.Output.Choices[0].FinishReason != "" {
|
||||
if aliResponse.Output.Choices[0].FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.Choices[0].FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
}
|
||||
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
if aliResponse.Output.FinishReason != "" {
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
}
|
||||
|
||||
h.lastStreamResponse = content
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
ID: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: aliResponse.Model,
|
||||
Model: h.Request.Model,
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
// 发送流请求
|
||||
func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
defer req.Body.Close()
|
||||
|
||||
usage = &types.Usage{}
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return nil, common.HandleErrorResp(resp)
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
h.Usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
h.Usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
*response = append(*response, streamResponse)
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
lastResponseText := ""
|
||||
index := 0
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var aliResponse AliChatResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
aliResponse.Model = model
|
||||
aliResponse.Output.Choices[0].Index = index
|
||||
index++
|
||||
response := p.streamResponseAli2OpenAI(&aliResponse)
|
||||
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||
lastResponseText = aliResponse.Output.Choices[0].Message.StringContent()
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
330
providers/ali/chat_test.go
Normal file
330
providers/ali/chat_test.go
Normal file
@ -0,0 +1,330 @@
|
||||
package ali_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common/test"
|
||||
_ "one-api/common/test/init"
|
||||
"one-api/providers"
|
||||
providers_base "one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func getChatProvider(url string, context *gin.Context) providers_base.ChatInterface {
|
||||
channel := getAliChannel(url)
|
||||
provider := providers.GetProvider(&channel, context)
|
||||
chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
|
||||
return chatProvider
|
||||
}
|
||||
|
||||
func TestChatCompletions(t *testing.T) {
|
||||
url, server, teardown := setupAliTestServer()
|
||||
context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
defer teardown()
|
||||
server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionEndpoint)
|
||||
|
||||
chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false")
|
||||
|
||||
chatProvider := getChatProvider(url, context)
|
||||
usage := &types.Usage{}
|
||||
chatProvider.SetUsage(usage)
|
||||
response, errWithCode := chatProvider.CreateChatCompletion(chatRequest)
|
||||
|
||||
assert.Nil(t, errWithCode)
|
||||
assert.IsType(t, &types.Usage{}, usage)
|
||||
assert.Equal(t, 33, usage.TotalTokens)
|
||||
assert.Equal(t, 14, usage.PromptTokens)
|
||||
assert.Equal(t, 19, usage.CompletionTokens)
|
||||
|
||||
// 转换成JSON字符串
|
||||
responseBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
assert.Fail(t, "json marshal error")
|
||||
}
|
||||
fmt.Println(string(responseBody))
|
||||
|
||||
test.CheckChat(t, response, "qwen-turbo", usage)
|
||||
}
|
||||
|
||||
func TestChatCompletionsError(t *testing.T) {
|
||||
url, server, teardown := setupAliTestServer()
|
||||
context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
defer teardown()
|
||||
server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionErrorEndpoint)
|
||||
|
||||
chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "false")
|
||||
|
||||
chatProvider := getChatProvider(url, context)
|
||||
_, err := chatProvider.CreateChatCompletion(chatRequest)
|
||||
usage := chatProvider.GetUsage()
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, usage)
|
||||
assert.Equal(t, "InvalidParameter", err.Code)
|
||||
}
|
||||
|
||||
// func TestChatCompletionsStream(t *testing.T) {
|
||||
// url, server, teardown := setupAliTestServer()
|
||||
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
// defer teardown()
|
||||
// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamEndpoint)
|
||||
|
||||
// channel := getAliChannel(url)
|
||||
// provider := providers.GetProvider(&channel, context)
|
||||
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true")
|
||||
|
||||
// usage := &types.Usage{}
|
||||
// chatProvider.SetUsage(usage)
|
||||
// response, errWithCode := chatProvider.CreateChatCompletionStream(chatRequest)
|
||||
// assert.Nil(t, errWithCode)
|
||||
|
||||
// assert.IsType(t, &types.Usage{}, usage)
|
||||
// assert.Equal(t, 16, usage.TotalTokens)
|
||||
// assert.Equal(t, 8, usage.PromptTokens)
|
||||
// assert.Equal(t, 8, usage.CompletionTokens)
|
||||
|
||||
// streamResponseCheck(t, w.Body.String())
|
||||
// }
|
||||
|
||||
// func TestChatCompletionsStreamError(t *testing.T) {
|
||||
// url, server, teardown := setupAliTestServer()
|
||||
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
// defer teardown()
|
||||
// server.RegisterHandler("/api/v1/services/aigc/text-generation/generation", handleChatCompletionStreamErrorEndpoint)
|
||||
|
||||
// channel := getAliChannel(url)
|
||||
// provider := providers.GetProvider(&channel, context)
|
||||
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
// chatRequest := test.GetChatCompletionRequest("default", "qwen-turbo", "true")
|
||||
|
||||
// usage, err := chatProvider.ChatAction(chatRequest, 0)
|
||||
|
||||
// // 打印 context 写入的内容
|
||||
// fmt.Println(w.Body.String())
|
||||
|
||||
// assert.NotNil(t, err)
|
||||
// assert.Nil(t, usage)
|
||||
// }
|
||||
|
||||
// func TestChatImageCompletions(t *testing.T) {
|
||||
// url, server, teardown := setupAliTestServer()
|
||||
// context, _ := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
// defer teardown()
|
||||
// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionEndpoint)
|
||||
|
||||
// channel := getAliChannel(url)
|
||||
// provider := providers.GetProvider(&channel, context)
|
||||
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "false")
|
||||
|
||||
// usage, err := chatProvider.ChatAction(chatRequest, 0)
|
||||
|
||||
// assert.Nil(t, err)
|
||||
// assert.IsType(t, &types.Usage{}, usage)
|
||||
// assert.Equal(t, 1306, usage.TotalTokens)
|
||||
// assert.Equal(t, 1279, usage.PromptTokens)
|
||||
// assert.Equal(t, 27, usage.CompletionTokens)
|
||||
// }
|
||||
|
||||
// func TestChatImageCompletionsStream(t *testing.T) {
|
||||
// url, server, teardown := setupAliTestServer()
|
||||
// context, w := test.GetContext("POST", "/v1/chat/completions", test.RequestJSONConfig(), nil)
|
||||
// defer teardown()
|
||||
// server.RegisterHandler("/api/v1/services/aigc/multimodal-generation/generation", handleChatImageCompletionStreamEndpoint)
|
||||
|
||||
// channel := getAliChannel(url)
|
||||
// provider := providers.GetProvider(&channel, context)
|
||||
// chatProvider, _ := provider.(providers_base.ChatInterface)
|
||||
// chatRequest := test.GetChatCompletionRequest("image", "qwen-vl-plus", "true")
|
||||
|
||||
// usage, err := chatProvider.ChatAction(chatRequest, 0)
|
||||
|
||||
// fmt.Println(w.Body.String())
|
||||
|
||||
// assert.Nil(t, err)
|
||||
// assert.IsType(t, &types.Usage{}, usage)
|
||||
// assert.Equal(t, 1342, usage.TotalTokens)
|
||||
// assert.Equal(t, 1279, usage.PromptTokens)
|
||||
// assert.Equal(t, 63, usage.CompletionTokens)
|
||||
// streamResponseCheck(t, w.Body.String())
|
||||
// }
|
||||
|
||||
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
response := `{"output":{"choices":[{"finish_reason":"stop","message":{"role":"assistant","content":"您好!我可以帮您查询最近的公园,请问您现在所在的位置是哪里呢?"}}]},"usage":{"total_tokens":33,"output_tokens":19,"input_tokens":14},"request_id":"2479f818-9717-9b0b-9769-0d26e873a3f6"}`
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintln(w, response)
|
||||
}
|
||||
|
||||
func handleChatCompletionErrorEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
response := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"4883ee8d-f095-94ff-a94a-5ce0a94bc81f"}`
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintln(w, response)
|
||||
}
|
||||
|
||||
func handleChatCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// 检测头部是否有X-DashScope-SSE: enable
|
||||
if r.Header.Get("X-DashScope-SSE") != "enable" {
|
||||
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// Send test responses
|
||||
dataBytes := []byte{}
|
||||
dataBytes = append(dataBytes, []byte("id:1\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data := `{"output":{"choices":[{"message":{"content":"你好!","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":10,"input_tokens":8,"output_tokens":2},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
dataBytes = append(dataBytes, []byte("id:2\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data = `{"output":{"choices":[{"message":{"content":"有什么我可以帮助你的吗?","role":"assistant"},"finish_reason":"null"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
dataBytes = append(dataBytes, []byte("id:3\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data = `{"output":{"choices":[{"message":{"content":"","role":"assistant"},"finish_reason":"stop"}]},"usage":{"total_tokens":16,"input_tokens":8,"output_tokens":8},"request_id":"215a2614-5486-936c-8d42-3b472d6fbd1c"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func handleChatCompletionStreamErrorEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// 检测头部是否有X-DashScope-SSE: enable
|
||||
if r.Header.Get("X-DashScope-SSE") != "enable" {
|
||||
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// Send test responses
|
||||
dataBytes := []byte{}
|
||||
dataBytes = append(dataBytes, []byte("id:1\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:error\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/400\n")...)
|
||||
//nolint:lll
|
||||
data := `{"code":"InvalidParameter","message":"Role must be user or assistant and Content length must be greater than 0","request_id":"6b932ba9-41bd-9ad3-b430-24bc1e125880"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func handleChatImageCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
response := `{"output":{"finish_reason":"stop","choices":[{"message":{"role":"assistant","content":[{"text":"这张照片展示的是一个海滩的场景,但是并没有明确指出具体的位置。可以看到海浪和日落背景下的沙滩景色。"}]}}]},"usage":{"output_tokens":27,"input_tokens":1279,"image_tokens":1247},"request_id":"a360d53b-b993-927f-9a68-bef6b2b2042e"}`
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintln(w, response)
|
||||
}
|
||||
|
||||
func handleChatImageCompletionStreamEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// 检测头部是否有X-DashScope-SSE: enable
|
||||
if r.Header.Get("X-DashScope-SSE") != "enable" {
|
||||
http.Error(w, "Header X-DashScope-SSE not found", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// Send test responses
|
||||
dataBytes := []byte{}
|
||||
dataBytes = append(dataBytes, []byte("id:1\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data := `{"output":{"choices":[{"message":{"content":[{"text":"这张"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":1,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
dataBytes = append(dataBytes, []byte("id:2\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片"}],"role":"assistant"}}],"finish_reason":"null"},"usage":{"input_tokens":1279,"output_tokens":2,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
dataBytes = append(dataBytes, []byte("id:3\n")...)
|
||||
dataBytes = append(dataBytes, []byte("event:result\n")...)
|
||||
dataBytes = append(dataBytes, []byte(":HTTP_STATUS/200\n")...)
|
||||
//nolint:lll
|
||||
data = `{"output":{"choices":[{"message":{"content":[{"text":"这张照片展示的是一个海滩的场景,具体来说是在日落时分。由于没有明显的地标或建筑物等特征可以辨认出具体的地点信息,所以无法确定这是哪个地方的海滩。但是根据图像中的元素和环境特点,我们可以推测这可能是一个位于沿海地区的沙滩海岸线。"}],"role":"assistant"}}],"finish_reason":"stop"},"usage":{"input_tokens":1279,"output_tokens":63,"image_tokens":1247},"request_id":"37bead8b-d87a-98f8-9193-b9e2da9d2451"}`
|
||||
dataBytes = append(dataBytes, []byte("data:"+data+"\n\n")...)
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func streamResponseCheck(t *testing.T, response string) {
|
||||
// 以换行符分割response
|
||||
lines := strings.Split(response, "\n\n")
|
||||
// 如果最后一行为空,则删除最后一行
|
||||
if lines[len(lines)-1] == "" {
|
||||
lines = lines[:len(lines)-1]
|
||||
}
|
||||
|
||||
// 循环遍历每一行
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
// assert判断 是否以data: 开头
|
||||
assert.True(t, strings.HasPrefix(line, "data: "))
|
||||
}
|
||||
|
||||
// 检测最后一行是否以data: [DONE] 结尾
|
||||
assert.True(t, strings.HasSuffix(lines[len(lines)-1], "data: [DONE]"))
|
||||
// 检测倒数第二行是否存在 `"finish_reason":"stop"`
|
||||
assert.True(t, strings.Contains(lines[len(lines)-2], `"finish_reason":"stop"`))
|
||||
}
|
@ -6,40 +6,37 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
// 嵌入请求处理
|
||||
func (aliResponse *AliEmbeddingResponse) ResponseHandler(resp *http.Response) (any, *types.OpenAIErrorWithStatusCode) {
|
||||
if aliResponse.Code != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
func (p *AliProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
aliRequest := convertFromEmbeddingOpenai(request)
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
aliResponse := &AliEmbeddingResponse{}
|
||||
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, aliResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
openAIEmbeddingResponse := &types.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range aliResponse.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
|
||||
return openAIEmbeddingResponse, nil
|
||||
return p.convertToEmbeddingOpenai(aliResponse, request)
|
||||
}
|
||||
|
||||
// 获取嵌入请求体
|
||||
func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
|
||||
func convertFromEmbeddingOpenai(request *types.EmbeddingRequest) *AliEmbeddingRequest {
|
||||
return &AliEmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
@ -50,24 +47,36 @@ func (p *AliProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AliProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getEmbeddingsRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
aliEmbeddingResponse := &AliEmbeddingResponse{}
|
||||
errWithCode = p.SendRequest(req, aliEmbeddingResponse, false)
|
||||
if errWithCode != nil {
|
||||
func (p *AliProvider) convertToEmbeddingOpenai(response *AliEmbeddingResponse, request *types.EmbeddingRequest) (openaiResponse *types.EmbeddingResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.AliError)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}
|
||||
|
||||
return usage, nil
|
||||
openaiResponse = &types.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]types.Embedding, 0, len(response.Output.Embeddings)),
|
||||
Model: request.Model,
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: response.Usage.TotalTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
|
||||
for _, item := range response.Output.Embeddings {
|
||||
openaiResponse.Data = append(openaiResponse.Data, types.Embedding{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
|
||||
*p.Usage = *openaiResponse.Usage
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -52,7 +52,8 @@ type AliChoice struct {
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Choices []types.ChatCompletionChoice `json:"choices"`
|
||||
Choices []types.ChatCompletionChoice `json:"choices"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice {
|
||||
@ -70,7 +71,6 @@ func (o *AliOutput) ToChatCompletionChoices() []types.ChatCompletionChoice {
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
Model string `json:"model,omitempty"`
|
||||
AliError
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,6 @@ package api2d
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
)
|
||||
@ -11,15 +10,14 @@ func (p *Api2dProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response base.BalanceResponse
|
||||
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||
_, errWithCode := p.Requester.SendRequest(req, &response, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
|
@ -1,18 +1,17 @@
|
||||
package api2d
|
||||
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Api2dProviderFactory struct{}
|
||||
|
||||
// 创建 Api2dProvider
|
||||
func (f Api2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f Api2dProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &Api2dProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"),
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://oa.api2d.net"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,6 @@ package api2gpt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
)
|
||||
@ -11,15 +10,14 @@ func (p *Api2gptProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response base.BalanceResponse
|
||||
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||
_, errWithCode := p.Requester.SendRequest(req, &response, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
|
@ -1,17 +1,16 @@
|
||||
package api2gpt
|
||||
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Api2gptProviderFactory struct{}
|
||||
|
||||
func (f Api2gptProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f Api2gptProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &Api2gptProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.api2gpt.com"),
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.api2gpt.com"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,30 +1,23 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AzureProviderFactory struct{}
|
||||
|
||||
// 创建 AzureProvider
|
||||
func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f AzureProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
config := getAzureConfig()
|
||||
return &AzureProvider{
|
||||
OpenAIProvider: openai.OpenAIProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "",
|
||||
Completions: "/completions",
|
||||
ChatCompletions: "/chat/completions",
|
||||
Embeddings: "/embeddings",
|
||||
AudioTranscriptions: "/audio/transcriptions",
|
||||
AudioTranslations: "/audio/translations",
|
||||
ImagesGenerations: "/images/generations",
|
||||
// ImagesEdit: "/images/edit",
|
||||
// ImagesVariations: "/images/variations",
|
||||
Context: c,
|
||||
// AudioSpeech: "/audio/speech",
|
||||
Config: config,
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, openai.RequestErrorHandle),
|
||||
},
|
||||
IsAzure: true,
|
||||
BalanceAction: false,
|
||||
@ -32,6 +25,18 @@ func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
}
|
||||
}
|
||||
|
||||
func getAzureConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "",
|
||||
Completions: "/completions",
|
||||
ChatCompletions: "/chat/completions",
|
||||
Embeddings: "/embeddings",
|
||||
AudioTranscriptions: "/audio/transcriptions",
|
||||
AudioTranslations: "/audio/translations",
|
||||
ImagesGenerations: "/images/generations",
|
||||
}
|
||||
}
|
||||
|
||||
type AzureProvider struct {
|
||||
openai.OpenAIProvider
|
||||
}
|
||||
|
@ -10,13 +10,60 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Status == "canceled" || c.Status == "failed" {
|
||||
func (p *AzureProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
if !openai.IsWithinRange(request.Model, request.N) {
|
||||
return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
var response *types.ImageResponse
|
||||
var resp *http.Response
|
||||
if request.Model == "dall-e-2" {
|
||||
imageAzureResponse := &ImageAzureResponse{}
|
||||
resp, errWithCode = p.Requester.SendRequest(req, imageAzureResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
response, errWithCode = p.ResponseAzureImageHandler(resp, imageAzureResponse)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
} else {
|
||||
var openaiResponse openai.OpenAIProviderImageResponse
|
||||
_, errWithCode = p.Requester.SendRequest(req, &openaiResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 检测是否错误
|
||||
openaiErr := openai.ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
response = &openaiResponse.ImageResponse
|
||||
}
|
||||
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return response, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *AzureProvider) ResponseAzureImageHandler(resp *http.Response, azure *ImageAzureResponse) (OpenAIResponse *types.ImageResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if azure.Status == "canceled" || azure.Status == "failed" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: c.Error.Message,
|
||||
Message: azure.Error.Message,
|
||||
Type: "one_api_error",
|
||||
Code: c.Error.Code,
|
||||
Code: azure.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
@ -28,8 +75,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons
|
||||
return nil, common.ErrorWrapper(errors.New("image url is empty"), "get_images_url_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", operation_location, common.WithHeader(c.Header))
|
||||
req, err := p.Requester.NewRequest("GET", operation_location, p.Requester.WithHeader(p.GetRequestHeaders()))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "get_images_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@ -38,7 +84,7 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons
|
||||
for i := 0; i < 3; i++ {
|
||||
// 休眠 2 秒
|
||||
time.Sleep(2 * time.Second)
|
||||
_, errWithCode = common.SendRequest(req, &getImageAzureResponse, false, c.Proxy)
|
||||
_, errWithCode = p.Requester.SendRequest(req, &getImageAzureResponse, false)
|
||||
fmt.Println("getImageAzureResponse", getImageAzureResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
@ -47,57 +93,17 @@ func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIRespons
|
||||
if getImageAzureResponse.Status == "canceled" || getImageAzureResponse.Status == "failed" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: c.Error.Message,
|
||||
Message: getImageAzureResponse.Error.Message,
|
||||
Type: "get_images_request_failed",
|
||||
Code: c.Error.Code,
|
||||
Code: getImageAzureResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
if getImageAzureResponse.Status == "succeeded" {
|
||||
return getImageAzureResponse.Result, nil
|
||||
return &getImageAzureResponse.Result, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, common.ErrorWrapper(errors.New("get image Timeout"), "get_images_url_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Model == "dall-e-2" {
|
||||
imageAzureResponse := &ImageAzureResponse{
|
||||
Header: headers,
|
||||
Proxy: p.Channel.Proxy,
|
||||
}
|
||||
errWithCode = p.SendRequest(req, imageAzureResponse, false)
|
||||
} else {
|
||||
openAIProviderImageResponseResponse := &openai.OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
}
|
||||
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -1,21 +1,23 @@
|
||||
package azureSpeech
|
||||
|
||||
import (
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 定义供应商工厂
|
||||
type AzureSpeechProviderFactory struct{}
|
||||
|
||||
// 创建 AliProvider
|
||||
func (f AzureSpeechProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f AzureSpeechProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &AzureSpeechProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "",
|
||||
AudioSpeech: "/cognitiveservices/v1",
|
||||
Context: c,
|
||||
Config: base.ProviderConfig{
|
||||
AudioSpeech: "/cognitiveservices/v1",
|
||||
},
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, nil),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -55,9 +55,12 @@ func (p *AzureSpeechProvider) getRequestBody(request *types.SpeechAudioRequest)
|
||||
|
||||
}
|
||||
|
||||
func (p *AzureSpeechProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model)
|
||||
func (p *AzureSpeechProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeAudioSpeech)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
responseFormatr := outputFormatMap[request.ResponseFormat]
|
||||
if responseFormatr == "" {
|
||||
@ -67,22 +70,19 @@ func (p *AzureSpeechProvider) SpeechAction(request *types.SpeechAudioRequest, is
|
||||
|
||||
requestBody := p.getRequestBody(request)
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(requestBody), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
errWithCode = p.SendRequestRaw(req)
|
||||
var resp *http.Response
|
||||
resp, errWithCode = p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
}
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return
|
||||
return resp, nil
|
||||
}
|
||||
|
@ -1,10 +1,10 @@
|
||||
package baichuan
|
||||
|
||||
import (
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 定义供应商工厂
|
||||
@ -12,19 +12,26 @@ type BaichuanProviderFactory struct{}
|
||||
|
||||
// 创建 BaichuanProvider
|
||||
// https://platform.baichuan-ai.com/docs/api
|
||||
func (f BaichuanProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f BaichuanProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &BaichuanProvider{
|
||||
OpenAIProvider: openai.OpenAIProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://api.baichuan-ai.com",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
Embeddings: "/v1/embeddings",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, openai.RequestErrorHandle),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://api.baichuan-ai.com",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
Embeddings: "/v1/embeddings",
|
||||
}
|
||||
}
|
||||
|
||||
type BaichuanProvider struct {
|
||||
openai.OpenAIProvider
|
||||
}
|
||||
|
@ -3,31 +3,61 @@ package baichuan
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/providers/openai"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (baichuanResponse *BaichuanChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if baichuanResponse.Error.Message != "" {
|
||||
func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, requestBody)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &BaichuanChatResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 检测是否错误
|
||||
openaiErr := openai.ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: baichuanResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
OpenAIResponse = types.ChatCompletionResponse{
|
||||
ID: baichuanResponse.ID,
|
||||
Object: baichuanResponse.Object,
|
||||
Created: baichuanResponse.Created,
|
||||
Model: baichuanResponse.Model,
|
||||
Choices: baichuanResponse.Choices,
|
||||
Usage: baichuanResponse.Usage,
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return &response.ChatCompletionResponse, nil
|
||||
}
|
||||
|
||||
func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
chatHandler := openai.OpenAIStreamHandler{
|
||||
Usage: p.Usage,
|
||||
ModelName: request.Model,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
|
||||
}
|
||||
|
||||
// 获取聊天请求体
|
||||
@ -55,46 +85,3 @@ func (p *BaichuanProvider) getChatRequestBody(request *types.ChatCompletionReque
|
||||
TopK: request.N,
|
||||
}
|
||||
}
|
||||
|
||||
// 聊天
|
||||
func (p *BaichuanProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
openAIProviderChatStreamResponse := &openai.OpenAIProviderChatStreamResponse{}
|
||||
var textResponse string
|
||||
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||
}
|
||||
|
||||
} else {
|
||||
baichuanResponse := &BaichuanChatResponse{}
|
||||
errWithCode = p.SendRequest(req, baichuanResponse, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = baichuanResponse.Usage
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -4,35 +4,65 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 定义供应商工厂
|
||||
type BaiduProviderFactory struct{}
|
||||
|
||||
// 创建 BaiduProvider
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
func (f BaiduProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
// 创建 BaiduProvider
|
||||
type BaiduProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func (f BaiduProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &BaiduProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://aip.baidubce.com",
|
||||
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
|
||||
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var baiduTokenStore sync.Map
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://aip.baidubce.com",
|
||||
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
|
||||
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
|
||||
}
|
||||
}
|
||||
|
||||
type BaiduProvider struct {
|
||||
base.BaseProvider
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var baiduError *BaiduError
|
||||
err := json.NewDecoder(resp.Body).Decode(baiduError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(baiduError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(baiduError *BaiduError) *types.OpenAIError {
|
||||
if baiduError.ErrorMsg == "" {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: baiduError.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Code: baiduError.ErrorCode,
|
||||
}
|
||||
}
|
||||
|
||||
// 获取完整请求 URL
|
||||
@ -92,32 +122,21 @@ func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessTo
|
||||
return nil, errors.New("invalid baidu apikey")
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
|
||||
url := fmt.Sprintf(p.Config.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
|
||||
|
||||
var headers = map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
req, err := client.NewRequest("POST", url, common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest("POST", url, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpClient := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
common.PutHttpClient(httpClient)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
var accessToken BaiduAccessToken
|
||||
err = json.NewDecoder(resp.Body).Decode(&accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
_, errWithCode := p.Requester.SendRequest(req, &accessToken, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
if accessToken.Error != "" {
|
||||
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
||||
|
@ -1,79 +1,155 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
type baiduStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getBaiduChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
baiduResponse := &BaiduChatResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, baiduResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return p.convertToChatOpenai(baiduResponse, request)
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getBaiduChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &baiduStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
baiduRequest := convertFromChatOpenai(request)
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(baiduRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) convertToChatOpenai(response *BaiduChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.BaiduError)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
// Content: baiduResponse.Result,
|
||||
},
|
||||
FinishReason: base.StopFinishReason,
|
||||
FinishReason: types.FinishReasonStop,
|
||||
}
|
||||
|
||||
if baiduResponse.FunctionCall != nil {
|
||||
if baiduResponse.FunctionCate == "tool" {
|
||||
if response.FunctionCall != nil {
|
||||
if request.Tools != nil {
|
||||
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
|
||||
{
|
||||
Id: baiduResponse.Id,
|
||||
Id: response.Id,
|
||||
Type: "function",
|
||||
Function: *baiduResponse.FunctionCall,
|
||||
Function: response.FunctionCall,
|
||||
},
|
||||
}
|
||||
choice.FinishReason = &base.StopFinishReasonToolFunction
|
||||
choice.FinishReason = types.FinishReasonToolCalls
|
||||
} else {
|
||||
choice.Message.FunctionCall = baiduResponse.FunctionCall
|
||||
choice.FinishReason = &base.StopFinishReasonCallFunction
|
||||
choice.Message.FunctionCall = response.FunctionCall
|
||||
choice.FinishReason = types.FinishReasonFunctionCall
|
||||
}
|
||||
} else {
|
||||
choice.Message.Content = baiduResponse.Result
|
||||
choice.Message.Content = response.Result
|
||||
}
|
||||
|
||||
OpenAIResponse = types.ChatCompletionResponse{
|
||||
ID: baiduResponse.Id,
|
||||
openaiResponse = &types.ChatCompletionResponse{
|
||||
ID: response.Id,
|
||||
Object: "chat.completion",
|
||||
Created: baiduResponse.Created,
|
||||
Model: request.Model,
|
||||
Created: response.Created,
|
||||
Choices: []types.ChatCompletionChoice{choice},
|
||||
Usage: baiduResponse.Usage,
|
||||
Usage: response.Usage,
|
||||
}
|
||||
|
||||
*p.Usage = *openaiResponse.Usage
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest {
|
||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatRequest {
|
||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
if message.Role == types.ChatMessageRoleSystem {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "user",
|
||||
Role: types.ChatMessageRoleUser,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "assistant",
|
||||
Role: types.ChatMessageRoleAssistant,
|
||||
Content: "Okay",
|
||||
})
|
||||
} else if message.Role == types.ChatMessageRoleFunction {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: types.ChatMessageRoleAssistant,
|
||||
Content: "Okay",
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: types.ChatMessageRoleUser,
|
||||
Content: "这是函数调用返回的内容,请回答之前的问题:\n" + message.StringContent(),
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
@ -101,154 +177,82 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest)
|
||||
return baiduChatRequest
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := json.Unmarshal(*rawLine, &baiduResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
usage, errWithCode = p.sendStreamRequest(req, request.Model, request.GetFunctionCate())
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
baiduChatRequest := &BaiduChatResponse{
|
||||
Model: request.Model,
|
||||
FunctionCate: request.GetFunctionCate(),
|
||||
}
|
||||
errWithCode = p.SendRequest(req, baiduChatRequest, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = baiduChatRequest.Usage
|
||||
if baiduResponse.IsEnd {
|
||||
*isFinished = true
|
||||
}
|
||||
return
|
||||
|
||||
return h.convertToOpenaiStream(&baiduResponse, response)
|
||||
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
choice := types.ChatCompletionStreamChoice{
|
||||
Index: 0,
|
||||
Delta: types.ChatCompletionStreamChoiceDelta{
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
|
||||
if baiduResponse.FunctionCall != nil {
|
||||
if baiduResponse.FunctionCate == "tool" {
|
||||
if h.Request.Tools != nil {
|
||||
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
|
||||
{
|
||||
Id: baiduResponse.Id,
|
||||
Type: "function",
|
||||
Function: *baiduResponse.FunctionCall,
|
||||
Function: baiduResponse.FunctionCall,
|
||||
},
|
||||
}
|
||||
choice.FinishReason = &base.StopFinishReasonToolFunction
|
||||
choice.FinishReason = types.FinishReasonToolCalls
|
||||
} else {
|
||||
choice.Delta.FunctionCall = baiduResponse.FunctionCall
|
||||
choice.FinishReason = &base.StopFinishReasonCallFunction
|
||||
choice.FinishReason = types.FinishReasonFunctionCall
|
||||
}
|
||||
} else {
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
choice.FinishReason = types.FinishReasonStop
|
||||
}
|
||||
}
|
||||
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
chatCompletion := types.ChatCompletionStreamResponse{
|
||||
ID: baiduResponse.Id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: baiduResponse.Created,
|
||||
Model: baiduResponse.Model,
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
Model: h.Request.Model,
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
defer req.Body.Close()
|
||||
|
||||
usage = &types.Usage{}
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return nil, common.HandleErrorResp(resp)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
baiduResponse.FunctionCate = functionCate
|
||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
baiduResponse.Model = model
|
||||
response := p.streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return usage, nil
|
||||
|
||||
if baiduResponse.FunctionCall == nil {
|
||||
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, chatCompletion)
|
||||
} else {
|
||||
choices := choice.ConvertOpenaiStream()
|
||||
for _, choice := range choices {
|
||||
chatCompletionCopy := chatCompletion
|
||||
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, chatCompletionCopy)
|
||||
}
|
||||
}
|
||||
|
||||
h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -6,33 +6,63 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
|
||||
func (p *BaiduProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
aliRequest := convertFromEmbeddingOpenai(request)
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(aliRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
baiduResponse := &BaiduEmbeddingResponse{}
|
||||
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, baiduResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return p.convertToEmbeddingOpenai(baiduResponse, request)
|
||||
}
|
||||
|
||||
func convertFromEmbeddingOpenai(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
|
||||
return &BaiduEmbeddingRequest{
|
||||
Input: request.ParseInput(),
|
||||
}
|
||||
}
|
||||
|
||||
func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
func (p *BaiduProvider) convertToEmbeddingOpenai(response *BaiduEmbeddingResponse, request *types.EmbeddingRequest) (openaiResponse *types.EmbeddingResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.BaiduError)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
openAIEmbeddingResponse := &types.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]types.Embedding, 0, len(baiduResponse.Data)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: &baiduResponse.Usage,
|
||||
Data: make([]types.Embedding, 0, len(response.Data)),
|
||||
Model: request.Model,
|
||||
Usage: &response.Usage,
|
||||
}
|
||||
|
||||
for _, item := range baiduResponse.Data {
|
||||
for _, item := range response.Data {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
|
||||
Object: item.Object,
|
||||
Index: item.Index,
|
||||
@ -40,30 +70,7 @@ func (baiduResponse *BaiduEmbeddingResponse) ResponseHandler(resp *http.Response
|
||||
})
|
||||
}
|
||||
|
||||
*p.Usage = response.Usage
|
||||
|
||||
return openAIEmbeddingResponse, nil
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody := p.getEmbeddingsRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
headers := p.GetRequestHeaders()
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
baiduEmbeddingResponse := &BaiduEmbeddingResponse{}
|
||||
errWithCode = p.SendRequest(req, baiduEmbeddingResponse, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
usage = &baiduEmbeddingResponse.Usage
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
@ -3,9 +3,9 @@ package base
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
@ -13,11 +13,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var StopFinishReason = "stop"
|
||||
var StopFinishReasonToolFunction = "tool_calls"
|
||||
var StopFinishReasonCallFunction = "function_call"
|
||||
|
||||
type BaseProvider struct {
|
||||
type ProviderConfig struct {
|
||||
BaseURL string
|
||||
Completions string
|
||||
ChatCompletions string
|
||||
@ -29,8 +25,15 @@ type BaseProvider struct {
|
||||
ImagesGenerations string
|
||||
ImagesEdit string
|
||||
ImagesVariations string
|
||||
Context *gin.Context
|
||||
Channel *model.Channel
|
||||
}
|
||||
|
||||
type BaseProvider struct {
|
||||
OriginalModel string
|
||||
Usage *types.Usage
|
||||
Config ProviderConfig
|
||||
Context *gin.Context
|
||||
Channel *model.Channel
|
||||
Requester *requester.HTTPRequester
|
||||
}
|
||||
|
||||
// 获取基础URL
|
||||
@ -39,11 +42,7 @@ func (p *BaseProvider) GetBaseURL() string {
|
||||
return p.Channel.GetBaseURL()
|
||||
}
|
||||
|
||||
return p.BaseURL
|
||||
}
|
||||
|
||||
func (p *BaseProvider) SetChannel(channel *model.Channel) {
|
||||
p.Channel = channel
|
||||
return p.Config.BaseURL
|
||||
}
|
||||
|
||||
// 获取完整请求URL
|
||||
@ -62,104 +61,85 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
defer req.Body.Close()
|
||||
|
||||
resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true, p.Channel.Proxy)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
openAIResponse, openAIErrorWithStatusCode := response.ResponseHandler(resp)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if rawOutput {
|
||||
for k, v := range resp.Header {
|
||||
p.Context.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err := io.Copy(p.Context.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
jsonResponse, err := json.Marshal(openAIResponse)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
p.Context.Writer.Header().Set("Content-Type", "application/json")
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = p.Context.Writer.Write(jsonResponse)
|
||||
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
func (p *BaseProvider) GetUsage() *types.Usage {
|
||||
return p.Usage
|
||||
}
|
||||
|
||||
func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理响应
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return common.HandleErrorResp(resp)
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
p.Context.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
|
||||
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(p.Context.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return nil
|
||||
func (p *BaseProvider) SetUsage(usage *types.Usage) {
|
||||
p.Usage = usage
|
||||
}
|
||||
|
||||
func (p *BaseProvider) SupportAPI(relayMode int) bool {
|
||||
func (p *BaseProvider) SetContext(c *gin.Context) {
|
||||
p.Context = c
|
||||
}
|
||||
|
||||
func (p *BaseProvider) SetOriginalModel(ModelName string) {
|
||||
p.OriginalModel = ModelName
|
||||
}
|
||||
|
||||
func (p *BaseProvider) GetOriginalModel() string {
|
||||
return p.OriginalModel
|
||||
}
|
||||
|
||||
func (p *BaseProvider) GetChannel() *model.Channel {
|
||||
return p.Channel
|
||||
}
|
||||
|
||||
func (p *BaseProvider) ModelMappingHandler(modelName string) (string, error) {
|
||||
p.OriginalModel = modelName
|
||||
|
||||
modelMapping := p.Channel.GetModelMapping()
|
||||
|
||||
if modelMapping == "" || modelMapping == "{}" {
|
||||
return modelName, nil
|
||||
}
|
||||
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if modelMap[modelName] != "" {
|
||||
return modelMap[modelName], nil
|
||||
}
|
||||
|
||||
return modelName, nil
|
||||
}
|
||||
|
||||
func (p *BaseProvider) GetAPIUri(relayMode int) string {
|
||||
switch relayMode {
|
||||
case common.RelayModeChatCompletions:
|
||||
return p.ChatCompletions != ""
|
||||
return p.Config.ChatCompletions
|
||||
case common.RelayModeCompletions:
|
||||
return p.Completions != ""
|
||||
return p.Config.Completions
|
||||
case common.RelayModeEmbeddings:
|
||||
return p.Embeddings != ""
|
||||
return p.Config.Embeddings
|
||||
case common.RelayModeAudioSpeech:
|
||||
return p.AudioSpeech != ""
|
||||
return p.Config.AudioSpeech
|
||||
case common.RelayModeAudioTranscription:
|
||||
return p.AudioTranscriptions != ""
|
||||
return p.Config.AudioTranscriptions
|
||||
case common.RelayModeAudioTranslation:
|
||||
return p.AudioTranslations != ""
|
||||
return p.Config.AudioTranslations
|
||||
case common.RelayModeModerations:
|
||||
return p.Moderation != ""
|
||||
return p.Config.Moderation
|
||||
case common.RelayModeImagesGenerations:
|
||||
return p.ImagesGenerations != ""
|
||||
return p.Config.ImagesGenerations
|
||||
case common.RelayModeImagesEdits:
|
||||
return p.ImagesEdit != ""
|
||||
return p.Config.ImagesEdit
|
||||
case common.RelayModeImagesVariations:
|
||||
return p.ImagesVariations != ""
|
||||
return p.Config.ImagesVariations
|
||||
default:
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (url string, err *types.OpenAIErrorWithStatusCode) {
|
||||
url = p.GetAPIUri(relayMode)
|
||||
if url == "" {
|
||||
err = common.StringErrorWrapper("The API interface is not supported", "unsupported_api", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
7
providers/base/handler.go
Normal file
7
providers/base/handler.go
Normal file
@ -0,0 +1,7 @@
|
||||
package base
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type BaseHandler struct {
|
||||
Usage *types.Usage
|
||||
}
|
@ -2,84 +2,108 @@ package base
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Requestable interface {
|
||||
types.CompletionRequest | types.ChatCompletionRequest | types.EmbeddingRequest | types.ModerationRequest | types.SpeechAudioRequest | types.AudioRequest | types.ImageRequest | types.ImageEditRequest
|
||||
}
|
||||
|
||||
// 基础接口
|
||||
type ProviderInterface interface {
|
||||
GetBaseURL() string
|
||||
GetFullRequestURL(requestURL string, modelName string) string
|
||||
GetRequestHeaders() (headers map[string]string)
|
||||
SupportAPI(relayMode int) bool
|
||||
SetChannel(channel *model.Channel)
|
||||
// 获取基础URL
|
||||
// GetBaseURL() string
|
||||
// 获取完整请求URL
|
||||
// GetFullRequestURL(requestURL string, modelName string) string
|
||||
// 获取请求头
|
||||
// GetRequestHeaders() (headers map[string]string)
|
||||
// 获取用量
|
||||
GetUsage() *types.Usage
|
||||
// 设置用量
|
||||
SetUsage(usage *types.Usage)
|
||||
// 设置Context
|
||||
SetContext(c *gin.Context)
|
||||
// 设置原始模型
|
||||
SetOriginalModel(ModelName string)
|
||||
// 获取原始模型
|
||||
GetOriginalModel() string
|
||||
|
||||
// SupportAPI(relayMode int) bool
|
||||
GetChannel() *model.Channel
|
||||
ModelMappingHandler(modelName string) (string, error)
|
||||
}
|
||||
|
||||
// 完成接口
|
||||
type CompletionInterface interface {
|
||||
ProviderInterface
|
||||
CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 聊天接口
|
||||
type ChatInterface interface {
|
||||
ProviderInterface
|
||||
ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 嵌入接口
|
||||
type EmbeddingsInterface interface {
|
||||
ProviderInterface
|
||||
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 审查接口
|
||||
type ModerationInterface interface {
|
||||
ProviderInterface
|
||||
ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 文字转语音接口
|
||||
type SpeechInterface interface {
|
||||
ProviderInterface
|
||||
SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 语音转文字接口
|
||||
type TranscriptionsInterface interface {
|
||||
ProviderInterface
|
||||
TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 语音翻译接口
|
||||
type TranslationInterface interface {
|
||||
ProviderInterface
|
||||
TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 图片生成接口
|
||||
type ImageGenerationsInterface interface {
|
||||
ProviderInterface
|
||||
ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 图片编辑接口
|
||||
type ImageEditsInterface interface {
|
||||
ProviderInterface
|
||||
ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type ImageVariationsInterface interface {
|
||||
ProviderInterface
|
||||
ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 余额接口
|
||||
type BalanceInterface interface {
|
||||
Balance(channel *model.Channel) (float64, error)
|
||||
Balance() (float64, error)
|
||||
}
|
||||
|
||||
type ProviderResponseHandler interface {
|
||||
// 响应处理函数
|
||||
ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
// type ProviderResponseHandler interface {
|
||||
// // 响应处理函数
|
||||
// ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
// }
|
||||
|
@ -1,20 +1,23 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type ClaudeProviderFactory struct{}
|
||||
|
||||
// 创建 ClaudeProvider
|
||||
func (f ClaudeProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f ClaudeProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &ClaudeProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
ChatCompletions: "/v1/complete",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -23,6 +26,36 @@ type ClaudeProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
ChatCompletions: "/v1/complete",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var claudeError *ClaudeResponseError
|
||||
err := json.NewDecoder(resp.Body).Decode(claudeError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(claudeError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(claudeError *ClaudeResponseError) *types.OpenAIError {
|
||||
if claudeError.Error.Type == "" {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: claudeError.Error.Message,
|
||||
Type: claudeError.Error.Type,
|
||||
Code: claudeError.Error.Type,
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
@ -41,9 +74,9 @@ func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
func stopReasonClaude2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
return types.FinishReasonStop
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
return types.FinishReasonLength
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
|
@ -1,56 +1,86 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if claudeResponse.Error.Type != "" {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Param: "",
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []types.ChatCompletionChoice{choice},
|
||||
Model: claudeResponse.Model,
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
|
||||
claudeResponse.Usage.CompletionTokens = completionTokens
|
||||
claudeResponse.Usage.TotalTokens = claudeResponse.Usage.PromptTokens + completionTokens
|
||||
|
||||
fullTextResponse.Usage = claudeResponse.Usage
|
||||
|
||||
return fullTextResponse, nil
|
||||
type claudeStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *ClaudeRequest) {
|
||||
func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
claudeResponse := &ClaudeResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, claudeResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return p.convertToChatOpenai(claudeResponse, request)
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &claudeStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_claude_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
claudeRequest := convertFromChatOpenai(request)
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(claudeRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) *ClaudeRequest {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: request.Model,
|
||||
Prompt: "",
|
||||
@ -80,138 +110,84 @@ func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.ClaudeResponseError)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: strings.TrimPrefix(response.Completion, " "),
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(response.StopReason),
|
||||
}
|
||||
openaiResponse = &types.ChatCompletionResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []types.ChatCompletionChoice{choice},
|
||||
Model: response.Model,
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
errWithCode, responseText = p.sendStreamRequest(req)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
completionTokens := common.CountTokenText(response.Completion, response.Model)
|
||||
response.Usage.CompletionTokens = completionTokens
|
||||
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(responseText, request.Model),
|
||||
}
|
||||
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||
openaiResponse.Usage = response.Usage
|
||||
|
||||
} else {
|
||||
var claudeResponse = &ClaudeResponse{
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
},
|
||||
}
|
||||
errWithCode = p.SendRequest(req, claudeResponse, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = claudeResponse.Usage
|
||||
}
|
||||
return
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return openaiResponse, nil
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *types.ChatCompletionStreamResponse {
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
|
||||
var claudeResponse *ClaudeResponse
|
||||
err := json.Unmarshal(*rawLine, claudeResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
if claudeResponse.StopReason == "stop_sequence" {
|
||||
*isFinished = true
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(claudeResponse, response)
|
||||
}
|
||||
|
||||
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
var response types.ChatCompletionStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return common.HandleErrorResp(resp), ""
|
||||
chatCompletion := types.ChatCompletionStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Model: h.Request.Model,
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
*response = append(*response, chatCompletion)
|
||||
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
|
||||
return i + 4, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if !strings.HasPrefix(data, "event: completion") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
var claudeResponse ClaudeResponse
|
||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
responseText += claudeResponse.Completion
|
||||
response := p.streamResponseClaude2OpenAI(&claudeResponse)
|
||||
response.ID = responseId
|
||||
response.Created = createdTime
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
h.Usage.PromptTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model)
|
||||
|
||||
return nil, responseText
|
||||
return nil
|
||||
}
|
||||
|
@ -23,10 +23,13 @@ type ClaudeRequest struct {
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeResponseError struct {
|
||||
Error ClaudeError `json:"error,omitempty"`
|
||||
}
|
||||
type ClaudeResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
ClaudeResponseError
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package closeai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
)
|
||||
|
||||
@ -10,15 +9,14 @@ func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error)
|
||||
fullRequestURL := p.GetFullRequestURL("/dashboard/billing/credit_grants", "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response OpenAICreditGrants
|
||||
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||
_, errWithCode := p.Requester.SendRequest(req, &response, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
|
@ -1,18 +1,17 @@
|
||||
package closeai
|
||||
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type CloseaiProviderFactory struct{}
|
||||
|
||||
// 创建 CloseaiProvider
|
||||
func (f CloseaiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f CloseaiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &CloseaiProxyProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.closeai-proxy.xyz"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,22 +1,25 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GeminiProviderFactory struct{}
|
||||
|
||||
// 创建 ClaudeProvider
|
||||
func (f GeminiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f GeminiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &GeminiProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://generativelanguage.googleapis.com",
|
||||
ChatCompletions: "/",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -25,6 +28,37 @@ type GeminiProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://generativelanguage.googleapis.com",
|
||||
ChatCompletions: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var geminiError *GeminiErrorResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(geminiError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(geminiError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(geminiError *GeminiErrorResponse) *types.OpenAIError {
|
||||
if geminiError.Error.Message == "" {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: geminiError.Error.Message,
|
||||
Type: "gemini_error",
|
||||
Param: geminiError.Error.Status,
|
||||
Code: geminiError.Error.Code,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
version := "v1"
|
||||
|
@ -1,14 +1,12 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/image"
|
||||
"one-api/providers/base"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
@ -17,50 +15,78 @@ const (
|
||||
GeminiVisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if len(response.Candidates) == 0 {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
fullTextResponse := &types.ChatCompletionResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: response.Model,
|
||||
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: i,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
},
|
||||
FinishReason: base.StopFinishReason,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(response.GetResponseText(), response.Model)
|
||||
response.Usage.CompletionTokens = completionTokens
|
||||
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
|
||||
|
||||
return fullTextResponse, nil
|
||||
type geminiStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
geminiChatResponse := &GeminiChatResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, geminiChatResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return p.convertToChatOpenai(geminiChatResponse, request)
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &geminiStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url := "generateContent"
|
||||
if request.Stream {
|
||||
url = "streamGenerateContent?alt=sse"
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
geminiRequest, errWithCode := convertFromChatOpenai(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(geminiRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatRequest, *types.OpenAIErrorWithStatusCode) {
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(request.Messages)),
|
||||
SafetySettings: []GeminiChatSafetySettings{
|
||||
@ -160,144 +186,99 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
|
||||
return &geminiRequest, nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody, errWithCode := p.getChatRequestBody(request)
|
||||
if errWithCode != nil {
|
||||
func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.GeminiErrorResponse)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
fullRequestURL := p.GetFullRequestURL("generateContent", request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
|
||||
openaiResponse = &types.ChatCompletionResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: request.Model,
|
||||
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
errWithCode, responseText = p.sendStreamRequest(req, request.Model)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(responseText, request.Model),
|
||||
}
|
||||
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||
|
||||
} else {
|
||||
var geminiResponse = &GeminiChatResponse{
|
||||
Model: request.Model,
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: i,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
},
|
||||
FinishReason: types.FinishReasonStop,
|
||||
}
|
||||
errWithCode = p.SendRequest(req, geminiResponse, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
||||
}
|
||||
|
||||
usage = geminiResponse.Usage
|
||||
openaiResponse.Choices = append(openaiResponse.Choices, choice)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(response.GetResponseText(), response.Model)
|
||||
|
||||
p.Usage.CompletionTokens = completionTokens
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens + completionTokens
|
||||
openaiResponse.Usage = p.Usage
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse {
|
||||
// var choice types.ChatCompletionStreamChoice
|
||||
// choice.Delta.Content = geminiResponse.GetResponseText()
|
||||
// choice.FinishReason = &base.StopFinishReason
|
||||
// var response types.ChatCompletionStreamResponse
|
||||
// response.Object = "chat.completion.chunk"
|
||||
// response.Model = "gemini"
|
||||
// response.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
// return &response
|
||||
// }
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
|
||||
defer req.Body.Close()
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := json.Unmarshal(*rawLine, &geminiResponse)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return common.HandleErrorResp(resp), ""
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
error := errorHandle(&geminiResponse.GeminiErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
responseText := ""
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
||||
data = strings.TrimSuffix(data, "\"")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// this is used to prevent annoying \ related format bug
|
||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||
type dummyStruct struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
var dummy dummyStruct
|
||||
json.Unmarshal([]byte(data), &dummy)
|
||||
responseText += dummy.Content
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = dummy.Content
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: model,
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
return h.convertToOpenaiStream(&geminiResponse, response)
|
||||
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
|
||||
|
||||
for i, candidate := range geminiResponse.Candidates {
|
||||
choice := types.ChatCompletionStreamChoice{
|
||||
Index: i,
|
||||
Delta: types.ChatCompletionStreamChoiceDelta{
|
||||
Content: candidate.Content.Parts[0].Text,
|
||||
},
|
||||
FinishReason: types.FinishReasonStop,
|
||||
}
|
||||
choices = append(choices, choice)
|
||||
}
|
||||
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: h.Request.Model,
|
||||
Choices: choices,
|
||||
}
|
||||
|
||||
*response = append(*response, streamResponse)
|
||||
|
||||
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
|
||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -42,11 +42,22 @@ type GeminiChatGenerationConfig struct {
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type GeminiErrorResponse struct {
|
||||
Error GeminiError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
GeminiErrorResponse
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
|
@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"time"
|
||||
)
|
||||
@ -16,15 +15,14 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/v1/dashboard/billing/subscription", "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var subscription OpenAISubscriptionResponse
|
||||
_, errWithCode := common.SendRequest(req, &subscription, false, p.Channel.Proxy)
|
||||
_, errWithCode := p.Requester.SendRequest(req, &subscription, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
@ -37,12 +35,15 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
}
|
||||
|
||||
fullRequestURL = p.GetFullRequestURL(fmt.Sprintf("/v1/dashboard/billing/usage?start_date=%s&end_date=%s", startDate, endDate), "")
|
||||
req, err = client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err = p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
usage := OpenAIUsageResponse{}
|
||||
_, errWithCode = common.SendRequest(req, &usage, false, p.Channel.Proxy)
|
||||
_, errWithCode = p.Requester.SendRequest(req, &usage, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errWithCode
|
||||
}
|
||||
|
||||
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
||||
channel.UpdateBalance(balance)
|
||||
|
@ -1,69 +1,100 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenAIProviderFactory struct{}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func (f OpenAIProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
openAIProvider := CreateOpenAIProvider(c, "")
|
||||
openAIProvider.BalanceAction = true
|
||||
return openAIProvider
|
||||
}
|
||||
|
||||
type OpenAIProvider struct {
|
||||
base.BaseProvider
|
||||
IsAzure bool
|
||||
BalanceAction bool
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
openAIProvider := CreateOpenAIProvider(channel, "https://api.openai.com")
|
||||
openAIProvider.BalanceAction = true
|
||||
return openAIProvider
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
// https://platform.openai.com/docs/api-reference/introduction
|
||||
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
|
||||
config := getOpenAIConfig(baseURL)
|
||||
|
||||
return &OpenAIProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: baseURL,
|
||||
Completions: "/v1/completions",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
Embeddings: "/v1/embeddings",
|
||||
Moderation: "/v1/moderations",
|
||||
AudioSpeech: "/v1/audio/speech",
|
||||
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||
AudioTranslations: "/v1/audio/translations",
|
||||
ImagesGenerations: "/v1/images/generations",
|
||||
ImagesEdit: "/v1/images/edits",
|
||||
ImagesVariations: "/v1/images/variations",
|
||||
Context: c,
|
||||
Config: config,
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, RequestErrorHandle),
|
||||
},
|
||||
IsAzure: false,
|
||||
BalanceAction: true,
|
||||
}
|
||||
}
|
||||
|
||||
func getOpenAIConfig(baseURL string) base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: baseURL,
|
||||
Completions: "/v1/completions",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
Embeddings: "/v1/embeddings",
|
||||
Moderation: "/v1/moderations",
|
||||
AudioSpeech: "/v1/audio/speech",
|
||||
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||
AudioTranslations: "/v1/audio/translations",
|
||||
ImagesGenerations: "/v1/images/generations",
|
||||
ImagesEdit: "/v1/images/edits",
|
||||
ImagesVariations: "/v1/images/variations",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func RequestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var errorResponse *types.OpenAIErrorResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(errorResponse)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrorHandle(errorResponse)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func ErrorHandle(openaiError *types.OpenAIErrorResponse) *types.OpenAIError {
|
||||
if openaiError.Error.Message == "" {
|
||||
return nil
|
||||
}
|
||||
return &openaiError.Error
|
||||
}
|
||||
|
||||
// 获取完整请求 URL
|
||||
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
if p.IsAzure {
|
||||
apiVersion := p.Channel.Other
|
||||
// 以-分割,检测modelName 最后一个元素是否为4位数字,必须是数字,如果是则删除modelName最后一个元素
|
||||
modelNameSlice := strings.Split(modelName, "-")
|
||||
lastModelNameSlice := modelNameSlice[len(modelNameSlice)-1]
|
||||
modelNum := common.String2Int(lastModelNameSlice)
|
||||
if modelNum > 999 && modelNum < 10000 {
|
||||
modelName = strings.TrimSuffix(modelName, "-"+lastModelNameSlice)
|
||||
}
|
||||
// 检测模型是是否包含 . 如果有则直接去掉
|
||||
modelName = strings.Replace(modelName, ".", "", -1)
|
||||
|
||||
if modelName == "dall-e-2" {
|
||||
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
||||
// 已经没有dall-e-2了,所以暂时写死
|
||||
@ -72,10 +103,6 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||
}
|
||||
|
||||
// 检测模型是是否包含 . 如果有则直接去掉
|
||||
if strings.Contains(requestURL, ".") {
|
||||
requestURL = strings.Replace(requestURL, ".", "", -1)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
@ -102,89 +129,21 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
return headers
|
||||
}
|
||||
|
||||
// 获取请求体
|
||||
func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = p.Context.Request.Body
|
||||
func (p *OpenAIProvider) GetRequestTextBody(relayMode int, ModelName string, request any) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(relayMode)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
return
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, ModelName)
|
||||
|
||||
// 发送流式请求
|
||||
func (p *OpenAIProvider) SendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
|
||||
defer req.Body.Close()
|
||||
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(request), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return common.HandleErrorResp(resp), ""
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
err := json.Unmarshal([]byte(data), response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue // just ignore the error
|
||||
}
|
||||
responseText += response.responseStreamHandler()
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
p.Context.Render(-1, common.CustomEvent{Data: data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, responseText
|
||||
return req, nil
|
||||
}
|
||||
|
@ -1,82 +1,101 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
type OpenAIStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
ModelName string
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderChatResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 检测是否错误
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Delta.Content
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return &response.ChatCompletionResponse, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := OpenAIStreamHandler{
|
||||
Usage: p.Usage,
|
||||
ModelName: request.Model,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
|
||||
}
|
||||
|
||||
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
|
||||
// 如果等于 DONE 则结束
|
||||
if string(*rawLine) == "[DONE]" {
|
||||
*isFinished = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var openaiResponse OpenAIProviderChatStreamResponse
|
||||
err := json.Unmarshal(*rawLine, &openaiResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream && headers["Accept"] == "" {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
||||
h.Usage.CompletionTokens += countTokenText
|
||||
h.Usage.TotalTokens += countTokenText
|
||||
|
||||
if request.Stream {
|
||||
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
|
||||
var textResponse string
|
||||
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
*response = append(*response, openaiResponse.ChatCompletionStreamResponse)
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||
}
|
||||
|
||||
} else {
|
||||
openAIProviderChatResponse := &OpenAIProviderChatResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderChatResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = openAIProviderChatResponse.Usage
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range openAIProviderChatResponse.Choices {
|
||||
completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model)
|
||||
}
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
@ -1,82 +1,96 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (openaiResponse *types.CompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderCompletionResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 检测是否错误
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Text
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return &response.CompletionResponse, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := OpenAIStreamHandler{
|
||||
Usage: p.Usage,
|
||||
ModelName: request.Model,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream)
|
||||
}
|
||||
|
||||
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
|
||||
// 如果等于 DONE 则结束
|
||||
if string(*rawLine) == "[DONE]" {
|
||||
*isFinished = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var openaiResponse OpenAIProviderCompletionResponse
|
||||
err := json.Unmarshal(*rawLine, &openaiResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream && headers["Accept"] == "" {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
||||
h.Usage.CompletionTokens += countTokenText
|
||||
h.Usage.TotalTokens += countTokenText
|
||||
|
||||
openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{}
|
||||
if request.Stream {
|
||||
// TODO
|
||||
var textResponse string
|
||||
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderCompletionResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
*response = append(*response, openaiResponse.CompletionResponse)
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||
}
|
||||
|
||||
} else {
|
||||
errWithCode = p.SendRequest(req, openAIProviderCompletionResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = openAIProviderCompletionResponse.Usage
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range openAIProviderCompletionResponse.Choices {
|
||||
completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model)
|
||||
}
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
@ -6,40 +6,30 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderEmbeddingsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderEmbeddingsResponse, true)
|
||||
func (p *OpenAIProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeEmbeddings, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderEmbeddingsResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = openAIProviderEmbeddingsResponse.Usage
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return &response.EmbeddingResponse, nil
|
||||
}
|
||||
|
@ -5,28 +5,71 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model)
|
||||
func (p *OpenAIProvider) CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getRequestImageBody(common.RelayModeEdits, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderImageResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return &response.ImageResponse, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) getRequestImageBody(relayMode int, ModelName string, request *types.ImageEditRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(relayMode)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, ModelName)
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
// 创建请求
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if p.OriginalModel != request.Model {
|
||||
var formBody bytes.Buffer
|
||||
builder := p.Requester.CreateFormBuilder(&formBody)
|
||||
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req, err = p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(&formBody),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req, err = p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(p.Context.Request.Body),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
@ -34,22 +77,10 @@ func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isMod
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
}
|
||||
|
||||
return
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error {
|
||||
func imagesEditsMultipartForm(request *types.ImageEditRequest, b requester.FormBuilder) error {
|
||||
err := b.CreateFormFile("image", request.Image)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form image: %w", err)
|
||||
|
@ -6,53 +6,41 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderImageResponseResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
if !isWithinRange(request.Model, request.N) {
|
||||
func (p *OpenAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
if !IsWithinRange(request.Model, request.N) {
|
||||
return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderImageResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
// 检测是否错误
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return &response.ImageResponse, nil
|
||||
|
||||
}
|
||||
|
||||
func isWithinRange(element string, value int) bool {
|
||||
func IsWithinRange(element string, value int) bool {
|
||||
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
|
||||
return false
|
||||
}
|
||||
|
@ -1,49 +1,35 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
func (p *OpenAIProvider) CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getRequestImageBody(common.RelayModeImagesVariations, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderImageResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return &response.ImageResponse, nil
|
||||
}
|
||||
|
@ -6,44 +6,31 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderModerationResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (p *OpenAIProvider) CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.Moderation, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderModerationResponse := &OpenAIProviderModerationResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderModerationResponse, true)
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeModerations, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderModerationResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return &response.ModerationResponse, nil
|
||||
}
|
||||
|
@ -3,35 +3,29 @@ package openai
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
errWithCode = p.SendRequestRaw(req)
|
||||
func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeAudioSpeech, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
var resp *http.Response
|
||||
resp, errWithCode = p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
if resp.Header.Get("Content-Type") == "application/json" {
|
||||
return nil, requester.HandleErrorResp(resp, p.Requester.ErrorHandler)
|
||||
}
|
||||
|
||||
return
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
@ -4,48 +4,99 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderTranscriptionsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
func (p *OpenAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranscription, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
return nil, nil
|
||||
defer req.Body.Close()
|
||||
|
||||
var textResponse string
|
||||
var resp *http.Response
|
||||
var err error
|
||||
audioResponseWrapper := &types.AudioResponseWrapper{}
|
||||
if hasJSONResponse(request) {
|
||||
openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
|
||||
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
textResponse = openAIProviderTranscriptionsResponse.Text
|
||||
} else {
|
||||
openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
|
||||
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
audioResponseWrapper.Headers = map[string]string{
|
||||
"Content-Type": resp.Header.Get("Content-Type"),
|
||||
}
|
||||
|
||||
audioResponseWrapper.Body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(textResponse, request.Model)
|
||||
|
||||
p.Usage.CompletionTokens = completionTokens
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
|
||||
|
||||
return audioResponseWrapper, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderTranscriptionsTextResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
return nil, nil
|
||||
func hasJSONResponse(request *types.AudioRequest) bool {
|
||||
return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json"
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.AudioTranscriptions, request.Model)
|
||||
func (p *OpenAIProvider) getRequestAudioBody(relayMode int, ModelName string, request *types.AudioRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(relayMode)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, ModelName)
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
// 创建请求
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if p.OriginalModel != request.Model {
|
||||
var formBody bytes.Buffer
|
||||
builder := p.Requester.CreateFormBuilder(&formBody)
|
||||
if err := audioMultipartForm(request, builder); err != nil {
|
||||
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req, err = p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(&formBody),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req, err = p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(p.Context.Request.Body),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
@ -53,37 +104,10 @@ func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isMod
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var textResponse string
|
||||
if hasJSONResponse(request) {
|
||||
openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
textResponse = openAIProviderTranscriptionsResponse.Text
|
||||
} else {
|
||||
openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
|
||||
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(textResponse, request.Model)
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
return
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func hasJSONResponse(request *types.AudioRequest) bool {
|
||||
return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json"
|
||||
}
|
||||
|
||||
func audioMultipartForm(request *types.AudioRequest, b common.FormBuilder) error {
|
||||
func audioMultipartForm(request *types.AudioRequest, b requester.FormBuilder) error {
|
||||
|
||||
err := b.CreateFormFile("file", request.File)
|
||||
if err != nil {
|
||||
|
@ -1,60 +1,53 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.AudioTranslations, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if err := audioMultipartForm(request, builder); err != nil {
|
||||
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
func (p *OpenAIProvider) CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranslation, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
var textResponse string
|
||||
var resp *http.Response
|
||||
var err error
|
||||
audioResponseWrapper := &types.AudioResponseWrapper{}
|
||||
if hasJSONResponse(request) {
|
||||
openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true)
|
||||
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
textResponse = openAIProviderTranscriptionsResponse.Text
|
||||
} else {
|
||||
openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
|
||||
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
|
||||
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
audioResponseWrapper.Headers = map[string]string{
|
||||
"Content-Type": resp.Header.Get("Content-Type"),
|
||||
}
|
||||
|
||||
audioResponseWrapper.Body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(textResponse, request.Model)
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
return
|
||||
|
||||
p.Usage.CompletionTokens = completionTokens
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
|
||||
|
||||
return audioResponseWrapper, nil
|
||||
}
|
||||
|
@ -12,11 +12,27 @@ type OpenAIProviderChatStreamResponse struct {
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderChatStreamResponse) getResponseText() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Delta.Content
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type OpenAIProviderCompletionResponse struct {
|
||||
types.CompletionResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) getResponseText() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Text
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type OpenAIProviderEmbeddingsResponse struct {
|
||||
types.EmbeddingResponse
|
||||
types.OpenAIErrorResponse
|
||||
@ -38,7 +54,7 @@ func (a *OpenAIProviderTranscriptionsTextResponse) GetString() *string {
|
||||
return (*string)(a)
|
||||
}
|
||||
|
||||
type OpenAIProviderImageResponseResponse struct {
|
||||
type OpenAIProviderImageResponse struct {
|
||||
types.ImageResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package openaisb
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
)
|
||||
@ -13,15 +12,14 @@ func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var response OpenAISBUsageResponse
|
||||
_, errWithCode := common.SendRequest(req, &response, false, p.Channel.Proxy)
|
||||
_, errWithCode := p.Requester.SendRequest(req, &response, false)
|
||||
if err != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
|
@ -1,18 +1,17 @@
|
||||
package openaisb
|
||||
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/providers/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenaiSBProviderFactory struct{}
|
||||
|
||||
// 创建 OpenaiSBProvider
|
||||
func (f OpenaiSBProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f OpenaiSBProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &OpenaiSBProvider{
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"),
|
||||
OpenAIProvider: openai.CreateOpenAIProvider(channel, "https://api.openai-sb.com"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,22 +1,25 @@
|
||||
package palm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PalmProviderFactory struct{}
|
||||
|
||||
// 创建 PalmProvider
|
||||
func (f PalmProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f PalmProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &PalmProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://generativelanguage.googleapis.com",
|
||||
ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -25,6 +28,37 @@ type PalmProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://generativelanguage.googleapis.com",
|
||||
ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var palmError *PaLMErrorResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(palmError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(palmError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(palmError *PaLMErrorResponse) *types.OpenAIError {
|
||||
if palmError.Error.Code == 0 {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: palmError.Error.Message,
|
||||
Type: "palm_error",
|
||||
Param: palmError.Error.Status,
|
||||
Code: palmError.Error.Code,
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
@ -3,30 +3,98 @@ package palm
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return nil, &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: palmResponse.Error.Message,
|
||||
Type: palmResponse.Error.Status,
|
||||
Param: "",
|
||||
Code: palmResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
type palmStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
}
|
||||
|
||||
func (p *PalmProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
palmResponse := &PaLMChatResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, palmResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
Choices: make([]types.ChatCompletionChoice, 0, len(palmResponse.Candidates)),
|
||||
return p.convertToChatOpenai(palmResponse, request)
|
||||
}
|
||||
|
||||
func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
for i, candidate := range palmResponse.Candidates {
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &palmStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
palmRequest := convertFromChatOpenai(request)
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(palmRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (p *PalmProvider) convertToChatOpenai(response *PaLMChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.PaLMErrorResponse)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
openaiResponse = &types.ChatCompletionResponse{
|
||||
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
|
||||
Model: request.Model,
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: i,
|
||||
Message: types.ChatCompletionMessage{
|
||||
@ -35,20 +103,21 @@ func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (Open
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
openaiResponse.Choices = append(openaiResponse.Choices, choice)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(palmResponse.Candidates[0].Content, palmResponse.Model)
|
||||
palmResponse.Usage.CompletionTokens = completionTokens
|
||||
palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
|
||||
completionTokens := common.CountTokenText(response.Candidates[0].Content, request.Model)
|
||||
response.Usage.CompletionTokens = completionTokens
|
||||
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
|
||||
|
||||
fullTextResponse.Usage = palmResponse.Usage
|
||||
fullTextResponse.Model = palmResponse.Model
|
||||
openaiResponse.Usage = response.Usage
|
||||
|
||||
return fullTextResponse, nil
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest) *PaLMChatRequest {
|
||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) *PaLMChatRequest {
|
||||
palmRequest := PaLMChatRequest{
|
||||
Prompt: PaLMPrompt{
|
||||
Messages: make([]PaLMChatMessage, 0, len(request.Messages)),
|
||||
@ -72,132 +141,51 @@ func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest)
|
||||
return &palmRequest
|
||||
}
|
||||
|
||||
func (p *PalmProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
|
||||
var palmChatResponse PaLMChatResponse
|
||||
err := json.Unmarshal(*rawLine, &palmChatResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
errWithCode, responseText = p.sendStreamRequest(req)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(responseText, request.Model),
|
||||
}
|
||||
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||
|
||||
} else {
|
||||
var palmChatResponse = &PaLMChatResponse{
|
||||
Model: request.Model,
|
||||
Usage: &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
},
|
||||
}
|
||||
errWithCode = p.SendRequest(req, palmChatResponse, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = palmChatResponse.Usage
|
||||
error := errorHandle(&palmChatResponse.PaLMErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
return
|
||||
|
||||
return h.convertToOpenaiStream(&palmChatResponse, response)
|
||||
|
||||
}
|
||||
|
||||
func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *types.ChatCompletionStreamResponse {
|
||||
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||
if len(palmChatResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmChatResponse.Candidates[0].Content
|
||||
}
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
var response types.ChatCompletionStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "palm2"
|
||||
response.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return common.HandleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
common.SysError("error closing stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
fullTextResponse := p.streamResponsePaLM2OpenAI(&palmResponse)
|
||||
fullTextResponse.ID = responseId
|
||||
fullTextResponse.Created = createdTime
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
responseText = palmResponse.Candidates[0].Content
|
||||
}
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
dataChan <- string(jsonResponse)
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, responseText
|
||||
choice.FinishReason = types.FinishReasonStop
|
||||
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Model: h.Request.Model,
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
Created: common.GetTimestamp(),
|
||||
}
|
||||
|
||||
*response = append(*response, streamResponse)
|
||||
|
||||
h.Usage.CompletionTokens += common.CountTokenText(palmChatResponse.Candidates[0].Content, h.Request.Model)
|
||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -30,11 +30,15 @@ type PaLMError struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PaLMErrorResponse struct {
|
||||
Error PaLMError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMChatResponse struct {
|
||||
Candidates []PaLMChatMessage `json:"candidates"`
|
||||
Messages []types.ChatCompletionMessage `json:"messages"`
|
||||
Filters []PaLMFilter `json:"filters"`
|
||||
Error PaLMError `json:"error"`
|
||||
Usage *types.Usage `json:"usage,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
PaLMErrorResponse
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ import (
|
||||
|
||||
// 定义供应商工厂接口
|
||||
type ProviderFactory interface {
|
||||
Create(c *gin.Context) base.ProviderInterface
|
||||
Create(Channel *model.Channel) base.ProviderInterface
|
||||
}
|
||||
|
||||
// 创建全局的供应商工厂映射
|
||||
@ -71,11 +71,11 @@ func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface
|
||||
return nil
|
||||
}
|
||||
|
||||
provider = openai.CreateOpenAIProvider(c, baseURL)
|
||||
provider = openai.CreateOpenAIProvider(channel, baseURL)
|
||||
} else {
|
||||
provider = factory.Create(c)
|
||||
provider = factory.Create(channel)
|
||||
}
|
||||
provider.SetChannel(channel)
|
||||
provider.SetContext(c)
|
||||
|
||||
return provider
|
||||
}
|
||||
|
@ -4,25 +4,28 @@ import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type TencentProviderFactory struct{}
|
||||
|
||||
// 创建 TencentProvider
|
||||
func (f TencentProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f TencentProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &TencentProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://hunyuan.cloud.tencent.com",
|
||||
ChatCompletions: "/hyllm/v1/chat/completions",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -31,6 +34,36 @@ type TencentProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://hunyuan.cloud.tencent.com",
|
||||
ChatCompletions: "/hyllm/v1/chat/completions",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var tencentError *TencentResponseError
|
||||
err := json.NewDecoder(resp.Body).Decode(tencentError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(tencentError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(tencentError *TencentResponseError) *types.OpenAIError {
|
||||
if tencentError.Error.Code == 0 {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: tencentError.Error.Message,
|
||||
Type: "tencent_error",
|
||||
Code: tencentError.Error.Code,
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
@ -77,7 +110,7 @@ func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
|
||||
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||
params = append(params, "messages=["+messageStr+"]")
|
||||
|
||||
sort.Sort(sort.StringSlice(params))
|
||||
sort.Strings(params)
|
||||
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||
signURL := url
|
||||
|
@ -1,50 +1,127 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
return &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: TencentResponse.Error.Message,
|
||||
Code: TencentResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
type tencentStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
}
|
||||
|
||||
func (p *TencentProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
tencentChatResponse := &TencentChatResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, tencentChatResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
return p.convertToChatOpenai(tencentChatResponse, request)
|
||||
}
|
||||
|
||||
func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &tencentStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
tencentRequest := convertFromChatOpenai(request)
|
||||
|
||||
sign := p.getTencentSign(*tencentRequest)
|
||||
if sign == "" {
|
||||
return nil, common.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
headers["Authorization"] = sign
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(tencentRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (p *TencentProvider) convertToChatOpenai(response *TencentChatResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(&response.TencentResponseError)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
openaiResponse = &types.ChatCompletionResponse{
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: TencentResponse.Usage,
|
||||
Model: TencentResponse.Model,
|
||||
Usage: response.Usage,
|
||||
Model: request.Model,
|
||||
}
|
||||
if len(TencentResponse.Choices) > 0 {
|
||||
if len(response.Choices) > 0 {
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: TencentResponse.Choices[0].Messages.Content,
|
||||
Content: response.Choices[0].Messages.Content,
|
||||
},
|
||||
FinishReason: TencentResponse.Choices[0].FinishReason,
|
||||
FinishReason: response.Choices[0].FinishReason,
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
openaiResponse.Choices = append(openaiResponse.Choices, choice)
|
||||
}
|
||||
|
||||
return fullTextResponse, nil
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionRequest) *TencentChatRequest {
|
||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatRequest {
|
||||
messages := make([]TencentMessage, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
@ -79,143 +156,51 @@ func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionReques
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
sign := p.getTencentSign(*requestBody)
|
||||
if sign == "" {
|
||||
return nil, common.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data:") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
headers["Authorization"] = sign
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
}
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[5:]
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
var tencentChatResponse TencentChatResponse
|
||||
err := json.Unmarshal(*rawLine, &tencentChatResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
var responseText string
|
||||
errWithCode, responseText = p.sendStreamRequest(req, request.Model)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(responseText, request.Model),
|
||||
}
|
||||
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||
|
||||
} else {
|
||||
tencentResponse := &TencentChatResponse{
|
||||
Model: request.Model,
|
||||
}
|
||||
errWithCode = p.SendRequest(req, tencentResponse, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = tencentResponse.Usage
|
||||
error := errorHandle(&tencentChatResponse.TencentResponseError)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
return
|
||||
|
||||
return h.convertToOpenaiStream(&tencentChatResponse, response)
|
||||
|
||||
}
|
||||
|
||||
func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *types.ChatCompletionStreamResponse {
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: TencentResponse.Model,
|
||||
Model: h.Request.Model,
|
||||
}
|
||||
if len(TencentResponse.Choices) > 0 {
|
||||
if len(tencentChatResponse.Choices) > 0 {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
choice.Delta.Content = tencentChatResponse.Choices[0].Delta.Content
|
||||
if tencentChatResponse.Choices[0].FinishReason == "stop" {
|
||||
choice.FinishReason = types.FinishReasonStop
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
streamResponse.Choices = append(streamResponse.Choices, choice)
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
|
||||
defer req.Body.Close()
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return common.HandleErrorResp(resp), ""
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var TencentResponse TencentChatResponse
|
||||
err := json.Unmarshal([]byte(data), &TencentResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
TencentResponse.Model = model
|
||||
response := p.streamResponseTencent2OpenAI(&TencentResponse)
|
||||
if len(response.Choices) != 0 {
|
||||
responseText += response.Choices[0].Delta.Content
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
return nil, responseText
|
||||
|
||||
*response = append(*response, streamResponse)
|
||||
|
||||
h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model)
|
||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -50,13 +50,17 @@ type TencentResponseChoices struct {
|
||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type TencentResponseError struct {
|
||||
Error TencentError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type TencentChatResponse struct {
|
||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage *types.Usage `json:"usage,omitempty"` // token 数量
|
||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
Model string `json:"model,omitempty"` // 模型名称
|
||||
TencentResponseError
|
||||
}
|
||||
|
@ -7,31 +7,53 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type XunfeiProviderFactory struct{}
|
||||
|
||||
// 创建 XunfeiProvider
|
||||
func (f XunfeiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f XunfeiProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &XunfeiProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "wss://spark-api.xf-yun.com",
|
||||
ChatCompletions: "true",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, nil),
|
||||
},
|
||||
wsRequester: requester.NewWSRequester(channel.Proxy),
|
||||
}
|
||||
}
|
||||
|
||||
// https://www.xfyun.cn/doc/spark/Web.html
|
||||
type XunfeiProvider struct {
|
||||
base.BaseProvider
|
||||
domain string
|
||||
apiId string
|
||||
domain string
|
||||
apiId string
|
||||
wsRequester *requester.WSRequester
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "wss://spark-api.xf-yun.com",
|
||||
ChatCompletions: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(xunfeiError *XunfeiChatResponse) *types.OpenAIError {
|
||||
if xunfeiError.Header.Code == 0 {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: xunfeiError.Header.Message,
|
||||
Type: "xunfei_error",
|
||||
Code: xunfeiError.Header.Code,
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
@ -68,7 +90,7 @@ func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (stri
|
||||
if apiVersion != "v1.1" {
|
||||
domain += strings.Split(apiVersion, ".")[0]
|
||||
}
|
||||
authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.BaseURL, apiVersion), apiKey, apiSecret)
|
||||
authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.Config.BaseURL, apiVersion), apiKey, apiSecret)
|
||||
return domain, authUrl
|
||||
}
|
||||
|
||||
|
@ -2,140 +2,93 @@ package xunfei
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"time"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
|
||||
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
return p.sendStreamRequest(dataChan, stopChan, request.GetFunctionCate())
|
||||
} else {
|
||||
return p.sendRequest(dataChan, stopChan, request.GetFunctionCate())
|
||||
}
|
||||
type xunfeiHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) sendRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
usage = &types.Usage{}
|
||||
|
||||
var content string
|
||||
var xunfeiResponse XunfeiChatResponse
|
||||
|
||||
stop := false
|
||||
for !stop {
|
||||
select {
|
||||
case xunfeiResponse = <-dataChan:
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
continue
|
||||
}
|
||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
case stop = <-stopChan:
|
||||
}
|
||||
}
|
||||
|
||||
if xunfeiResponse.Header.Code != 0 {
|
||||
return nil, common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
|
||||
}
|
||||
|
||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||
|
||||
response := p.responseXunfei2OpenAI(&xunfeiResponse, functionCate)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
p.Context.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, _ = p.Context.Writer.Write(jsonResponse)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) sendStreamRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
usage = &types.Usage{}
|
||||
|
||||
// 等待第一个dataChan的响应
|
||||
xunfeiResponse, ok := <-dataChan
|
||||
if !ok {
|
||||
return nil, common.ErrorWrapper(fmt.Errorf("xunfei response channel closed"), "xunfei_response_error", http.StatusInternalServerError)
|
||||
}
|
||||
if xunfeiResponse.Header.Code != 0 {
|
||||
errWithCode = common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
|
||||
func (p *XunfeiProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
wsConn, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 如果第一个响应没有错误,设置StreamHeaders并开始streaming
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
// 处理第一个响应
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
xunfeiRequest := p.convertFromChatOpenai(request)
|
||||
|
||||
chatHandler := &xunfeiHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
stream, errWithCode := requester.SendWSJsonRequest[XunfeiChatResponse](wsConn, xunfeiRequest, chatHandler.handlerNotStream)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return chatHandler.convertToChatOpenai(stream)
|
||||
|
||||
// 处理后续的响应
|
||||
for {
|
||||
select {
|
||||
case xunfeiResponse, ok := <-dataChan:
|
||||
if !ok {
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
}
|
||||
})
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionRequest) *XunfeiChatRequest {
|
||||
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
wsConn, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
xunfeiRequest := p.convertFromChatOpenai(request)
|
||||
|
||||
chatHandler := &xunfeiHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.SendWSJsonRequest[types.ChatCompletionStreamResponse](wsConn, xunfeiRequest, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
authUrl := p.GetFullRequestURL(url, request.Model)
|
||||
|
||||
wsConn, err := p.wsRequester.NewRequest(authUrl, nil)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return wsConn, nil
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *XunfeiChatRequest {
|
||||
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "user",
|
||||
Role: types.ChatMessageRoleUser,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "assistant",
|
||||
Role: types.ChatMessageRoleAssistant,
|
||||
Content: "Okay",
|
||||
})
|
||||
} else if message.Role == types.ChatMessageRoleFunction {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: types.ChatMessageRoleUser,
|
||||
Content: "这是函数调用返回的内容,请回答之前的问题:\n" + message.StringContent(),
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: message.Role,
|
||||
@ -143,6 +96,7 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
xunfeiRequest := XunfeiChatRequest{}
|
||||
|
||||
if request.Tools != nil {
|
||||
@ -166,35 +120,57 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
|
||||
return &xunfeiRequest
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, functionCate string) *types.ChatCompletionResponse {
|
||||
if len(response.Payload.Choices.Text) == 0 {
|
||||
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
|
||||
func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
var content string
|
||||
var xunfeiResponse XunfeiChatResponse
|
||||
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) && response == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if len((*response)[0].Payload.Choices.Text) == 0 {
|
||||
continue
|
||||
}
|
||||
xunfeiResponse = (*response)[0]
|
||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
}
|
||||
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
|
||||
}
|
||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||
|
||||
choice := types.ChatCompletionChoice{
|
||||
Index: 0,
|
||||
FinishReason: base.StopFinishReason,
|
||||
FinishReason: types.FinishReasonStop,
|
||||
}
|
||||
|
||||
xunfeiText := response.Payload.Choices.Text[0]
|
||||
xunfeiText := xunfeiResponse.Payload.Choices.Text[0]
|
||||
|
||||
if xunfeiText.FunctionCall != nil {
|
||||
choice.Message = types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
}
|
||||
|
||||
if functionCate == "tool" {
|
||||
if h.Request.Tools != nil {
|
||||
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
|
||||
{
|
||||
Id: response.Header.Sid,
|
||||
Id: xunfeiResponse.Header.Sid,
|
||||
Type: "function",
|
||||
Function: *xunfeiText.FunctionCall,
|
||||
Function: xunfeiText.FunctionCall,
|
||||
},
|
||||
}
|
||||
choice.FinishReason = &base.StopFinishReasonToolFunction
|
||||
choice.FinishReason = types.FinishReasonToolCalls
|
||||
} else {
|
||||
choice.Message.FunctionCall = xunfeiText.FunctionCall
|
||||
choice.FinishReason = &base.StopFinishReasonCallFunction
|
||||
choice.FinishReason = types.FinishReasonFunctionCall
|
||||
}
|
||||
|
||||
} else {
|
||||
@ -204,97 +180,128 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, fun
|
||||
}
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
ID: response.Header.Sid,
|
||||
fullTextResponse := &types.ChatCompletionResponse{
|
||||
ID: xunfeiResponse.Header.Sid,
|
||||
Object: "chat.completion",
|
||||
Model: "SparkDesk",
|
||||
Model: h.Request.Model,
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []types.ChatCompletionChoice{choice},
|
||||
Usage: &response.Payload.Usage.Text,
|
||||
Usage: &xunfeiResponse.Payload.Usage.Text,
|
||||
}
|
||||
return &fullTextResponse
|
||||
|
||||
return fullTextResponse, nil
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequest, authUrl string) (chan XunfeiChatResponse, chan bool, error) {
|
||||
d := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "{") {
|
||||
*rawLine = nil
|
||||
return nil, nil
|
||||
}
|
||||
conn, resp, err := d.Dial(authUrl, nil)
|
||||
if err != nil || resp.StatusCode != 101 {
|
||||
return nil, nil, err
|
||||
}
|
||||
data := p.requestOpenAI2Xunfei(textRequest)
|
||||
err = conn.WriteJSON(data)
|
||||
|
||||
var xunfeiChatResponse XunfeiChatResponse
|
||||
err := json.Unmarshal(*rawLine, &xunfeiChatResponse)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
dataChan := make(chan XunfeiChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
var response XunfeiChatResponse
|
||||
err = json.Unmarshal(msg, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
dataChan <- response
|
||||
if response.Payload.Choices.Status == 2 {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
common.SysError("error closing websocket connection: " + err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
error := errorHandle(&xunfeiChatResponse)
|
||||
if error != nil {
|
||||
return nil, error
|
||||
}
|
||||
|
||||
return dataChan, stopChan, nil
|
||||
if xunfeiChatResponse.Payload.Choices.Status == 2 {
|
||||
*isFinished = true
|
||||
}
|
||||
|
||||
h.Usage.PromptTokens = xunfeiChatResponse.Payload.Usage.Text.PromptTokens
|
||||
h.Usage.CompletionTokens = xunfeiChatResponse.Payload.Usage.Text.CompletionTokens
|
||||
h.Usage.TotalTokens = xunfeiChatResponse.Payload.Usage.Text.TotalTokens
|
||||
|
||||
return &xunfeiChatResponse, nil
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse, functionCate string) *types.ChatCompletionStreamResponse {
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
|
||||
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, isFinished *bool, response *[]XunfeiChatResponse) error {
|
||||
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
xunfeiText := xunfeiResponse.Payload.Choices.Text[0]
|
||||
|
||||
if *rawLine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
*response = append(*response, *xunfeiChatResponse)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *rawLine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(xunfeiChatResponse, response)
|
||||
}
|
||||
|
||||
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
|
||||
}
|
||||
|
||||
choice := types.ChatCompletionStreamChoice{
|
||||
Index: 0,
|
||||
Delta: types.ChatCompletionStreamChoiceDelta{
|
||||
Role: types.ChatMessageRoleAssistant,
|
||||
},
|
||||
}
|
||||
xunfeiText := xunfeiChatResponse.Payload.Choices.Text[0]
|
||||
|
||||
if xunfeiText.FunctionCall != nil {
|
||||
if functionCate == "tool" {
|
||||
if h.Request.Tools != nil {
|
||||
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
|
||||
{
|
||||
Id: xunfeiResponse.Header.Sid,
|
||||
Id: xunfeiChatResponse.Header.Sid,
|
||||
Index: 0,
|
||||
Type: "function",
|
||||
Function: *xunfeiText.FunctionCall,
|
||||
Function: xunfeiText.FunctionCall,
|
||||
},
|
||||
}
|
||||
choice.FinishReason = &base.StopFinishReasonToolFunction
|
||||
choice.FinishReason = types.FinishReasonToolCalls
|
||||
} else {
|
||||
choice.Delta.FunctionCall = xunfeiText.FunctionCall
|
||||
choice.FinishReason = &base.StopFinishReasonCallFunction
|
||||
choice.FinishReason = types.FinishReasonFunctionCall
|
||||
}
|
||||
|
||||
} else {
|
||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
choice.Delta.Content = xunfeiChatResponse.Payload.Choices.Text[0].Content
|
||||
if xunfeiChatResponse.Payload.Choices.Status == 2 {
|
||||
choice.FinishReason = types.FinishReasonStop
|
||||
}
|
||||
}
|
||||
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
ID: xunfeiResponse.Header.Sid,
|
||||
chatCompletion := types.ChatCompletionStreamResponse{
|
||||
ID: xunfeiChatResponse.Header.Sid,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "SparkDesk",
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
Model: h.Request.Model,
|
||||
}
|
||||
return &response
|
||||
|
||||
if xunfeiText.FunctionCall == nil {
|
||||
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, chatCompletion)
|
||||
} else {
|
||||
choices := choice.ConvertOpenaiStream()
|
||||
for _, choice := range choices {
|
||||
chatCompletionCopy := chatCompletion
|
||||
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, chatCompletionCopy)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,14 +1,18 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
@ -18,12 +22,12 @@ var expSeconds int64 = 24 * 3600
|
||||
type ZhipuProviderFactory struct{}
|
||||
|
||||
// 创建 ZhipuProvider
|
||||
func (f ZhipuProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
func (f ZhipuProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &ZhipuProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://open.bigmodel.cn",
|
||||
ChatCompletions: "/api/paas/v3/model-api",
|
||||
Context: c,
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -32,6 +36,36 @@ type ZhipuProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://open.bigmodel.cn",
|
||||
ChatCompletions: "/api/paas/v3/model-api",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var zhipuError *ZhipuResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(zhipuError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(zhipuError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(zhipuError *ZhipuResponse) *types.OpenAIError {
|
||||
if zhipuError.Success {
|
||||
return nil
|
||||
}
|
||||
return &types.OpenAIError{
|
||||
Message: zhipuError.Msg,
|
||||
Type: "zhipu_error",
|
||||
Code: zhipuError.Code,
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
|
@ -1,38 +1,108 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if !zhipuResponse.Success {
|
||||
return &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: types.OpenAIError{
|
||||
Message: zhipuResponse.Msg,
|
||||
Type: "zhipu_error",
|
||||
Param: "",
|
||||
Code: zhipuResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
type zhipuStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
Request *types.ChatCompletionRequest
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
zhipuChatResponse := &ZhipuResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, zhipuChatResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
fullTextResponse := types.ChatCompletionResponse{
|
||||
ID: zhipuResponse.Data.TaskId,
|
||||
return p.convertToChatOpenai(zhipuChatResponse, request)
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := &zhipuStreamHandler{
|
||||
Usage: p.Usage,
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, request.Model)
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
fullRequestURL += "/sse-invoke"
|
||||
} else {
|
||||
fullRequestURL += "/invoke"
|
||||
}
|
||||
|
||||
zhipuRequest := convertFromChatOpenai(request)
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(zhipuRequest), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) convertToChatOpenai(response *ZhipuResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
error := errorHandle(response)
|
||||
if error != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *error,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
openaiResponse = &types.ChatCompletionResponse{
|
||||
ID: response.Data.TaskId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: zhipuResponse.Model,
|
||||
Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
|
||||
Usage: &zhipuResponse.Data.Usage,
|
||||
Model: request.Model,
|
||||
Choices: make([]types.ChatCompletionChoice, 0, len(response.Data.Choices)),
|
||||
Usage: &response.Data.Usage,
|
||||
}
|
||||
for i, choice := range zhipuResponse.Data.Choices {
|
||||
for i, choice := range response.Data.Choices {
|
||||
openaiChoice := types.ChatCompletionChoice{
|
||||
Index: i,
|
||||
Message: types.ChatCompletionMessage{
|
||||
@ -41,16 +111,18 @@ func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAI
|
||||
},
|
||||
FinishReason: "",
|
||||
}
|
||||
if i == len(zhipuResponse.Data.Choices)-1 {
|
||||
if i == len(response.Data.Choices)-1 {
|
||||
openaiChoice.FinishReason = "stop"
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
||||
openaiResponse.Choices = append(openaiResponse.Choices, openaiChoice)
|
||||
}
|
||||
return fullTextResponse, nil
|
||||
|
||||
*p.Usage = response.Data.Usage
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest) *ZhipuRequest {
|
||||
func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
|
||||
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
@ -77,158 +149,60 @@ func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody := p.getChatRequestBody(request)
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
fullRequestURL += "/sse-invoke"
|
||||
} else {
|
||||
fullRequestURL += "/invoke"
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data:") && !strings.HasPrefix(string(*rawLine), "meta:") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
if strings.HasPrefix(string(*rawLine), "meta:") {
|
||||
*rawLine = (*rawLine)[5:]
|
||||
var zhipuStreamMetaResponse ZhipuStreamMetaResponse
|
||||
err := json.Unmarshal(*rawLine, &zhipuStreamMetaResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
*isFinished = true
|
||||
return h.handlerMeta(&zhipuStreamMetaResponse, response)
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
errWithCode, usage = p.sendStreamRequest(req, request.Model)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
zhipuResponse := &ZhipuResponse{
|
||||
Model: request.Model,
|
||||
}
|
||||
errWithCode = p.SendRequest(req, zhipuResponse, false)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &zhipuResponse.Data.Usage
|
||||
}
|
||||
return
|
||||
|
||||
*rawLine = (*rawLine)[5:]
|
||||
return h.convertToOpenaiStream(string(*rawLine), response)
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.ChatCompletionStreamResponse {
|
||||
func (h *zhipuStreamHandler) convertToOpenaiStream(content string, response *[]types.ChatCompletionStreamResponse) error {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = zhipuResponse
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
choice.Delta.Content = content
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Model: h.Request.Model,
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
return &response
|
||||
|
||||
*response = append(*response, streamResponse)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) {
|
||||
func (h *zhipuStreamHandler) handlerMeta(zhipuResponse *ZhipuStreamMetaResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = ""
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
choice.FinishReason = types.FinishReasonStop
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
ID: zhipuResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: zhipuResponse.Model,
|
||||
Model: h.Request.Model,
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
return &response, &zhipuResponse.Usage
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return common.HandleErrorResp(resp), nil
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
var usage *types.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Contains(string(data), ":") {
|
||||
return i + 2, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
metaChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
lines := strings.Split(data, "\n")
|
||||
for i, line := range lines {
|
||||
if len(line) < 5 {
|
||||
continue
|
||||
}
|
||||
if line[:5] == "data:" {
|
||||
dataChan <- line[5:]
|
||||
if i != len(lines)-1 {
|
||||
dataChan <- "\n"
|
||||
}
|
||||
} else if line[:5] == "meta:" {
|
||||
metaChan <- line[5:]
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
response := p.streamResponseZhipu2OpenAI(data)
|
||||
response.Model = model
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case data := <-metaChan:
|
||||
var zhipuResponse ZhipuStreamMetaResponse
|
||||
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
zhipuResponse.Model = model
|
||||
response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
usage = zhipuUsage
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
return nil, usage
|
||||
|
||||
*response = append(*response, streamResponse)
|
||||
|
||||
*h.Usage = zhipuResponse.Usage
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -26,3 +26,8 @@ type AudioResponse struct {
|
||||
Segments any `json:"segments,omitempty"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type AudioResponseWrapper struct {
|
||||
Headers map[string]string
|
||||
Body []byte
|
||||
}
|
||||
|
119
types/chat.go
119
types/chat.go
@ -5,16 +5,33 @@ const (
|
||||
ContentTypeImageURL = "image_url"
|
||||
)
|
||||
|
||||
const (
|
||||
FinishReasonStop = "stop"
|
||||
FinishReasonLength = "length"
|
||||
FinishReasonFunctionCall = "function_call"
|
||||
FinishReasonToolCalls = "tool_calls"
|
||||
FinishReasonContentFilter = "content_filter"
|
||||
FinishReasonNull = "null"
|
||||
)
|
||||
|
||||
const (
|
||||
ChatMessageRoleSystem = "system"
|
||||
ChatMessageRoleUser = "user"
|
||||
ChatMessageRoleAssistant = "assistant"
|
||||
ChatMessageRoleFunction = "function"
|
||||
ChatMessageRoleTool = "tool"
|
||||
)
|
||||
|
||||
type ChatCompletionToolCallsFunction struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type ChatCompletionToolCalls struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function ChatCompletionToolCallsFunction `json:"function"`
|
||||
Index int `json:"index"`
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *ChatCompletionToolCallsFunction `json:"function"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type ChatCompletionMessage struct {
|
||||
@ -123,6 +140,8 @@ type ChatCompletionRequest struct {
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
LogitBias any `json:"logit_bias,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Functions []*ChatCompletionFunction `json:"functions,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
@ -151,19 +170,89 @@ type ChatCompletionTool struct {
|
||||
}
|
||||
|
||||
type ChatCompletionChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message ChatCompletionMessage `json:"message"`
|
||||
FinishReason any `json:"finish_reason,omitempty"`
|
||||
Index int `json:"index"`
|
||||
Message ChatCompletionMessage `json:"message"`
|
||||
FinishReason any `json:"finish_reason,omitempty"`
|
||||
ContentFilterResults any `json:"content_filter_results,omitempty"`
|
||||
FinishDetails any `json:"finish_details,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatCompletionChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatCompletionChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
PromptFilterResults any `json:"prompt_filter_results,omitempty"`
|
||||
}
|
||||
|
||||
func (c ChatCompletionStreamChoice) ConvertOpenaiStream() []ChatCompletionStreamChoice {
|
||||
var function *ChatCompletionToolCallsFunction
|
||||
var functions []*ChatCompletionToolCallsFunction
|
||||
var choices []ChatCompletionStreamChoice
|
||||
var stopFinish string
|
||||
if c.Delta.FunctionCall != nil {
|
||||
function = c.Delta.FunctionCall
|
||||
stopFinish = FinishReasonFunctionCall
|
||||
} else {
|
||||
function = c.Delta.ToolCalls[0].Function
|
||||
stopFinish = FinishReasonToolCalls
|
||||
}
|
||||
|
||||
if function.Name == "" {
|
||||
c.FinishReason = stopFinish
|
||||
choices = append(choices, c)
|
||||
return choices
|
||||
}
|
||||
|
||||
functions = append(functions, &ChatCompletionToolCallsFunction{
|
||||
Name: function.Name,
|
||||
Arguments: "",
|
||||
})
|
||||
|
||||
if function.Arguments == "" || function.Arguments == "{}" {
|
||||
functions = append(functions, &ChatCompletionToolCallsFunction{
|
||||
Arguments: "{}",
|
||||
})
|
||||
} else {
|
||||
functions = append(functions, &ChatCompletionToolCallsFunction{
|
||||
Arguments: function.Arguments,
|
||||
})
|
||||
}
|
||||
|
||||
// 循环functions, 生成choices
|
||||
for _, function := range functions {
|
||||
choice := ChatCompletionStreamChoice{
|
||||
Index: 0,
|
||||
Delta: ChatCompletionStreamChoiceDelta{
|
||||
Role: c.Delta.Role,
|
||||
},
|
||||
}
|
||||
if stopFinish == FinishReasonFunctionCall {
|
||||
choice.Delta.FunctionCall = function
|
||||
} else {
|
||||
choice.Delta.ToolCalls = []*ChatCompletionToolCalls{
|
||||
{
|
||||
Id: c.Delta.ToolCalls[0].Id,
|
||||
Index: 0,
|
||||
Type: "function",
|
||||
Function: function,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
choices = append(choices, choice)
|
||||
}
|
||||
|
||||
choices = append(choices, ChatCompletionStreamChoice{
|
||||
Index: c.Index,
|
||||
Delta: ChatCompletionStreamChoiceDelta{},
|
||||
FinishReason: stopFinish,
|
||||
})
|
||||
|
||||
return choices
|
||||
}
|
||||
|
||||
type ChatCompletionStreamChoiceDelta struct {
|
||||
|
@ -1,5 +1,7 @@
|
||||
package types
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
@ -14,6 +16,16 @@ type OpenAIError struct {
|
||||
InnerError any `json:"innererror,omitempty"`
|
||||
}
|
||||
|
||||
func (e *OpenAIError) Error() string {
|
||||
response := &OpenAIErrorResponse{
|
||||
Error: *e,
|
||||
}
|
||||
|
||||
// 转换为JSON
|
||||
bytes, _ := json.Marshal(response)
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
type OpenAIErrorWithStatusCode struct {
|
||||
OpenAIError
|
||||
StatusCode int `json:"status_code"`
|
||||
|
@ -8,9 +8,9 @@ type EmbeddingRequest struct {
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object"`
|
||||
Embedding any `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
|
Loading…
Reference in New Issue
Block a user