♻️ refactor: split relay
This commit is contained in:
parent
53da7134b2
commit
902c2faa2c
126
common/client.go
Normal file
126
common/client.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var HttpClient *http.Client
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if RelayTimeout == 0 {
|
||||||
|
HttpClient = &http.Client{}
|
||||||
|
} else {
|
||||||
|
HttpClient = &http.Client{
|
||||||
|
Timeout: time.Duration(RelayTimeout) * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
requestBuilder RequestBuilder
|
||||||
|
createFormBuilder func(io.Writer) FormBuilder
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient() *Client {
|
||||||
|
return &Client{
|
||||||
|
requestBuilder: NewRequestBuilder(),
|
||||||
|
createFormBuilder: func(body io.Writer) FormBuilder {
|
||||||
|
return NewFormBuilder(body)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestOptions struct {
|
||||||
|
body any
|
||||||
|
header http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestOption func(*requestOptions)
|
||||||
|
|
||||||
|
func WithBody(body any) requestOption {
|
||||||
|
return func(args *requestOptions) {
|
||||||
|
args.body = body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHeader(header map[string]string) requestOption {
|
||||||
|
return func(args *requestOptions) {
|
||||||
|
for k, v := range header {
|
||||||
|
args.header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestError struct {
|
||||||
|
HTTPStatusCode int
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) NewRequest(method, url string, setters ...requestOption) (*http.Request, error) {
|
||||||
|
// Default Options
|
||||||
|
args := &requestOptions{
|
||||||
|
body: nil,
|
||||||
|
header: make(http.Header),
|
||||||
|
}
|
||||||
|
for _, setter := range setters {
|
||||||
|
setter(args)
|
||||||
|
}
|
||||||
|
req, err := c.requestBuilder.Build(method, url, args.body, args.header)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) SendRequest(req *http.Request, response any) error {
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err := HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 处理响应
|
||||||
|
if IsFailureStatusCode(resp) {
|
||||||
|
return fmt.Errorf("status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析响应
|
||||||
|
err = DecodeResponse(resp.Body, response)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsFailureStatusCode(resp *http.Response) bool {
|
||||||
|
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeResponse(body io.Reader, v any) error {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if result, ok := v.(*string); ok {
|
||||||
|
return DecodeString(body, result)
|
||||||
|
}
|
||||||
|
return json.NewDecoder(body).Decode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeString(body io.Reader, output *string) error {
|
||||||
|
b, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*output = string(b)
|
||||||
|
return nil
|
||||||
|
}
|
65
common/form_builder.go
Normal file
65
common/form_builder.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FormBuilder interface {
|
||||||
|
CreateFormFile(fieldname string, file *os.File) error
|
||||||
|
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
|
||||||
|
WriteField(fieldname, value string) error
|
||||||
|
Close() error
|
||||||
|
FormDataContentType() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultFormBuilder struct {
|
||||||
|
writer *multipart.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
|
||||||
|
return &DefaultFormBuilder{
|
||||||
|
writer: multipart.NewWriter(body),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
||||||
|
return fb.createFormFile(fieldname, file, file.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
||||||
|
return fb.createFormFile(fieldname, r, path.Base(filename))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
|
||||||
|
if filename == "" {
|
||||||
|
return fmt.Errorf("filename cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = io.Copy(fieldWriter, r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
|
||||||
|
return fb.writer.WriteField(fieldname, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) Close() error {
|
||||||
|
return fb.writer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) FormDataContentType() string {
|
||||||
|
return fb.writer.FormDataContentType()
|
||||||
|
}
|
15
common/marshaller.go
Normal file
15
common/marshaller.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Marshaller interface {
|
||||||
|
Marshal(value any) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type JSONMarshaller struct{}
|
||||||
|
|
||||||
|
func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
|
||||||
|
return json.Marshal(value)
|
||||||
|
}
|
59
common/quota.go
Normal file
59
common/quota.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
// type Quota struct {
|
||||||
|
// ModelName string
|
||||||
|
// ModelRatio float64
|
||||||
|
// GroupRatio float64
|
||||||
|
// Ratio float64
|
||||||
|
// UserQuota int
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func CreateQuota(modelName string, userQuota int, group string) *Quota {
|
||||||
|
// modelRatio := GetModelRatio(modelName)
|
||||||
|
// groupRatio := GetGroupRatio(group)
|
||||||
|
|
||||||
|
// return &Quota{
|
||||||
|
// ModelName: modelName,
|
||||||
|
// ModelRatio: modelRatio,
|
||||||
|
// GroupRatio: groupRatio,
|
||||||
|
// Ratio: modelRatio * groupRatio,
|
||||||
|
// UserQuota: userQuota,
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (q *Quota) getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||||
|
// if ApproximateTokenEnabled {
|
||||||
|
// return int(float64(len(text)) * 0.38)
|
||||||
|
// }
|
||||||
|
// return len(tokenEncoder.Encode(text, nil, nil))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (q *Quota) CountTokenMessages(messages []Message, model string) int {
|
||||||
|
// tokenEncoder := q.getTokenEncoder(model)
|
||||||
|
// // Reference:
|
||||||
|
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
|
// // https://github.com/pkoukk/tiktoken-go/issues/6
|
||||||
|
// //
|
||||||
|
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
|
// var tokensPerMessage int
|
||||||
|
// var tokensPerName int
|
||||||
|
// if model == "gpt-3.5-turbo-0301" {
|
||||||
|
// tokensPerMessage = 4
|
||||||
|
// tokensPerName = -1 // If there's a name, the role is omitted
|
||||||
|
// } else {
|
||||||
|
// tokensPerMessage = 3
|
||||||
|
// tokensPerName = 1
|
||||||
|
// }
|
||||||
|
// tokenNum := 0
|
||||||
|
// for _, message := range messages {
|
||||||
|
// tokenNum += tokensPerMessage
|
||||||
|
// tokenNum += q.getTokenNum(tokenEncoder, message.StringContent())
|
||||||
|
// tokenNum += q.getTokenNum(tokenEncoder, message.Role)
|
||||||
|
// if message.Name != nil {
|
||||||
|
// tokenNum += tokensPerName
|
||||||
|
// tokenNum += q.getTokenNum(tokenEncoder, *message.Name)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||||
|
// return tokenNum
|
||||||
|
// }
|
50
common/request_builder.go
Normal file
50
common/request_builder.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestBuilder interface {
|
||||||
|
Build(method, url string, body any, header http.Header) (*http.Request, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTPRequestBuilder struct {
|
||||||
|
marshaller Marshaller
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequestBuilder() *HTTPRequestBuilder {
|
||||||
|
return &HTTPRequestBuilder{
|
||||||
|
marshaller: &JSONMarshaller{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *HTTPRequestBuilder) Build(
|
||||||
|
method string,
|
||||||
|
url string,
|
||||||
|
body any,
|
||||||
|
header http.Header,
|
||||||
|
) (req *http.Request, err error) {
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if body != nil {
|
||||||
|
if v, ok := body.(io.Reader); ok {
|
||||||
|
bodyReader = v
|
||||||
|
} else {
|
||||||
|
var reqBytes []byte
|
||||||
|
reqBytes, err = b.marshaller.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bodyReader = bytes.NewBuffer(reqBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req, err = http.NewRequest(method, url, bodyReader)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if header != nil {
|
||||||
|
req.Header = header
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
109
common/token.go
Normal file
109
common/token.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/pkoukk/tiktoken-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||||
|
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||||
|
|
||||||
|
func InitTokenEncoders() {
|
||||||
|
SysLog("initializing token encoders")
|
||||||
|
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||||
|
if err != nil {
|
||||||
|
FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
||||||
|
}
|
||||||
|
defaultTokenEncoder = gpt35TokenEncoder
|
||||||
|
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
||||||
|
if err != nil {
|
||||||
|
FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||||
|
}
|
||||||
|
for model, _ := range ModelRatio {
|
||||||
|
if strings.HasPrefix(model, "gpt-3.5") {
|
||||||
|
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||||
|
} else if strings.HasPrefix(model, "gpt-4") {
|
||||||
|
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||||
|
} else {
|
||||||
|
tokenEncoderMap[model] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SysLog("token encoders initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||||
|
tokenEncoder, ok := tokenEncoderMap[model]
|
||||||
|
if ok && tokenEncoder != nil {
|
||||||
|
return tokenEncoder
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||||
|
if err != nil {
|
||||||
|
SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||||
|
tokenEncoder = defaultTokenEncoder
|
||||||
|
}
|
||||||
|
tokenEncoderMap[model] = tokenEncoder
|
||||||
|
return tokenEncoder
|
||||||
|
}
|
||||||
|
return defaultTokenEncoder
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||||
|
if ApproximateTokenEnabled {
|
||||||
|
return int(float64(len(text)) * 0.38)
|
||||||
|
}
|
||||||
|
return len(tokenEncoder.Encode(text, nil, nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func CountTokenMessages(messages []types.ChatCompletionMessage, model string) int {
|
||||||
|
tokenEncoder := getTokenEncoder(model)
|
||||||
|
// Reference:
|
||||||
|
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
|
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||||
|
//
|
||||||
|
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
|
var tokensPerMessage int
|
||||||
|
var tokensPerName int
|
||||||
|
if model == "gpt-3.5-turbo-0301" {
|
||||||
|
tokensPerMessage = 4
|
||||||
|
tokensPerName = -1 // If there's a name, the role is omitted
|
||||||
|
} else {
|
||||||
|
tokensPerMessage = 3
|
||||||
|
tokensPerName = 1
|
||||||
|
}
|
||||||
|
tokenNum := 0
|
||||||
|
for _, message := range messages {
|
||||||
|
tokenNum += tokensPerMessage
|
||||||
|
tokenNum += getTokenNum(tokenEncoder, message.StringContent())
|
||||||
|
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||||
|
if message.Name != nil {
|
||||||
|
tokenNum += tokensPerName
|
||||||
|
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||||
|
return tokenNum
|
||||||
|
}
|
||||||
|
|
||||||
|
func CountTokenInput(input any, model string) int {
|
||||||
|
switch input.(type) {
|
||||||
|
case string:
|
||||||
|
return CountTokenText(input.(string), model)
|
||||||
|
case []string:
|
||||||
|
text := ""
|
||||||
|
for _, s := range input.([]string) {
|
||||||
|
text += s
|
||||||
|
}
|
||||||
|
return CountTokenText(text, model)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func CountTokenText(text string, model string) int {
|
||||||
|
tokenEncoder := getTokenEncoder(model)
|
||||||
|
return getTokenNum(tokenEncoder, text)
|
||||||
|
}
|
@ -92,7 +92,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
|||||||
for k := range headers {
|
for k := range headers {
|
||||||
req.Header.Add(k, headers.Get(k))
|
req.Header.Add(k, headers.Get(k))
|
||||||
}
|
}
|
||||||
res, err := httpClient.Do(req)
|
res, err := common.HttpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -204,6 +204,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|||||||
if channel.GetBaseURL() == "" {
|
if channel.GetBaseURL() == "" {
|
||||||
channel.BaseURL = &baseURL
|
channel.BaseURL = &baseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeOpenAI:
|
case common.ChannelTypeOpenAI:
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -16,86 +15,81 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) {
|
||||||
|
// 创建一个 http.Request
|
||||||
|
req, err := http.NewRequest("POST", "/v1/chat/completions", nil)
|
||||||
|
if err != nil {
|
||||||
|
return err, nil
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = req
|
||||||
|
c.Set("channel", channel.Type)
|
||||||
|
c.Set("channel_id", channel.Id)
|
||||||
|
c.Set("channel_name", channel.Name)
|
||||||
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
|
c.Set("api_key", channel.Key)
|
||||||
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
|
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypePaLM:
|
case common.ChannelTypePaLM:
|
||||||
fallthrough
|
request.Model = "PaLM-2"
|
||||||
case common.ChannelTypeAnthropic:
|
case common.ChannelTypeAnthropic:
|
||||||
fallthrough
|
request.Model = "claude-2"
|
||||||
case common.ChannelTypeBaidu:
|
case common.ChannelTypeBaidu:
|
||||||
fallthrough
|
request.Model = "ERNIE-Bot"
|
||||||
case common.ChannelTypeZhipu:
|
case common.ChannelTypeZhipu:
|
||||||
fallthrough
|
request.Model = "chatglm_lite"
|
||||||
case common.ChannelTypeAli:
|
case common.ChannelTypeAli:
|
||||||
fallthrough
|
request.Model = "qwen-turbo"
|
||||||
case common.ChannelType360:
|
case common.ChannelType360:
|
||||||
fallthrough
|
request.Model = "360GPT_S2_V9"
|
||||||
case common.ChannelTypeXunfei:
|
case common.ChannelTypeXunfei:
|
||||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
request.Model = "SparkDesk"
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeTencent:
|
||||||
|
request.Model = "hunyuan"
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
request.Model = "gpt-35-turbo"
|
request.Model = "gpt-3.5-turbo"
|
||||||
defer func() {
|
c.Set("api_version", channel.Other)
|
||||||
if err != nil {
|
|
||||||
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
default:
|
default:
|
||||||
request.Model = "gpt-3.5-turbo"
|
request.Model = "gpt-3.5-turbo"
|
||||||
}
|
}
|
||||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
|
||||||
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
|
|
||||||
} else {
|
|
||||||
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
|
|
||||||
requestURL = baseURL
|
|
||||||
}
|
|
||||||
|
|
||||||
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
|
chatProvider := GetChatProvider(channel.Type, c)
|
||||||
}
|
isModelMapped := false
|
||||||
jsonData, err := json.Marshal(request)
|
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
if modelMap != nil && modelMap[request.Model] != "" {
|
||||||
if err != nil {
|
request.Model = modelMap[request.Model]
|
||||||
return err, nil
|
isModelMapped = true
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
|
||||||
req.Header.Set("api-key", channel.Key)
|
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
||||||
} else {
|
_, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&request, isModelMapped, promptTokens)
|
||||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
if openAIErrorWithStatusCode != nil {
|
||||||
}
|
return nil, &openAIErrorWithStatusCode.OpenAIError
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err, nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
var response TextResponse
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return err, nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(body, &response)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
|
|
||||||
}
|
|
||||||
if response.Usage.CompletionTokens == 0 {
|
|
||||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTestRequest() *ChatRequest {
|
func buildTestRequest() *types.ChatCompletionRequest {
|
||||||
testRequest := &ChatRequest{
|
testRequest := &types.ChatCompletionRequest{
|
||||||
Model: "", // this will be set later
|
Messages: []types.ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "You just need to output 'hi' next.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Model: "",
|
||||||
MaxTokens: 1,
|
MaxTokens: 1,
|
||||||
|
Stream: false,
|
||||||
}
|
}
|
||||||
testMessage := Message{
|
|
||||||
Role: "user",
|
|
||||||
Content: "hi",
|
|
||||||
}
|
|
||||||
testRequest.Messages = append(testRequest.Messages, testMessage)
|
|
||||||
return testRequest
|
return testRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,220 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
|
|
||||||
|
|
||||||
type AIProxyLibraryRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Query string `json:"query"`
|
|
||||||
LibraryId string `json:"libraryId"`
|
|
||||||
Stream bool `json:"stream"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AIProxyLibraryError struct {
|
|
||||||
ErrCode int `json:"errCode"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AIProxyLibraryDocument struct {
|
|
||||||
Title string `json:"title"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AIProxyLibraryResponse struct {
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Answer string `json:"answer"`
|
|
||||||
Documents []AIProxyLibraryDocument `json:"documents"`
|
|
||||||
AIProxyLibraryError
|
|
||||||
}
|
|
||||||
|
|
||||||
type AIProxyLibraryStreamResponse struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
Finish bool `json:"finish"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Documents []AIProxyLibraryDocument `json:"documents"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
|
||||||
query := ""
|
|
||||||
if len(request.Messages) != 0 {
|
|
||||||
query = request.Messages[len(request.Messages)-1].StringContent()
|
|
||||||
}
|
|
||||||
return &AIProxyLibraryRequest{
|
|
||||||
Model: request.Model,
|
|
||||||
Stream: request.Stream,
|
|
||||||
Query: query,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
|
|
||||||
if len(documents) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
content := "\n\n参考文档:\n"
|
|
||||||
for i, document := range documents {
|
|
||||||
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
|
|
||||||
}
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
|
|
||||||
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: content,
|
|
||||||
},
|
|
||||||
FinishReason: "stop",
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: common.GetUUID(),
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
return &ChatCompletionsStreamResponse{
|
|
||||||
Id: common.GetUUID(),
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = response.Content
|
|
||||||
return &ChatCompletionsStreamResponse{
|
|
||||||
Id: common.GetUUID(),
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: response.Model,
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var usage Usage
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 5 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if data[:5] != "data:" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = data[5:]
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
var documents []AIProxyLibraryDocument
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if len(AIProxyLibraryResponse.Documents) != 0 {
|
|
||||||
documents = AIProxyLibraryResponse.Documents
|
|
||||||
}
|
|
||||||
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
response := documentsAIProxyLibrary(documents)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var AIProxyLibraryResponse AIProxyLibraryResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if AIProxyLibraryResponse.ErrCode != 0 {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: AIProxyLibraryResponse.Message,
|
|
||||||
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
|
|
||||||
Code: AIProxyLibraryResponse.ErrCode,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
@ -1,329 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
|
||||||
|
|
||||||
type AliMessage struct {
|
|
||||||
User string `json:"user"`
|
|
||||||
Bot string `json:"bot"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliInput struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
History []AliMessage `json:"history"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliParameters struct {
|
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
|
||||||
TopK int `json:"top_k,omitempty"`
|
|
||||||
Seed uint64 `json:"seed,omitempty"`
|
|
||||||
EnableSearch bool `json:"enable_search,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliChatRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Input AliInput `json:"input"`
|
|
||||||
Parameters AliParameters `json:"parameters,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliEmbeddingRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Input struct {
|
|
||||||
Texts []string `json:"texts"`
|
|
||||||
} `json:"input"`
|
|
||||||
Parameters *struct {
|
|
||||||
TextType string `json:"text_type,omitempty"`
|
|
||||||
} `json:"parameters,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliEmbedding struct {
|
|
||||||
Embedding []float64 `json:"embedding"`
|
|
||||||
TextIndex int `json:"text_index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliEmbeddingResponse struct {
|
|
||||||
Output struct {
|
|
||||||
Embeddings []AliEmbedding `json:"embeddings"`
|
|
||||||
} `json:"output"`
|
|
||||||
Usage AliUsage `json:"usage"`
|
|
||||||
AliError
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliError struct {
|
|
||||||
Code string `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
RequestId string `json:"request_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliUsage struct {
|
|
||||||
InputTokens int `json:"input_tokens"`
|
|
||||||
OutputTokens int `json:"output_tokens"`
|
|
||||||
TotalTokens int `json:"total_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliOutput struct {
|
|
||||||
Text string `json:"text"`
|
|
||||||
FinishReason string `json:"finish_reason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AliChatResponse struct {
|
|
||||||
Output AliOutput `json:"output"`
|
|
||||||
Usage AliUsage `json:"usage"`
|
|
||||||
AliError
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
|
||||||
messages := make([]AliMessage, 0, len(request.Messages))
|
|
||||||
prompt := ""
|
|
||||||
for i := 0; i < len(request.Messages); i++ {
|
|
||||||
message := request.Messages[i]
|
|
||||||
if message.Role == "system" {
|
|
||||||
messages = append(messages, AliMessage{
|
|
||||||
User: message.StringContent(),
|
|
||||||
Bot: "Okay",
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
if i == len(request.Messages)-1 {
|
|
||||||
prompt = message.StringContent()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
messages = append(messages, AliMessage{
|
|
||||||
User: message.StringContent(),
|
|
||||||
Bot: request.Messages[i+1].StringContent(),
|
|
||||||
})
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &AliChatRequest{
|
|
||||||
Model: request.Model,
|
|
||||||
Input: AliInput{
|
|
||||||
Prompt: prompt,
|
|
||||||
History: messages,
|
|
||||||
},
|
|
||||||
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
|
|
||||||
// TopP: request.TopP,
|
|
||||||
// TopK: 50,
|
|
||||||
// //Seed: 0,
|
|
||||||
// //EnableSearch: false,
|
|
||||||
//},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
|
|
||||||
return &AliEmbeddingRequest{
|
|
||||||
Model: "text-embedding-v1",
|
|
||||||
Input: struct {
|
|
||||||
Texts []string `json:"texts"`
|
|
||||||
}{
|
|
||||||
Texts: request.ParseInput(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var aliResponse AliEmbeddingResponse
|
|
||||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if aliResponse.Code != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: aliResponse.Message,
|
|
||||||
Type: aliResponse.Code,
|
|
||||||
Param: aliResponse.RequestId,
|
|
||||||
Code: aliResponse.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
|
|
||||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
|
||||||
Object: "list",
|
|
||||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
|
||||||
Model: "text-embedding-v1",
|
|
||||||
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, item := range response.Output.Embeddings {
|
|
||||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
|
||||||
Object: `embedding`,
|
|
||||||
Index: item.TextIndex,
|
|
||||||
Embedding: item.Embedding,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return &openAIEmbeddingResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: response.Output.Text,
|
|
||||||
},
|
|
||||||
FinishReason: response.Output.FinishReason,
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: response.RequestId,
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
Usage: Usage{
|
|
||||||
PromptTokens: response.Usage.InputTokens,
|
|
||||||
CompletionTokens: response.Usage.OutputTokens,
|
|
||||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = aliResponse.Output.Text
|
|
||||||
if aliResponse.Output.FinishReason != "null" {
|
|
||||||
finishReason := aliResponse.Output.FinishReason
|
|
||||||
choice.FinishReason = &finishReason
|
|
||||||
}
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Id: aliResponse.RequestId,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "ernie-bot",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var usage Usage
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 5 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if data[:5] != "data:" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = data[5:]
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
lastResponseText := ""
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
var aliResponse AliChatResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if aliResponse.Usage.OutputTokens != 0 {
|
|
||||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
|
||||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
|
||||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
|
||||||
}
|
|
||||||
response := streamResponseAli2OpenAI(&aliResponse)
|
|
||||||
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
|
||||||
lastResponseText = aliResponse.Output.Text
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var aliResponse AliChatResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &aliResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if aliResponse.Code != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: aliResponse.Message,
|
|
||||||
Type: aliResponse.Code,
|
|
||||||
Param: aliResponse.RequestId,
|
|
||||||
Code: aliResponse.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
@ -1,183 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
||||||
audioModel := "whisper-1"
|
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
channelType := c.GetInt("channel")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
group := c.GetString("group")
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
|
|
||||||
var ttsRequest TextToSpeechRequest
|
|
||||||
if relayMode == RelayModeAudioSpeech {
|
|
||||||
// Read JSON
|
|
||||||
err := common.UnmarshalBodyReusable(c, &ttsRequest)
|
|
||||||
// Check if JSON is valid
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "invalid_json", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
audioModel = ttsRequest.Model
|
|
||||||
// Check if text is too long 4096
|
|
||||||
if len(ttsRequest.Input) > 4096 {
|
|
||||||
return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
|
||||||
modelRatio := common.GetModelRatio(audioModel)
|
|
||||||
groupRatio := common.GetGroupRatio(group)
|
|
||||||
ratio := modelRatio * groupRatio
|
|
||||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
quota := 0
|
|
||||||
// Check if user quota is enough
|
|
||||||
if relayMode == RelayModeAudioSpeech {
|
|
||||||
quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
|
|
||||||
if quota > userQuota {
|
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if userQuota-preConsumedQuota < 0 {
|
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if userQuota > 100*preConsumedQuota {
|
|
||||||
// in this case, we do not pre-consume quota
|
|
||||||
// because the user has enough quota
|
|
||||||
preConsumedQuota = 0
|
|
||||||
}
|
|
||||||
if preConsumedQuota > 0 {
|
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// map model name
|
|
||||||
modelMapping := c.GetString("model_mapping")
|
|
||||||
if modelMapping != "" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if modelMap[audioModel] != "" {
|
|
||||||
audioModel = modelMap[audioModel]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
|
||||||
requestURL := c.Request.URL.String()
|
|
||||||
if c.GetString("base_url") != "" {
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
}
|
|
||||||
|
|
||||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
|
||||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
|
||||||
query := c.Request.URL.Query()
|
|
||||||
apiVersion := query.Get("api-version")
|
|
||||||
if apiVersion == "" {
|
|
||||||
apiVersion = c.GetString("api_version")
|
|
||||||
}
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
requestBody := c.Request.Body
|
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
||||||
req.Header.Set("api-key", apiKey)
|
|
||||||
req.ContentLength = c.Request.ContentLength
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = req.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = c.Request.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if relayMode == RelayModeAudioSpeech {
|
|
||||||
defer func(ctx context.Context) {
|
|
||||||
go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
|
||||||
}(c.Request.Context())
|
|
||||||
} else {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
var whisperResponse WhisperResponse
|
|
||||||
err = json.Unmarshal(responseBody, &whisperResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
defer func(ctx context.Context) {
|
|
||||||
quota := countTokenText(whisperResponse.Text, audioModel)
|
|
||||||
quotaDelta := quota - preConsumedQuota
|
|
||||||
go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
|
||||||
}(c.Request.Context())
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
}
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,359 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
|
||||||
|
|
||||||
type BaiduTokenResponse struct {
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduChatRequest struct {
|
|
||||||
Messages []BaiduMessage `json:"messages"`
|
|
||||||
Stream bool `json:"stream"`
|
|
||||||
UserId string `json:"user_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduError struct {
|
|
||||||
ErrorCode int `json:"error_code"`
|
|
||||||
ErrorMsg string `json:"error_msg"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduChatResponse struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Result string `json:"result"`
|
|
||||||
IsTruncated bool `json:"is_truncated"`
|
|
||||||
NeedClearHistory bool `json:"need_clear_history"`
|
|
||||||
Usage Usage `json:"usage"`
|
|
||||||
BaiduError
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduChatStreamResponse struct {
|
|
||||||
BaiduChatResponse
|
|
||||||
SentenceId int `json:"sentence_id"`
|
|
||||||
IsEnd bool `json:"is_end"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduEmbeddingRequest struct {
|
|
||||||
Input []string `json:"input"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduEmbeddingData struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
Embedding []float64 `json:"embedding"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduEmbeddingResponse struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
Data []BaiduEmbeddingData `json:"data"`
|
|
||||||
Usage Usage `json:"usage"`
|
|
||||||
BaiduError
|
|
||||||
}
|
|
||||||
|
|
||||||
type BaiduAccessToken struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
ErrorDescription string `json:"error_description,omitempty"`
|
|
||||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
|
||||||
ExpiresAt time.Time `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var baiduTokenStore sync.Map
|
|
||||||
|
|
||||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
|
||||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
|
||||||
for _, message := range request.Messages {
|
|
||||||
if message.Role == "system" {
|
|
||||||
messages = append(messages, BaiduMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
messages = append(messages, BaiduMessage{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "Okay",
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
messages = append(messages, BaiduMessage{
|
|
||||||
Role: message.Role,
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &BaiduChatRequest{
|
|
||||||
Messages: messages,
|
|
||||||
Stream: request.Stream,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: response.Result,
|
|
||||||
},
|
|
||||||
FinishReason: "stop",
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: response.Id,
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: response.Created,
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
Usage: response.Usage,
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = baiduResponse.Result
|
|
||||||
if baiduResponse.IsEnd {
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
}
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Id: baiduResponse.Id,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: baiduResponse.Created,
|
|
||||||
Model: "ernie-bot",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
|
||||||
return &BaiduEmbeddingRequest{
|
|
||||||
Input: request.ParseInput(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
|
||||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
|
||||||
Object: "list",
|
|
||||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
|
||||||
Model: "baidu-embedding",
|
|
||||||
Usage: response.Usage,
|
|
||||||
}
|
|
||||||
for _, item := range response.Data {
|
|
||||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
|
||||||
Object: item.Object,
|
|
||||||
Index: item.Index,
|
|
||||||
Embedding: item.Embedding,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return &openAIEmbeddingResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var usage Usage
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 6 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = data[6:]
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
var baiduResponse BaiduChatStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if baiduResponse.Usage.TotalTokens != 0 {
|
|
||||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
|
||||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
|
||||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
|
||||||
}
|
|
||||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var baiduResponse BaiduChatResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if baiduResponse.ErrorMsg != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: baiduResponse.ErrorMsg,
|
|
||||||
Type: "baidu_error",
|
|
||||||
Param: "",
|
|
||||||
Code: baiduResponse.ErrorCode,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var baiduResponse BaiduEmbeddingResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if baiduResponse.ErrorMsg != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: baiduResponse.ErrorMsg,
|
|
||||||
Type: "baidu_error",
|
|
||||||
Param: "",
|
|
||||||
Code: baiduResponse.ErrorCode,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBaiduAccessToken(apiKey string) (string, error) {
|
|
||||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
|
||||||
var accessToken BaiduAccessToken
|
|
||||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
|
||||||
// soon this will expire
|
|
||||||
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
|
||||||
go func() {
|
|
||||||
_, _ = getBaiduAccessTokenHelper(apiKey)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
return accessToken.AccessToken, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
accessToken, err := getBaiduAccessTokenHelper(apiKey)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if accessToken == nil {
|
|
||||||
return "", errors.New("getBaiduAccessToken return a nil token")
|
|
||||||
}
|
|
||||||
return (*accessToken).AccessToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
|
||||||
parts := strings.Split(apiKey, "|")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return nil, errors.New("invalid baidu apikey")
|
|
||||||
}
|
|
||||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
|
||||||
parts[0], parts[1]), nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
|
||||||
req.Header.Add("Accept", "application/json")
|
|
||||||
res, err := impatientHTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
var accessToken BaiduAccessToken
|
|
||||||
err = json.NewDecoder(res.Body).Decode(&accessToken)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if accessToken.Error != "" {
|
|
||||||
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
|
||||||
}
|
|
||||||
if accessToken.AccessToken == "" {
|
|
||||||
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
|
||||||
}
|
|
||||||
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
|
||||||
baiduTokenStore.Store(apiKey, accessToken)
|
|
||||||
return &accessToken, nil
|
|
||||||
}
|
|
127
controller/relay-chat.go
Normal file
127
controller/relay-chat.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/providers"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func relayChatHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
|
||||||
|
|
||||||
|
// 获取请求参数
|
||||||
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
// consumeQuota := c.GetBool("consume_quota")
|
||||||
|
group := c.GetString("group")
|
||||||
|
|
||||||
|
// 获取 Provider
|
||||||
|
chatProvider := GetChatProvider(channelType, c)
|
||||||
|
if chatProvider == nil {
|
||||||
|
return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求体
|
||||||
|
var chatRequest types.ChatCompletionRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &chatRequest)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查模型映射
|
||||||
|
isModelMapped := false
|
||||||
|
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if modelMap != nil && modelMap[chatRequest.Model] != "" {
|
||||||
|
chatRequest.Model = modelMap[chatRequest.Model]
|
||||||
|
isModelMapped = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始计算Tokens
|
||||||
|
var promptTokens int
|
||||||
|
promptTokens = common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
|
||||||
|
|
||||||
|
// 计算预付费配额
|
||||||
|
quotaInfo := &QuotaInfo{
|
||||||
|
modelName: chatRequest.Model,
|
||||||
|
promptTokens: promptTokens,
|
||||||
|
userId: userId,
|
||||||
|
channelId: channelId,
|
||||||
|
tokenId: tokenId,
|
||||||
|
}
|
||||||
|
quotaInfo.initQuotaInfo(group)
|
||||||
|
quota_err := quotaInfo.preQuotaConsumption()
|
||||||
|
if quota_err != nil {
|
||||||
|
return quota_err
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, openAIErrorWithStatusCode := chatProvider.ChatCompleteResponse(&chatRequest, isModelMapped, promptTokens)
|
||||||
|
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
if quotaInfo.preConsumedQuota != 0 {
|
||||||
|
go func(ctx context.Context) {
|
||||||
|
// return pre-consumed quota
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||||
|
}
|
||||||
|
}(c.Request.Context())
|
||||||
|
}
|
||||||
|
return openAIErrorWithStatusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}(c.Request.Context())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetChatProvider(channelType int, c *gin.Context) providers.ChatProviderAction {
|
||||||
|
switch channelType {
|
||||||
|
case common.ChannelTypeOpenAI:
|
||||||
|
return providers.CreateOpenAIProvider(c, "")
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
|
return providers.CreateAzureProvider(c)
|
||||||
|
case common.ChannelTypeAli:
|
||||||
|
return providers.CreateAliAIProvider(c)
|
||||||
|
case common.ChannelTypeTencent:
|
||||||
|
return providers.CreateTencentProvider(c)
|
||||||
|
case common.ChannelTypeBaidu:
|
||||||
|
return providers.CreateBaiduProvider(c)
|
||||||
|
case common.ChannelTypeAnthropic:
|
||||||
|
return providers.CreateClaudeProvider(c)
|
||||||
|
case common.ChannelTypePaLM:
|
||||||
|
return providers.CreatePalmProvider(c)
|
||||||
|
case common.ChannelTypeZhipu:
|
||||||
|
return providers.CreateZhipuProvider(c)
|
||||||
|
case common.ChannelTypeXunfei:
|
||||||
|
return providers.CreateXunfeiProvider(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
|
if c.GetString("base_url") != "" {
|
||||||
|
baseURL = c.GetString("base_url")
|
||||||
|
}
|
||||||
|
|
||||||
|
if baseURL != "" {
|
||||||
|
return providers.CreateOpenAIProvider(c, baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,220 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ClaudeMetadata struct {
|
|
||||||
UserId string `json:"user_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClaudeRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
|
||||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
|
||||||
TopK int `json:"top_k,omitempty"`
|
|
||||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClaudeError struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClaudeResponse struct {
|
|
||||||
Completion string `json:"completion"`
|
|
||||||
StopReason string `json:"stop_reason"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Error ClaudeError `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func stopReasonClaude2OpenAI(reason string) string {
|
|
||||||
switch reason {
|
|
||||||
case "stop_sequence":
|
|
||||||
return "stop"
|
|
||||||
case "max_tokens":
|
|
||||||
return "length"
|
|
||||||
default:
|
|
||||||
return reason
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
|
|
||||||
claudeRequest := ClaudeRequest{
|
|
||||||
Model: textRequest.Model,
|
|
||||||
Prompt: "",
|
|
||||||
MaxTokensToSample: textRequest.MaxTokens,
|
|
||||||
StopSequences: nil,
|
|
||||||
Temperature: textRequest.Temperature,
|
|
||||||
TopP: textRequest.TopP,
|
|
||||||
Stream: textRequest.Stream,
|
|
||||||
}
|
|
||||||
if claudeRequest.MaxTokensToSample == 0 {
|
|
||||||
claudeRequest.MaxTokensToSample = 1000000
|
|
||||||
}
|
|
||||||
prompt := ""
|
|
||||||
for _, message := range textRequest.Messages {
|
|
||||||
if message.Role == "user" {
|
|
||||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
|
||||||
} else if message.Role == "assistant" {
|
|
||||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
|
||||||
} else if message.Role == "system" {
|
|
||||||
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prompt += "\n\nAssistant:"
|
|
||||||
claudeRequest.Prompt = prompt
|
|
||||||
return &claudeRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = claudeResponse.Completion
|
|
||||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
|
||||||
if finishReason != "null" {
|
|
||||||
choice.FinishReason = &finishReason
|
|
||||||
}
|
|
||||||
var response ChatCompletionsStreamResponse
|
|
||||||
response.Object = "chat.completion.chunk"
|
|
||||||
response.Model = claudeResponse.Model
|
|
||||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
|
|
||||||
Name: nil,
|
|
||||||
},
|
|
||||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
|
||||||
responseText := ""
|
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
|
||||||
createdTime := common.GetTimestamp()
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
|
|
||||||
return i + 4, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if !strings.HasPrefix(data, "event: completion") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
// some implementations may add \r at the end of data
|
|
||||||
data = strings.TrimSuffix(data, "\r")
|
|
||||||
var claudeResponse ClaudeResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
responseText += claudeResponse.Completion
|
|
||||||
response := streamResponseClaude2OpenAI(&claudeResponse)
|
|
||||||
response.Id = responseId
|
|
||||||
response.Created = createdTime
|
|
||||||
jsonStr, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var claudeResponse ClaudeResponse
|
|
||||||
err = json.Unmarshal(responseBody, &claudeResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if claudeResponse.Error.Type != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: claudeResponse.Error.Message,
|
|
||||||
Type: claudeResponse.Error.Type,
|
|
||||||
Param: "",
|
|
||||||
Code: claudeResponse.Error.Type,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
|
|
||||||
completionTokens := countTokenText(claudeResponse.Completion, model)
|
|
||||||
usage := Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
CompletionTokens: completionTokens,
|
|
||||||
TotalTokens: promptTokens + completionTokens,
|
|
||||||
}
|
|
||||||
fullTextResponse.Usage = usage
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
113
controller/relay-completion.go
Normal file
113
controller/relay-completion.go
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/providers"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func relayCompletionHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
|
||||||
|
|
||||||
|
// 获取请求参数
|
||||||
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
// consumeQuota := c.GetBool("consume_quota")
|
||||||
|
group := c.GetString("group")
|
||||||
|
|
||||||
|
// 获取 Provider
|
||||||
|
completionProvider := GetCompletionProvider(channelType, c)
|
||||||
|
if completionProvider == nil {
|
||||||
|
return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求体
|
||||||
|
var completionRequest types.CompletionRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &completionRequest)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查模型映射
|
||||||
|
isModelMapped := false
|
||||||
|
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if modelMap != nil && modelMap[completionRequest.Model] != "" {
|
||||||
|
completionRequest.Model = modelMap[completionRequest.Model]
|
||||||
|
isModelMapped = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始计算Tokens
|
||||||
|
var promptTokens int
|
||||||
|
promptTokens = common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
|
||||||
|
|
||||||
|
// 计算预付费配额
|
||||||
|
quotaInfo := &QuotaInfo{
|
||||||
|
modelName: completionRequest.Model,
|
||||||
|
promptTokens: promptTokens,
|
||||||
|
userId: userId,
|
||||||
|
channelId: channelId,
|
||||||
|
tokenId: tokenId,
|
||||||
|
}
|
||||||
|
quotaInfo.initQuotaInfo(group)
|
||||||
|
quota_err := quotaInfo.preQuotaConsumption()
|
||||||
|
if quota_err != nil {
|
||||||
|
return quota_err
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, openAIErrorWithStatusCode := completionProvider.CompleteResponse(&completionRequest, isModelMapped, promptTokens)
|
||||||
|
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
if quotaInfo.preConsumedQuota != 0 {
|
||||||
|
go func(ctx context.Context) {
|
||||||
|
// return pre-consumed quota
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||||
|
}
|
||||||
|
}(c.Request.Context())
|
||||||
|
}
|
||||||
|
return openAIErrorWithStatusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}(c.Request.Context())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCompletionProvider(channelType int, c *gin.Context) providers.CompletionProviderAction {
|
||||||
|
switch channelType {
|
||||||
|
case common.ChannelTypeOpenAI:
|
||||||
|
return providers.CreateOpenAIProvider(c, "")
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
|
return providers.CreateAzureProvider(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
|
if c.GetString("base_url") != "" {
|
||||||
|
baseURL = c.GetString("base_url")
|
||||||
|
}
|
||||||
|
|
||||||
|
if baseURL != "" {
|
||||||
|
return providers.CreateOpenAIProvider(c, baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
117
controller/relay-embeddings.go
Normal file
117
controller/relay-embeddings.go
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/providers"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func relayEmbeddingsHelper(c *gin.Context) *types.OpenAIErrorWithStatusCode {
|
||||||
|
|
||||||
|
// 获取请求参数
|
||||||
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
// consumeQuota := c.GetBool("consume_quota")
|
||||||
|
group := c.GetString("group")
|
||||||
|
|
||||||
|
// 获取 Provider
|
||||||
|
embeddingsProvider := GetEmbeddingsProvider(channelType, c)
|
||||||
|
if embeddingsProvider == nil {
|
||||||
|
return types.ErrorWrapper(errors.New("API not implemented"), "api_not_implemented", http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求体
|
||||||
|
var embeddingsRequest types.EmbeddingRequest
|
||||||
|
err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查模型映射
|
||||||
|
isModelMapped := false
|
||||||
|
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
|
||||||
|
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
|
||||||
|
isModelMapped = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始计算Tokens
|
||||||
|
var promptTokens int
|
||||||
|
promptTokens = common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
|
||||||
|
|
||||||
|
// 计算预付费配额
|
||||||
|
quotaInfo := &QuotaInfo{
|
||||||
|
modelName: embeddingsRequest.Model,
|
||||||
|
promptTokens: promptTokens,
|
||||||
|
userId: userId,
|
||||||
|
channelId: channelId,
|
||||||
|
tokenId: tokenId,
|
||||||
|
}
|
||||||
|
quotaInfo.initQuotaInfo(group)
|
||||||
|
quota_err := quotaInfo.preQuotaConsumption()
|
||||||
|
if quota_err != nil {
|
||||||
|
return quota_err
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, openAIErrorWithStatusCode := embeddingsProvider.EmbeddingsResponse(&embeddingsRequest, isModelMapped, promptTokens)
|
||||||
|
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
if quotaInfo.preConsumedQuota != 0 {
|
||||||
|
go func(ctx context.Context) {
|
||||||
|
// return pre-consumed quota
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||||
|
}
|
||||||
|
}(c.Request.Context())
|
||||||
|
}
|
||||||
|
return openAIErrorWithStatusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}(c.Request.Context())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetEmbeddingsProvider(channelType int, c *gin.Context) providers.EmbeddingsProviderAction {
|
||||||
|
switch channelType {
|
||||||
|
case common.ChannelTypeOpenAI:
|
||||||
|
return providers.CreateOpenAIProvider(c, "")
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
|
return providers.CreateAzureProvider(c)
|
||||||
|
case common.ChannelTypeAli:
|
||||||
|
return providers.CreateAliAIProvider(c)
|
||||||
|
case common.ChannelTypeBaidu:
|
||||||
|
return providers.CreateBaiduProvider(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
|
if c.GetString("base_url") != "" {
|
||||||
|
baseURL = c.GetString("base_url")
|
||||||
|
}
|
||||||
|
|
||||||
|
if baseURL != "" {
|
||||||
|
return providers.CreateOpenAIProvider(c, baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,206 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func isWithinRange(element string, value int) bool {
|
|
||||||
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
min := common.DalleGenerationImageAmounts[element][0]
|
|
||||||
max := common.DalleGenerationImageAmounts[element][1]
|
|
||||||
|
|
||||||
return value >= min && value <= max
|
|
||||||
}
|
|
||||||
|
|
||||||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
||||||
imageModel := "dall-e-2"
|
|
||||||
imageSize := "1024x1024"
|
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
channelType := c.GetInt("channel")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
|
||||||
group := c.GetString("group")
|
|
||||||
|
|
||||||
var imageRequest ImageRequest
|
|
||||||
if consumeQuota {
|
|
||||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size validation
|
|
||||||
if imageRequest.Size != "" {
|
|
||||||
imageSize = imageRequest.Size
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model validation
|
|
||||||
if imageRequest.Model != "" {
|
|
||||||
imageModel = imageRequest.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
|
|
||||||
|
|
||||||
// Check if model is supported
|
|
||||||
if hasValidSize {
|
|
||||||
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
|
|
||||||
if imageSize == "1024x1024" {
|
|
||||||
imageCostRatio *= 2
|
|
||||||
} else {
|
|
||||||
imageCostRatio *= 1.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prompt validation
|
|
||||||
if imageRequest.Prompt == "" {
|
|
||||||
return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check prompt length
|
|
||||||
if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
|
|
||||||
return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Number of generated images validation
|
|
||||||
if isWithinRange(imageModel, imageRequest.N) == false {
|
|
||||||
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// map model name
|
|
||||||
modelMapping := c.GetString("model_mapping")
|
|
||||||
isModelMapped := false
|
|
||||||
if modelMapping != "" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if modelMap[imageModel] != "" {
|
|
||||||
imageModel = modelMap[imageModel]
|
|
||||||
isModelMapped = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
|
||||||
requestURL := c.Request.URL.String()
|
|
||||||
if c.GetString("base_url") != "" {
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
}
|
|
||||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
|
||||||
var requestBody io.Reader
|
|
||||||
if isModelMapped {
|
|
||||||
jsonStr, err := json.Marshal(imageRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
} else {
|
|
||||||
requestBody = c.Request.Body
|
|
||||||
}
|
|
||||||
|
|
||||||
modelRatio := common.GetModelRatio(imageModel)
|
|
||||||
groupRatio := common.GetGroupRatio(group)
|
|
||||||
ratio := modelRatio * groupRatio
|
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
|
||||||
|
|
||||||
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
|
|
||||||
|
|
||||||
if consumeQuota && userQuota-quota < 0 {
|
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = req.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = c.Request.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
var textResponse ImageResponse
|
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
|
||||||
if consumeQuota {
|
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
|
||||||
}
|
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
|
||||||
}
|
|
||||||
if quota != 0 {
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(c.Request.Context())
|
|
||||||
|
|
||||||
if consumeQuota {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,144 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
|
|
||||||
responseText := ""
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 6 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dataChan <- data
|
|
||||||
data = data[6:]
|
|
||||||
if !strings.HasPrefix(data, "[DONE]") {
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeChatCompletions:
|
|
||||||
var streamResponse ChatCompletionsStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
continue // just ignore the error
|
|
||||||
}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseText += choice.Delta.Content
|
|
||||||
}
|
|
||||||
case RelayModeCompletions:
|
|
||||||
var streamResponse CompletionsStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
|
||||||
responseText += choice.Text
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
if strings.HasPrefix(data, "data: [DONE]") {
|
|
||||||
data = data[:12]
|
|
||||||
}
|
|
||||||
// some implementations may add \r at the end of data
|
|
||||||
data = strings.TrimSuffix(data, "\r")
|
|
||||||
c.Render(-1, common.CustomEvent{Data: data})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
|
||||||
}
|
|
||||||
|
|
||||||
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var textResponse TextResponse
|
|
||||||
if consumeQuota {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if textResponse.Error.Type != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: textResponse.Error,
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
// Reset response body
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
}
|
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
||||||
// So the httpClient will be confused by the response.
|
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err := io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if textResponse.Usage.TotalTokens == 0 {
|
|
||||||
completionTokens := 0
|
|
||||||
for _, choice := range textResponse.Choices {
|
|
||||||
completionTokens += countTokenText(choice.Message.StringContent(), model)
|
|
||||||
}
|
|
||||||
textResponse.Usage = Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
CompletionTokens: completionTokens,
|
|
||||||
TotalTokens: promptTokens + completionTokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, &textResponse.Usage
|
|
||||||
}
|
|
@ -1,205 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
|
||||||
|
|
||||||
type PaLMChatMessage struct {
|
|
||||||
Author string `json:"author"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMFilter struct {
|
|
||||||
Reason string `json:"reason"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMPrompt struct {
|
|
||||||
Messages []PaLMChatMessage `json:"messages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMChatRequest struct {
|
|
||||||
Prompt PaLMPrompt `json:"prompt"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
CandidateCount int `json:"candidateCount,omitempty"`
|
|
||||||
TopP float64 `json:"topP,omitempty"`
|
|
||||||
TopK int `json:"topK,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMError struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PaLMChatResponse struct {
|
|
||||||
Candidates []PaLMChatMessage `json:"candidates"`
|
|
||||||
Messages []Message `json:"messages"`
|
|
||||||
Filters []PaLMFilter `json:"filters"`
|
|
||||||
Error PaLMError `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
|
||||||
palmRequest := PaLMChatRequest{
|
|
||||||
Prompt: PaLMPrompt{
|
|
||||||
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
|
|
||||||
},
|
|
||||||
Temperature: textRequest.Temperature,
|
|
||||||
CandidateCount: textRequest.N,
|
|
||||||
TopP: textRequest.TopP,
|
|
||||||
TopK: textRequest.MaxTokens,
|
|
||||||
}
|
|
||||||
for _, message := range textRequest.Messages {
|
|
||||||
palmMessage := PaLMChatMessage{
|
|
||||||
Content: message.StringContent(),
|
|
||||||
}
|
|
||||||
if message.Role == "user" {
|
|
||||||
palmMessage.Author = "0"
|
|
||||||
} else {
|
|
||||||
palmMessage.Author = "1"
|
|
||||||
}
|
|
||||||
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
|
||||||
}
|
|
||||||
return &palmRequest
|
|
||||||
}
|
|
||||||
|
|
||||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
|
||||||
}
|
|
||||||
for i, candidate := range response.Candidates {
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: i,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: candidate.Content,
|
|
||||||
},
|
|
||||||
FinishReason: "stop",
|
|
||||||
}
|
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
if len(palmResponse.Candidates) > 0 {
|
|
||||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
|
||||||
}
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
var response ChatCompletionsStreamResponse
|
|
||||||
response.Object = "chat.completion.chunk"
|
|
||||||
response.Model = "palm2"
|
|
||||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
|
||||||
responseText := ""
|
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
|
||||||
createdTime := common.GetTimestamp()
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error reading stream response: " + err.Error())
|
|
||||||
stopChan <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error closing stream response: " + err.Error())
|
|
||||||
stopChan <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var palmResponse PaLMChatResponse
|
|
||||||
err = json.Unmarshal(responseBody, &palmResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
stopChan <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
|
|
||||||
fullTextResponse.Id = responseId
|
|
||||||
fullTextResponse.Created = createdTime
|
|
||||||
if len(palmResponse.Candidates) > 0 {
|
|
||||||
responseText = palmResponse.Candidates[0].Content
|
|
||||||
}
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
stopChan <- true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dataChan <- string(jsonResponse)
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + data})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
|
||||||
}
|
|
||||||
|
|
||||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var palmResponse PaLMChatResponse
|
|
||||||
err = json.Unmarshal(responseBody, &palmResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: palmResponse.Error.Message,
|
|
||||||
Type: palmResponse.Error.Status,
|
|
||||||
Param: "",
|
|
||||||
Code: palmResponse.Error.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
|
||||||
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
|
|
||||||
usage := Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
CompletionTokens: completionTokens,
|
|
||||||
TotalTokens: promptTokens + completionTokens,
|
|
||||||
}
|
|
||||||
fullTextResponse.Usage = usage
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
@ -1,649 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
APITypeOpenAI = iota
|
|
||||||
APITypeClaude
|
|
||||||
APITypePaLM
|
|
||||||
APITypeBaidu
|
|
||||||
APITypeZhipu
|
|
||||||
APITypeAli
|
|
||||||
APITypeXunfei
|
|
||||||
APITypeAIProxyLibrary
|
|
||||||
APITypeTencent
|
|
||||||
)
|
|
||||||
|
|
||||||
var httpClient *http.Client
|
|
||||||
var impatientHTTPClient *http.Client
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if common.RelayTimeout == 0 {
|
|
||||||
httpClient = &http.Client{}
|
|
||||||
} else {
|
|
||||||
httpClient = &http.Client{
|
|
||||||
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impatientHTTPClient = &http.Client{
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
||||||
channelType := c.GetInt("channel")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
|
||||||
group := c.GetString("group")
|
|
||||||
var textRequest GeneralOpenAIRequest
|
|
||||||
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
|
|
||||||
err := common.UnmarshalBodyReusable(c, &textRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if relayMode == RelayModeModerations && textRequest.Model == "" {
|
|
||||||
textRequest.Model = "text-moderation-latest"
|
|
||||||
}
|
|
||||||
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
|
|
||||||
textRequest.Model = c.Param("model")
|
|
||||||
}
|
|
||||||
// request validation
|
|
||||||
if textRequest.Model == "" {
|
|
||||||
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeCompletions:
|
|
||||||
if textRequest.Prompt == "" {
|
|
||||||
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
case RelayModeChatCompletions:
|
|
||||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
|
||||||
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
case RelayModeModerations:
|
|
||||||
if textRequest.Input == "" {
|
|
||||||
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
case RelayModeEdits:
|
|
||||||
if textRequest.Instruction == "" {
|
|
||||||
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// map model name
|
|
||||||
modelMapping := c.GetString("model_mapping")
|
|
||||||
isModelMapped := false
|
|
||||||
if modelMapping != "" && modelMapping != "{}" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if modelMap[textRequest.Model] != "" {
|
|
||||||
textRequest.Model = modelMap[textRequest.Model]
|
|
||||||
isModelMapped = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
apiType := APITypeOpenAI
|
|
||||||
switch channelType {
|
|
||||||
case common.ChannelTypeAnthropic:
|
|
||||||
apiType = APITypeClaude
|
|
||||||
case common.ChannelTypeBaidu:
|
|
||||||
apiType = APITypeBaidu
|
|
||||||
case common.ChannelTypePaLM:
|
|
||||||
apiType = APITypePaLM
|
|
||||||
case common.ChannelTypeZhipu:
|
|
||||||
apiType = APITypeZhipu
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
apiType = APITypeAli
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
apiType = APITypeXunfei
|
|
||||||
case common.ChannelTypeAIProxyLibrary:
|
|
||||||
apiType = APITypeAIProxyLibrary
|
|
||||||
case common.ChannelTypeTencent:
|
|
||||||
apiType = APITypeTencent
|
|
||||||
}
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
|
||||||
requestURL := c.Request.URL.String()
|
|
||||||
if c.GetString("base_url") != "" {
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
}
|
|
||||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
|
||||||
switch apiType {
|
|
||||||
case APITypeOpenAI:
|
|
||||||
if channelType == common.ChannelTypeAzure {
|
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
|
||||||
query := c.Request.URL.Query()
|
|
||||||
apiVersion := query.Get("api-version")
|
|
||||||
if apiVersion == "" {
|
|
||||||
apiVersion = c.GetString("api_version")
|
|
||||||
}
|
|
||||||
requestURL := strings.Split(requestURL, "?")[0]
|
|
||||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
|
||||||
model_ := textRequest.Model
|
|
||||||
model_ = strings.Replace(model_, ".", "", -1)
|
|
||||||
// https://github.com/songquanpeng/one-api/issues/67
|
|
||||||
model_ = strings.TrimSuffix(model_, "-0301")
|
|
||||||
model_ = strings.TrimSuffix(model_, "-0314")
|
|
||||||
model_ = strings.TrimSuffix(model_, "-0613")
|
|
||||||
|
|
||||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
|
||||||
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
|
|
||||||
}
|
|
||||||
case APITypeClaude:
|
|
||||||
fullRequestURL = "https://api.anthropic.com/v1/complete"
|
|
||||||
if baseURL != "" {
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
|
|
||||||
}
|
|
||||||
case APITypeBaidu:
|
|
||||||
switch textRequest.Model {
|
|
||||||
case "ERNIE-Bot":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
|
||||||
case "ERNIE-Bot-turbo":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
|
||||||
case "ERNIE-Bot-4":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
|
||||||
case "BLOOMZ-7B":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
|
||||||
case "Embedding-V1":
|
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
|
||||||
}
|
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
||||||
var err error
|
|
||||||
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
|
|
||||||
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
fullRequestURL += "?access_token=" + apiKey
|
|
||||||
case APITypePaLM:
|
|
||||||
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
|
|
||||||
if baseURL != "" {
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
|
|
||||||
}
|
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
||||||
fullRequestURL += "?key=" + apiKey
|
|
||||||
case APITypeZhipu:
|
|
||||||
method := "invoke"
|
|
||||||
if textRequest.Stream {
|
|
||||||
method = "sse-invoke"
|
|
||||||
}
|
|
||||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
|
||||||
case APITypeAli:
|
|
||||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
|
||||||
if relayMode == RelayModeEmbeddings {
|
|
||||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
|
||||||
}
|
|
||||||
case APITypeTencent:
|
|
||||||
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
|
|
||||||
case APITypeAIProxyLibrary:
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
|
||||||
}
|
|
||||||
var promptTokens int
|
|
||||||
var completionTokens int
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeChatCompletions:
|
|
||||||
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
|
|
||||||
case RelayModeCompletions:
|
|
||||||
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
|
|
||||||
case RelayModeModerations:
|
|
||||||
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
|
|
||||||
}
|
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
|
||||||
if textRequest.MaxTokens != 0 {
|
|
||||||
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
|
||||||
}
|
|
||||||
modelRatio := common.GetModelRatio(textRequest.Model)
|
|
||||||
groupRatio := common.GetGroupRatio(group)
|
|
||||||
ratio := modelRatio * groupRatio
|
|
||||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if userQuota-preConsumedQuota < 0 {
|
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if userQuota > 100*preConsumedQuota {
|
|
||||||
// in this case, we do not pre-consume quota
|
|
||||||
// because the user has enough quota
|
|
||||||
preConsumedQuota = 0
|
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
|
||||||
}
|
|
||||||
if consumeQuota && preConsumedQuota > 0 {
|
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var requestBody io.Reader
|
|
||||||
if isModelMapped {
|
|
||||||
jsonStr, err := json.Marshal(textRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
} else {
|
|
||||||
requestBody = c.Request.Body
|
|
||||||
}
|
|
||||||
switch apiType {
|
|
||||||
case APITypeClaude:
|
|
||||||
claudeRequest := requestOpenAI2Claude(textRequest)
|
|
||||||
jsonStr, err := json.Marshal(claudeRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeBaidu:
|
|
||||||
var jsonData []byte
|
|
||||||
var err error
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
|
|
||||||
jsonData, err = json.Marshal(baiduEmbeddingRequest)
|
|
||||||
default:
|
|
||||||
baiduRequest := requestOpenAI2Baidu(textRequest)
|
|
||||||
jsonData, err = json.Marshal(baiduRequest)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
|
||||||
case APITypePaLM:
|
|
||||||
palmRequest := requestOpenAI2PaLM(textRequest)
|
|
||||||
jsonStr, err := json.Marshal(palmRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeZhipu:
|
|
||||||
zhipuRequest := requestOpenAI2Zhipu(textRequest)
|
|
||||||
jsonStr, err := json.Marshal(zhipuRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeAli:
|
|
||||||
var jsonStr []byte
|
|
||||||
var err error
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
|
|
||||||
jsonStr, err = json.Marshal(aliEmbeddingRequest)
|
|
||||||
default:
|
|
||||||
aliRequest := requestOpenAI2Ali(textRequest)
|
|
||||||
jsonStr, err = json.Marshal(aliRequest)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeTencent:
|
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
||||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
tencentRequest := requestOpenAI2Tencent(textRequest)
|
|
||||||
tencentRequest.AppId = appId
|
|
||||||
tencentRequest.SecretId = secretId
|
|
||||||
jsonStr, err := json.Marshal(tencentRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
sign := getTencentSign(*tencentRequest, secretKey)
|
|
||||||
c.Request.Header.Set("Authorization", sign)
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
case APITypeAIProxyLibrary:
|
|
||||||
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
|
|
||||||
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
|
|
||||||
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
var req *http.Request
|
|
||||||
var resp *http.Response
|
|
||||||
isStream := textRequest.Stream
|
|
||||||
|
|
||||||
if apiType != APITypeXunfei { // cause xunfei use websocket
|
|
||||||
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
apiKey := c.Request.Header.Get("Authorization")
|
|
||||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
||||||
switch apiType {
|
|
||||||
case APITypeOpenAI:
|
|
||||||
if channelType == common.ChannelTypeAzure {
|
|
||||||
req.Header.Set("api-key", apiKey)
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
|
||||||
if channelType == common.ChannelTypeOpenRouter {
|
|
||||||
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
|
||||||
req.Header.Set("X-Title", "One API")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case APITypeClaude:
|
|
||||||
req.Header.Set("x-api-key", apiKey)
|
|
||||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
|
||||||
if anthropicVersion == "" {
|
|
||||||
anthropicVersion = "2023-06-01"
|
|
||||||
}
|
|
||||||
req.Header.Set("anthropic-version", anthropicVersion)
|
|
||||||
case APITypeZhipu:
|
|
||||||
token := getZhipuToken(apiKey)
|
|
||||||
req.Header.Set("Authorization", token)
|
|
||||||
case APITypeAli:
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
if textRequest.Stream {
|
|
||||||
req.Header.Set("X-DashScope-SSE", "enable")
|
|
||||||
}
|
|
||||||
case APITypeTencent:
|
|
||||||
req.Header.Set("Authorization", apiKey)
|
|
||||||
case APITypePaLM:
|
|
||||||
// do not set Authorization header
|
|
||||||
default:
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
||||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
|
||||||
req.Header.Set("Accept", "text/event-stream")
|
|
||||||
}
|
|
||||||
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
|
||||||
resp, err = httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = req.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = c.Request.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if preConsumedQuota != 0 {
|
|
||||||
go func(ctx context.Context) {
|
|
||||||
// return pre-consumed quota
|
|
||||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
|
||||||
}
|
|
||||||
}(c.Request.Context())
|
|
||||||
}
|
|
||||||
return relayErrorHandler(resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var textResponse TextResponse
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
|
||||||
// c.Writer.Flush()
|
|
||||||
go func() {
|
|
||||||
if consumeQuota {
|
|
||||||
quota := 0
|
|
||||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
|
||||||
promptTokens = textResponse.Usage.PromptTokens
|
|
||||||
completionTokens = textResponse.Usage.CompletionTokens
|
|
||||||
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
|
||||||
if ratio != 0 && quota <= 0 {
|
|
||||||
quota = 1
|
|
||||||
}
|
|
||||||
totalTokens := promptTokens + completionTokens
|
|
||||||
if totalTokens == 0 {
|
|
||||||
// in this case, must be some error happened
|
|
||||||
// we cannot just return, because we may have to return the pre-consumed quota
|
|
||||||
quota = 0
|
|
||||||
}
|
|
||||||
quotaDelta := quota - preConsumedQuota
|
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
|
||||||
}
|
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
|
||||||
}
|
|
||||||
if quota != 0 {
|
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}(c.Request.Context())
|
|
||||||
switch apiType {
|
|
||||||
case APITypeOpenAI:
|
|
||||||
if isStream {
|
|
||||||
err, responseText := openaiStreamHandler(c, resp, relayMode)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
textResponse.Usage.PromptTokens = promptTokens
|
|
||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeClaude:
|
|
||||||
if isStream {
|
|
||||||
err, responseText := claudeStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
textResponse.Usage.PromptTokens = promptTokens
|
|
||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeBaidu:
|
|
||||||
if isStream {
|
|
||||||
err, usage := baiduStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
var err *OpenAIErrorWithStatusCode
|
|
||||||
var usage *Usage
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
err, usage = baiduEmbeddingHandler(c, resp)
|
|
||||||
default:
|
|
||||||
err, usage = baiduHandler(c, resp)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypePaLM:
|
|
||||||
if textRequest.Stream { // PaLM2 API does not support stream
|
|
||||||
err, responseText := palmStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
textResponse.Usage.PromptTokens = promptTokens
|
|
||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeZhipu:
|
|
||||||
if isStream {
|
|
||||||
err, usage := zhipuStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
// zhipu's API does not return prompt tokens & completion tokens
|
|
||||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := zhipuHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
// zhipu's API does not return prompt tokens & completion tokens
|
|
||||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeAli:
|
|
||||||
if isStream {
|
|
||||||
err, usage := aliStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
var err *OpenAIErrorWithStatusCode
|
|
||||||
var usage *Usage
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeEmbeddings:
|
|
||||||
err, usage = aliEmbeddingHandler(c, resp)
|
|
||||||
default:
|
|
||||||
err, usage = aliHandler(c, resp)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeXunfei:
|
|
||||||
auth := c.Request.Header.Get("Authorization")
|
|
||||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
|
||||||
splits := strings.Split(auth, "|")
|
|
||||||
if len(splits) != 3 {
|
|
||||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
var err *OpenAIErrorWithStatusCode
|
|
||||||
var usage *Usage
|
|
||||||
if isStream {
|
|
||||||
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
|
||||||
} else {
|
|
||||||
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
case APITypeAIProxyLibrary:
|
|
||||||
if isStream {
|
|
||||||
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := aiProxyLibraryHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case APITypeTencent:
|
|
||||||
if isStream {
|
|
||||||
err, responseText := tencentStreamHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
textResponse.Usage.PromptTokens = promptTokens
|
|
||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
err, usage := tencentHandler(c, resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
@ -3,133 +3,16 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"one-api/types"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/pkoukk/tiktoken-go"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var stopFinishReason = "stop"
|
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
||||||
|
|
||||||
// tokenEncoderMap won't grow after initialization
|
|
||||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
|
||||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
|
||||||
|
|
||||||
func InitTokenEncoders() {
|
|
||||||
common.SysLog("initializing token encoders")
|
|
||||||
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
|
||||||
}
|
|
||||||
defaultTokenEncoder = gpt35TokenEncoder
|
|
||||||
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
|
||||||
if err != nil {
|
|
||||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
|
||||||
}
|
|
||||||
for model, _ := range common.ModelRatio {
|
|
||||||
if strings.HasPrefix(model, "gpt-3.5") {
|
|
||||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
|
||||||
} else if strings.HasPrefix(model, "gpt-4") {
|
|
||||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
|
||||||
} else {
|
|
||||||
tokenEncoderMap[model] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
common.SysLog("token encoders initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
||||||
tokenEncoder, ok := tokenEncoderMap[model]
|
|
||||||
if ok && tokenEncoder != nil {
|
|
||||||
return tokenEncoder
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
|
||||||
tokenEncoder = defaultTokenEncoder
|
|
||||||
}
|
|
||||||
tokenEncoderMap[model] = tokenEncoder
|
|
||||||
return tokenEncoder
|
|
||||||
}
|
|
||||||
return defaultTokenEncoder
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|
||||||
if common.ApproximateTokenEnabled {
|
|
||||||
return int(float64(len(text)) * 0.38)
|
|
||||||
}
|
|
||||||
return len(tokenEncoder.Encode(text, nil, nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func countTokenMessages(messages []Message, model string) int {
|
|
||||||
tokenEncoder := getTokenEncoder(model)
|
|
||||||
// Reference:
|
|
||||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
|
||||||
//
|
|
||||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
|
||||||
var tokensPerMessage int
|
|
||||||
var tokensPerName int
|
|
||||||
if model == "gpt-3.5-turbo-0301" {
|
|
||||||
tokensPerMessage = 4
|
|
||||||
tokensPerName = -1 // If there's a name, the role is omitted
|
|
||||||
} else {
|
|
||||||
tokensPerMessage = 3
|
|
||||||
tokensPerName = 1
|
|
||||||
}
|
|
||||||
tokenNum := 0
|
|
||||||
for _, message := range messages {
|
|
||||||
tokenNum += tokensPerMessage
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.StringContent())
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
|
||||||
if message.Name != nil {
|
|
||||||
tokenNum += tokensPerName
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
|
||||||
return tokenNum
|
|
||||||
}
|
|
||||||
|
|
||||||
func countTokenInput(input any, model string) int {
|
|
||||||
switch input.(type) {
|
|
||||||
case string:
|
|
||||||
return countTokenText(input.(string), model)
|
|
||||||
case []string:
|
|
||||||
text := ""
|
|
||||||
for _, s := range input.([]string) {
|
|
||||||
text += s
|
|
||||||
}
|
|
||||||
return countTokenText(text, model)
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func countTokenText(text string, model string) int {
|
|
||||||
tokenEncoder := getTokenEncoder(model)
|
|
||||||
return getTokenNum(tokenEncoder, text)
|
|
||||||
}
|
|
||||||
|
|
||||||
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
|
|
||||||
openAIError := OpenAIError{
|
|
||||||
Message: err.Error(),
|
|
||||||
Type: "one_api_error",
|
|
||||||
Code: code,
|
|
||||||
}
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: openAIError,
|
|
||||||
StatusCode: statusCode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
|
|
||||||
if !common.AutomaticDisableChannelEnabled {
|
if !common.AutomaticDisableChannelEnabled {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -145,56 +28,6 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func setEventStreamHeaders(c *gin.Context) {
|
|
||||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
|
||||||
c.Writer.Header().Set("Connection", "keep-alive")
|
|
||||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
|
||||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
}
|
|
||||||
|
|
||||||
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
|
|
||||||
openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
|
||||||
Type: "upstream_error",
|
|
||||||
Code: "bad_response_status_code",
|
|
||||||
Param: strconv.Itoa(resp.StatusCode),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var textResponse TextResponse
|
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
||||||
|
|
||||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
|
||||||
switch channelType {
|
|
||||||
case common.ChannelTypeOpenAI:
|
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fullRequestURL
|
|
||||||
}
|
|
||||||
|
|
||||||
func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -211,3 +44,110 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c
|
|||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseModelMapping(modelMapping string) (map[string]string, error) {
|
||||||
|
if modelMapping == "" || modelMapping == "{}" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
modelMap := make(map[string]string)
|
||||||
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return modelMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type QuotaInfo struct {
|
||||||
|
modelName string
|
||||||
|
promptTokens int
|
||||||
|
preConsumedTokens int
|
||||||
|
modelRatio float64
|
||||||
|
groupRatio float64
|
||||||
|
ratio float64
|
||||||
|
preConsumedQuota int
|
||||||
|
userId int
|
||||||
|
channelId int
|
||||||
|
tokenId int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QuotaInfo) initQuotaInfo(groupName string) {
|
||||||
|
modelRatio := common.GetModelRatio(q.modelName)
|
||||||
|
groupRatio := common.GetGroupRatio(groupName)
|
||||||
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
|
ratio := modelRatio * groupRatio
|
||||||
|
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
|
||||||
|
|
||||||
|
q.preConsumedTokens = preConsumedTokens
|
||||||
|
q.modelRatio = modelRatio
|
||||||
|
q.groupRatio = groupRatio
|
||||||
|
q.ratio = ratio
|
||||||
|
q.preConsumedQuota = preConsumedQuota
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
|
||||||
|
userQuota, err := model.CacheGetUserQuota(q.userId)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userQuota < q.preConsumedQuota {
|
||||||
|
return types.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userQuota > 100*q.preConsumedQuota {
|
||||||
|
// in this case, we do not pre-consume quota
|
||||||
|
// because the user has enough quota
|
||||||
|
q.preConsumedQuota = 0
|
||||||
|
// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.preConsumedQuota > 0 {
|
||||||
|
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
|
||||||
|
quota := 0
|
||||||
|
completionRatio := common.GetCompletionRatio(q.modelName)
|
||||||
|
promptTokens := usage.PromptTokens
|
||||||
|
completionTokens := usage.CompletionTokens
|
||||||
|
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio))
|
||||||
|
if q.ratio != 0 && quota <= 0 {
|
||||||
|
quota = 1
|
||||||
|
}
|
||||||
|
totalTokens := promptTokens + completionTokens
|
||||||
|
if totalTokens == 0 {
|
||||||
|
// in this case, must be some error happened
|
||||||
|
// we cannot just return, because we may have to return the pre-consumed quota
|
||||||
|
quota = 0
|
||||||
|
}
|
||||||
|
quotaDelta := quota - q.preConsumedQuota
|
||||||
|
err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("error consuming token remain quota: " + err.Error())
|
||||||
|
}
|
||||||
|
err = model.CacheUpdateUserQuota(q.userId)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("error consuming token remain quota: " + err.Error())
|
||||||
|
}
|
||||||
|
if quota != 0 {
|
||||||
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio)
|
||||||
|
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent)
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
|
||||||
|
model.UpdateChannelUsedQuota(q.channelId, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,301 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// https://open.bigmodel.cn/doc/api#chatglm_std
|
|
||||||
// chatglm_std, chatglm_lite
|
|
||||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
|
|
||||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
|
||||||
|
|
||||||
type ZhipuMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ZhipuRequest struct {
|
|
||||||
Prompt []ZhipuMessage `json:"prompt"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
|
||||||
RequestId string `json:"request_id,omitempty"`
|
|
||||||
Incremental bool `json:"incremental,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ZhipuResponseData struct {
|
|
||||||
TaskId string `json:"task_id"`
|
|
||||||
RequestId string `json:"request_id"`
|
|
||||||
TaskStatus string `json:"task_status"`
|
|
||||||
Choices []ZhipuMessage `json:"choices"`
|
|
||||||
Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ZhipuResponse struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Msg string `json:"msg"`
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Data ZhipuResponseData `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ZhipuStreamMetaResponse struct {
|
|
||||||
RequestId string `json:"request_id"`
|
|
||||||
TaskId string `json:"task_id"`
|
|
||||||
TaskStatus string `json:"task_status"`
|
|
||||||
Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type zhipuTokenData struct {
|
|
||||||
Token string
|
|
||||||
ExpiryTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
var zhipuTokens sync.Map
|
|
||||||
var expSeconds int64 = 24 * 3600
|
|
||||||
|
|
||||||
func getZhipuToken(apikey string) string {
|
|
||||||
data, ok := zhipuTokens.Load(apikey)
|
|
||||||
if ok {
|
|
||||||
tokenData := data.(zhipuTokenData)
|
|
||||||
if time.Now().Before(tokenData.ExpiryTime) {
|
|
||||||
return tokenData.Token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
split := strings.Split(apikey, ".")
|
|
||||||
if len(split) != 2 {
|
|
||||||
common.SysError("invalid zhipu key: " + apikey)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
id := split[0]
|
|
||||||
secret := split[1]
|
|
||||||
|
|
||||||
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
|
||||||
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
|
||||||
|
|
||||||
timestamp := time.Now().UnixNano() / 1e6
|
|
||||||
|
|
||||||
payload := jwt.MapClaims{
|
|
||||||
"api_key": id,
|
|
||||||
"exp": expMillis,
|
|
||||||
"timestamp": timestamp,
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
|
||||||
|
|
||||||
token.Header["alg"] = "HS256"
|
|
||||||
token.Header["sign_type"] = "SIGN"
|
|
||||||
|
|
||||||
tokenString, err := token.SignedString([]byte(secret))
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
zhipuTokens.Store(apikey, zhipuTokenData{
|
|
||||||
Token: tokenString,
|
|
||||||
ExpiryTime: expiryTime,
|
|
||||||
})
|
|
||||||
|
|
||||||
return tokenString
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
|
||||||
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
|
||||||
for _, message := range request.Messages {
|
|
||||||
if message.Role == "system" {
|
|
||||||
messages = append(messages, ZhipuMessage{
|
|
||||||
Role: "system",
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
messages = append(messages, ZhipuMessage{
|
|
||||||
Role: "user",
|
|
||||||
Content: "Okay",
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
messages = append(messages, ZhipuMessage{
|
|
||||||
Role: message.Role,
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &ZhipuRequest{
|
|
||||||
Prompt: messages,
|
|
||||||
Temperature: request.Temperature,
|
|
||||||
TopP: request.TopP,
|
|
||||||
Incremental: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Id: response.Data.TaskId,
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
|
|
||||||
Usage: response.Data.Usage,
|
|
||||||
}
|
|
||||||
for i, choice := range response.Data.Choices {
|
|
||||||
openaiChoice := OpenAITextResponseChoice{
|
|
||||||
Index: i,
|
|
||||||
Message: Message{
|
|
||||||
Role: choice.Role,
|
|
||||||
Content: strings.Trim(choice.Content, "\""),
|
|
||||||
},
|
|
||||||
FinishReason: "",
|
|
||||||
}
|
|
||||||
if i == len(response.Data.Choices)-1 {
|
|
||||||
openaiChoice.FinishReason = "stop"
|
|
||||||
}
|
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = zhipuResponse
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "chatglm",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = ""
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Id: zhipuResponse.RequestId,
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "chatglm",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response, &zhipuResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var usage *Usage
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
|
|
||||||
return i + 2, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
metaChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
lines := strings.Split(data, "\n")
|
|
||||||
for i, line := range lines {
|
|
||||||
if len(line) < 5 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if line[:5] == "data:" {
|
|
||||||
dataChan <- line[5:]
|
|
||||||
if i != len(lines)-1 {
|
|
||||||
dataChan <- "\n"
|
|
||||||
}
|
|
||||||
} else if line[:5] == "meta:" {
|
|
||||||
metaChan <- line[5:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
response := streamResponseZhipu2OpenAI(data)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case data := <-metaChan:
|
|
||||||
var zhipuResponse ZhipuStreamMetaResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
usage = zhipuUsage
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var zhipuResponse ZhipuResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if !zhipuResponse.Success {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: zhipuResponse.Msg,
|
|
||||||
Type: "zhipu_error",
|
|
||||||
Param: "",
|
|
||||||
Code: zhipuResponse.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -234,41 +235,46 @@ type CompletionsStreamResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
relayMode := RelayModeUnknown
|
var err *types.OpenAIErrorWithStatusCode
|
||||||
|
|
||||||
|
// relayMode := RelayModeUnknown
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||||
relayMode = RelayModeChatCompletions
|
err = relayChatHelper(c)
|
||||||
|
// relayMode = RelayModeChatCompletions
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
||||||
relayMode = RelayModeCompletions
|
err = relayCompletionHelper(c)
|
||||||
|
// relayMode = RelayModeCompletions
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
||||||
relayMode = RelayModeEmbeddings
|
err = relayEmbeddingsHelper(c)
|
||||||
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
|
||||||
relayMode = RelayModeEmbeddings
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
|
||||||
relayMode = RelayModeModerations
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
|
||||||
relayMode = RelayModeImagesGenerations
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
|
||||||
relayMode = RelayModeEdits
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
|
||||||
relayMode = RelayModeAudioSpeech
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
|
||||||
relayMode = RelayModeAudioTranscription
|
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
|
||||||
relayMode = RelayModeAudioTranslation
|
|
||||||
}
|
|
||||||
var err *OpenAIErrorWithStatusCode
|
|
||||||
switch relayMode {
|
|
||||||
case RelayModeImagesGenerations:
|
|
||||||
err = relayImageHelper(c, relayMode)
|
|
||||||
case RelayModeAudioSpeech:
|
|
||||||
fallthrough
|
|
||||||
case RelayModeAudioTranslation:
|
|
||||||
fallthrough
|
|
||||||
case RelayModeAudioTranscription:
|
|
||||||
err = relayAudioHelper(c, relayMode)
|
|
||||||
default:
|
|
||||||
err = relayTextHelper(c, relayMode)
|
|
||||||
}
|
}
|
||||||
|
// relayMode = RelayModeEmbeddings
|
||||||
|
// } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||||
|
// relayMode = RelayModeEmbeddings
|
||||||
|
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
|
// relayMode = RelayModeModerations
|
||||||
|
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
|
// relayMode = RelayModeImagesGenerations
|
||||||
|
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
||||||
|
// relayMode = RelayModeEdits
|
||||||
|
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||||
|
// relayMode = RelayModeAudioSpeech
|
||||||
|
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||||
|
// relayMode = RelayModeAudioTranscription
|
||||||
|
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||||
|
// relayMode = RelayModeAudioTranslation
|
||||||
|
// }
|
||||||
|
// switch relayMode {
|
||||||
|
// case RelayModeImagesGenerations:
|
||||||
|
// err = relayImageHelper(c, relayMode)
|
||||||
|
// case RelayModeAudioSpeech:
|
||||||
|
// fallthrough
|
||||||
|
// case RelayModeAudioTranslation:
|
||||||
|
// fallthrough
|
||||||
|
// case RelayModeAudioTranscription:
|
||||||
|
// err = relayAudioHelper(c, relayMode)
|
||||||
|
// default:
|
||||||
|
// err = relayTextHelper(c, relayMode)
|
||||||
|
// }
|
||||||
if err != nil {
|
if err != nil {
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
retryTimesStr := c.Query("retry")
|
retryTimesStr := c.Query("retry")
|
||||||
|
9
main.go
9
main.go
@ -3,9 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-contrib/sessions/cookie"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
@ -13,6 +10,10 @@ import (
|
|||||||
"one-api/router"
|
"one-api/router"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-contrib/sessions/cookie"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed web/build
|
//go:embed web/build
|
||||||
@ -82,7 +83,7 @@ func main() {
|
|||||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
model.InitBatchUpdater()
|
model.InitBatchUpdater()
|
||||||
}
|
}
|
||||||
controller.InitTokenEncoders()
|
common.InitTokenEncoders()
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.New()
|
server := gin.New()
|
||||||
|
@ -80,7 +80,8 @@ func Distribute() func(c *gin.Context) {
|
|||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Set("api_key", channel.Key)
|
||||||
|
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"gorm.io/gorm"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Channel struct {
|
type Channel struct {
|
||||||
|
50
providers/ali_base.go
Normal file
50
providers/ali_base.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AliAIProvider struct {
|
||||||
|
ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliError struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
RequestId string `json:"request_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliUsage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 AliAIProvider
|
||||||
|
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
|
||||||
|
func CreateAliAIProvider(c *gin.Context) *AliAIProvider {
|
||||||
|
return &AliAIProvider{
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
BaseURL: "https://dashscope.aliyuncs.com",
|
||||||
|
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
|
||||||
|
Embeddings: "/api/v1/services/embeddings/text-embedding/text-embedding",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
func (p *AliAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
||||||
|
|
||||||
|
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||||
|
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||||
|
if headers["Content-Type"] == "" {
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
256
providers/ali_chat.go
Normal file
256
providers/ali_chat.go
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AliMessage struct {
|
||||||
|
User string `json:"user"`
|
||||||
|
Bot string `json:"bot"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliInput struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
History []AliMessage `json:"history"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliParameters struct {
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
Seed uint64 `json:"seed,omitempty"`
|
||||||
|
EnableSearch bool `json:"enable_search,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliChatRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input AliInput `json:"input"`
|
||||||
|
Parameters AliParameters `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliOutput struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliChatResponse struct {
|
||||||
|
Output AliOutput `json:"output"`
|
||||||
|
Usage AliUsage `json:"usage"`
|
||||||
|
AliError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aliResponse *AliChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if aliResponse.Code != "" {
|
||||||
|
return nil, &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: types.OpenAIError{
|
||||||
|
Message: aliResponse.Message,
|
||||||
|
Type: aliResponse.Code,
|
||||||
|
Param: aliResponse.RequestId,
|
||||||
|
Code: aliResponse.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := types.ChatCompletionChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: types.ChatCompletionMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: aliResponse.Output.Text,
|
||||||
|
},
|
||||||
|
FinishReason: aliResponse.Output.FinishReason,
|
||||||
|
}
|
||||||
|
|
||||||
|
fullTextResponse := types.ChatCompletionResponse{
|
||||||
|
ID: aliResponse.RequestId,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Choices: []types.ChatCompletionChoice{choice},
|
||||||
|
Usage: &types.Usage{
|
||||||
|
PromptTokens: aliResponse.Usage.InputTokens,
|
||||||
|
CompletionTokens: aliResponse.Usage.OutputTokens,
|
||||||
|
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullTextResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AliAIProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
|
||||||
|
messages := make([]AliMessage, 0, len(request.Messages))
|
||||||
|
prompt := ""
|
||||||
|
for i := 0; i < len(request.Messages); i++ {
|
||||||
|
message := request.Messages[i]
|
||||||
|
if message.Role == "system" {
|
||||||
|
messages = append(messages, AliMessage{
|
||||||
|
User: message.StringContent(),
|
||||||
|
Bot: "Okay",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
if i == len(request.Messages)-1 {
|
||||||
|
prompt = message.StringContent()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
messages = append(messages, AliMessage{
|
||||||
|
User: message.StringContent(),
|
||||||
|
Bot: request.Messages[i+1].StringContent(),
|
||||||
|
})
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &AliChatRequest{
|
||||||
|
Model: request.Model,
|
||||||
|
Input: AliInput{
|
||||||
|
Prompt: prompt,
|
||||||
|
History: messages,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AliAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
|
requestBody := p.getChatRequestBody(request)
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
if request.Stream {
|
||||||
|
headers["Accept"] = "text/event-stream"
|
||||||
|
headers["X-DashScope-SSE"] = "enable"
|
||||||
|
}
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Stream {
|
||||||
|
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage == nil {
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: 0,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
aliResponse := &AliChatResponse{}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, aliResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: aliResponse.Usage.InputTokens,
|
||||||
|
CompletionTokens: aliResponse.Usage.OutputTokens,
|
||||||
|
TotalTokens: aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AliAIProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *types.ChatCompletionStreamResponse {
|
||||||
|
var choice types.ChatCompletionStreamChoice
|
||||||
|
choice.Delta.Content = aliResponse.Output.Text
|
||||||
|
if aliResponse.Output.FinishReason != "null" {
|
||||||
|
finishReason := aliResponse.Output.FinishReason
|
||||||
|
choice.FinishReason = &finishReason
|
||||||
|
}
|
||||||
|
|
||||||
|
response := types.ChatCompletionStreamResponse{
|
||||||
|
ID: aliResponse.RequestId,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "ernie-bot",
|
||||||
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
|
}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AliAIProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
|
||||||
|
usage = &types.Usage{}
|
||||||
|
// 发送请求
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.IsFailureStatusCode(resp) {
|
||||||
|
return p.handleErrorResp(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 5 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if data[:5] != "data:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[5:]
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(p.Context)
|
||||||
|
lastResponseText := ""
|
||||||
|
p.Context.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var aliResponse AliChatResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if aliResponse.Usage.OutputTokens != 0 {
|
||||||
|
usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||||
|
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||||
|
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||||
|
}
|
||||||
|
response := p.streamResponseAli2OpenAI(&aliResponse)
|
||||||
|
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||||
|
lastResponseText = aliResponse.Output.Text
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, usage
|
||||||
|
}
|
94
providers/ali_embeddings.go
Normal file
94
providers/ali_embeddings.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AliEmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input struct {
|
||||||
|
Texts []string `json:"texts"`
|
||||||
|
} `json:"input"`
|
||||||
|
Parameters *struct {
|
||||||
|
TextType string `json:"text_type,omitempty"`
|
||||||
|
} `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliEmbedding struct {
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
TextIndex int `json:"text_index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliEmbeddingResponse struct {
|
||||||
|
Output struct {
|
||||||
|
Embeddings []AliEmbedding `json:"embeddings"`
|
||||||
|
} `json:"output"`
|
||||||
|
Usage AliUsage `json:"usage"`
|
||||||
|
AliError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aliResponse *AliEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if aliResponse.Code != "" {
|
||||||
|
return nil, &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: types.OpenAIError{
|
||||||
|
Message: aliResponse.Message,
|
||||||
|
Type: aliResponse.Code,
|
||||||
|
Param: aliResponse.RequestId,
|
||||||
|
Code: aliResponse.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIEmbeddingResponse := &types.EmbeddingResponse{
|
||||||
|
Object: "list",
|
||||||
|
Data: make([]types.Embedding, 0, len(aliResponse.Output.Embeddings)),
|
||||||
|
Model: "text-embedding-v1",
|
||||||
|
Usage: &types.Usage{TotalTokens: aliResponse.Usage.TotalTokens},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range aliResponse.Output.Embeddings {
|
||||||
|
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
|
||||||
|
Object: `embedding`,
|
||||||
|
Index: item.TextIndex,
|
||||||
|
Embedding: item.Embedding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return openAIEmbeddingResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AliAIProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *AliEmbeddingRequest {
|
||||||
|
return &AliEmbeddingRequest{
|
||||||
|
Model: "text-embedding-v1",
|
||||||
|
Input: struct {
|
||||||
|
Texts []string `json:"texts"`
|
||||||
|
}{
|
||||||
|
Texts: request.ParseInput(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AliAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
|
requestBody := p.getEmbeddingsRequestBody(request)
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
aliEmbeddingResponse := &AliEmbeddingResponse{}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, aliEmbeddingResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
usage = &types.Usage{TotalTokens: aliEmbeddingResponse.Usage.TotalTokens}
|
||||||
|
|
||||||
|
return usage, nil
|
||||||
|
}
|
14
providers/api2d_base.go
Normal file
14
providers/api2d_base.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
type Api2dProvider struct {
|
||||||
|
*OpenAIProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 OpenAIProvider
|
||||||
|
func CreateApi2dProvider(c *gin.Context) *Api2dProvider {
|
||||||
|
return &Api2dProvider{
|
||||||
|
OpenAIProvider: CreateOpenAIProvider(c, "https://oa.api2d.net"),
|
||||||
|
}
|
||||||
|
}
|
41
providers/azure_base.go
Normal file
41
providers/azure_base.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AzureProvider struct {
|
||||||
|
OpenAIProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 OpenAIProvider
|
||||||
|
func CreateAzureProvider(c *gin.Context) *AzureProvider {
|
||||||
|
return &AzureProvider{
|
||||||
|
OpenAIProvider: OpenAIProvider{
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
BaseURL: "",
|
||||||
|
Completions: "/completions",
|
||||||
|
ChatCompletions: "/chat/completions",
|
||||||
|
Embeddings: "/embeddings",
|
||||||
|
AudioSpeech: "/audio/speech",
|
||||||
|
AudioTranscriptions: "/audio/transcriptions",
|
||||||
|
AudioTranslations: "/audio/translations",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
isAzure: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// // 获取完整请求 URL
|
||||||
|
// func (p *AzureProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
// apiVersion := p.Context.GetString("api_version")
|
||||||
|
// requestURL = fmt.Sprintf("/openai/deployments/%s/%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||||
|
// baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
|
// if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||||
|
// requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
// }
|
136
providers/baidu_base.go
Normal file
136
providers/baidu_base.go
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
var baiduTokenStore sync.Map
|
||||||
|
|
||||||
|
type BaiduProvider struct {
|
||||||
|
ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduAccessToken struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
ErrorDescription string `json:"error_description,omitempty"`
|
||||||
|
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||||
|
ExpiresAt time.Time `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateBaiduProvider(c *gin.Context) *BaiduProvider {
|
||||||
|
return &BaiduProvider{
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
BaseURL: "https://aip.baidubce.com",
|
||||||
|
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
|
||||||
|
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取完整请求 URL
|
||||||
|
func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
var modelNameMap = map[string]string{
|
||||||
|
"ERNIE-Bot": "completions",
|
||||||
|
"ERNIE-Bot-turbo": "eb-instant",
|
||||||
|
"ERNIE-Bot-4": "completions_pro",
|
||||||
|
"BLOOMZ-7B": "bloomz_7b1",
|
||||||
|
"Embedding-V1": "embedding-v1",
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
apiKey, err := p.getBaiduAccessToken()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%s/%s?access_token=%s", baseURL, requestURL, modelNameMap[modelName], apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
|
||||||
|
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||||
|
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||||
|
if headers["Content-Type"] == "" {
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
|
||||||
|
apiKey := p.Context.GetString("api_key")
|
||||||
|
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||||
|
var accessToken BaiduAccessToken
|
||||||
|
if accessToken, ok = val.(BaiduAccessToken); ok {
|
||||||
|
// soon this will expire
|
||||||
|
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
||||||
|
go func() {
|
||||||
|
_, _ = p.getBaiduAccessTokenHelper(apiKey)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
return accessToken.AccessToken, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
accessToken, err := p.getBaiduAccessTokenHelper(apiKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if accessToken == nil {
|
||||||
|
return "", errors.New("getBaiduAccessToken return a nil token")
|
||||||
|
}
|
||||||
|
return (*accessToken).AccessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||||
|
parts := strings.Split(apiKey, "|")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, errors.New("invalid baidu apikey")
|
||||||
|
}
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
|
||||||
|
|
||||||
|
var headers = map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := client.NewRequest("POST", url, common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var accessToken BaiduAccessToken
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accessToken.Error != "" {
|
||||||
|
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
||||||
|
}
|
||||||
|
if accessToken.AccessToken == "" {
|
||||||
|
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
||||||
|
}
|
||||||
|
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
||||||
|
baiduTokenStore.Store(apiKey, accessToken)
|
||||||
|
return &accessToken, nil
|
||||||
|
}
|
228
providers/baidu_chat.go
Normal file
228
providers/baidu_chat.go
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BaiduMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduChatRequest struct {
|
||||||
|
Messages []BaiduMessage `json:"messages"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
UserId string `json:"user_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduChatResponse struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Result string `json:"result"`
|
||||||
|
IsTruncated bool `json:"is_truncated"`
|
||||||
|
NeedClearHistory bool `json:"need_clear_history"`
|
||||||
|
Usage *types.Usage `json:"usage"`
|
||||||
|
BaiduError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (baiduResponse *BaiduChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if baiduResponse.ErrorMsg != "" {
|
||||||
|
return nil, &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: types.OpenAIError{
|
||||||
|
Message: baiduResponse.ErrorMsg,
|
||||||
|
Type: "baidu_error",
|
||||||
|
Param: "",
|
||||||
|
Code: baiduResponse.ErrorCode,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := types.ChatCompletionChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: types.ChatCompletionMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: baiduResponse.Result,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}
|
||||||
|
|
||||||
|
fullTextResponse := types.ChatCompletionResponse{
|
||||||
|
ID: baiduResponse.Id,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: baiduResponse.Created,
|
||||||
|
Choices: []types.ChatCompletionChoice{choice},
|
||||||
|
Usage: baiduResponse.Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullTextResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduChatStreamResponse struct {
|
||||||
|
BaiduChatResponse
|
||||||
|
SentenceId int `json:"sentence_id"`
|
||||||
|
IsEnd bool `json:"is_end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduError struct {
|
||||||
|
ErrorCode int `json:"error_code"`
|
||||||
|
ErrorMsg string `json:"error_msg"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaiduChatRequest {
|
||||||
|
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
if message.Role == "system" {
|
||||||
|
messages = append(messages, BaiduMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: message.StringContent(),
|
||||||
|
})
|
||||||
|
messages = append(messages, BaiduMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Okay",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
messages = append(messages, BaiduMessage{
|
||||||
|
Role: message.Role,
|
||||||
|
Content: message.StringContent(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &BaiduChatRequest{
|
||||||
|
Messages: messages,
|
||||||
|
Stream: request.Stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BaiduProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
requestBody := p.getChatRequestBody(request)
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
|
if fullRequestURL == "" {
|
||||||
|
return nil, types.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
if request.Stream {
|
||||||
|
headers["Accept"] = "text/event-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Stream {
|
||||||
|
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
baiduChatRequest := &BaiduChatResponse{}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, baiduChatRequest)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = baiduChatRequest.Usage
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
|
||||||
|
var choice types.ChatCompletionStreamChoice
|
||||||
|
choice.Delta.Content = baiduResponse.Result
|
||||||
|
if baiduResponse.IsEnd {
|
||||||
|
choice.FinishReason = &stopFinishReason
|
||||||
|
}
|
||||||
|
|
||||||
|
response := types.ChatCompletionStreamResponse{
|
||||||
|
ID: baiduResponse.Id,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: baiduResponse.Created,
|
||||||
|
Model: "ernie-bot",
|
||||||
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
|
}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BaiduProvider) sendStreamRequest(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, usage *types.Usage) {
|
||||||
|
usage = &types.Usage{}
|
||||||
|
// 发送请求
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.IsFailureStatusCode(resp) {
|
||||||
|
return p.handleErrorResp(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 6 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[6:]
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(p.Context)
|
||||||
|
p.Context.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var baiduResponse BaiduChatStreamResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if baiduResponse.Usage.TotalTokens != 0 {
|
||||||
|
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||||
|
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||||
|
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||||
|
}
|
||||||
|
response := p.streamResponseBaidu2OpenAI(&baiduResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, usage
|
||||||
|
}
|
88
providers/baidu_embeddings.go
Normal file
88
providers/baidu_embeddings.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BaiduEmbeddingRequest struct {
|
||||||
|
Input []string `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduEmbeddingData struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduEmbeddingResponse struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Data []BaiduEmbeddingData `json:"data"`
|
||||||
|
Usage types.Usage `json:"usage"`
|
||||||
|
BaiduError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BaiduProvider) getEmbeddingsRequestBody(request *types.EmbeddingRequest) *BaiduEmbeddingRequest {
|
||||||
|
return &BaiduEmbeddingRequest{
|
||||||
|
Input: request.ParseInput(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (baiduResponse *BaiduEmbeddingResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if baiduResponse.ErrorMsg != "" {
|
||||||
|
return nil, &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: types.OpenAIError{
|
||||||
|
Message: baiduResponse.ErrorMsg,
|
||||||
|
Type: "baidu_error",
|
||||||
|
Param: "",
|
||||||
|
Code: baiduResponse.ErrorCode,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIEmbeddingResponse := &types.EmbeddingResponse{
|
||||||
|
Object: "list",
|
||||||
|
Data: make([]types.Embedding, 0, len(baiduResponse.Data)),
|
||||||
|
Model: "text-embedding-v1",
|
||||||
|
Usage: &baiduResponse.Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range baiduResponse.Data {
|
||||||
|
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, types.Embedding{
|
||||||
|
Object: item.Object,
|
||||||
|
Index: item.Index,
|
||||||
|
Embedding: item.Embedding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return openAIEmbeddingResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BaiduProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
|
requestBody := p.getEmbeddingsRequestBody(request)
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||||
|
if fullRequestURL == "" {
|
||||||
|
return nil, types.ErrorWrapper(nil, "invalid_baidu_config", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
baiduEmbeddingResponse := &BaiduEmbeddingResponse{}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, baiduEmbeddingResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
usage = &baiduEmbeddingResponse.Usage
|
||||||
|
|
||||||
|
return usage, nil
|
||||||
|
}
|
150
providers/base.go
Normal file
150
providers/base.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/types"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
var stopFinishReason = "stop"
|
||||||
|
|
||||||
|
type ProviderConfig struct {
|
||||||
|
BaseURL string
|
||||||
|
Completions string
|
||||||
|
ChatCompletions string
|
||||||
|
Embeddings string
|
||||||
|
AudioSpeech string
|
||||||
|
AudioTranscriptions string
|
||||||
|
AudioTranslations string
|
||||||
|
Proxy string
|
||||||
|
Context *gin.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaseProviderAction interface {
|
||||||
|
GetBaseURL() string
|
||||||
|
GetFullRequestURL(requestURL string, modelName string) string
|
||||||
|
GetRequestHeaders() (headers map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionProviderAction interface {
|
||||||
|
BaseProviderAction
|
||||||
|
CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatProviderAction interface {
|
||||||
|
BaseProviderAction
|
||||||
|
ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingsProviderAction interface {
|
||||||
|
BaseProviderAction
|
||||||
|
EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
type BalanceProviderAction interface {
|
||||||
|
Balance(channel *model.Channel) (float64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProviderConfig) GetBaseURL() string {
|
||||||
|
if p.Context.GetString("base_url") != "" {
|
||||||
|
return p.Context.GetString("base_url")
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.BaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProviderConfig) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setEventStreamHeaders(c *gin.Context) {
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProviderConfig) handleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
OpenAIError: types.OpenAIError{
|
||||||
|
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||||
|
Type: "upstream_error",
|
||||||
|
Code: "bad_response_status_code",
|
||||||
|
Param: strconv.Itoa(resp.StatusCode),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var errorResponse types.OpenAIErrorResponse
|
||||||
|
err = json.Unmarshal(responseBody, &errorResponse)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errorResponse.Error.Type != "" {
|
||||||
|
openAIErrorWithStatusCode.OpenAIError = errorResponse.Error
|
||||||
|
} else {
|
||||||
|
openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 供应商响应处理函数
|
||||||
|
type ProviderResponseHandler interface {
|
||||||
|
// 请求处理函数
|
||||||
|
requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
func (p *ProviderConfig) sendRequest(req *http.Request, response ProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 处理响应
|
||||||
|
if common.IsFailureStatusCode(resp) {
|
||||||
|
return p.handleErrorResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析响应
|
||||||
|
err = common.DecodeResponse(resp.Body, response)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIResponse, openAIErrorWithStatusCode := response.requestHandler(resp)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResponse, err := json.Marshal(openAIResponse)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
p.Context.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = p.Context.Writer.Write(jsonResponse)
|
||||||
|
return nil
|
||||||
|
}
|
55
providers/claude_base.go
Normal file
55
providers/claude_base.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClaudeProvider struct {
|
||||||
|
ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeError struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
|
||||||
|
return &ClaudeProvider{
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
BaseURL: "https://api.anthropic.com",
|
||||||
|
ChatCompletions: "/v1/complete",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
|
||||||
|
headers["x-api-key"] = p.Context.GetString("api_key")
|
||||||
|
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||||
|
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||||
|
if headers["Content-Type"] == "" {
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
|
||||||
|
if anthropicVersion == "" {
|
||||||
|
anthropicVersion = "2023-06-01"
|
||||||
|
}
|
||||||
|
headers["anthropic-version"] = anthropicVersion
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopReasonClaude2OpenAI(reason string) string {
|
||||||
|
switch reason {
|
||||||
|
case "stop_sequence":
|
||||||
|
return "stop"
|
||||||
|
case "max_tokens":
|
||||||
|
return "length"
|
||||||
|
default:
|
||||||
|
return reason
|
||||||
|
}
|
||||||
|
}
|
232
providers/claude_chat.go
Normal file
232
providers/claude_chat.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClaudeMetadata struct {
|
||||||
|
UserId string `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||||
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeResponse struct {
|
||||||
|
Completion string `json:"completion"`
|
||||||
|
StopReason string `json:"stop_reason"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Error ClaudeError `json:"error"`
|
||||||
|
Usage *types.Usage `json:"usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (claudeResponse *ClaudeResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if claudeResponse.Error.Type != "" {
|
||||||
|
return nil, &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: types.OpenAIError{
|
||||||
|
Message: claudeResponse.Error.Message,
|
||||||
|
Type: claudeResponse.Error.Type,
|
||||||
|
Param: "",
|
||||||
|
Code: claudeResponse.Error.Type,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := types.ChatCompletionChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: types.ChatCompletionMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
|
||||||
|
Name: nil,
|
||||||
|
},
|
||||||
|
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||||
|
}
|
||||||
|
fullTextResponse := types.ChatCompletionResponse{
|
||||||
|
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Choices: []types.ChatCompletionChoice{choice},
|
||||||
|
}
|
||||||
|
|
||||||
|
completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
|
||||||
|
claudeResponse.Usage.CompletionTokens = completionTokens
|
||||||
|
claudeResponse.Usage.TotalTokens = claudeResponse.Usage.PromptTokens + completionTokens
|
||||||
|
|
||||||
|
fullTextResponse.Usage = claudeResponse.Usage
|
||||||
|
|
||||||
|
return fullTextResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *ClaudeRequest) {
|
||||||
|
claudeRequest := ClaudeRequest{
|
||||||
|
Model: request.Model,
|
||||||
|
Prompt: "",
|
||||||
|
MaxTokensToSample: request.MaxTokens,
|
||||||
|
StopSequences: nil,
|
||||||
|
Temperature: request.Temperature,
|
||||||
|
TopP: request.TopP,
|
||||||
|
Stream: request.Stream,
|
||||||
|
}
|
||||||
|
if claudeRequest.MaxTokensToSample == 0 {
|
||||||
|
claudeRequest.MaxTokensToSample = 1000000
|
||||||
|
}
|
||||||
|
prompt := ""
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
if message.Role == "user" {
|
||||||
|
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
||||||
|
} else if message.Role == "assistant" {
|
||||||
|
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||||
|
} else if message.Role == "system" {
|
||||||
|
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prompt += "\n\nAssistant:"
|
||||||
|
claudeRequest.Prompt = prompt
|
||||||
|
return &claudeRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
requestBody := p.getChatRequestBody(request)
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
if request.Stream {
|
||||||
|
headers["Accept"] = "text/event-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Stream {
|
||||||
|
var responseText string
|
||||||
|
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage.PromptTokens = promptTokens
|
||||||
|
usage.CompletionTokens = common.CountTokenText(responseText, request.Model)
|
||||||
|
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||||
|
|
||||||
|
} else {
|
||||||
|
var claudeResponse = &ClaudeResponse{
|
||||||
|
Usage: &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, claudeResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = claudeResponse.Usage
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *types.ChatCompletionStreamResponse {
|
||||||
|
var choice types.ChatCompletionStreamChoice
|
||||||
|
choice.Delta.Content = claudeResponse.Completion
|
||||||
|
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||||
|
if finishReason != "null" {
|
||||||
|
choice.FinishReason = &finishReason
|
||||||
|
}
|
||||||
|
var response types.ChatCompletionStreamResponse
|
||||||
|
response.Object = "chat.completion.chunk"
|
||||||
|
response.Model = claudeResponse.Model
|
||||||
|
response.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
||||||
|
// 发送请求
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.IsFailureStatusCode(resp) {
|
||||||
|
return p.handleErrorResp(resp), ""
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
responseText := ""
|
||||||
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
|
createdTime := common.GetTimestamp()
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
|
||||||
|
return i + 4, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if !strings.HasPrefix(data, "event: completion") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(p.Context)
|
||||||
|
p.Context.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
// some implementations may add \r at the end of data
|
||||||
|
data = strings.TrimSuffix(data, "\r")
|
||||||
|
var claudeResponse ClaudeResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
responseText += claudeResponse.Completion
|
||||||
|
response := p.streamResponseClaude2OpenAI(&claudeResponse)
|
||||||
|
response.ID = responseId
|
||||||
|
response.Created = createdTime
|
||||||
|
jsonStr, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, responseText
|
||||||
|
}
|
50
providers/closeai_proxy_base.go
Normal file
50
providers/closeai_proxy_base.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CloseaiProxyProvider struct {
|
||||||
|
*OpenAIProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAICreditGrants struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
TotalGranted float64 `json:"total_granted"`
|
||||||
|
TotalUsed float64 `json:"total_used"`
|
||||||
|
TotalAvailable float64 `json:"total_available"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 CloseaiProxyProvider
|
||||||
|
func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider {
|
||||||
|
return &CloseaiProxyProvider{
|
||||||
|
OpenAIProvider: CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *CloseaiProxyProvider) Balance(channel *model.Channel) (float64, error) {
|
||||||
|
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
|
||||||
|
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
var response OpenAICreditGrants
|
||||||
|
err = client.SendRequest(req, &response)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
channel.UpdateBalance(response.TotalAvailable)
|
||||||
|
|
||||||
|
return response.TotalAvailable, nil
|
||||||
|
}
|
215
providers/openai_base.go
Normal file
215
providers/openai_base.go
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenAIProvider struct {
|
||||||
|
ProviderConfig
|
||||||
|
isAzure bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIProviderResponseHandler interface {
|
||||||
|
// 请求处理函数
|
||||||
|
requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIProviderStreamResponseHandler interface {
|
||||||
|
// 请求流处理函数
|
||||||
|
requestStreamHandler() (responseText string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 OpenAIProvider
|
||||||
|
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.openai.com"
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OpenAIProvider{
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
BaseURL: baseURL,
|
||||||
|
Completions: "/v1/completions",
|
||||||
|
ChatCompletions: "/v1/chat/completions",
|
||||||
|
Embeddings: "/v1/embeddings",
|
||||||
|
AudioSpeech: "/v1/audio/speech",
|
||||||
|
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||||
|
AudioTranslations: "/v1/audio/translations",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
isAzure: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取完整请求 URL
|
||||||
|
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
|
if p.isAzure {
|
||||||
|
apiVersion := p.Context.GetString("api_version")
|
||||||
|
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||||
|
if p.isAzure {
|
||||||
|
requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
|
||||||
|
} else {
|
||||||
|
requestURL = strings.TrimPrefix(requestURL, "/v1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
if p.isAzure {
|
||||||
|
headers["api-key"] = p.Context.GetString("api_key")
|
||||||
|
} else {
|
||||||
|
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
||||||
|
}
|
||||||
|
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||||
|
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||||
|
if headers["Content-Type"] == "" {
|
||||||
|
headers["Content-Type"] = "application/json; charset=utf-8"
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求体
|
||||||
|
func (p *OpenAIProvider) getRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
|
||||||
|
if isModelMapped {
|
||||||
|
jsonStr, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
} else {
|
||||||
|
requestBody = p.Context.Request.Body
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
func (p *OpenAIProvider) sendRequest(req *http.Request, response OpenAIProviderResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 处理响应
|
||||||
|
if common.IsFailureStatusCode(resp) {
|
||||||
|
return p.handleErrorResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建一个 bytes.Buffer 来存储响应体
|
||||||
|
var buf bytes.Buffer
|
||||||
|
tee := io.TeeReader(resp.Body, &buf)
|
||||||
|
|
||||||
|
// 解析响应
|
||||||
|
err = common.DecodeResponse(tee, response)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIErrorWithStatusCode = response.requestHandler(resp)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
p.Context.Writer.Header().Set(k, v[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Context.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = io.Copy(p.Context.Writer, &buf)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
|
||||||
|
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.IsFailureStatusCode(resp) {
|
||||||
|
return p.handleErrorResp(resp), ""
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 6 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dataChan <- data
|
||||||
|
data = data[6:]
|
||||||
|
if !strings.HasPrefix(data, "[DONE]") {
|
||||||
|
err := json.Unmarshal([]byte(data), response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
continue // just ignore the error
|
||||||
|
}
|
||||||
|
responseText += response.requestStreamHandler()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(p.Context)
|
||||||
|
p.Context.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
if strings.HasPrefix(data, "data: [DONE]") {
|
||||||
|
data = data[:12]
|
||||||
|
}
|
||||||
|
// some implementations may add \r at the end of data
|
||||||
|
data = strings.TrimSuffix(data, "\r")
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: data})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, responseText
|
||||||
|
}
|
92
providers/openai_chat.go
Normal file
92
providers/openai_chat.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenAIProviderChatResponse struct {
|
||||||
|
types.ChatCompletionResponse
|
||||||
|
types.OpenAIErrorResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIProviderChatStreamResponse struct {
|
||||||
|
types.ChatCompletionStreamResponse
|
||||||
|
types.OpenAIErrorResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIProviderChatResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if c.Error.Type != "" {
|
||||||
|
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: c.Error,
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIProviderChatStreamResponse) requestStreamHandler() (responseText string) {
|
||||||
|
for _, choice := range c.Choices {
|
||||||
|
responseText += choice.Delta.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
if request.Stream && headers["Accept"] == "" {
|
||||||
|
headers["Accept"] = "text/event-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Stream {
|
||||||
|
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
|
||||||
|
var textResponse string
|
||||||
|
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||||
|
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
openAIProviderChatResponse := &OpenAIProviderChatResponse{}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderChatResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = openAIProviderChatResponse.Usage
|
||||||
|
|
||||||
|
if usage.TotalTokens == 0 {
|
||||||
|
completionTokens := 0
|
||||||
|
for _, choice := range openAIProviderChatResponse.Choices {
|
||||||
|
completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model)
|
||||||
|
}
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: completionTokens,
|
||||||
|
TotalTokens: promptTokens + completionTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
87
providers/openai_completion.go
Normal file
87
providers/openai_completion.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenAIProviderCompletionResponse struct {
|
||||||
|
types.CompletionResponse
|
||||||
|
types.OpenAIErrorResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIProviderCompletionResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if c.Error.Type != "" {
|
||||||
|
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: c.Error,
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIProviderCompletionResponse) requestStreamHandler() (responseText string) {
|
||||||
|
for _, choice := range c.Choices {
|
||||||
|
responseText += choice.Text
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) CompleteResponse(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
if request.Stream && headers["Accept"] == "" {
|
||||||
|
headers["Accept"] = "text/event-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{}
|
||||||
|
if request.Stream {
|
||||||
|
// TODO
|
||||||
|
var textResponse string
|
||||||
|
openAIErrorWithStatusCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||||
|
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderCompletionResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = openAIProviderCompletionResponse.Usage
|
||||||
|
|
||||||
|
if usage.TotalTokens == 0 {
|
||||||
|
completionTokens := 0
|
||||||
|
for _, choice := range openAIProviderCompletionResponse.Choices {
|
||||||
|
completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model)
|
||||||
|
}
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: completionTokens,
|
||||||
|
TotalTokens: promptTokens + completionTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
50
providers/openai_embeddings.go
Normal file
50
providers/openai_embeddings.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenAIProviderEmbeddingsResponse struct {
|
||||||
|
types.EmbeddingResponse
|
||||||
|
types.OpenAIErrorResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIProviderEmbeddingsResponse) requestHandler(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if c.Error.Type != "" {
|
||||||
|
openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: c.Error,
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) EmbeddingsResponse(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
|
||||||
|
requestBody, err := p.getRequestBody(&request, isModelMapped)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, openAIProviderEmbeddingsResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = openAIProviderEmbeddingsResponse.Usage
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
58
providers/openaisb_base.go
Normal file
58
providers/openaisb_base.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenaiSBProvider struct {
|
||||||
|
*OpenAIProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAISBUsageResponse struct {
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Data *struct {
|
||||||
|
Credit string `json:"credit"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 OpenaiSBProvider
|
||||||
|
func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider {
|
||||||
|
return &OpenaiSBProvider{
|
||||||
|
OpenAIProvider: CreateOpenAIProvider(c, "https://api.openai-sb.com"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenaiSBProvider) Balance(channel *model.Channel) (float64, error) {
|
||||||
|
fullRequestURL := p.GetFullRequestURL("/sb-api/user/status", "")
|
||||||
|
fullRequestURL = fmt.Sprintf("%s?api_key=%s", fullRequestURL, channel.Key)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest("GET", fullRequestURL, common.WithBody(nil), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
var response OpenAISBUsageResponse
|
||||||
|
err = client.SendRequest(req, &response)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.Data == nil {
|
||||||
|
return 0, errors.New(response.Msg)
|
||||||
|
}
|
||||||
|
balance, err := strconv.ParseFloat(response.Data.Credit, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
channel.UpdateBalance(balance)
|
||||||
|
return balance, nil
|
||||||
|
}
|
43
providers/palm_base.go
Normal file
43
providers/palm_base.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PalmProvider struct {
|
||||||
|
ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 PalmProvider
|
||||||
|
func CreatePalmProvider(c *gin.Context) *PalmProvider {
|
||||||
|
return &PalmProvider{
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
BaseURL: "https://generativelanguage.googleapis.com",
|
||||||
|
ChatCompletions: "/v1beta2/models/chat-bison-001:generateMessage",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
|
||||||
|
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||||
|
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||||
|
if headers["Content-Type"] == "" {
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取完整请求 URL
|
||||||
|
func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%s?key=%s", baseURL, requestURL, p.Context.GetString("api_key"))
|
||||||
|
}
|
232
providers/palm_chat.go
Normal file
232
providers/palm_chat.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PaLMChatMessage struct {
|
||||||
|
Author string `json:"author"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaLMFilter struct {
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaLMPrompt struct {
|
||||||
|
Messages []PaLMChatMessage `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaLMChatRequest struct {
|
||||||
|
Prompt PaLMPrompt `json:"prompt"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
CandidateCount int `json:"candidateCount,omitempty"`
|
||||||
|
TopP float64 `json:"topP,omitempty"`
|
||||||
|
TopK int `json:"topK,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaLMError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaLMChatResponse struct {
|
||||||
|
Candidates []PaLMChatMessage `json:"candidates"`
|
||||||
|
Messages []types.ChatCompletionMessage `json:"messages"`
|
||||||
|
Filters []PaLMFilter `json:"filters"`
|
||||||
|
Error PaLMError `json:"error"`
|
||||||
|
Usage *types.Usage `json:"usage,omitempty"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (palmResponse *PaLMChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||||
|
return nil, &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: types.OpenAIError{
|
||||||
|
Message: palmResponse.Error.Message,
|
||||||
|
Type: palmResponse.Error.Status,
|
||||||
|
Param: "",
|
||||||
|
Code: palmResponse.Error.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fullTextResponse := types.ChatCompletionResponse{
|
||||||
|
Choices: make([]types.ChatCompletionChoice, 0, len(palmResponse.Candidates)),
|
||||||
|
}
|
||||||
|
for i, candidate := range palmResponse.Candidates {
|
||||||
|
choice := types.ChatCompletionChoice{
|
||||||
|
Index: i,
|
||||||
|
Message: types.ChatCompletionMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: candidate.Content,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}
|
||||||
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
|
}
|
||||||
|
|
||||||
|
completionTokens := common.CountTokenText(palmResponse.Candidates[0].Content, palmResponse.Model)
|
||||||
|
palmResponse.Usage.CompletionTokens = completionTokens
|
||||||
|
palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens
|
||||||
|
|
||||||
|
fullTextResponse.Usage = palmResponse.Usage
|
||||||
|
|
||||||
|
return fullTextResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PalmProvider) getChatRequestBody(request *types.ChatCompletionRequest) *PaLMChatRequest {
|
||||||
|
palmRequest := PaLMChatRequest{
|
||||||
|
Prompt: PaLMPrompt{
|
||||||
|
Messages: make([]PaLMChatMessage, 0, len(request.Messages)),
|
||||||
|
},
|
||||||
|
Temperature: request.Temperature,
|
||||||
|
CandidateCount: request.N,
|
||||||
|
TopP: request.TopP,
|
||||||
|
TopK: request.MaxTokens,
|
||||||
|
}
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
palmMessage := PaLMChatMessage{
|
||||||
|
Content: message.StringContent(),
|
||||||
|
}
|
||||||
|
if message.Role == "user" {
|
||||||
|
palmMessage.Author = "0"
|
||||||
|
} else {
|
||||||
|
palmMessage.Author = "1"
|
||||||
|
}
|
||||||
|
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
||||||
|
}
|
||||||
|
return &palmRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PalmProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
requestBody := p.getChatRequestBody(request)
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
if request.Stream {
|
||||||
|
headers["Accept"] = "text/event-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Stream {
|
||||||
|
var responseText string
|
||||||
|
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage.PromptTokens = promptTokens
|
||||||
|
usage.CompletionTokens = common.CountTokenText(responseText, request.Model)
|
||||||
|
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||||
|
|
||||||
|
} else {
|
||||||
|
var palmChatResponse = &PaLMChatResponse{
|
||||||
|
Model: request.Model,
|
||||||
|
Usage: &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, palmChatResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = palmChatResponse.Usage
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *types.ChatCompletionStreamResponse {
|
||||||
|
var choice types.ChatCompletionStreamChoice
|
||||||
|
if len(palmResponse.Candidates) > 0 {
|
||||||
|
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||||
|
}
|
||||||
|
choice.FinishReason = &stopFinishReason
|
||||||
|
var response types.ChatCompletionStreamResponse
|
||||||
|
response.Object = "chat.completion.chunk"
|
||||||
|
response.Model = "palm2"
|
||||||
|
response.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
||||||
|
// 发送请求
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.IsFailureStatusCode(resp) {
|
||||||
|
return p.handleErrorResp(resp), ""
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
responseText := ""
|
||||||
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
|
createdTime := common.GetTimestamp()
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error reading stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error closing stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var palmResponse PaLMChatResponse
|
||||||
|
err = json.Unmarshal(responseBody, &palmResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fullTextResponse := p.streamResponsePaLM2OpenAI(&palmResponse)
|
||||||
|
fullTextResponse.ID = responseId
|
||||||
|
fullTextResponse.Created = createdTime
|
||||||
|
if len(palmResponse.Candidates) > 0 {
|
||||||
|
responseText = palmResponse.Candidates[0].Content
|
||||||
|
}
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataChan <- string(jsonResponse)
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(p.Context)
|
||||||
|
p.Context.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: " + data})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, responseText
|
||||||
|
}
|
94
providers/tencent_base.go
Normal file
94
providers/tencent_base.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TencentProvider struct {
|
||||||
|
ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 TencentProvider
|
||||||
|
func CreateTencentProvider(c *gin.Context) *TencentProvider {
|
||||||
|
return &TencentProvider{
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
BaseURL: "https://hunyuan.cloud.tencent.com",
|
||||||
|
ChatCompletions: "/hyllm/v1/chat/completions",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
|
||||||
|
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||||
|
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||||
|
if headers["Content-Type"] == "" {
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TencentProvider) parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
||||||
|
parts := strings.Split(config, "|")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
err = errors.New("invalid tencent config")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
appId, err = strconv.ParseInt(parts[0], 10, 64)
|
||||||
|
secretId = parts[1]
|
||||||
|
secretKey = parts[2]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
|
||||||
|
apiKey := p.Context.GetString("api_key")
|
||||||
|
appId, secretId, secretKey, err := p.parseTencentConfig(apiKey)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
req.AppId = appId
|
||||||
|
req.SecretId = secretId
|
||||||
|
|
||||||
|
params := make([]string, 0)
|
||||||
|
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
||||||
|
params = append(params, "secret_id="+req.SecretId)
|
||||||
|
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
||||||
|
params = append(params, "query_id="+req.QueryID)
|
||||||
|
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
||||||
|
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
||||||
|
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
||||||
|
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
||||||
|
|
||||||
|
var messageStr string
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
||||||
|
}
|
||||||
|
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||||
|
params = append(params, "messages=["+messageStr+"]")
|
||||||
|
|
||||||
|
sort.Sort(sort.StringSlice(params))
|
||||||
|
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||||
|
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||||
|
signURL := url
|
||||||
|
mac.Write([]byte(signURL))
|
||||||
|
sign := mac.Sum([]byte(nil))
|
||||||
|
return base64.StdEncoding.EncodeToString(sign)
|
||||||
|
}
|
@ -1,24 +1,16 @@
|
|||||||
package controller
|
package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha1"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"sort"
|
"one-api/types"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://cloud.tencent.com/document/product/1729/97732
|
|
||||||
|
|
||||||
type TencentMessage struct {
|
type TencentMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
@ -50,11 +42,6 @@ type TencentChatRequest struct {
|
|||||||
Messages []TencentMessage `json:"messages"`
|
Messages []TencentMessage `json:"messages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TencentError struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TencentUsage struct {
|
type TencentUsage struct {
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
@ -71,13 +58,44 @@ type TencentChatResponse struct {
|
|||||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||||
Id string `json:"id,omitempty"` // 会话 id
|
Id string `json:"id,omitempty"` // 会话 id
|
||||||
Usage Usage `json:"usage,omitempty"` // token 数量
|
Usage *types.Usage `json:"usage,omitempty"` // token 数量
|
||||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||||
Note string `json:"note,omitempty"` // 注释
|
Note string `json:"note,omitempty"` // 注释
|
||||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
func (TencentResponse *TencentChatResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if TencentResponse.Error.Code != 0 {
|
||||||
|
return &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: types.OpenAIError{
|
||||||
|
Message: TencentResponse.Error.Message,
|
||||||
|
Code: TencentResponse.Error.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fullTextResponse := types.ChatCompletionResponse{
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Usage: TencentResponse.Usage,
|
||||||
|
}
|
||||||
|
if len(TencentResponse.Choices) > 0 {
|
||||||
|
choice := types.ChatCompletionChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: types.ChatCompletionMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: TencentResponse.Choices[0].Messages.Content,
|
||||||
|
},
|
||||||
|
FinishReason: TencentResponse.Choices[0].FinishReason,
|
||||||
|
}
|
||||||
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullTextResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TencentProvider) getChatRequestBody(request *types.ChatCompletionRequest) *TencentChatRequest {
|
||||||
messages := make([]TencentMessage, 0, len(request.Messages))
|
messages := make([]TencentMessage, 0, len(request.Messages))
|
||||||
for i := 0; i < len(request.Messages); i++ {
|
for i := 0; i < len(request.Messages); i++ {
|
||||||
message := request.Messages[i]
|
message := request.Messages[i]
|
||||||
@ -112,34 +130,58 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
|
func (p *TencentProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
fullTextResponse := OpenAITextResponse{
|
requestBody := p.getChatRequestBody(request)
|
||||||
Object: "chat.completion",
|
sign := p.getTencentSign(*requestBody)
|
||||||
Created: common.GetTimestamp(),
|
if sign == "" {
|
||||||
Usage: response.Usage,
|
return nil, types.ErrorWrapper(errors.New("get tencent sign failed"), "get_tencent_sign_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if len(response.Choices) > 0 {
|
|
||||||
choice := OpenAITextResponseChoice{
|
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
Index: 0,
|
headers := p.GetRequestHeaders()
|
||||||
Message: Message{
|
headers["Authorization"] = sign
|
||||||
Role: "assistant",
|
if request.Stream {
|
||||||
Content: response.Choices[0].Messages.Content,
|
headers["Accept"] = "text/event-stream"
|
||||||
},
|
}
|
||||||
FinishReason: response.Choices[0].FinishReason,
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Stream {
|
||||||
|
var responseText string
|
||||||
|
openAIErrorWithStatusCode, responseText = p.sendStreamRequest(req)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
|
||||||
|
usage.PromptTokens = promptTokens
|
||||||
|
usage.CompletionTokens = common.CountTokenText(responseText, request.Model)
|
||||||
|
usage.TotalTokens = promptTokens + usage.CompletionTokens
|
||||||
|
|
||||||
|
} else {
|
||||||
|
tencentResponse := &TencentChatResponse{}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, tencentResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = tencentResponse.Usage
|
||||||
}
|
}
|
||||||
return &fullTextResponse
|
return
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
|
func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *types.ChatCompletionStreamResponse {
|
||||||
response := ChatCompletionsStreamResponse{
|
response := types.ChatCompletionStreamResponse{
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: "tencent-hunyuan",
|
Model: "tencent-hunyuan",
|
||||||
}
|
}
|
||||||
if len(TencentResponse.Choices) > 0 {
|
if len(TencentResponse.Choices) > 0 {
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
var choice types.ChatCompletionStreamChoice
|
||||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||||
choice.FinishReason = &stopFinishReason
|
choice.FinishReason = &stopFinishReason
|
||||||
@ -149,7 +191,19 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCom
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
|
||||||
|
// 发送请求
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.IsFailureStatusCode(resp) {
|
||||||
|
return p.handleErrorResp(resp), ""
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var responseText string
|
var responseText string
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
@ -180,8 +234,8 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
setEventStreamHeaders(c)
|
setEventStreamHeaders(p.Context)
|
||||||
c.Stream(func(w io.Writer) bool {
|
p.Context.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
var TencentResponse TencentChatResponse
|
var TencentResponse TencentChatResponse
|
||||||
@ -190,7 +244,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
|
|||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
response := streamResponseTencent2OpenAI(&TencentResponse)
|
response := p.streamResponseTencent2OpenAI(&TencentResponse)
|
||||||
if len(response.Choices) != 0 {
|
if len(response.Choices) != 0 {
|
||||||
responseText += response.Choices[0].Delta.Content
|
responseText += response.Choices[0].Delta.Content
|
||||||
}
|
}
|
||||||
@ -199,89 +253,13 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
|
|||||||
common.SysError("error marshalling stream response: " + err.Error())
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
return true
|
return true
|
||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
|
||||||
}
|
|
||||||
return nil, responseText
|
return nil, responseText
|
||||||
}
|
}
|
||||||
|
|
||||||
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
var TencentResponse TencentChatResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &TencentResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if TencentResponse.Error.Code != 0 {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: TencentResponse.Error.Message,
|
|
||||||
Code: TencentResponse.Error.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
|
||||||
parts := strings.Split(config, "|")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
err = errors.New("invalid tencent config")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
appId, err = strconv.ParseInt(parts[0], 10, 64)
|
|
||||||
secretId = parts[1]
|
|
||||||
secretKey = parts[2]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTencentSign(req TencentChatRequest, secretKey string) string {
|
|
||||||
params := make([]string, 0)
|
|
||||||
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
|
||||||
params = append(params, "secret_id="+req.SecretId)
|
|
||||||
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
|
||||||
params = append(params, "query_id="+req.QueryID)
|
|
||||||
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
|
||||||
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
|
||||||
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
|
||||||
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
|
||||||
|
|
||||||
var messageStr string
|
|
||||||
for _, msg := range req.Messages {
|
|
||||||
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
|
||||||
}
|
|
||||||
messageStr = strings.TrimSuffix(messageStr, ",")
|
|
||||||
params = append(params, "messages=["+messageStr+"]")
|
|
||||||
|
|
||||||
sort.Sort(sort.StringSlice(params))
|
|
||||||
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
|
||||||
mac := hmac.New(sha1.New, []byte(secretKey))
|
|
||||||
signURL := url
|
|
||||||
mac.Write([]byte(signURL))
|
|
||||||
sign := mac.Sum([]byte(nil))
|
|
||||||
return base64.StdEncoding.EncodeToString(sign)
|
|
||||||
}
|
|
96
providers/xunfei_base.go
Normal file
96
providers/xunfei_base.go
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://www.xfyun.cn/doc/spark/Web.html
|
||||||
|
type XunfeiProvider struct {
|
||||||
|
ProviderConfig
|
||||||
|
domain string
|
||||||
|
apiId string
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 XunfeiProvider
|
||||||
|
func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider {
|
||||||
|
return &XunfeiProvider{
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
BaseURL: "wss://spark-api.xf-yun.com",
|
||||||
|
ChatCompletions: "",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取完整请求 URL
|
||||||
|
func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
splits := strings.Split(p.Context.GetString("api_key"), "|")
|
||||||
|
if len(splits) != 3 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
domain, authUrl := p.getXunfeiAuthUrl(splits[2], splits[1])
|
||||||
|
|
||||||
|
p.domain = domain
|
||||||
|
p.apiId = splits[0]
|
||||||
|
|
||||||
|
return authUrl
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (string, string) {
|
||||||
|
query := p.Context.Request.URL.Query()
|
||||||
|
apiVersion := query.Get("api-version")
|
||||||
|
if apiVersion == "" {
|
||||||
|
apiVersion = p.Context.GetString("api_version")
|
||||||
|
}
|
||||||
|
if apiVersion == "" {
|
||||||
|
apiVersion = "v1.1"
|
||||||
|
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||||
|
}
|
||||||
|
domain := "general"
|
||||||
|
if apiVersion != "v1.1" {
|
||||||
|
domain += strings.Split(apiVersion, ".")[0]
|
||||||
|
}
|
||||||
|
authUrl := p.buildXunfeiAuthUrl(fmt.Sprintf("%s/%s/chat", p.BaseURL, apiVersion), apiKey, apiSecret)
|
||||||
|
return domain, authUrl
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *XunfeiProvider) buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||||
|
HmacWithShaToBase64 := func(algorithm, data, key string) string {
|
||||||
|
mac := hmac.New(sha256.New, []byte(key))
|
||||||
|
mac.Write([]byte(data))
|
||||||
|
encodeData := mac.Sum(nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(encodeData)
|
||||||
|
}
|
||||||
|
ul, err := url.Parse(hostUrl)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
date := time.Now().UTC().Format(time.RFC1123)
|
||||||
|
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
||||||
|
sign := strings.Join(signString, "\n")
|
||||||
|
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
|
||||||
|
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
||||||
|
"hmac-sha256", "host date request-line", sha)
|
||||||
|
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
||||||
|
v := url.Values{}
|
||||||
|
v.Add("host", ul.Host)
|
||||||
|
v.Add("date", date)
|
||||||
|
v.Add("authorization", authorization)
|
||||||
|
callUrl := hostUrl + "?" + v.Encode()
|
||||||
|
return callUrl
|
||||||
|
}
|
@ -1,23 +1,15 @@
|
|||||||
package controller
|
package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"strings"
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
)
|
|
||||||
|
|
||||||
// https://console.xfyun.cn/services/cbm
|
"github.com/gorilla/websocket"
|
||||||
// https://www.xfyun.cn/doc/spark/Web.html
|
)
|
||||||
|
|
||||||
type XunfeiMessage struct {
|
type XunfeiMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
@ -70,150 +62,28 @@ type XunfeiChatResponse struct {
|
|||||||
// CompletionTokens string `json:"completion_tokens"`
|
// CompletionTokens string `json:"completion_tokens"`
|
||||||
// TotalTokens string `json:"total_tokens"`
|
// TotalTokens string `json:"total_tokens"`
|
||||||
//} `json:"text"`
|
//} `json:"text"`
|
||||||
Text Usage `json:"text"`
|
Text types.Usage `json:"text"`
|
||||||
} `json:"usage"`
|
} `json:"usage"`
|
||||||
} `json:"payload"`
|
} `json:"payload"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
|
func (p *XunfeiProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
for _, message := range request.Messages {
|
|
||||||
if message.Role == "system" {
|
if request.Stream {
|
||||||
messages = append(messages, XunfeiMessage{
|
return p.sendStreamRequest(request, authUrl)
|
||||||
Role: "user",
|
} else {
|
||||||
Content: message.StringContent(),
|
return p.sendRequest(request, authUrl)
|
||||||
})
|
|
||||||
messages = append(messages, XunfeiMessage{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "Okay",
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
messages = append(messages, XunfeiMessage{
|
|
||||||
Role: message.Role,
|
|
||||||
Content: message.StringContent(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
xunfeiRequest := XunfeiChatRequest{}
|
|
||||||
xunfeiRequest.Header.AppId = xunfeiAppId
|
|
||||||
xunfeiRequest.Parameter.Chat.Domain = domain
|
|
||||||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
|
||||||
xunfeiRequest.Parameter.Chat.TopK = request.N
|
|
||||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
|
||||||
xunfeiRequest.Payload.Message.Text = messages
|
|
||||||
return &xunfeiRequest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
if len(response.Payload.Choices.Text) == 0 {
|
usage = &types.Usage{}
|
||||||
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
|
||||||
{
|
|
||||||
Content: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
choice := OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: response.Payload.Choices.Text[0].Content,
|
|
||||||
},
|
|
||||||
FinishReason: stopFinishReason,
|
|
||||||
}
|
|
||||||
fullTextResponse := OpenAITextResponse{
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Choices: []OpenAITextResponseChoice{choice},
|
|
||||||
Usage: response.Payload.Usage.Text,
|
|
||||||
}
|
|
||||||
return &fullTextResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
|
|
||||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
|
||||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
|
||||||
{
|
|
||||||
Content: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var choice ChatCompletionsStreamResponseChoice
|
|
||||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
|
||||||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
|
||||||
choice.FinishReason = &stopFinishReason
|
|
||||||
}
|
|
||||||
response := ChatCompletionsStreamResponse{
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "SparkDesk",
|
|
||||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
|
||||||
}
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
|
||||||
HmacWithShaToBase64 := func(algorithm, data, key string) string {
|
|
||||||
mac := hmac.New(sha256.New, []byte(key))
|
|
||||||
mac.Write([]byte(data))
|
|
||||||
encodeData := mac.Sum(nil)
|
|
||||||
return base64.StdEncoding.EncodeToString(encodeData)
|
|
||||||
}
|
|
||||||
ul, err := url.Parse(hostUrl)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
date := time.Now().UTC().Format(time.RFC1123)
|
|
||||||
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
|
||||||
sign := strings.Join(signString, "\n")
|
|
||||||
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
|
|
||||||
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
|
||||||
"hmac-sha256", "host date request-line", sha)
|
|
||||||
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
|
||||||
v := url.Values{}
|
|
||||||
v.Add("host", ul.Host)
|
|
||||||
v.Add("date", date)
|
|
||||||
v.Add("authorization", authorization)
|
|
||||||
callUrl := hostUrl + "?" + v.Encode()
|
|
||||||
return callUrl
|
|
||||||
}
|
|
||||||
|
|
||||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
|
||||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
setEventStreamHeaders(c)
|
|
||||||
var usage Usage
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case xunfeiResponse := <-dataChan:
|
|
||||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
||||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
||||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
||||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
||||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
|
||||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var usage Usage
|
|
||||||
var content string
|
var content string
|
||||||
var xunfeiResponse XunfeiChatResponse
|
var xunfeiResponse XunfeiChatResponse
|
||||||
stop := false
|
stop := false
|
||||||
@ -233,17 +103,100 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
|
|||||||
|
|
||||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||||
|
|
||||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
response := p.responseXunfei2OpenAI(&xunfeiResponse)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
p.Context.Writer.Header().Set("Content-Type", "application/json")
|
||||||
_, _ = c.Writer.Write(jsonResponse)
|
_, _ = p.Context.Writer.Write(jsonResponse)
|
||||||
return nil, &usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
usage = &types.Usage{}
|
||||||
|
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
setEventStreamHeaders(p.Context)
|
||||||
|
p.Context.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case xunfeiResponse := <-dataChan:
|
||||||
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||||
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||||
|
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||||
|
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionRequest) *XunfeiChatRequest {
|
||||||
|
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
if message.Role == "system" {
|
||||||
|
messages = append(messages, XunfeiMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: message.StringContent(),
|
||||||
|
})
|
||||||
|
messages = append(messages, XunfeiMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Okay",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
messages = append(messages, XunfeiMessage{
|
||||||
|
Role: message.Role,
|
||||||
|
Content: message.StringContent(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xunfeiRequest := XunfeiChatRequest{}
|
||||||
|
xunfeiRequest.Header.AppId = p.apiId
|
||||||
|
xunfeiRequest.Parameter.Chat.Domain = p.domain
|
||||||
|
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
||||||
|
xunfeiRequest.Parameter.Chat.TopK = request.N
|
||||||
|
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
||||||
|
xunfeiRequest.Payload.Message.Text = messages
|
||||||
|
return &xunfeiRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *types.ChatCompletionResponse {
|
||||||
|
if len(response.Payload.Choices.Text) == 0 {
|
||||||
|
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||||
|
{
|
||||||
|
Content: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
choice := types.ChatCompletionChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: types.ChatCompletionMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: response.Payload.Choices.Text[0].Content,
|
||||||
|
},
|
||||||
|
FinishReason: stopFinishReason,
|
||||||
|
}
|
||||||
|
fullTextResponse := types.ChatCompletionResponse{
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Choices: []types.ChatCompletionChoice{choice},
|
||||||
|
Usage: &response.Payload.Usage.Text,
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequest, authUrl string) (chan XunfeiChatResponse, chan bool, error) {
|
||||||
d := websocket.Dialer{
|
d := websocket.Dialer{
|
||||||
HandshakeTimeout: 5 * time.Second,
|
HandshakeTimeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
@ -251,7 +204,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId
|
|||||||
if err != nil || resp.StatusCode != 101 {
|
if err != nil || resp.StatusCode != 101 {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
data := p.requestOpenAI2Xunfei(textRequest)
|
||||||
err = conn.WriteJSON(data)
|
err = conn.WriteJSON(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@ -287,20 +240,24 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId
|
|||||||
return dataChan, stopChan, nil
|
return dataChan, stopChan, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *types.ChatCompletionStreamResponse {
|
||||||
query := c.Request.URL.Query()
|
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||||
apiVersion := query.Get("api-version")
|
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||||
if apiVersion == "" {
|
{
|
||||||
apiVersion = c.GetString("api_version")
|
Content: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if apiVersion == "" {
|
var choice types.ChatCompletionStreamChoice
|
||||||
apiVersion = "v1.1"
|
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
if xunfeiResponse.Payload.Choices.Status == 2 {
|
||||||
|
choice.FinishReason = &stopFinishReason
|
||||||
}
|
}
|
||||||
domain := "general"
|
response := types.ChatCompletionStreamResponse{
|
||||||
if apiVersion != "v1.1" {
|
Object: "chat.completion.chunk",
|
||||||
domain += strings.Split(apiVersion, ".")[0]
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "SparkDesk",
|
||||||
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
return &response
|
||||||
return domain, authUrl
|
|
||||||
}
|
}
|
104
providers/zhipu_base.go
Normal file
104
providers/zhipu_base.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var zhipuTokens sync.Map
|
||||||
|
var expSeconds int64 = 24 * 3600
|
||||||
|
|
||||||
|
type ZhipuProvider struct {
|
||||||
|
ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type zhipuTokenData struct {
|
||||||
|
Token string
|
||||||
|
ExpiryTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 ZhipuProvider
|
||||||
|
func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
|
||||||
|
return &ZhipuProvider{
|
||||||
|
ProviderConfig: ProviderConfig{
|
||||||
|
BaseURL: "https://open.bigmodel.cn",
|
||||||
|
ChatCompletions: "/api/paas/v3/model-api",
|
||||||
|
Context: c,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求头
|
||||||
|
func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
|
||||||
|
headers["Authorization"] = p.getZhipuToken()
|
||||||
|
headers["Content-Type"] = p.Context.Request.Header.Get("Content-Type")
|
||||||
|
headers["Accept"] = p.Context.Request.Header.Get("Accept")
|
||||||
|
if headers["Content-Type"] == "" {
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取完整请求 URL
|
||||||
|
func (p *ZhipuProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ZhipuProvider) getZhipuToken() string {
|
||||||
|
apikey := p.Context.GetString("api_key")
|
||||||
|
data, ok := zhipuTokens.Load(apikey)
|
||||||
|
if ok {
|
||||||
|
tokenData := data.(zhipuTokenData)
|
||||||
|
if time.Now().Before(tokenData.ExpiryTime) {
|
||||||
|
return tokenData.Token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
split := strings.Split(apikey, ".")
|
||||||
|
if len(split) != 2 {
|
||||||
|
common.SysError("invalid zhipu key: " + apikey)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
id := split[0]
|
||||||
|
secret := split[1]
|
||||||
|
|
||||||
|
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
||||||
|
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
||||||
|
|
||||||
|
timestamp := time.Now().UnixNano() / 1e6
|
||||||
|
|
||||||
|
payload := jwt.MapClaims{
|
||||||
|
"api_key": id,
|
||||||
|
"exp": expMillis,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||||
|
|
||||||
|
token.Header["alg"] = "HS256"
|
||||||
|
token.Header["sign_type"] = "SIGN"
|
||||||
|
|
||||||
|
tokenString, err := token.SignedString([]byte(secret))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
zhipuTokens.Store(apikey, zhipuTokenData{
|
||||||
|
Token: tokenString,
|
||||||
|
ExpiryTime: expiryTime,
|
||||||
|
})
|
||||||
|
|
||||||
|
return tokenString
|
||||||
|
}
|
260
providers/zhipu_chat.go
Normal file
260
providers/zhipu_chat.go
Normal file
@ -0,0 +1,260 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ZhipuMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZhipuRequest struct {
|
||||||
|
Prompt []ZhipuMessage `json:"prompt"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
RequestId string `json:"request_id,omitempty"`
|
||||||
|
Incremental bool `json:"incremental,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZhipuResponseData struct {
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
RequestId string `json:"request_id"`
|
||||||
|
TaskStatus string `json:"task_status"`
|
||||||
|
Choices []ZhipuMessage `json:"choices"`
|
||||||
|
types.Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZhipuResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data ZhipuResponseData `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZhipuStreamMetaResponse struct {
|
||||||
|
RequestId string `json:"request_id"`
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
TaskStatus string `json:"task_status"`
|
||||||
|
types.Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (zhipuResponse *ZhipuResponse) requestHandler(resp *http.Response) (OpenAIResponse any, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
if !zhipuResponse.Success {
|
||||||
|
return &types.OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: types.OpenAIError{
|
||||||
|
Message: zhipuResponse.Msg,
|
||||||
|
Type: "zhipu_error",
|
||||||
|
Param: "",
|
||||||
|
Code: zhipuResponse.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fullTextResponse := types.ChatCompletionResponse{
|
||||||
|
ID: zhipuResponse.Data.TaskId,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
|
||||||
|
Usage: &zhipuResponse.Data.Usage,
|
||||||
|
}
|
||||||
|
for i, choice := range zhipuResponse.Data.Choices {
|
||||||
|
openaiChoice := types.ChatCompletionChoice{
|
||||||
|
Index: i,
|
||||||
|
Message: types.ChatCompletionMessage{
|
||||||
|
Role: choice.Role,
|
||||||
|
Content: strings.Trim(choice.Content, "\""),
|
||||||
|
},
|
||||||
|
FinishReason: "",
|
||||||
|
}
|
||||||
|
if i == len(zhipuResponse.Data.Choices)-1 {
|
||||||
|
openaiChoice.FinishReason = "stop"
|
||||||
|
}
|
||||||
|
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
||||||
|
}
|
||||||
|
return fullTextResponse, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ZhipuProvider) getChatRequestBody(request *types.ChatCompletionRequest) *ZhipuRequest {
|
||||||
|
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
if message.Role == "system" {
|
||||||
|
messages = append(messages, ZhipuMessage{
|
||||||
|
Role: "system",
|
||||||
|
Content: message.StringContent(),
|
||||||
|
})
|
||||||
|
messages = append(messages, ZhipuMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Okay",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
messages = append(messages, ZhipuMessage{
|
||||||
|
Role: message.Role,
|
||||||
|
Content: message.StringContent(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &ZhipuRequest{
|
||||||
|
Prompt: messages,
|
||||||
|
Temperature: request.Temperature,
|
||||||
|
TopP: request.TopP,
|
||||||
|
Incremental: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ZhipuProvider) ChatCompleteResponse(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
requestBody := p.getChatRequestBody(request)
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
if request.Stream {
|
||||||
|
headers["Accept"] = "text/event-stream"
|
||||||
|
fullRequestURL += "/sse-invoke"
|
||||||
|
} else {
|
||||||
|
fullRequestURL += "/invoke"
|
||||||
|
}
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Stream {
|
||||||
|
openAIErrorWithStatusCode, usage = p.sendStreamRequest(req)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
zhipuResponse := &ZhipuResponse{}
|
||||||
|
openAIErrorWithStatusCode = p.sendRequest(req, zhipuResponse)
|
||||||
|
if openAIErrorWithStatusCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = &zhipuResponse.Data.Usage
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ZhipuProvider) streamResponseZhipu2OpenAI(zhipuResponse string) *types.ChatCompletionStreamResponse {
|
||||||
|
var choice types.ChatCompletionStreamChoice
|
||||||
|
choice.Delta.Content = zhipuResponse
|
||||||
|
response := types.ChatCompletionStreamResponse{
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "chatglm",
|
||||||
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
|
}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*types.ChatCompletionStreamResponse, *types.Usage) {
|
||||||
|
var choice types.ChatCompletionStreamChoice
|
||||||
|
choice.Delta.Content = ""
|
||||||
|
choice.FinishReason = &stopFinishReason
|
||||||
|
response := types.ChatCompletionStreamResponse{
|
||||||
|
ID: zhipuResponse.RequestId,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "chatglm",
|
||||||
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
|
}
|
||||||
|
return &response, &zhipuResponse.Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
|
||||||
|
// 发送请求
|
||||||
|
resp, err := common.HttpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.IsFailureStatusCode(resp) {
|
||||||
|
return p.handleErrorResp(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var usage *types.Usage
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
|
||||||
|
return i + 2, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
metaChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
lines := strings.Split(data, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
if len(line) < 5 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if line[:5] == "data:" {
|
||||||
|
dataChan <- line[5:]
|
||||||
|
if i != len(lines)-1 {
|
||||||
|
dataChan <- "\n"
|
||||||
|
}
|
||||||
|
} else if line[:5] == "meta:" {
|
||||||
|
metaChan <- line[5:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(p.Context)
|
||||||
|
p.Context.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
response := p.streamResponseZhipu2OpenAI(data)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case data := <-metaChan:
|
||||||
|
var zhipuResponse ZhipuStreamMetaResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
usage = zhipuUsage
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, usage
|
||||||
|
}
|
53
types/assistant.go
Normal file
53
types/assistant.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type Assistant struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
Name *string `json:"name,omitempty"`
|
||||||
|
Description *string `json:"description,omitempty"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Instructions *string `json:"instructions,omitempty"`
|
||||||
|
Tools any `json:"tools,omitempty"`
|
||||||
|
FileIDs []string `json:"file_ids,omitempty"`
|
||||||
|
Metadata map[string]any `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssistantRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Name *string `json:"name,omitempty"`
|
||||||
|
Description *string `json:"description,omitempty"`
|
||||||
|
Instructions *string `json:"instructions,omitempty"`
|
||||||
|
Tools any `json:"tools,omitempty"`
|
||||||
|
FileIDs []string `json:"file_ids,omitempty"`
|
||||||
|
Metadata map[string]any `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssistantsList is a list of assistants.
|
||||||
|
type AssistantsList struct {
|
||||||
|
Assistants []Assistant `json:"data"`
|
||||||
|
LastID *string `json:"last_id"`
|
||||||
|
FirstID *string `json:"first_id"`
|
||||||
|
HasMore bool `json:"has_more"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssistantDeleteResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Deleted bool `json:"deleted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssistantFile struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
AssistantID string `json:"assistant_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssistantFileRequest struct {
|
||||||
|
FileID string `json:"file_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssistantFilesList struct {
|
||||||
|
AssistantFiles []AssistantFile `json:"data"`
|
||||||
|
}
|
9
types/audio.go
Normal file
9
types/audio.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type SpeechAudioRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input string `json:"input"`
|
||||||
|
Voice string `json:"voice"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Speed float64 `json:"speed,omitempty"`
|
||||||
|
}
|
109
types/chat.go
Normal file
109
types/chat.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type ChatCompletionMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content any `json:"content"`
|
||||||
|
Name *string `json:"name,omitempty"`
|
||||||
|
FunctionCall any `json:"function_call,omitempty"`
|
||||||
|
ToolCalls any `json:"tool_calls,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ChatCompletionMessage) StringContent() string {
|
||||||
|
content, ok := m.Content.(string)
|
||||||
|
if ok {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
contentList, ok := m.Content.([]any)
|
||||||
|
if ok {
|
||||||
|
var contentStr string
|
||||||
|
for _, contentItem := range contentList {
|
||||||
|
contentMap, ok := contentItem.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if contentMap["type"] == "text" {
|
||||||
|
if subStr, ok := contentMap["text"].(string); ok {
|
||||||
|
contentStr += subStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return contentStr
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatMessageImageURL struct {
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
Detail string `json:"detail,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatMessagePart struct {
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionResponseFormat struct {
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []ChatCompletionMessage `json:"messages"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
Stop []string `json:"stop,omitempty"`
|
||||||
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
|
ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
|
||||||
|
Seed *int `json:"seed,omitempty"`
|
||||||
|
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||||
|
LogitBias any `json:"logit_bias,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
Functions any `json:"functions,omitempty"`
|
||||||
|
FunctionCall any `json:"function_call,omitempty"`
|
||||||
|
Tools any `json:"tools,omitempty"`
|
||||||
|
ToolChoice any `json:"tool_choice,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message ChatCompletionMessage `json:"message"`
|
||||||
|
FinishReason any `json:"finish_reason,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ChatCompletionChoice `json:"choices"`
|
||||||
|
Usage *Usage `json:"usage,omitempty"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStreamChoiceDelta struct {
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
FunctionCall any `json:"function_call,omitempty"`
|
||||||
|
ToolCalls any `json:"tool_calls,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStreamChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Delta ChatCompletionStreamChoiceDelta `json:"delta"`
|
||||||
|
FinishReason any `json:"finish_reason"`
|
||||||
|
ContentFilterResults any `json:"content_filter_results,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStreamResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []ChatCompletionStreamChoice `json:"choices"`
|
||||||
|
PromptAnnotations any `json:"prompt_annotations,omitempty"`
|
||||||
|
}
|
40
types/common.go
Normal file
40
types/common.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIError struct {
|
||||||
|
Code any `json:"code,omitempty"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Param string `json:"param,omitempty"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
InnerError any `json:"innererror,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIErrorWithStatusCode struct {
|
||||||
|
OpenAIError
|
||||||
|
StatusCode int `json:"status_code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIErrorResponse struct {
|
||||||
|
Error OpenAIError `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
|
||||||
|
openAIError := OpenAIError{
|
||||||
|
Message: err.Error(),
|
||||||
|
Type: "one_api_error",
|
||||||
|
Code: code,
|
||||||
|
}
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: openAIError,
|
||||||
|
StatusCode: statusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// type GeneralErrorHandling interface {
|
||||||
|
// HandleError(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode)
|
||||||
|
// }
|
36
types/completion.go
Normal file
36
types/completion.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type CompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt any `json:"prompt,omitempty"`
|
||||||
|
Suffix string `json:"suffix,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
TopP float32 `json:"top_p,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
LogProbs int `json:"logprobs,omitempty"`
|
||||||
|
Echo bool `json:"echo,omitempty"`
|
||||||
|
Stop []string `json:"stop,omitempty"`
|
||||||
|
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||||
|
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||||
|
BestOf int `json:"best_of,omitempty"`
|
||||||
|
LogitBias any `json:"logit_bias,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionChoice struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
LogProbs any `json:"logprobs,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []CompletionChoice `json:"choices"`
|
||||||
|
Usage *Usage `json:"usage,omitempty"`
|
||||||
|
}
|
40
types/embeddings.go
Normal file
40
types/embeddings.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input any `json:"input"`
|
||||||
|
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Embedding struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []Embedding `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Usage *Usage `json:"usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r EmbeddingRequest) ParseInput() []string {
|
||||||
|
if r.Input == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var input []string
|
||||||
|
switch r.Input.(type) {
|
||||||
|
case string:
|
||||||
|
input = []string{r.Input.(string)}
|
||||||
|
case []any:
|
||||||
|
input = make([]string, 0, len(r.Input.([]any)))
|
||||||
|
for _, item := range r.Input.([]any) {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
input = append(input, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return input
|
||||||
|
}
|
23
types/image.go
Normal file
23
types/image.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type ImageRequest struct {
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Style string `json:"style,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageResponse struct {
|
||||||
|
Created int64 `json:"created,omitempty"`
|
||||||
|
Data []ImageResponseDataInner `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageResponseDataInner struct {
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
B64JSON string `json:"b64_json,omitempty"`
|
||||||
|
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user