🔖 chore: Rename relay/util to relay/relay_util package and add utils package
This commit is contained in:
parent
853f2681f4
commit
79524108a3
@ -3,13 +3,13 @@ package cli
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/relay/util"
|
"one-api/relay/relay_util"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ExportPrices() {
|
func ExportPrices() {
|
||||||
prices := util.GetPricesList("default")
|
prices := relay_util.GetPricesList("default")
|
||||||
|
|
||||||
if len(prices) == 0 {
|
if len(prices) == 0 {
|
||||||
common.SysError("No prices found")
|
common.SysError("No prices found")
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"one-api/cli"
|
"one-api/cli"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
@ -22,11 +23,11 @@ func InitConf() {
|
|||||||
|
|
||||||
common.IsMasterNode = viper.GetString("node_type") != "slave"
|
common.IsMasterNode = viper.GetString("node_type") != "slave"
|
||||||
common.RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second
|
common.RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second
|
||||||
common.SessionSecret = common.GetOrDefault("session_secret", common.SessionSecret)
|
common.SessionSecret = utils.GetOrDefault("session_secret", common.SessionSecret)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setConfigFile() {
|
func setConfigFile() {
|
||||||
if !common.IsFileExist(*cli.Config) {
|
if !utils.IsFileExist(*cli.Config) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,6 +16,15 @@ import (
|
|||||||
_ "golang.org/x/image/webp"
|
_ "golang.org/x/image/webp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// var ImageHttpClients = &http.Client{
|
||||||
|
// Transport: &http.Transport{
|
||||||
|
// DialContext: requester.Socks5ProxyFunc,
|
||||||
|
// Proxy: requester.ProxyFunc,
|
||||||
|
// },
|
||||||
|
// //
|
||||||
|
// // // Timeout: 30 * time.Second,
|
||||||
|
// }
|
||||||
|
|
||||||
func IsImageUrl(url string) (bool, error) {
|
func IsImageUrl(url string) (bool, error) {
|
||||||
resp, err := http.Head(url)
|
resp, err := http.Head(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -10,6 +10,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"one-api/common/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
@ -62,13 +64,13 @@ func getLogDir() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
logDir, err = filepath.Abs(viper.GetString("log_dir"))
|
logDir, err = filepath.Abs(logDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if !IsFileExist(logDir) {
|
if !utils.IsFileExist(logDir) {
|
||||||
err = os.Mkdir(logDir, 0777)
|
err = os.Mkdir(logDir, 0777)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
@ -1,87 +1,24 @@
|
|||||||
package requester
|
package requester
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"one-api/common/utils"
|
||||||
"one-api/common"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ContextKey string
|
|
||||||
|
|
||||||
const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr"
|
|
||||||
const ProxySock5AddrKey ContextKey = "proxySock5Addr"
|
|
||||||
|
|
||||||
func proxyFunc(req *http.Request) (*url.URL, error) {
|
|
||||||
proxyAddr := req.Context().Value(ProxyHTTPAddrKey)
|
|
||||||
if proxyAddr == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyURL, err := url.Parse(proxyAddr.(string))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing proxy address: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch proxyURL.Scheme {
|
|
||||||
case "http", "https":
|
|
||||||
return proxyURL, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
// 设置TCP超时
|
|
||||||
dialer := &net.Dialer{
|
|
||||||
Timeout: time.Duration(common.GetOrDefault("connect_timeout", 5)) * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从上下文中获取代理地址
|
|
||||||
proxyAddr, ok := ctx.Value(ProxySock5AddrKey).(string)
|
|
||||||
if !ok {
|
|
||||||
return dialer.DialContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyURL, err := url.Parse(proxyAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing proxy address: %w", err)
|
|
||||||
}
|
|
||||||
var auth *proxy.Auth = nil
|
|
||||||
password, isSetPassword := proxyURL.User.Password()
|
|
||||||
if isSetPassword {
|
|
||||||
auth = &proxy.Auth{
|
|
||||||
User: proxyURL.User.Username(),
|
|
||||||
Password: password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
proxyDialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating socks5 dialer: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return proxyDialer.Dial(network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
var HTTPClient *http.Client
|
var HTTPClient *http.Client
|
||||||
|
|
||||||
func InitHttpClient() {
|
func InitHttpClient() {
|
||||||
trans := &http.Transport{
|
trans := &http.Transport{
|
||||||
DialContext: socks5ProxyFunc,
|
DialContext: utils.Socks5ProxyFunc,
|
||||||
Proxy: proxyFunc,
|
Proxy: utils.ProxyFunc,
|
||||||
}
|
}
|
||||||
|
|
||||||
HTTPClient = &http.Client{
|
HTTPClient = &http.Client{
|
||||||
Transport: trans,
|
Transport: trans,
|
||||||
}
|
}
|
||||||
|
|
||||||
relayTimeout := common.GetOrDefault("relay_timeout", 600)
|
relayTimeout := utils.GetOrDefault("relay_timeout", 600)
|
||||||
if relayTimeout != 0 {
|
if relayTimeout != 0 {
|
||||||
HTTPClient.Timeout = time.Duration(relayTimeout) * time.Second
|
HTTPClient.Timeout = time.Duration(relayTimeout) * time.Second
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -52,18 +53,7 @@ type requestOptions struct {
|
|||||||
type requestOption func(*requestOptions)
|
type requestOption func(*requestOptions)
|
||||||
|
|
||||||
func (r *HTTPRequester) setProxy() context.Context {
|
func (r *HTTPRequester) setProxy() context.Context {
|
||||||
if r.proxyAddr == "" {
|
return utils.SetProxy(r.Context, r.proxyAddr)
|
||||||
return r.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
|
|
||||||
if strings.HasPrefix(r.proxyAddr, "socks5://") {
|
|
||||||
return context.WithValue(r.Context, ProxySock5AddrKey, r.proxyAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 否则使用 http 代理
|
|
||||||
return context.WithValue(r.Context, ProxyHTTPAddrKey, r.proxyAddr)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建请求
|
// 创建请求
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
@ -14,7 +15,7 @@ import (
|
|||||||
|
|
||||||
func GetWSClient(proxyAddr string) *websocket.Dialer {
|
func GetWSClient(proxyAddr string) *websocket.Dialer {
|
||||||
dialer := &websocket.Dialer{
|
dialer := &websocket.Dialer{
|
||||||
HandshakeTimeout: time.Duration(common.GetOrDefault("connect_timeout", 5)) * time.Second,
|
HandshakeTimeout: time.Duration(utils.GetOrDefault("connect_timeout", 5)) * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
if proxyAddr != "" {
|
if proxyAddr != "" {
|
||||||
@ -38,20 +39,16 @@ func setWSProxy(dialer *websocket.Dialer, proxyAddr string) error {
|
|||||||
case "http", "https":
|
case "http", "https":
|
||||||
dialer.Proxy = http.ProxyURL(proxyURL)
|
dialer.Proxy = http.ProxyURL(proxyURL)
|
||||||
case "socks5":
|
case "socks5":
|
||||||
var auth *proxy.Auth = nil
|
proxyDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||||
password, isSetPassword := proxyURL.User.Password()
|
|
||||||
if isSetPassword {
|
|
||||||
auth = &proxy.Auth{
|
|
||||||
User: proxyURL.User.Username(),
|
|
||||||
Password: password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
socks5Proxy, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating socks5 dialer: %w", err)
|
return fmt.Errorf("error creating proxy dialer: %w", err)
|
||||||
}
|
}
|
||||||
|
originalNetDial := dialer.NetDial
|
||||||
dialer.NetDial = func(network, addr string) (net.Conn, error) {
|
dialer.NetDial = func(network, addr string) (net.Conn, error) {
|
||||||
return socks5Proxy.Dial(network, addr)
|
if originalNetDial != nil {
|
||||||
|
return originalNetDial(network, addr)
|
||||||
|
}
|
||||||
|
return proxyDialer.Dial(network, addr)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
return fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||||
|
@ -3,6 +3,7 @@ package stmp
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/wneessen/go-mail"
|
"github.com/wneessen/go-mail"
|
||||||
@ -67,7 +68,7 @@ func (s *StmpConfig) Send(to, subject, body string) error {
|
|||||||
|
|
||||||
func (s *StmpConfig) getReferences() string {
|
func (s *StmpConfig) getReferences() string {
|
||||||
froms := strings.Split(s.From, "@")
|
froms := strings.Split(s.From, "@")
|
||||||
return fmt.Sprintf("<%s.%s@%s>", froms[0], common.GetUUID(), froms[1])
|
return fmt.Sprintf("<%s.%s@%s>", froms[0], utils.GetUUID(), froms[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StmpConfig) Render(to, subject, content string) error {
|
func (s *StmpConfig) Render(to, subject, content string) error {
|
||||||
|
@ -5,7 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"one-api/common"
|
"one-api/common/utils"
|
||||||
|
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
"one-api/common/storage/drives"
|
"one-api/common/storage/drives"
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ func TestSMMSUpload(t *testing.T) {
|
|||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
url, err := smUpload.Upload(image, common.GetUUID()+".png")
|
url, err := smUpload.Upload(image, utils.GetUUID()+".png")
|
||||||
fmt.Println(url)
|
fmt.Println(url)
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
@ -48,7 +49,7 @@ func TestImgurUpload(t *testing.T) {
|
|||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
url, err := imgurUpload.Upload(image, common.GetUUID()+".png")
|
url, err := imgurUpload.Upload(image, utils.GetUUID()+".png")
|
||||||
fmt.Println(url)
|
fmt.Println(url)
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
@ -2,6 +2,7 @@ package telegram
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/PaulSonOfLars/gotgbot/v2"
|
"github.com/PaulSonOfLars/gotgbot/v2"
|
||||||
@ -15,7 +16,7 @@ func commandAffStart(b *gotgbot.Bot, ctx *ext.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if user.AffCode == "" {
|
if user.AffCode == "" {
|
||||||
user.AffCode = common.GetRandomString(4)
|
user.AffCode = utils.GetRandomString(4)
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
ctx.EffectiveMessage.Reply(b, "系统错误,请稍后再试", nil)
|
ctx.EffectiveMessage.Reply(b, "系统错误,请稍后再试", nil)
|
||||||
return nil
|
return nil
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
package telegram
|
package telegram
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@ -243,22 +245,16 @@ func getHttpClient() (httpClient *http.Client) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
case "socks5":
|
case "socks5":
|
||||||
var auth *proxy.Auth = nil
|
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||||
password, isSetPassword := proxyURL.User.Password()
|
|
||||||
if isSetPassword {
|
|
||||||
auth = &proxy.Auth{
|
|
||||||
User: proxyURL.User.Username(),
|
|
||||||
Password: password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("failed to create TG SOCKS5 dialer: " + err.Error())
|
common.SysLog("failed to create TG SOCKS5 dialer: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
httpClient = &http.Client{
|
httpClient = &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
Dial: dialer.Dial,
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -119,7 +119,7 @@ func CountTokenMessages(messages []types.ChatCompletionMessage, model string) in
|
|||||||
imageTokens, err := countImageTokens(url, detail)
|
imageTokens, err := countImageTokens(url, detail)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//Due to the excessive length of the error information, only extract and record the most critical part.
|
//Due to the excessive length of the error information, only extract and record the most critical part.
|
||||||
SysError("error counting image tokens: " + err.Error())
|
SysError("error counting image tokens: " + err.Error())
|
||||||
} else {
|
} else {
|
||||||
tokenNum += imageTokens
|
tokenNum += imageTokens
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package common
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -109,13 +109,13 @@ func Seconds2Time(num int) (time string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Interface2String(inter interface{}) string {
|
func Interface2String(inter interface{}) string {
|
||||||
switch inter.(type) {
|
switch inter := inter.(type) {
|
||||||
case string:
|
case string:
|
||||||
return inter.(string)
|
return inter
|
||||||
case int:
|
case int:
|
||||||
return fmt.Sprintf("%d", inter.(int))
|
return fmt.Sprintf("%d", inter)
|
||||||
case float64:
|
case float64:
|
||||||
return fmt.Sprintf("%f", inter.(float64))
|
return fmt.Sprintf("%f", inter)
|
||||||
}
|
}
|
||||||
return "Not Implemented"
|
return "Not Implemented"
|
||||||
}
|
}
|
||||||
@ -140,12 +140,7 @@ func GetUUID() string {
|
|||||||
|
|
||||||
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
|
||||||
func init() {
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
}
|
|
||||||
|
|
||||||
func GenerateKey() string {
|
func GenerateKey() string {
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
key := make([]byte, 48)
|
key := make([]byte, 48)
|
||||||
for i := 0; i < 16; i++ {
|
for i := 0; i < 16; i++ {
|
||||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
key[i] = keyChars[rand.Intn(len(keyChars))]
|
||||||
@ -162,7 +157,6 @@ func GenerateKey() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomString(length int) string {
|
func GetRandomString(length int) string {
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
key := make([]byte, length)
|
key := make([]byte, length)
|
||||||
for i := 0; i < length; i++ {
|
for i := 0; i < length; i++ {
|
||||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
key[i] = keyChars[rand.Intn(len(keyChars))]
|
77
common/utils/proxy.go
Normal file
77
common/utils/proxy.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ContextKey string
|
||||||
|
|
||||||
|
const ProxyHTTPAddrKey ContextKey = "proxyHttpAddr"
|
||||||
|
const ProxySock5AddrKey ContextKey = "proxySock5Addr"
|
||||||
|
|
||||||
|
func ProxyFunc(req *http.Request) (*url.URL, error) {
|
||||||
|
proxyAddr := req.Context().Value(ProxyHTTPAddrKey)
|
||||||
|
if proxyAddr == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := url.Parse(proxyAddr.(string))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing proxy address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch proxyURL.Scheme {
|
||||||
|
case "http", "https":
|
||||||
|
return proxyURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: time.Duration(GetOrDefault("connect_timeout", 5)) * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyAddr, ok := ctx.Value(ProxySock5AddrKey).(string)
|
||||||
|
if !ok {
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := url.Parse(proxyAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing proxy address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyDialer, err := proxy.FromURL(proxyURL, dialer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating proxy dialer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxyDialer.Dial(network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetProxy(ctx context.Context, proxyAddr string) context.Context {
|
||||||
|
if proxyAddr == "" {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
key := ProxyHTTPAddrKey
|
||||||
|
|
||||||
|
// 如果是以 socks5:// 开头的地址,那么使用 socks5 代理
|
||||||
|
if strings.HasPrefix(proxyAddr, "socks5://") {
|
||||||
|
key = ProxySock5AddrKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// 否则使用 http 代理
|
||||||
|
return context.WithValue(ctx, key, proxyAddr)
|
||||||
|
}
|
@ -8,6 +8,7 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/notify"
|
"one-api/common/notify"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/providers"
|
"one-api/providers"
|
||||||
providers_base "one-api/providers/base"
|
providers_base "one-api/providers/base"
|
||||||
@ -153,7 +154,7 @@ func testAllChannels(isNotify bool) error {
|
|||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(common.RequestInterval)
|
||||||
|
|
||||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||||
sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", common.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr())
|
sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", utils.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr())
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err, openaiErr := testChannel(channel, "")
|
err, openaiErr := testChannel(channel, "")
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
@ -161,7 +162,7 @@ func testAllChannels(isNotify bool) error {
|
|||||||
// 通道为禁用状态,并且还是请求错误 或者 响应时间超过阈值 直接跳过,也不需要更新响应时间。
|
// 通道为禁用状态,并且还是请求错误 或者 响应时间超过阈值 直接跳过,也不需要更新响应时间。
|
||||||
if !isChannelEnabled {
|
if !isChannelEnabled {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendMessage += fmt.Sprintf("- 测试报错: %s \n\n- 无需改变状态,跳过\n\n", common.EscapeMarkdownText(err.Error()))
|
sendMessage += fmt.Sprintf("- 测试报错: %s \n\n- 无需改变状态,跳过\n\n", utils.EscapeMarkdownText(err.Error()))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if milliseconds > disableThreshold {
|
if milliseconds > disableThreshold {
|
||||||
@ -187,13 +188,13 @@ func testAllChannels(isNotify bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ShouldDisableChannel(openaiErr, -1) {
|
if ShouldDisableChannel(openaiErr, -1) {
|
||||||
sendMessage += fmt.Sprintf("- 已被禁用,原因:%s\n\n", common.EscapeMarkdownText(err.Error()))
|
sendMessage += fmt.Sprintf("- 已被禁用,原因:%s\n\n", utils.EscapeMarkdownText(err.Error()))
|
||||||
DisableChannel(channel.Id, channel.Name, err.Error(), false)
|
DisableChannel(channel.Id, channel.Name, err.Error(), false)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendMessage += fmt.Sprintf("- 测试报错: %s \n\n", common.EscapeMarkdownText(err.Error()))
|
sendMessage += fmt.Sprintf("- 测试报错: %s \n\n", utils.EscapeMarkdownText(err.Error()))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -64,7 +65,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel.CreatedTime = common.GetTimestamp()
|
channel.CreatedTime = utils.GetTimestamp()
|
||||||
keys := strings.Split(channel.Key, "\n")
|
keys := strings.Split(channel.Key, "\n")
|
||||||
channels := make([]model.Channel, 0, len(keys))
|
channels := make([]model.Channel, 0, len(keys))
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@ -216,7 +217,7 @@ func GitHubBind(c *gin.Context) {
|
|||||||
|
|
||||||
func GenerateOAuthCode(c *gin.Context) {
|
func GenerateOAuthCode(c *gin.Context) {
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
state := common.GetRandomString(12)
|
state := utils.GetRandomString(12)
|
||||||
session.Set("oauth_state", state)
|
session.Set("oauth_state", state)
|
||||||
err := session.Save()
|
err := session.Save()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -19,7 +20,7 @@ func GetOptions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
options = append(options, &model.Option{
|
options = append(options, &model.Option{
|
||||||
Key: k,
|
Key: k,
|
||||||
Value: common.Interface2String(v),
|
Value: utils.Interface2String(v),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
common.OptionMapRWMutex.Unlock()
|
common.OptionMapRWMutex.Unlock()
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay/util"
|
"one-api/relay/relay_util"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@ -14,7 +14,7 @@ import (
|
|||||||
func GetPricesList(c *gin.Context) {
|
func GetPricesList(c *gin.Context) {
|
||||||
pricesType := c.DefaultQuery("type", "db")
|
pricesType := c.DefaultQuery("type", "db")
|
||||||
|
|
||||||
prices := util.GetPricesList(pricesType)
|
prices := relay_util.GetPricesList(pricesType)
|
||||||
|
|
||||||
if len(prices) == 0 {
|
if len(prices) == 0 {
|
||||||
common.APIRespondWithError(c, http.StatusOK, errors.New("pricing data not found"))
|
common.APIRespondWithError(c, http.StatusOK, errors.New("pricing data not found"))
|
||||||
@ -33,7 +33,7 @@ func GetPricesList(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetAllModelList(c *gin.Context) {
|
func GetAllModelList(c *gin.Context) {
|
||||||
prices := util.PricingInstance.GetAllPrices()
|
prices := relay_util.PricingInstance.GetAllPrices()
|
||||||
channelModel := model.ChannelGroup.Rule
|
channelModel := model.ChannelGroup.Rule
|
||||||
|
|
||||||
modelsMap := make(map[string]bool)
|
modelsMap := make(map[string]bool)
|
||||||
@ -68,7 +68,7 @@ func AddPrice(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := util.PricingInstance.AddPrice(&price); err != nil {
|
if err := relay_util.PricingInstance.AddPrice(&price); err != nil {
|
||||||
common.APIRespondWithError(c, http.StatusOK, err)
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -94,7 +94,7 @@ func UpdatePrice(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := util.PricingInstance.UpdatePrice(modelName, &price); err != nil {
|
if err := relay_util.PricingInstance.UpdatePrice(modelName, &price); err != nil {
|
||||||
common.APIRespondWithError(c, http.StatusOK, err)
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -114,7 +114,7 @@ func DeletePrice(c *gin.Context) {
|
|||||||
modelName = modelName[1:]
|
modelName = modelName[1:]
|
||||||
modelName, _ = url.PathUnescape(modelName)
|
modelName, _ = url.PathUnescape(modelName)
|
||||||
|
|
||||||
if err := util.PricingInstance.DeletePrice(modelName); err != nil {
|
if err := relay_util.PricingInstance.DeletePrice(modelName); err != nil {
|
||||||
common.APIRespondWithError(c, http.StatusOK, err)
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -127,7 +127,7 @@ func DeletePrice(c *gin.Context) {
|
|||||||
|
|
||||||
type PriceBatchRequest struct {
|
type PriceBatchRequest struct {
|
||||||
OriginalModels []string `json:"original_models"`
|
OriginalModels []string `json:"original_models"`
|
||||||
util.BatchPrices
|
relay_util.BatchPrices
|
||||||
}
|
}
|
||||||
|
|
||||||
func BatchSetPrices(c *gin.Context) {
|
func BatchSetPrices(c *gin.Context) {
|
||||||
@ -137,7 +137,7 @@ func BatchSetPrices(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := util.PricingInstance.BatchSetPrices(&pricesBatch.BatchPrices, pricesBatch.OriginalModels); err != nil {
|
if err := relay_util.PricingInstance.BatchSetPrices(&pricesBatch.BatchPrices, pricesBatch.OriginalModels); err != nil {
|
||||||
common.APIRespondWithError(c, http.StatusOK, err)
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -159,7 +159,7 @@ func BatchDeletePrices(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := util.PricingInstance.BatchDeletePrices(pricesBatch.Models); err != nil {
|
if err := relay_util.PricingInstance.BatchDeletePrices(pricesBatch.Models); err != nil {
|
||||||
common.APIRespondWithError(c, http.StatusOK, err)
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -184,7 +184,7 @@ func SyncPricing(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := util.PricingInstance.SyncPricing(prices, overwrite == "true")
|
err := relay_util.PricingInstance.SyncPricing(prices, overwrite == "true")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.APIRespondWithError(c, http.StatusOK, err)
|
common.APIRespondWithError(c, http.StatusOK, err)
|
||||||
return
|
return
|
||||||
|
@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@ -85,12 +86,12 @@ func AddRedemption(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var keys []string
|
var keys []string
|
||||||
for i := 0; i < redemption.Count; i++ {
|
for i := 0; i < redemption.Count; i++ {
|
||||||
key := common.GetUUID()
|
key := utils.GetUUID()
|
||||||
cleanRedemption := model.Redemption{
|
cleanRedemption := model.Redemption{
|
||||||
UserId: c.GetInt("id"),
|
UserId: c.GetInt("id"),
|
||||||
Name: redemption.Name,
|
Name: redemption.Name,
|
||||||
Key: key,
|
Key: key,
|
||||||
CreatedTime: common.GetTimestamp(),
|
CreatedTime: utils.GetTimestamp(),
|
||||||
Quota: redemption.Quota,
|
Quota: redemption.Quota,
|
||||||
}
|
}
|
||||||
err = cleanRedemption.Insert()
|
err = cleanRedemption.Insert()
|
||||||
|
@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@ -62,9 +63,9 @@ func GetPlaygroundToken(c *gin.Context) {
|
|||||||
cleanToken := model.Token{
|
cleanToken := model.Token{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Name: tokenName,
|
Name: tokenName,
|
||||||
Key: common.GenerateKey(),
|
Key: utils.GenerateKey(),
|
||||||
CreatedTime: common.GetTimestamp(),
|
CreatedTime: utils.GetTimestamp(),
|
||||||
AccessedTime: common.GetTimestamp(),
|
AccessedTime: utils.GetTimestamp(),
|
||||||
ExpiredTime: 0,
|
ExpiredTime: 0,
|
||||||
RemainQuota: 0,
|
RemainQuota: 0,
|
||||||
UnlimitedQuota: true,
|
UnlimitedQuota: true,
|
||||||
@ -132,9 +133,9 @@ func AddToken(c *gin.Context) {
|
|||||||
cleanToken := model.Token{
|
cleanToken := model.Token{
|
||||||
UserId: c.GetInt("id"),
|
UserId: c.GetInt("id"),
|
||||||
Name: token.Name,
|
Name: token.Name,
|
||||||
Key: common.GenerateKey(),
|
Key: utils.GenerateKey(),
|
||||||
CreatedTime: common.GetTimestamp(),
|
CreatedTime: utils.GetTimestamp(),
|
||||||
AccessedTime: common.GetTimestamp(),
|
AccessedTime: utils.GetTimestamp(),
|
||||||
ExpiredTime: token.ExpiredTime,
|
ExpiredTime: token.ExpiredTime,
|
||||||
RemainQuota: token.RemainQuota,
|
RemainQuota: token.RemainQuota,
|
||||||
UnlimitedQuota: token.UnlimitedQuota,
|
UnlimitedQuota: token.UnlimitedQuota,
|
||||||
@ -199,7 +200,7 @@ func UpdateToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if token.Status == common.TokenStatusEnabled {
|
if token.Status == common.TokenStatusEnabled {
|
||||||
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
|
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= utils.GetTimestamp() && cleanToken.ExpiredTime != -1 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
|
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@ -261,7 +262,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.AccessToken = common.GetUUID()
|
user.AccessToken = utils.GetUUID()
|
||||||
|
|
||||||
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
|
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -297,7 +298,7 @@ func GetAffCode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if user.AffCode == "" {
|
if user.AffCode == "" {
|
||||||
user.AffCode = common.GetRandomString(4)
|
user.AffCode = utils.GetRandomString(4)
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
|
6
main.go
6
main.go
@ -13,7 +13,7 @@ import (
|
|||||||
"one-api/cron"
|
"one-api/cron"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay/util"
|
"one-api/relay/relay_util"
|
||||||
"one-api/router"
|
"one-api/router"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ func main() {
|
|||||||
common.InitRedisClient()
|
common.InitRedisClient()
|
||||||
// Initialize options
|
// Initialize options
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
util.NewPricing()
|
relay_util.NewPricing()
|
||||||
initMemoryCache()
|
initMemoryCache()
|
||||||
initSync()
|
initSync()
|
||||||
|
|
||||||
@ -112,6 +112,6 @@ func SyncChannelCache(frequency int) {
|
|||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
time.Sleep(time.Duration(frequency) * time.Second)
|
||||||
common.SysLog("syncing channels from database")
|
common.SysLog("syncing channels from database")
|
||||||
model.ChannelGroup.Load()
|
model.ChannelGroup.Load()
|
||||||
util.PricingInstance.Init()
|
relay_util.PricingInstance.Init()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -109,10 +110,10 @@ func tokenAuth(c *gin.Context, key string) {
|
|||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
if strings.HasPrefix(parts[1], "!") {
|
if strings.HasPrefix(parts[1], "!") {
|
||||||
channelId := common.String2Int(parts[1][1:])
|
channelId := utils.String2Int(parts[1][1:])
|
||||||
c.Set("skip_channel_id", channelId)
|
c.Set("skip_channel_id", channelId)
|
||||||
} else {
|
} else {
|
||||||
channelId := common.String2Int(parts[1])
|
channelId := utils.String2Int(parts[1])
|
||||||
if channelId == 0 {
|
if channelId == 0 {
|
||||||
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
|
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
|
||||||
return
|
return
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -99,11 +100,11 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GlobalWebRateLimit() func(c *gin.Context) {
|
func GlobalWebRateLimit() func(c *gin.Context) {
|
||||||
return rateLimitFactory(common.GetOrDefault("global.web_rate_limit", GlobalWebRateLimitNum), GlobalWebRateLimitDuration, "GW")
|
return rateLimitFactory(utils.GetOrDefault("global.web_rate_limit", GlobalWebRateLimitNum), GlobalWebRateLimitDuration, "GW")
|
||||||
}
|
}
|
||||||
|
|
||||||
func GlobalAPIRateLimit() func(c *gin.Context) {
|
func GlobalAPIRateLimit() func(c *gin.Context) {
|
||||||
return rateLimitFactory(common.GetOrDefault("global.api_rate_limit", GlobalApiRateLimitNum), GlobalApiRateLimitDuration, "GA")
|
return rateLimitFactory(utils.GetOrDefault("global.api_rate_limit", GlobalApiRateLimitNum), GlobalApiRateLimitDuration, "GA")
|
||||||
}
|
}
|
||||||
|
|
||||||
func CriticalRateLimit() func(c *gin.Context) {
|
func CriticalRateLimit() func(c *gin.Context) {
|
||||||
|
@ -3,6 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -10,7 +11,7 @@ import (
|
|||||||
|
|
||||||
func RequestId() func(c *gin.Context) {
|
func RequestId() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
id := common.GetTimeString() + common.GetRandomString(8)
|
id := utils.GetTimeString() + utils.GetRandomString(8)
|
||||||
c.Set(common.RequestIdKey, id)
|
c.Set(common.RequestIdKey, id)
|
||||||
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
||||||
ctx = context.WithValue(ctx, "requestStartTime", time.Now())
|
ctx = context.WithValue(ctx, "requestStartTime", time.Now())
|
||||||
|
@ -1,14 +1,16 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
"message": utils.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||||
"type": "one_api_error",
|
"type": "one_api_error",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -105,7 +106,7 @@ func (cc *ChannelsChooser) Next(group, modelName string, filters ...ChannelsFilt
|
|||||||
|
|
||||||
channelsPriority, ok := cc.Rule[group][modelName]
|
channelsPriority, ok := cc.Rule[group][modelName]
|
||||||
if !ok {
|
if !ok {
|
||||||
matchModel := common.GetModelsWithMatch(&cc.Match, modelName)
|
matchModel := utils.GetModelsWithMatch(&cc.Match, modelName)
|
||||||
channelsPriority, ok = cc.Rule[group][matchModel]
|
channelsPriority, ok = cc.Rule[group][matchModel]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("model not found")
|
return nil, errors.New("model not found")
|
||||||
@ -199,7 +200,7 @@ func (cc *ChannelsChooser) Load() {
|
|||||||
// 逗号分割 ability.ChannelId
|
// 逗号分割 ability.ChannelId
|
||||||
channelIds := strings.Split(ability.ChannelIds, ",")
|
channelIds := strings.Split(ability.ChannelIds, ",")
|
||||||
for _, channelId := range channelIds {
|
for _, channelId := range channelIds {
|
||||||
priorityIds = append(priorityIds, common.String2Int(channelId))
|
priorityIds = append(priorityIds, utils.String2Int(channelId))
|
||||||
}
|
}
|
||||||
|
|
||||||
newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)
|
newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)
|
||||||
|
@ -2,6 +2,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/datatypes"
|
"gorm.io/datatypes"
|
||||||
@ -235,7 +236,7 @@ func (channel *Channel) UpdateRaw(overwrite bool) error {
|
|||||||
|
|
||||||
func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
||||||
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
|
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
|
||||||
TestTime: common.GetTimestamp(),
|
TestTime: utils.GetTimestamp(),
|
||||||
ResponseTime: int(responseTime),
|
ResponseTime: int(responseTime),
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -245,7 +246,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
|||||||
|
|
||||||
func (channel *Channel) UpdateBalance(balance float64) {
|
func (channel *Channel) UpdateBalance(balance float64) {
|
||||||
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
|
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
|
||||||
BalanceUpdatedTime: common.GetTimestamp(),
|
BalanceUpdatedTime: utils.GetTimestamp(),
|
||||||
Balance: balance,
|
Balance: balance,
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@ -41,7 +42,7 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
log := &Log{
|
log := &Log{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Username: GetUsernameById(userId),
|
Username: GetUsernameById(userId),
|
||||||
CreatedAt: common.GetTimestamp(),
|
CreatedAt: utils.GetTimestamp(),
|
||||||
Type: logType,
|
Type: logType,
|
||||||
Content: content,
|
Content: content,
|
||||||
}
|
}
|
||||||
@ -59,7 +60,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
|||||||
log := &Log{
|
log := &Log{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Username: GetUsernameById(userId),
|
Username: GetUsernameById(userId),
|
||||||
CreatedAt: common.GetTimestamp(),
|
CreatedAt: utils.GetTimestamp(),
|
||||||
Type: LogTypeConsume,
|
Type: LogTypeConsume,
|
||||||
Content: content,
|
Content: content,
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
|
@ -3,6 +3,7 @@ package model
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -26,7 +27,7 @@ func SetupDB() {
|
|||||||
|
|
||||||
if viper.GetBool("batch_update_enabled") {
|
if viper.GetBool("batch_update_enabled") {
|
||||||
common.BatchUpdateEnabled = true
|
common.BatchUpdateEnabled = true
|
||||||
common.BatchUpdateInterval = common.GetOrDefault("batch_update_interval", 5)
|
common.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5)
|
||||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
InitBatchUpdater()
|
InitBatchUpdater()
|
||||||
}
|
}
|
||||||
@ -47,7 +48,7 @@ func createRootAccountIfNeed() error {
|
|||||||
Role: common.RoleRootUser,
|
Role: common.RoleRootUser,
|
||||||
Status: common.UserStatusEnabled,
|
Status: common.UserStatusEnabled,
|
||||||
DisplayName: "Root User",
|
DisplayName: "Root User",
|
||||||
AccessToken: common.GetUUID(),
|
AccessToken: utils.GetUUID(),
|
||||||
Quota: 100000000,
|
Quota: 100000000,
|
||||||
}
|
}
|
||||||
DB.Create(&rootUser)
|
DB.Create(&rootUser)
|
||||||
@ -78,7 +79,7 @@ func chooseDB() (*gorm.DB, error) {
|
|||||||
// Use SQLite
|
// Use SQLite
|
||||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||||
common.UsingSQLite = true
|
common.UsingSQLite = true
|
||||||
config := fmt.Sprintf("?_busy_timeout=%d", common.GetOrDefault("sqlite_busy_timeout", 3000))
|
config := fmt.Sprintf("?_busy_timeout=%d", utils.GetOrDefault("sqlite_busy_timeout", 3000))
|
||||||
return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{
|
return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{
|
||||||
PrepareStmt: true, // precompile SQL
|
PrepareStmt: true, // precompile SQL
|
||||||
})
|
})
|
||||||
@ -96,9 +97,9 @@ func InitDB() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
sqlDB.SetMaxIdleConns(utils.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
||||||
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
sqlDB.SetMaxOpenConns(utils.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
||||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(utils.GetOrDefault("SQL_MAX_LIFETIME", 60)))
|
||||||
|
|
||||||
if !common.IsMasterNode {
|
if !common.IsMasterNode {
|
||||||
return nil
|
return nil
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@ -33,7 +34,7 @@ func GetRedemptionsList(params *GenericParams) (*DataResult[Redemption], error)
|
|||||||
var redemptions []*Redemption
|
var redemptions []*Redemption
|
||||||
db := DB
|
db := DB
|
||||||
if params.Keyword != "" {
|
if params.Keyword != "" {
|
||||||
db = db.Where("id = ? or name LIKE ?", common.String2Int(params.Keyword), params.Keyword+"%")
|
db = db.Where("id = ? or name LIKE ?", utils.String2Int(params.Keyword), params.Keyword+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
return PaginateAndOrder[Redemption](db, ¶ms.PaginationParams, &redemptions, allowedRedemptionslOrderFields)
|
return PaginateAndOrder[Redemption](db, ¶ms.PaginationParams, &redemptions, allowedRedemptionslOrderFields)
|
||||||
@ -75,7 +76,7 @@ func Redeem(key string, userId int) (quota int, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
redemption.RedeemedTime = common.GetTimestamp()
|
redemption.RedeemedTime = utils.GetTimestamp()
|
||||||
redemption.Status = common.RedemptionCodeStatusUsed
|
redemption.Status = common.RedemptionCodeStatusUsed
|
||||||
err = tx.Save(redemption).Error
|
err = tx.Save(redemption).Error
|
||||||
return err
|
return err
|
||||||
|
@ -2,7 +2,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"one-api/common"
|
"one-api/common/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TelegramMenu struct {
|
type TelegramMenu struct {
|
||||||
@ -22,7 +22,7 @@ func GetTelegramMenusList(params *GenericParams) (*DataResult[TelegramMenu], err
|
|||||||
var menus []*TelegramMenu
|
var menus []*TelegramMenu
|
||||||
db := DB
|
db := DB
|
||||||
if params.Keyword != "" {
|
if params.Keyword != "" {
|
||||||
db = db.Where("id = ? or command LIKE ?", common.String2Int(params.Keyword), params.Keyword+"%")
|
db = db.Where("id = ? or command LIKE ?", utils.String2Int(params.Keyword), params.Keyword+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
return PaginateAndOrder[TelegramMenu](db, ¶ms.PaginationParams, &menus, allowedTelegramMenusOrderFields)
|
return PaginateAndOrder[TelegramMenu](db, ¶ms.PaginationParams, &menus, allowedTelegramMenusOrderFields)
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/stmp"
|
"one-api/common/stmp"
|
||||||
|
"one-api/common/utils"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@ -71,7 +72,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
if token.Status != common.TokenStatusEnabled {
|
if token.Status != common.TokenStatusEnabled {
|
||||||
return nil, errors.New("该令牌状态不可用")
|
return nil, errors.New("该令牌状态不可用")
|
||||||
}
|
}
|
||||||
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
if token.ExpiredTime != -1 && token.ExpiredTime < utils.GetTimestamp() {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
token.Status = common.TokenStatusExpired
|
token.Status = common.TokenStatusExpired
|
||||||
err := token.SelectUpdate()
|
err := token.SelectUpdate()
|
||||||
@ -188,7 +189,7 @@ func increaseTokenQuota(id int, quota int) (err error) {
|
|||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||||
"used_quota": gorm.Expr("used_quota - ?", quota),
|
"used_quota": gorm.Expr("used_quota - ?", quota),
|
||||||
"accessed_time": common.GetTimestamp(),
|
"accessed_time": utils.GetTimestamp(),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
return err
|
return err
|
||||||
@ -210,7 +211,7 @@ func decreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
"accessed_time": common.GetTimestamp(),
|
"accessed_time": utils.GetTimestamp(),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
return err
|
return err
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -55,7 +56,7 @@ func GetUsersList(params *GenericParams) (*DataResult[User], error) {
|
|||||||
var users []*User
|
var users []*User
|
||||||
db := DB.Omit("password")
|
db := DB.Omit("password")
|
||||||
if params.Keyword != "" {
|
if params.Keyword != "" {
|
||||||
db = db.Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", common.String2Int(params.Keyword), params.Keyword+"%", params.Keyword+"%", params.Keyword+"%")
|
db = db.Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", utils.String2Int(params.Keyword), params.Keyword+"%", params.Keyword+"%", params.Keyword+"%")
|
||||||
}
|
}
|
||||||
|
|
||||||
return PaginateAndOrder[User](db, ¶ms.PaginationParams, &users, allowedUserOrderFields)
|
return PaginateAndOrder[User](db, ¶ms.PaginationParams, &users, allowedUserOrderFields)
|
||||||
@ -115,9 +116,9 @@ func (user *User) Insert(inviterId int) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
user.Quota = common.QuotaForNewUser
|
user.Quota = common.QuotaForNewUser
|
||||||
user.AccessToken = common.GetUUID()
|
user.AccessToken = utils.GetUUID()
|
||||||
user.AffCode = common.GetRandomString(4)
|
user.AffCode = utils.GetRandomString(4)
|
||||||
user.CreatedTime = common.GetTimestamp()
|
user.CreatedTime = utils.GetTimestamp()
|
||||||
result := DB.Create(user)
|
result := DB.Create(user)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
@ -165,7 +166,7 @@ func (user *User) Delete() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 不改变当前数据库索引,通过更改用户名来删除用户
|
// 不改变当前数据库索引,通过更改用户名来删除用户
|
||||||
user.Username = user.Username + "_del_" + common.GetRandomString(6)
|
user.Username = user.Username + "_del_" + utils.GetRandomString(6)
|
||||||
err := user.Update(false)
|
err := user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -92,7 +93,7 @@ func (p *AliProvider) convertToChatOpenai(response *AliChatResponse, request *ty
|
|||||||
openaiResponse = &types.ChatCompletionResponse{
|
openaiResponse = &types.ChatCompletionResponse{
|
||||||
ID: response.RequestId,
|
ID: response.RequestId,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Choices: response.Output.ToChatCompletionChoices(),
|
Choices: response.Output.ToChatCompletionChoices(),
|
||||||
Usage: &types.Usage{
|
Usage: &types.Usage{
|
||||||
@ -223,7 +224,7 @@ func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, d
|
|||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
ID: aliResponse.RequestId,
|
ID: aliResponse.RequestId,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/image"
|
"one-api/common/image"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
@ -172,7 +173,7 @@ func ConvertToChatOpenai(provider base.ProviderInterface, response *ClaudeRespon
|
|||||||
openaiResponse = &types.ChatCompletionResponse{
|
openaiResponse = &types.ChatCompletionResponse{
|
||||||
ID: response.Id,
|
ID: response.Id,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Choices: []types.ChatCompletionChoice{choice},
|
Choices: []types.ChatCompletionChoice{choice},
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Usage: &types.Usage{
|
Usage: &types.Usage{
|
||||||
@ -264,9 +265,9 @@ func (h *ClaudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStream
|
|||||||
choice.FinishReason = &finishReason
|
choice.FinishReason = &finishReason
|
||||||
}
|
}
|
||||||
chatCompletion := types.ChatCompletionStreamResponse{
|
chatCompletion := types.ChatCompletionStreamResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -85,9 +86,9 @@ func (p *CloudflareAIProvider) convertToChatOpenai(response *ChatRespone, reques
|
|||||||
}
|
}
|
||||||
|
|
||||||
openaiResponse = &types.ChatCompletionResponse{
|
openaiResponse = &types.ChatCompletionResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Choices: []types.ChatCompletionChoice{{
|
Choices: []types.ChatCompletionChoice{{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
@ -155,9 +156,9 @@ func (h *CloudflareAIStreamHandler) handlerStream(rawLine *[]byte, dataChan chan
|
|||||||
|
|
||||||
func (h *CloudflareAIStreamHandler) convertToOpenaiStream(chatResponse *ChatResult, dataChan chan string, isStop bool) {
|
func (h *CloudflareAIStreamHandler) convertToOpenaiStream(chatResponse *ChatResult, dataChan chan string, isStop bool) {
|
||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/storage"
|
"one-api/common/storage"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -46,7 +47,7 @@ func (p *CloudflareAIProvider) CreateImageGenerations(request *types.ImageReques
|
|||||||
|
|
||||||
url := ""
|
url := ""
|
||||||
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
|
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
|
||||||
url = storage.Upload(body, common.GetUUID()+".png")
|
url = storage.Upload(body, utils.GetUUID()+".png")
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiResponse := &types.ImageResponse{
|
openaiResponse := &types.ImageResponse{
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/providers/base"
|
"one-api/providers/base"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
@ -138,7 +139,7 @@ func ConvertToChatOpenai(provider base.ProviderInterface, response *CohereRespon
|
|||||||
openaiResponse = &types.ChatCompletionResponse{
|
openaiResponse = &types.ChatCompletionResponse{
|
||||||
ID: response.GenerationID,
|
ID: response.GenerationID,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Choices: []types.ChatCompletionChoice{choice},
|
Choices: []types.ChatCompletionChoice{choice},
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Usage: &types.Usage{},
|
Usage: &types.Usage{},
|
||||||
@ -190,9 +191,9 @@ func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStream
|
|||||||
}
|
}
|
||||||
|
|
||||||
chatCompletion := types.ChatCompletionStreamResponse{
|
chatCompletion := types.ChatCompletionStreamResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -89,9 +90,9 @@ func (p *CozeProvider) convertToChatOpenai(response *CozeResponse, request *type
|
|||||||
}
|
}
|
||||||
|
|
||||||
openaiResponse = &types.ChatCompletionResponse{
|
openaiResponse = &types.ChatCompletionResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Choices: []types.ChatCompletionChoice{{
|
Choices: []types.ChatCompletionChoice{{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
@ -168,9 +169,9 @@ func (h *CozeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string,
|
|||||||
|
|
||||||
func (h *CozeStreamHandler) convertToOpenaiStream(chatResponse *CozeStreamResponse, dataChan chan string) {
|
func (h *CozeStreamHandler) convertToOpenaiStream(chatResponse *CozeStreamResponse, dataChan chan string) {
|
||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -145,9 +146,9 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
|
|||||||
}
|
}
|
||||||
|
|
||||||
openaiResponse = &types.ChatCompletionResponse{
|
openaiResponse = &types.ChatCompletionResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
|
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
|
||||||
}
|
}
|
||||||
@ -191,9 +192,9 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
|
|||||||
|
|
||||||
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) {
|
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) {
|
||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
// Choices: choices,
|
// Choices: choices,
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/image"
|
"one-api/common/image"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -120,7 +121,7 @@ func (g *GeminiFunctionCall) ToOpenAITool() *types.ChatCompletionToolCalls {
|
|||||||
args, _ := json.Marshal(g.Args)
|
args, _ := json.Marshal(g.Args)
|
||||||
|
|
||||||
return &types.ChatCompletionToolCalls{
|
return &types.ChatCompletionToolCalls{
|
||||||
Id: "call_" + common.GetRandomString(24),
|
Id: "call_" + utils.GetRandomString(24),
|
||||||
Type: types.ChatMessageRoleFunction,
|
Type: types.ChatMessageRoleFunction,
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Function: &types.ChatCompletionToolCallsFunction{
|
Function: &types.ChatCompletionToolCallsFunction{
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/image"
|
"one-api/common/image"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -100,9 +101,9 @@ func (p *OllamaProvider) convertToChatOpenai(response *ChatResponse, request *ty
|
|||||||
}
|
}
|
||||||
|
|
||||||
openaiResponse = &types.ChatCompletionResponse{
|
openaiResponse = &types.ChatCompletionResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Choices: []types.ChatCompletionChoice{choices},
|
Choices: []types.ChatCompletionChoice{choices},
|
||||||
Usage: &types.Usage{
|
Usage: &types.Usage{
|
||||||
@ -195,9 +196,9 @@ func (h *ollamaStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
chatCompletion := types.ChatCompletionStreamResponse{
|
chatCompletion := types.ChatCompletionStreamResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -177,11 +178,11 @@ func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResp
|
|||||||
choice.FinishReason = types.FinishReasonStop
|
choice.FinishReason = types.FinishReasonStop
|
||||||
|
|
||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
}
|
}
|
||||||
|
|
||||||
responseBody, _ := json.Marshal(streamResponse)
|
responseBody, _ := json.Marshal(streamResponse)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/storage"
|
"one-api/common/storage"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -71,7 +72,7 @@ func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest
|
|||||||
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
|
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
|
||||||
body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image)
|
body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
imgUrl = storage.Upload(body, common.GetUUID()+".png")
|
imgUrl = storage.Upload(body, utils.GetUUID()+".png")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -101,7 +102,7 @@ func (p *TencentProvider) convertToChatOpenai(response *TencentChatResponse, req
|
|||||||
|
|
||||||
openaiResponse = &types.ChatCompletionResponse{
|
openaiResponse = &types.ChatCompletionResponse{
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Usage: response.Usage,
|
Usage: response.Usage,
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
}
|
}
|
||||||
@ -137,9 +138,9 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatReq
|
|||||||
stream = 1
|
stream = 1
|
||||||
}
|
}
|
||||||
return &TencentChatRequest{
|
return &TencentChatRequest{
|
||||||
Timestamp: common.GetTimestamp(),
|
Timestamp: utils.GetTimestamp(),
|
||||||
Expired: common.GetTimestamp() + 24*60*60,
|
Expired: utils.GetTimestamp() + 24*60*60,
|
||||||
QueryID: common.GetUUID(),
|
QueryID: utils.GetUUID(),
|
||||||
Temperature: request.Temperature,
|
Temperature: request.Temperature,
|
||||||
TopP: request.TopP,
|
TopP: request.TopP,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
@ -178,7 +179,7 @@ func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, dataChan chan stri
|
|||||||
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string) {
|
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string) {
|
||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
}
|
}
|
||||||
if len(tencentChatResponse.Choices) > 0 {
|
if len(tencentChatResponse.Choices) > 0 {
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -194,7 +195,7 @@ func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterfa
|
|||||||
ID: xunfeiResponse.Header.Sid,
|
ID: xunfeiResponse.Header.Sid,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Choices: []types.ChatCompletionChoice{choice},
|
Choices: []types.ChatCompletionChoice{choice},
|
||||||
Usage: &xunfeiResponse.Payload.Usage.Text,
|
Usage: &xunfeiResponse.Payload.Usage.Text,
|
||||||
}
|
}
|
||||||
@ -310,7 +311,7 @@ func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResp
|
|||||||
chatCompletion := types.ChatCompletionStreamResponse{
|
chatCompletion := types.ChatCompletionStreamResponse{
|
||||||
ID: xunfeiChatResponse.Header.Sid,
|
ID: xunfeiChatResponse.Header.Sid,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package relay
|
package relay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/relay/util"
|
"one-api/relay/relay_util"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
|
||||||
providersBase "one-api/providers/base"
|
providersBase "one-api/providers/base"
|
||||||
@ -14,7 +14,7 @@ type relayBase struct {
|
|||||||
provider providersBase.ProviderInterface
|
provider providersBase.ProviderInterface
|
||||||
originalModel string
|
originalModel string
|
||||||
modelName string
|
modelName string
|
||||||
cache *util.ChatCacheProps
|
cache *relay_util.ChatCacheProps
|
||||||
}
|
}
|
||||||
|
|
||||||
type RelayBaseInterface interface {
|
type RelayBaseInterface interface {
|
||||||
@ -28,14 +28,14 @@ type RelayBaseInterface interface {
|
|||||||
getModelName() string
|
getModelName() string
|
||||||
getContext() *gin.Context
|
getContext() *gin.Context
|
||||||
SetChatCache(allow bool)
|
SetChatCache(allow bool)
|
||||||
GetChatCache() *util.ChatCacheProps
|
GetChatCache() *relay_util.ChatCacheProps
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relayBase) SetChatCache(allow bool) {
|
func (r *relayBase) SetChatCache(allow bool) {
|
||||||
r.cache = util.NewChatCacheProps(r.c, allow)
|
r.cache = relay_util.NewChatCacheProps(r.c, allow)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relayBase) GetChatCache() *util.ChatCacheProps {
|
func (r *relayBase) GetChatCache() *relay_util.ChatCacheProps {
|
||||||
return r.cache
|
return r.cache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
providersBase "one-api/providers/base"
|
providersBase "one-api/providers/base"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
|
||||||
@ -100,9 +101,9 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
|||||||
func (r *relayChat) getUsageResponse() string {
|
func (r *relayChat) getUsageResponse() string {
|
||||||
if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage {
|
if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage {
|
||||||
usageResponse := types.ChatCompletionStreamResponse{
|
usageResponse := types.ChatCompletionStreamResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: r.chatRequest.Model,
|
Model: r.chatRequest.Model,
|
||||||
Choices: []types.ChatCompletionStreamChoice{},
|
Choices: []types.ChatCompletionStreamChoice{},
|
||||||
Usage: r.provider.GetUsage(),
|
Usage: r.provider.GetUsage(),
|
||||||
|
@ -9,11 +9,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/providers"
|
"one-api/providers"
|
||||||
providersBase "one-api/providers/base"
|
providersBase "one-api/providers/base"
|
||||||
"one-api/relay/util"
|
"one-api/relay/relay_util"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -142,7 +143,7 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
|
|||||||
|
|
||||||
type StreamEndHandler func() string
|
type StreamEndHandler func() string
|
||||||
|
|
||||||
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) {
|
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *relay_util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) {
|
||||||
requester.SetEventStreamHeaders(c)
|
requester.SetEventStreamHeaders(c)
|
||||||
dataChan, errChan := stream.Recv()
|
dataChan, errChan := stream.Recv()
|
||||||
|
|
||||||
@ -257,7 +258,7 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st
|
|||||||
|
|
||||||
func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
|
func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
err.OpenAIError.Message = utils.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||||
c.JSON(err.StatusCode, gin.H{
|
c.JSON(err.StatusCode, gin.H{
|
||||||
"error": err.OpenAIError,
|
"error": err.OpenAIError,
|
||||||
})
|
})
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/common/utils"
|
||||||
providersBase "one-api/providers/base"
|
providersBase "one-api/providers/base"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
|
||||||
@ -93,9 +94,9 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
|
|||||||
func (r *relayCompletions) getUsageResponse() string {
|
func (r *relayCompletions) getUsageResponse() string {
|
||||||
if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage {
|
if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage {
|
||||||
usageResponse := types.CompletionResponse{
|
usageResponse := types.CompletionResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", utils.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: utils.GetTimestamp(),
|
||||||
Model: r.request.Model,
|
Model: r.request.Model,
|
||||||
Choices: []types.CompletionChoice{},
|
Choices: []types.CompletionChoice{},
|
||||||
Usage: r.provider.GetUsage(),
|
Usage: r.provider.GetUsage(),
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay/util"
|
"one-api/relay/relay_util"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -96,8 +96,8 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod
|
|||||||
|
|
||||||
relay.getProvider().SetUsage(usage)
|
relay.getProvider().SetUsage(usage)
|
||||||
|
|
||||||
var quota *util.Quota
|
var quota *relay_util.Quota
|
||||||
quota, err = util.NewQuota(relay.getContext(), relay.getModelName(), promptTokens)
|
quota, err = relay_util.NewQuota(relay.getContext(), relay.getModelName(), promptTokens)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
done = true
|
done = true
|
||||||
return
|
return
|
||||||
@ -119,7 +119,7 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func cacheProcessing(c *gin.Context, cacheProps *util.ChatCacheProps) {
|
func cacheProcessing(c *gin.Context, cacheProps *relay_util.ChatCacheProps) {
|
||||||
responseCache(c, cacheProps.Response)
|
responseCache(c, cacheProps.Response)
|
||||||
|
|
||||||
// 写入日志
|
// 写入日志
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
provider "one-api/providers/midjourney"
|
provider "one-api/providers/midjourney"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/util"
|
"one-api/relay/relay_util"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -539,10 +539,10 @@ func getMjRequestPath(path string) string {
|
|||||||
return requestURL
|
return requestURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func getQuota(c *gin.Context, action string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
|
func getQuota(c *gin.Context, action string) (*relay_util.Quota, *types.OpenAIErrorWithStatusCode) {
|
||||||
modelName := CoverActionToModelName(action)
|
modelName := CoverActionToModelName(action)
|
||||||
|
|
||||||
return util.NewQuota(c, modelName, 1000)
|
return relay_util.NewQuota(c, modelName, 1000)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay/util"
|
"one-api/relay/relay_util"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ func ListModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ListModelsForAdmin(c *gin.Context) {
|
func ListModelsForAdmin(c *gin.Context) {
|
||||||
prices := util.PricingInstance.GetAllPrices()
|
prices := relay_util.PricingInstance.GetAllPrices()
|
||||||
var openAIModels []OpenAIModels
|
var openAIModels []OpenAIModels
|
||||||
for modelId, price := range prices {
|
for modelId, price := range prices {
|
||||||
openAIModels = append(openAIModels, OpenAIModels{
|
openAIModels = append(openAIModels, OpenAIModels{
|
||||||
@ -123,7 +123,7 @@ func ListModelsForAdmin(c *gin.Context) {
|
|||||||
func RetrieveModel(c *gin.Context) {
|
func RetrieveModel(c *gin.Context) {
|
||||||
modelName := c.Param("model")
|
modelName := c.Param("model")
|
||||||
openaiModel := getOpenAIModelWithName(modelName)
|
openaiModel := getOpenAIModelWithName(modelName)
|
||||||
if *openaiModel.OwnedBy != util.UnknownOwnedBy {
|
if *openaiModel.OwnedBy != relay_util.UnknownOwnedBy {
|
||||||
c.JSON(200, openaiModel)
|
c.JSON(200, openaiModel)
|
||||||
} else {
|
} else {
|
||||||
openAIError := types.OpenAIError{
|
openAIError := types.OpenAIError{
|
||||||
@ -139,15 +139,15 @@ func RetrieveModel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getModelOwnedBy(channelType int) (ownedBy *string) {
|
func getModelOwnedBy(channelType int) (ownedBy *string) {
|
||||||
if ownedByName, ok := util.ModelOwnedBy[channelType]; ok {
|
if ownedByName, ok := relay_util.ModelOwnedBy[channelType]; ok {
|
||||||
return &ownedByName
|
return &ownedByName
|
||||||
}
|
}
|
||||||
|
|
||||||
return &util.UnknownOwnedBy
|
return &relay_util.UnknownOwnedBy
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOpenAIModelWithName(modelName string) *OpenAIModels {
|
func getOpenAIModelWithName(modelName string) *OpenAIModels {
|
||||||
price := util.PricingInstance.GetPrice(modelName)
|
price := relay_util.PricingInstance.GetPrice(modelName)
|
||||||
|
|
||||||
return &OpenAIModels{
|
return &OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
@ -164,6 +164,6 @@ func GetModelOwnedBy(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": util.ModelOwnedBy,
|
"data": relay_util.ModelOwnedBy,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
package util
|
package relay_util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -37,7 +38,7 @@ func GetDebugList(userId int) ([]*ChatCacheProps, error) {
|
|||||||
|
|
||||||
var props []*ChatCacheProps
|
var props []*ChatCacheProps
|
||||||
for _, cache := range caches {
|
for _, cache := range caches {
|
||||||
prop, err := common.UnmarshalString[ChatCacheProps](cache.Data)
|
prop, err := utils.UnmarshalString[ChatCacheProps](cache.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -77,7 +78,7 @@ func (p *ChatCacheProps) SetHash(request any) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.hash(common.Marshal(request))
|
p.hash(utils.Marshal(request))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ChatCacheProps) SetResponse(response any) {
|
func (p *ChatCacheProps) SetResponse(response any) {
|
||||||
@ -90,7 +91,7 @@ func (p *ChatCacheProps) SetResponse(response any) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
responseStr := common.Marshal(response)
|
responseStr := utils.Marshal(response)
|
||||||
if responseStr == "" {
|
if responseStr == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
@ -1,8 +1,8 @@
|
|||||||
package util
|
package relay_util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"one-api/common"
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -15,7 +15,7 @@ func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
props, err := common.UnmarshalString[ChatCacheProps](cache.Data)
|
props, err := utils.UnmarshalString[ChatCacheProps](cache.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -28,7 +28,7 @@ func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
|
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
|
||||||
data := common.Marshal(props)
|
data := utils.Marshal(props)
|
||||||
if data == "" {
|
if data == "" {
|
||||||
return errors.New("marshal error")
|
return errors.New("marshal error")
|
||||||
}
|
}
|
@ -1,9 +1,10 @@
|
|||||||
package util
|
package relay_util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
props, err := common.UnmarshalString[ChatCacheProps](cache)
|
props, err := utils.UnmarshalString[ChatCacheProps](cache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -31,7 +32,7 @@ func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
data := common.Marshal(&props)
|
data := utils.Marshal(&props)
|
||||||
if data == "" {
|
if data == "" {
|
||||||
return errors.New("marshal error")
|
return errors.New("marshal error")
|
||||||
}
|
}
|
@ -1,9 +1,10 @@
|
|||||||
package util
|
package relay_util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/utils"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@ -98,7 +99,7 @@ func (p *Pricing) GetPrice(modelName string) *model.Price {
|
|||||||
return price
|
return price
|
||||||
}
|
}
|
||||||
|
|
||||||
matchModel := common.GetModelsWithMatch(&p.Match, modelName)
|
matchModel := utils.GetModelsWithMatch(&p.Match, modelName)
|
||||||
if price, ok := p.Prices[matchModel]; ok {
|
if price, ok := p.Prices[matchModel]; ok {
|
||||||
return price
|
return price
|
||||||
}
|
}
|
||||||
@ -281,7 +282,7 @@ func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []stri
|
|||||||
var updatePrices []string
|
var updatePrices []string
|
||||||
|
|
||||||
for _, model := range originalModels {
|
for _, model := range originalModels {
|
||||||
if !common.Contains(model, batchPrices.Models) {
|
if !utils.Contains(model, batchPrices.Models) {
|
||||||
deletePrices = append(deletePrices, model)
|
deletePrices = append(deletePrices, model)
|
||||||
} else {
|
} else {
|
||||||
updatePrices = append(updatePrices, model)
|
updatePrices = append(updatePrices, model)
|
||||||
@ -289,7 +290,7 @@ func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, model := range batchPrices.Models {
|
for _, model := range batchPrices.Models {
|
||||||
if !common.Contains(model, originalModels) {
|
if !utils.Contains(model, originalModels) {
|
||||||
addPrice := batchPrices.Price
|
addPrice := batchPrices.Price
|
||||||
addPrice.Model = model
|
addPrice.Model = model
|
||||||
addPrices = append(addPrices, &addPrice)
|
addPrices = append(addPrices, &addPrice)
|
@ -1,4 +1,4 @@
|
|||||||
package util
|
package relay_util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
@ -1,4 +1,4 @@
|
|||||||
package util
|
package relay_util
|
||||||
|
|
||||||
import "one-api/common"
|
import "one-api/common"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user