213 lines
5.6 KiB
Go
213 lines
5.6 KiB
Go
package controller
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/songquanpeng/one-api/model"
|
|
relay "github.com/songquanpeng/one-api/relay"
|
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
|
"github.com/songquanpeng/one-api/relay/apitype"
|
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
|
"github.com/songquanpeng/one-api/relay/meta"
|
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// OpenAIModelPermission mirrors the "permission" objects embedded in model
// entries of the OpenAI models API
// (https://platform.openai.com/docs/api-reference/models/list).
//
// Field order is significant: encoding/json emits struct fields in
// declaration order, so reordering them changes the response body.
type OpenAIModelPermission struct {
	// Id is the permission record identifier.
	Id string `json:"id"`
	// Object is the OpenAI object type tag for this record.
	Object string `json:"object"`
	// Created is the record's creation time as an integer
	// (Unix seconds in the OpenAI schema).
	Created int `json:"created"`

	// Capability flags, serialised verbatim.
	AllowCreateEngine  bool `json:"allow_create_engine"`
	AllowSampling      bool `json:"allow_sampling"`
	AllowLogprobs      bool `json:"allow_logprobs"`
	AllowSearchIndices bool `json:"allow_search_indices"`
	AllowView          bool `json:"allow_view"`
	AllowFineTuning    bool `json:"allow_fine_tuning"`

	// Organization is the organization the permission applies to.
	Organization string `json:"organization"`
	// Group is optional; a nil pointer is serialised as JSON null.
	Group *string `json:"group"`
	// IsBlocking is serialised verbatim.
	IsBlocking bool `json:"is_blocking"`
}
|
|
|
|
type OpenAIModels struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int `json:"created"`
|
|
OwnedBy string `json:"owned_by"`
|
|
Permission []OpenAIModelPermission `json:"permission"`
|
|
Root string `json:"root"`
|
|
Parent *string `json:"parent"`
|
|
}
|
|
|
|
var models []OpenAIModels
|
|
var modelsMap map[string]OpenAIModels
|
|
var channelId2Models map[int][]string
|
|
|
|
func init() {
|
|
var permission []OpenAIModelPermission
|
|
permission = append(permission, OpenAIModelPermission{
|
|
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
|
Object: "model_permission",
|
|
Created: 1626777600,
|
|
AllowCreateEngine: true,
|
|
AllowSampling: true,
|
|
AllowLogprobs: true,
|
|
AllowSearchIndices: false,
|
|
AllowView: true,
|
|
AllowFineTuning: false,
|
|
Organization: "*",
|
|
Group: nil,
|
|
IsBlocking: false,
|
|
})
|
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
|
for i := 0; i < apitype.Dummy; i++ {
|
|
if i == apitype.AIProxyLibrary {
|
|
continue
|
|
}
|
|
adaptor := relay.GetAdaptor(i)
|
|
channelName := adaptor.GetChannelName()
|
|
modelNames := adaptor.GetModelList()
|
|
for _, modelName := range modelNames {
|
|
models = append(models, OpenAIModels{
|
|
Id: modelName,
|
|
Object: "model",
|
|
Created: 1626777600,
|
|
OwnedBy: channelName,
|
|
Permission: permission,
|
|
Root: modelName,
|
|
Parent: nil,
|
|
})
|
|
}
|
|
}
|
|
for _, channelType := range openai.CompatibleChannels {
|
|
if channelType == channeltype.Azure {
|
|
continue
|
|
}
|
|
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
|
|
for _, modelName := range channelModelList {
|
|
models = append(models, OpenAIModels{
|
|
Id: modelName,
|
|
Object: "model",
|
|
Created: 1626777600,
|
|
OwnedBy: channelName,
|
|
Permission: permission,
|
|
Root: modelName,
|
|
Parent: nil,
|
|
})
|
|
}
|
|
}
|
|
modelsMap = make(map[string]OpenAIModels)
|
|
for _, model := range models {
|
|
modelsMap[model.Id] = model
|
|
}
|
|
channelId2Models = make(map[int][]string)
|
|
for i := 1; i < channeltype.Dummy; i++ {
|
|
adaptor := relay.GetAdaptor(channeltype.ToAPIType(i))
|
|
meta := &meta.Meta{
|
|
ChannelType: i,
|
|
}
|
|
adaptor.Init(meta)
|
|
channelId2Models[i] = adaptor.GetModelList()
|
|
}
|
|
}
|
|
|
|
func DashboardListModels(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": channelId2Models,
|
|
})
|
|
}
|
|
|
|
func ListAllModels(c *gin.Context) {
|
|
c.JSON(200, gin.H{
|
|
"object": "list",
|
|
"data": models,
|
|
})
|
|
}
|
|
|
|
func ListModels(c *gin.Context) {
|
|
ctx := c.Request.Context()
|
|
var availableModels []string
|
|
if c.GetString("available_models") != "" {
|
|
availableModels = strings.Split(c.GetString("available_models"), ",")
|
|
} else {
|
|
userId := c.GetInt("id")
|
|
userGroup, _ := model.CacheGetUserGroup(userId)
|
|
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
|
|
}
|
|
modelSet := make(map[string]bool)
|
|
for _, availableModel := range availableModels {
|
|
modelSet[availableModel] = true
|
|
}
|
|
availableOpenAIModels := make([]OpenAIModels, 0)
|
|
for _, model := range models {
|
|
if _, ok := modelSet[model.Id]; ok {
|
|
modelSet[model.Id] = false
|
|
availableOpenAIModels = append(availableOpenAIModels, model)
|
|
}
|
|
}
|
|
for modelName, ok := range modelSet {
|
|
if ok {
|
|
availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
|
|
Id: modelName,
|
|
Object: "model",
|
|
Created: 1626777600,
|
|
OwnedBy: "custom",
|
|
Root: modelName,
|
|
Parent: nil,
|
|
})
|
|
}
|
|
}
|
|
c.JSON(200, gin.H{
|
|
"object": "list",
|
|
"data": availableOpenAIModels,
|
|
})
|
|
}
|
|
|
|
func RetrieveModel(c *gin.Context) {
|
|
modelId := c.Param("model")
|
|
if model, ok := modelsMap[modelId]; ok {
|
|
c.JSON(200, model)
|
|
} else {
|
|
Error := relaymodel.Error{
|
|
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
|
Type: "invalid_request_error",
|
|
Param: "model",
|
|
Code: "model_not_found",
|
|
}
|
|
c.JSON(200, gin.H{
|
|
"error": Error,
|
|
})
|
|
}
|
|
}
|
|
|
|
func GetUserAvailableModels(c *gin.Context) {
|
|
ctx := c.Request.Context()
|
|
id := c.GetInt("id")
|
|
userGroup, err := model.CacheGetUserGroup(id)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
models, err := model.CacheGetGroupModels(ctx, userGroup)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": models,
|
|
})
|
|
return
|
|
}
|