🎨 Change the method of getting channel parameters
This commit is contained in:
parent
47b72b850f
commit
eeb867da10
@ -55,10 +55,9 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Request = req
|
c.Request = req
|
||||||
|
|
||||||
setChannelToContext(c, channel)
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
provider := providers.GetProvider(channel.Type, c)
|
provider := providers.GetProvider(channel, c)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
return 0, errors.New("provider not found")
|
return 0, errors.New("provider not found")
|
||||||
}
|
}
|
||||||
@ -102,7 +101,6 @@ func UpdateChannelBalance(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"balance": balance,
|
"balance": balance,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateAllChannelsBalance() error {
|
func updateAllChannelsBalance() error {
|
||||||
@ -146,7 +144,6 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AutomaticallyUpdateChannels(frequency int) {
|
func AutomaticallyUpdateChannels(frequency int) {
|
||||||
|
@ -29,7 +29,6 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
|||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Request = req
|
c.Request = req
|
||||||
|
|
||||||
setChannelToContext(c, channel)
|
|
||||||
// 创建映射
|
// 创建映射
|
||||||
channelTypeToModel := map[int]string{
|
channelTypeToModel := map[int]string{
|
||||||
common.ChannelTypePaLM: "PaLM-2",
|
common.ChannelTypePaLM: "PaLM-2",
|
||||||
@ -50,7 +49,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
|||||||
}
|
}
|
||||||
request.Model = model
|
request.Model = model
|
||||||
|
|
||||||
provider := providers.GetProvider(channel.Type, c)
|
provider := providers.GetProvider(channel, c)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
return errors.New("channel not implemented"), nil
|
return errors.New("channel not implemented"), nil
|
||||||
}
|
}
|
||||||
@ -74,7 +73,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
|||||||
}
|
}
|
||||||
|
|
||||||
if Usage.CompletionTokens == 0 {
|
if Usage.CompletionTokens == 0 {
|
||||||
return errors.New(fmt.Sprintf("channel %s, message 补全 tokens 非预期返回 0", channel.Name)), nil
|
return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -132,7 +131,6 @@ func TestChannel(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"time": consumedTime,
|
"time": consumedTime,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var testAllChannelsLock sync.Mutex
|
var testAllChannelsLock sync.Mutex
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllChannels(c *gin.Context) {
|
func GetAllChannels(c *gin.Context) {
|
||||||
@ -27,7 +28,6 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": channels,
|
"data": channels,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchChannels(c *gin.Context) {
|
func SearchChannels(c *gin.Context) {
|
||||||
@ -45,7 +45,6 @@ func SearchChannels(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": channels,
|
"data": channels,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetChannel(c *gin.Context) {
|
func GetChannel(c *gin.Context) {
|
||||||
@ -70,7 +69,6 @@ func GetChannel(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": channel,
|
"data": channel,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddChannel(c *gin.Context) {
|
func AddChannel(c *gin.Context) {
|
||||||
@ -106,7 +104,6 @@ func AddChannel(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteChannel(c *gin.Context) {
|
func DeleteChannel(c *gin.Context) {
|
||||||
@ -124,7 +121,6 @@ func DeleteChannel(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteDisabledChannel(c *gin.Context) {
|
func DeleteDisabledChannel(c *gin.Context) {
|
||||||
@ -141,7 +137,6 @@ func DeleteDisabledChannel(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": rows,
|
"data": rows,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannel(c *gin.Context) {
|
func UpdateChannel(c *gin.Context) {
|
||||||
@ -167,5 +162,4 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": channel,
|
"data": channel,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
@ -5,13 +5,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GitHubOAuthResponse struct {
|
type GitHubOAuthResponse struct {
|
||||||
@ -211,7 +212,6 @@ func GitHubBind(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "bind",
|
"message": "bind",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateOAuthCode(c *gin.Context) {
|
func GenerateOAuthCode(c *gin.Context) {
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
func GetGroups(c *gin.Context) {
|
func GetGroups(c *gin.Context) {
|
||||||
groupNames := make([]string, 0)
|
groupNames := make([]string, 0)
|
||||||
for groupName, _ := range common.GroupRatio {
|
for groupName := range common.GroupRatio {
|
||||||
groupNames = append(groupNames, groupName)
|
groupNames = append(groupNames, groupName)
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllLogs(c *gin.Context) {
|
func GetAllLogs(c *gin.Context) {
|
||||||
@ -33,7 +34,6 @@ func GetAllLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserLogs(c *gin.Context) {
|
func GetUserLogs(c *gin.Context) {
|
||||||
@ -60,7 +60,6 @@ func GetUserLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchAllLogs(c *gin.Context) {
|
func SearchAllLogs(c *gin.Context) {
|
||||||
@ -78,7 +77,6 @@ func SearchAllLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUserLogs(c *gin.Context) {
|
func SearchUserLogs(c *gin.Context) {
|
||||||
@ -97,7 +95,6 @@ func SearchUserLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLogsStat(c *gin.Context) {
|
func GetLogsStat(c *gin.Context) {
|
||||||
@ -118,7 +115,6 @@ func GetLogsStat(c *gin.Context) {
|
|||||||
//"token": tokenNum,
|
//"token": tokenNum,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLogsSelfStat(c *gin.Context) {
|
func GetLogsSelfStat(c *gin.Context) {
|
||||||
@ -139,7 +135,6 @@ func GetLogsSelfStat(c *gin.Context) {
|
|||||||
//"token": tokenNum,
|
//"token": tokenNum,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteHistoryLogs(c *gin.Context) {
|
func DeleteHistoryLogs(c *gin.Context) {
|
||||||
@ -164,5 +159,4 @@ func DeleteHistoryLogs(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": count,
|
"data": count,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
@ -35,7 +35,6 @@ func GetStatus(c *gin.Context) {
|
|||||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetNotice(c *gin.Context) {
|
func GetNotice(c *gin.Context) {
|
||||||
@ -46,7 +45,6 @@ func GetNotice(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["Notice"],
|
"data": common.OptionMap["Notice"],
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAbout(c *gin.Context) {
|
func GetAbout(c *gin.Context) {
|
||||||
@ -57,7 +55,6 @@ func GetAbout(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["About"],
|
"data": common.OptionMap["About"],
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetHomePageContent(c *gin.Context) {
|
func GetHomePageContent(c *gin.Context) {
|
||||||
@ -68,7 +65,6 @@ func GetHomePageContent(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["HomePageContent"],
|
"data": common.OptionMap["HomePageContent"],
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendEmailVerification(c *gin.Context) {
|
func SendEmailVerification(c *gin.Context) {
|
||||||
@ -121,7 +117,6 @@ func SendEmailVerification(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendPasswordResetEmail(c *gin.Context) {
|
func SendPasswordResetEmail(c *gin.Context) {
|
||||||
@ -160,7 +155,6 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PasswordResetRequest struct {
|
type PasswordResetRequest struct {
|
||||||
@ -200,5 +194,4 @@ func ResetPassword(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": password,
|
"data": password,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ func RelayChat(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeChatCompletions)
|
provider, pass := getProvider(c, channel, common.RelayModeChatCompletions)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ func RelayCompletions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeCompletions)
|
provider, pass := getProvider(c, channel, common.RelayModeCompletions)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ func RelayEmbeddings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeEmbeddings)
|
provider, pass := getProvider(c, channel, common.RelayModeEmbeddings)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,7 @@ func RelayImageEdits(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeImagesEdits)
|
provider, pass := getProvider(c, channel, common.RelayModeImagesEdits)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,7 @@ func RelayImageGenerations(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeImagesGenerations)
|
provider, pass := getProvider(c, channel, common.RelayModeImagesGenerations)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ func RelayImageVariations(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeImagesVariations)
|
provider, pass := getProvider(c, channel, common.RelayModeImagesVariations)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ func RelayModerations(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeModerations)
|
provider, pass := getProvider(c, channel, common.RelayModeModerations)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,7 @@ func RelaySpeech(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeAudioSpeech)
|
provider, pass := getProvider(c, channel, common.RelayModeAudioSpeech)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,7 @@ func RelayTranscriptions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranscription)
|
provider, pass := getProvider(c, channel, common.RelayModeAudioTranscription)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,7 @@ func RelayTranslations(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranslation)
|
provider, pass := getProvider(c, channel, common.RelayModeAudioTranslation)
|
||||||
if pass {
|
if pass {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,6 @@ func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pas
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
setChannelToContext(c, channel)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,8 +83,8 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool
|
|||||||
return channel, false
|
return channel, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase.ProviderInterface, bool) {
|
func getProvider(c *gin.Context, channel *model.Channel, relayMode int) (providersBase.ProviderInterface, bool) {
|
||||||
provider := providers.GetProvider(channelType, c)
|
provider := providers.GetProvider(channel, c)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
|
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
|
||||||
return nil, true
|
return nil, true
|
||||||
@ -99,27 +98,6 @@ func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase.
|
|||||||
return provider, false
|
return provider, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func setChannelToContext(c *gin.Context, channel *model.Channel) {
|
|
||||||
// c.Set("channel", channel.Type)
|
|
||||||
c.Set("channel_id", channel.Id)
|
|
||||||
c.Set("channel_name", channel.Name)
|
|
||||||
c.Set("api_key", channel.Key)
|
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
|
||||||
switch channel.Type {
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeGemini:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeAIProxyLibrary:
|
|
||||||
c.Set("library_id", channel.Other)
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
c.Set("plugin", channel.Other)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
||||||
if !common.AutomaticDisableChannelEnabled {
|
if !common.AutomaticDisableChannelEnabled {
|
||||||
return false
|
return false
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllTokens(c *gin.Context) {
|
func GetAllTokens(c *gin.Context) {
|
||||||
@ -27,7 +28,6 @@ func GetAllTokens(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": tokens,
|
"data": tokens,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchTokens(c *gin.Context) {
|
func SearchTokens(c *gin.Context) {
|
||||||
@ -46,7 +46,6 @@ func SearchTokens(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": tokens,
|
"data": tokens,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetToken(c *gin.Context) {
|
func GetToken(c *gin.Context) {
|
||||||
@ -72,7 +71,6 @@ func GetToken(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": token,
|
"data": token,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTokenStatus(c *gin.Context) {
|
func GetTokenStatus(c *gin.Context) {
|
||||||
@ -138,7 +136,6 @@ func AddToken(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteToken(c *gin.Context) {
|
func DeleteToken(c *gin.Context) {
|
||||||
@ -156,7 +153,6 @@ func DeleteToken(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateToken(c *gin.Context) {
|
func UpdateToken(c *gin.Context) {
|
||||||
@ -224,5 +220,4 @@ func UpdateToken(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": cleanToken,
|
"data": cleanToken,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
@ -174,7 +174,6 @@ func Register(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUsers(c *gin.Context) {
|
func GetAllUsers(c *gin.Context) {
|
||||||
@ -195,7 +194,6 @@ func GetAllUsers(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": users,
|
"data": users,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUsers(c *gin.Context) {
|
func SearchUsers(c *gin.Context) {
|
||||||
@ -213,7 +211,6 @@ func SearchUsers(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": users,
|
"data": users,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(c *gin.Context) {
|
func GetUser(c *gin.Context) {
|
||||||
@ -246,7 +243,6 @@ func GetUser(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": user,
|
"data": user,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserDashboard(c *gin.Context) {
|
func GetUserDashboard(c *gin.Context) {
|
||||||
@ -306,7 +302,6 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": user.AccessToken,
|
"data": user.AccessToken,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAffCode(c *gin.Context) {
|
func GetAffCode(c *gin.Context) {
|
||||||
@ -334,7 +329,6 @@ func GetAffCode(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": user.AffCode,
|
"data": user.AffCode,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSelf(c *gin.Context) {
|
func GetSelf(c *gin.Context) {
|
||||||
@ -352,7 +346,6 @@ func GetSelf(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": user,
|
"data": user,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(c *gin.Context) {
|
func UpdateUser(c *gin.Context) {
|
||||||
@ -416,7 +409,6 @@ func UpdateUser(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateSelf(c *gin.Context) {
|
func UpdateSelf(c *gin.Context) {
|
||||||
@ -463,7 +455,6 @@ func UpdateSelf(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteUser(c *gin.Context) {
|
func DeleteUser(c *gin.Context) {
|
||||||
@ -525,7 +516,6 @@ func DeleteSelf(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateUser(c *gin.Context) {
|
func CreateUser(c *gin.Context) {
|
||||||
@ -574,7 +564,6 @@ func CreateUser(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ManageRequest struct {
|
type ManageRequest struct {
|
||||||
@ -691,7 +680,6 @@ func ManageUser(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": clearUser,
|
"data": clearUser,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func EmailBind(c *gin.Context) {
|
func EmailBind(c *gin.Context) {
|
||||||
@ -733,7 +721,6 @@ func EmailBind(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type topUpRequest struct {
|
type topUpRequest struct {
|
||||||
@ -764,5 +751,4 @@ func TopUp(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": quota,
|
"data": quota,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
@ -4,12 +4,13 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wechatLoginResponse struct {
|
type wechatLoginResponse struct {
|
||||||
@ -160,5 +161,4 @@ func WeChatBind(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
@ -32,9 +32,9 @@ type AliProvider struct {
|
|||||||
func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
|
func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
|
||||||
if p.Context.GetString("plugin") != "" {
|
if p.Channel.Other != "" {
|
||||||
headers["X-DashScope-Plugin"] = p.Context.GetString("plugin")
|
headers["X-DashScope-Plugin"] = p.Channel.Other
|
||||||
}
|
}
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
@ -27,7 +27,7 @@ type AzureSpeechProvider struct {
|
|||||||
// 获取请求头
|
// 获取请求头
|
||||||
func (p *AzureSpeechProvider) GetRequestHeaders() (headers map[string]string) {
|
func (p *AzureSpeechProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
headers["Ocp-Apim-Subscription-Key"] = p.Context.GetString("api_key")
|
headers["Ocp-Apim-Subscription-Key"] = p.Channel.Key
|
||||||
headers["Content-Type"] = "application/ssml+xml"
|
headers["Content-Type"] = "application/ssml+xml"
|
||||||
headers["User-Agent"] = "OneAPI"
|
headers["User-Agent"] = "OneAPI"
|
||||||
// headers["X-Microsoft-OutputFormat"] = "audio-16khz-128kbitrate-mono-mp3"
|
// headers["X-Microsoft-OutputFormat"] = "audio-16khz-128kbitrate-mono-mp3"
|
||||||
|
@ -63,7 +63,7 @@ func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
|
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
|
||||||
apiKey := p.Context.GetString("api_key")
|
apiKey := p.Channel.Key
|
||||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||||
var accessToken BaiduAccessToken
|
var accessToken BaiduAccessToken
|
||||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
if accessToken, ok = val.(BaiduAccessToken); ok {
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -28,17 +29,22 @@ type BaseProvider struct {
|
|||||||
ImagesVariations string
|
ImagesVariations string
|
||||||
Proxy string
|
Proxy string
|
||||||
Context *gin.Context
|
Context *gin.Context
|
||||||
|
Channel *model.Channel
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取基础URL
|
// 获取基础URL
|
||||||
func (p *BaseProvider) GetBaseURL() string {
|
func (p *BaseProvider) GetBaseURL() string {
|
||||||
if p.Context.GetString("base_url") != "" {
|
if p.Channel.GetBaseURL() != "" {
|
||||||
return p.Context.GetString("base_url")
|
return p.Channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.BaseURL
|
return p.BaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *BaseProvider) SetChannel(channel *model.Channel) {
|
||||||
|
p.Channel = channel
|
||||||
|
}
|
||||||
|
|
||||||
// 获取完整请求URL
|
// 获取完整请求URL
|
||||||
func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
@ -12,6 +12,7 @@ type ProviderInterface interface {
|
|||||||
GetFullRequestURL(requestURL string, modelName string) string
|
GetFullRequestURL(requestURL string, modelName string) string
|
||||||
GetRequestHeaders() (headers map[string]string)
|
GetRequestHeaders() (headers map[string]string)
|
||||||
SupportAPI(relayMode int) bool
|
SupportAPI(relayMode int) bool
|
||||||
|
SetChannel(channel *model.Channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 完成接口
|
// 完成接口
|
||||||
|
@ -28,7 +28,7 @@ func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
|
|||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
|
|
||||||
headers["x-api-key"] = p.Context.GetString("api_key")
|
headers["x-api-key"] = p.Channel.Key
|
||||||
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
|
anthropicVersion := p.Context.Request.Header.Get("anthropic-version")
|
||||||
if anthropicVersion == "" {
|
if anthropicVersion == "" {
|
||||||
anthropicVersion = "2023-06-01"
|
anthropicVersion = "2023-06-01"
|
||||||
|
@ -28,8 +28,8 @@ type GeminiProvider struct {
|
|||||||
func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
version := "v1"
|
version := "v1"
|
||||||
if p.Context.GetString("api_version") != "" {
|
if p.Channel.Other != "" {
|
||||||
version = p.Context.GetString("api_version")
|
version = p.Channel.Other
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL)
|
return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL)
|
||||||
@ -40,7 +40,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
|
|||||||
func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) {
|
func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
headers["x-goog-api-key"] = p.Context.GetString("api_key")
|
headers["x-goog-api-key"] = p.Channel.Key
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
|
|||||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||||
|
|
||||||
if p.IsAzure {
|
if p.IsAzure {
|
||||||
apiVersion := p.Context.GetString("api_version")
|
apiVersion := p.Channel.Other
|
||||||
if modelName == "dall-e-2" {
|
if modelName == "dall-e-2" {
|
||||||
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
||||||
// 已经没有dall-e-2了,所以暂时写死
|
// 已经没有dall-e-2了,所以暂时写死
|
||||||
@ -85,9 +85,9 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
|
|||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
if p.IsAzure {
|
if p.IsAzure {
|
||||||
headers["api-key"] = p.Context.GetString("api_key")
|
headers["api-key"] = p.Channel.Key
|
||||||
} else {
|
} else {
|
||||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
|
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
@ -29,7 +29,7 @@ type PalmProvider struct {
|
|||||||
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
|
||||||
headers = make(map[string]string)
|
headers = make(map[string]string)
|
||||||
p.CommonRequestHeaders(headers)
|
p.CommonRequestHeaders(headers)
|
||||||
headers["x-goog-api-key"] = p.Context.GetString("api_key")
|
headers["x-goog-api-key"] = p.Channel.Key
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
"one-api/providers/aigc2d"
|
"one-api/providers/aigc2d"
|
||||||
"one-api/providers/aiproxy"
|
"one-api/providers/aiproxy"
|
||||||
"one-api/providers/ali"
|
"one-api/providers/ali"
|
||||||
@ -55,19 +56,23 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取供应商
|
// 获取供应商
|
||||||
func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
|
func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface {
|
||||||
factory, ok := providerFactories[channelType]
|
factory, ok := providerFactories[channel.Type]
|
||||||
|
var provider base.ProviderInterface
|
||||||
if !ok {
|
if !ok {
|
||||||
// 处理未找到的供应商工厂
|
// 处理未找到的供应商工厂
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||||
if c.GetString("base_url") != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
if baseURL != "" {
|
if baseURL == "" {
|
||||||
return openai.CreateOpenAIProvider(c, baseURL)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
provider = openai.CreateOpenAIProvider(c, baseURL)
|
||||||
}
|
}
|
||||||
return factory.Create(c)
|
provider = factory.Create(c)
|
||||||
|
provider.SetChannel(channel)
|
||||||
|
|
||||||
|
return provider
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ func (p *TencentProvider) parseTencentConfig(config string) (appId int64, secret
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
|
func (p *TencentProvider) getTencentSign(req TencentChatRequest) string {
|
||||||
apiKey := p.Context.GetString("api_key")
|
apiKey := p.Channel.Key
|
||||||
appId, secretId, secretKey, err := p.parseTencentConfig(apiKey)
|
appId, secretId, secretKey, err := p.parseTencentConfig(apiKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
|
@ -42,7 +42,7 @@ func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) {
|
|||||||
|
|
||||||
// 获取完整请求 URL
|
// 获取完整请求 URL
|
||||||
func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||||
splits := strings.Split(p.Context.GetString("api_key"), "|")
|
splits := strings.Split(p.Channel.Key, "|")
|
||||||
if len(splits) != 3 {
|
if len(splits) != 3 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -58,7 +58,7 @@ func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (stri
|
|||||||
query := p.Context.Request.URL.Query()
|
query := p.Context.Request.URL.Query()
|
||||||
apiVersion := query.Get("api-version")
|
apiVersion := query.Get("api-version")
|
||||||
if apiVersion == "" {
|
if apiVersion == "" {
|
||||||
apiVersion = p.Context.GetString("api_version")
|
apiVersion = p.Channel.Key
|
||||||
}
|
}
|
||||||
if apiVersion == "" {
|
if apiVersion == "" {
|
||||||
apiVersion = "v1.1"
|
apiVersion = "v1.1"
|
||||||
|
@ -49,7 +49,7 @@ func (p *ZhipuProvider) GetFullRequestURL(requestURL string, modelName string) s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ZhipuProvider) getZhipuToken() string {
|
func (p *ZhipuProvider) getZhipuToken() string {
|
||||||
apikey := p.Context.GetString("api_key")
|
apikey := p.Channel.Key
|
||||||
data, ok := zhipuTokens.Load(apikey)
|
data, ok := zhipuTokens.Load(apikey)
|
||||||
if ok {
|
if ok {
|
||||||
tokenData := data.(zhipuTokenData)
|
tokenData := data.(zhipuTokenData)
|
||||||
|
Loading…
Reference in New Issue
Block a user