diff --git a/common/email.go b/common/email.go deleted file mode 100644 index b915f0f9..00000000 --- a/common/email.go +++ /dev/null @@ -1,86 +0,0 @@ -package common - -import ( - "crypto/rand" - "crypto/tls" - "encoding/base64" - "fmt" - "net/smtp" - "strings" - "time" -) - -func SendEmail(subject string, receiver string, content string) error { - if SMTPFrom == "" { // for compatibility - SMTPFrom = SMTPAccount - } - encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) - - // Extract domain from SMTPFrom - parts := strings.Split(SMTPFrom, "@") - var domain string - if len(parts) > 1 { - domain = parts[1] - } - // Generate a unique Message-ID - buf := make([]byte, 16) - _, err := rand.Read(buf) - if err != nil { - return err - } - messageId := fmt.Sprintf("<%x@%s>", buf, domain) - - mail := []byte(fmt.Sprintf("To: %s\r\n"+ - "From: %s<%s>\r\n"+ - "Subject: %s\r\n"+ - "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 - "Date: %s\r\n"+ - "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) - auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) - addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) - to := strings.Split(receiver, ";") - - if SMTPPort == 465 { - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - ServerName: SMTPServer, - } - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) - if err != nil { - return err - } - client, err := smtp.NewClient(conn, SMTPServer) - if err != nil { - return err - } - defer client.Close() - if err = client.Auth(auth); err != nil { - return err - } - if err = client.Mail(SMTPFrom); err != nil { - return err - } - receiverEmails := strings.Split(receiver, ";") - for _, receiver := range receiverEmails { - if err = client.Rcpt(receiver); err != nil { - return err - } - } - w, err := client.Data() - if err != nil { - return err - } - _, err = w.Write(mail) - if err != nil { - return err - } - err = w.Close() - if err != nil { - return err - } - } else { - err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) - } - return err -} diff --git a/common/notify/channel/channel_test.go b/common/notify/channel/channel_test.go new file mode 100644 index 00000000..79a9bdc6 --- /dev/null +++ b/common/notify/channel/channel_test.go @@ -0,0 +1,128 @@ +package channel_test + +import ( + "context" + "fmt" + "testing" + + "one-api/common/notify/channel" + "one-api/common/requester" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +func InitConfig() { + viper.AddConfigPath("/one-api") + viper.SetConfigName("config") + viper.ReadInConfig() + requester.InitHttpClient() +} + +func TestDingTalkSend(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.dingtalk.token") + secret := viper.GetString("notify.dingtalk.secret") + dingTalk := channel.NewDingTalk(access_token, secret) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Nil(t, err) +} + +func TestDingTalkSendWithKeyWord(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.dingtalk.token") + keyWord := viper.GetString("notify.dingtalk.keyWord") + + dingTalk := channel.NewDingTalkWithKeyWord(access_token, keyWord) + + err := dingTalk.Send(context.Background(), "Test Title", "Test Message") + assert.Nil(t, err) +} + +func TestDingTalkSendError(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.dingtalk.token") + secret := "test" + dingTalk := channel.NewDingTalk(access_token, secret) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Error(t, err) +} + +func TestLarkSend(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.lark.token") + secret := viper.GetString("notify.lark.secret") + dingTalk := channel.NewLark(access_token, secret) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Nil(t, err) +} + +func TestLarkSendWithKeyWord(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.lark.token") + keyWord := viper.GetString("notify.lark.keyWord") + + dingTalk := channel.NewLarkWithKeyWord(access_token, keyWord) + + err := dingTalk.Send(context.Background(), "Test Title", "Test Message\n\n- 111\n- 222") + assert.Nil(t, err) +} + +func TestLarkSendError(t *testing.T) { + InitConfig() + access_token := viper.GetString("notify.lark.token") + secret := "test" + dingTalk := channel.NewLark(access_token, secret) + + err := dingTalk.Send(context.Background(), "Title", "*Message*") + fmt.Println(err) + assert.Error(t, err) +} + +func TestPushdeerSend(t *testing.T) { + InitConfig() + pushkey := viper.GetString("notify.pushdeer.pushkey") + dingTalk := channel.NewPushdeer(pushkey, "") + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Nil(t, err) +} + +func TestPushdeerSendError(t *testing.T) { + InitConfig() + pushkey := "test" + dingTalk := channel.NewPushdeer(pushkey, "") + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Error(t, err) +} + +func TestTelegramSend(t *testing.T) { + InitConfig() + secret := viper.GetString("notify.telegram.bot_api_key") + chatID := viper.GetString("notify.telegram.chat_id") + dingTalk := channel.NewTelegram(secret, chatID) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Nil(t, err) +} + +func TestTelegramSendError(t *testing.T) { + InitConfig() + secret := "test" + chatID := viper.GetString("notify.telegram.chat_id") + dingTalk := channel.NewTelegram(secret, chatID) + + err := dingTalk.Send(context.Background(), "Test Title", "*Test Message*") + fmt.Println(err) + assert.Error(t, err) +} diff --git a/common/notify/channel/dingTalk.go b/common/notify/channel/dingTalk.go new file mode 100644 index 00000000..d20beddf --- /dev/null +++ b/common/notify/channel/dingTalk.go @@ -0,0 +1,127 @@ +package channel + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "one-api/common/requester" + "one-api/types" + "time" +) + +const dingTalkURL = "https://oapi.dingtalk.com/robot/send?" + +type DingTalk struct { + token string + secret string + keyWord string +} + +type dingTalkMessage struct { + MsgType string `json:"msgtype"` + Markdown struct { + Title string `json:"title"` + Text string `json:"text"` + } `json:"markdown"` +} + +type dingTalkResponse struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +func NewDingTalk(token string, secret string) *DingTalk { + return &DingTalk{ + token: token, + secret: secret, + } +} + +func NewDingTalkWithKeyWord(token string, keyWord string) *DingTalk { + return &DingTalk{ + token: token, + keyWord: keyWord, + } +} + +func (d *DingTalk) Name() string { + return "DingTalk" +} + +func (d *DingTalk) Send(ctx context.Context, title, message string) error { + msg := dingTalkMessage{ + MsgType: "markdown", + } + msg.Markdown.Title = title + msg.Markdown.Text = message + + if d.keyWord != "" { + msg.Markdown.Text = fmt.Sprintf("%s\n%s", d.keyWord, msg.Markdown.Text) + } + + query := url.Values{} + query.Set("access_token", d.token) + if d.secret != "" { + t := time.Now().UnixMilli() + query.Set("timestamp", fmt.Sprintf("%d", t)) + query.Set("sign", d.sign(t)) + } + uri := dingTalkURL + query.Encode() + + client := requester.NewHTTPRequester("", dingtalkErrFunc) + client.Context = ctx + client.IsOpenAI = false + + req, err := client.NewRequest(http.MethodPost, uri, client.WithHeader(requester.GetJsonHeaders()), client.WithBody(msg)) + if err != nil { + return err + } + + resp, errWithOP := client.SendRequestRaw(req) + if errWithOP != nil { + return fmt.Errorf("%s", errWithOP.Message) + } + defer resp.Body.Close() + + dingtalkErr := dingtalkErrFunc(resp) + if dingtalkErr != nil { + return fmt.Errorf("%s", dingtalkErr.Message) + } + + return nil +} + +func (d *DingTalk) sign(timestamp int64) string { + stringToHash := fmt.Sprintf("%d\n%s", timestamp, d.secret) + hmac256 := hmac.New(sha256.New, []byte(d.secret)) + hmac256.Write([]byte(stringToHash)) + data := hmac256.Sum(nil) + signature := base64.StdEncoding.EncodeToString(data) + + return url.QueryEscape(signature) +} + +func dingtalkErrFunc(resp *http.Response) *types.OpenAIError { + respMsg := &dingTalkResponse{} + + err := json.NewDecoder(resp.Body).Decode(respMsg) + if err != nil { + fmt.Println(err) + return nil + } + + if respMsg.ErrCode == 0 { + return nil + } + + return &types.OpenAIError{ + Message: fmt.Sprintf("send msg err. err msg: %s", respMsg.ErrMsg), + Type: "dingtalk_error", + Code: fmt.Sprintf("%d", respMsg.ErrCode), + } +} diff --git a/common/notify/channel/email.go b/common/notify/channel/email.go new file mode 100644 index 00000000..6fde98a5 --- /dev/null +++ b/common/notify/channel/email.go @@ -0,0 +1,50 @@ +package channel + +import ( + "context" + "errors" + "one-api/common" + "one-api/common/stmp" + + "github.com/gomarkdown/markdown" + "github.com/gomarkdown/markdown/html" + "github.com/gomarkdown/markdown/parser" +) + +type Email struct { + To string +} + +func NewEmail(to string) *Email { + return &Email{ + To: to, + } +} + +func (e *Email) Name() string { + return "Email" +} + +func (e *Email) Send(ctx context.Context, title, message string) error { + to := e.To + if to == "" { + to = common.RootUserEmail + } + + if common.SMTPServer == "" || common.SMTPAccount == "" || common.SMTPToken == "" || to == "" { + return errors.New("smtp config is not set, skip send email notifier") + } + + p := parser.NewWithExtensions(parser.CommonExtensions | parser.DefinitionLists | parser.OrderedListStart) + doc := p.Parse([]byte(message)) + + htmlFlags := html.CommonFlags | html.HrefTargetBlank + opts := html.RendererOptions{Flags: htmlFlags} + renderer := html.NewRenderer(opts) + + body := markdown.Render(doc, renderer) + + emailClient := stmp.NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom) + + return emailClient.Send(to, title, string(body)) +} diff --git a/common/notify/channel/lark.go b/common/notify/channel/lark.go new file mode 100644 index 00000000..cddcfc1b --- /dev/null +++ b/common/notify/channel/lark.go @@ -0,0 +1,149 @@ +package channel + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/types" + "strconv" + "time" +) + +const larkURL = "https://open.feishu.cn/open-apis/bot/v2/hook/" + +type Lark struct { + token string + secret string + keyWord string +} + +type larkMessage struct { + MessageType string `json:"msg_type"` + Timestamp string `json:"timestamp,omitempty"` + Sign string `json:"sign,omitempty"` + Card larkCardContent `json:"card"` +} + +type larkCardContent struct { + Config struct { + WideScreenMode bool `json:"wide_screen_mode"` + EnableForward bool `json:"enable_forward"` + } + Elements []larkMessageRequestCardElement `json:"elements"` +} + +type larkMessageRequestCardElementText struct { + Content string `json:"content"` + Tag string `json:"tag"` +} + +type larkMessageRequestCardElement struct { + Tag string `json:"tag"` + Text larkMessageRequestCardElementText `json:"text"` +} + +type larkResponse struct { + Code int `json:"code"` + Message string `json:"msg"` +} + +func NewLark(token, secret string) *Lark { + return &Lark{ + token: token, + secret: secret, + } +} + +func NewLarkWithKeyWord(token, keyWord string) *Lark { + return &Lark{ + token: token, + keyWord: keyWord, + } +} + +func (l *Lark) Name() string { + return "Lark" +} + +func (l *Lark) Send(ctx context.Context, title, message string) error { + msg := larkMessage{ + MessageType: "interactive", + } + + if l.keyWord != "" { + title = fmt.Sprintf("%s(%s)", title, l.keyWord) + } + + msg.Card.Config.WideScreenMode = true + msg.Card.Config.EnableForward = true + msg.Card.Elements = append(msg.Card.Elements, larkMessageRequestCardElement{ + Tag: "div", + Text: larkMessageRequestCardElementText{ + Content: fmt.Sprintf("**%s**\n%s", title, message), + Tag: "lark_md", + }, + }) + + if l.secret != "" { + t := time.Now().Unix() + msg.Timestamp = strconv.FormatInt(t, 10) + msg.Sign = l.sign(t) + } + + uri := larkURL + l.token + client := requester.NewHTTPRequester("", larkErrFunc) + client.Context = ctx + client.IsOpenAI = false + + req, err := client.NewRequest(http.MethodPost, uri, client.WithHeader(requester.GetJsonHeaders()), client.WithBody(msg)) + if err != nil { + return err + } + + resp, errWithOP := client.SendRequestRaw(req) + if errWithOP != nil { + return fmt.Errorf("%s", errWithOP.Message) + } + defer resp.Body.Close() + + larkErr := larkErrFunc(resp) + if larkErr != nil { + return fmt.Errorf("%s", larkErr.Message) + } + + return nil + +} + +func (l *Lark) sign(timestamp int64) string { + //timestamp + key 做sha256, 再进行base64 encode + stringToSign := fmt.Sprintf("%v", timestamp) + "\n" + l.secret + var data []byte + h := hmac.New(sha256.New, []byte(stringToSign)) + h.Write(data) + + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func larkErrFunc(resp *http.Response) *types.OpenAIError { + respMsg := &larkResponse{} + err := json.NewDecoder(resp.Body).Decode(respMsg) + if err != nil { + return nil + } + + if respMsg.Code == 0 { + return nil + } + + return &types.OpenAIError{ + Message: fmt.Sprintf("send msg err. err msg: %s", respMsg.Message), + Type: "lark_error", + Code: respMsg.Code, + } +} diff --git a/common/notify/channel/pushdeer.go b/common/notify/channel/pushdeer.go new file mode 100644 index 00000000..20a5f1b2 --- /dev/null +++ b/common/notify/channel/pushdeer.go @@ -0,0 +1,96 @@ +package channel + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/types" + "strings" +) + +const pushdeerURL = "https://api2.pushdeer.com" + +type Pushdeer struct { + url string + pushkey string +} + +type pushdeerMessage struct { + Text string `json:"text"` + Desp string `json:"desp"` + Type string `json:"type"` +} + +type pushdeerResponse struct { + Code int `json:"code,omitempty"` + Error string `json:"error,omitempty"` + Message string `json:"message,omitempty"` +} + +func NewPushdeer(pushkey, url string) *Pushdeer { + return &Pushdeer{ + url: url, + pushkey: pushkey, + } +} + +func (p *Pushdeer) Name() string { + return "Pushdeer" +} + +func (p *Pushdeer) Send(ctx context.Context, title, message string) error { + msg := pushdeerMessage{ + Text: title, + Desp: message, + Type: "markdown", + } + + url := p.url + if url == "" { + url = pushdeerURL + } + + // 去除最后一个/ + url = strings.TrimSuffix(url, "/") + uri := fmt.Sprintf("%s/message/push?pushkey=%s", url, p.pushkey) + + client := requester.NewHTTPRequester("", pushdeerErrFunc) + client.Context = ctx + client.IsOpenAI = false + + req, err := client.NewRequest(http.MethodPost, uri, client.WithHeader(requester.GetJsonHeaders()), client.WithBody(msg)) + if err != nil { + return err + } + + respMsg := &pushdeerResponse{} + _, errWithOP := client.SendRequest(req, respMsg, false) + if errWithOP != nil { + return fmt.Errorf("%s", errWithOP.Message) + } + + if respMsg.Code != 0 { + return fmt.Errorf("send msg err. err msg: %s", respMsg.Error) + } + + return nil +} + +func pushdeerErrFunc(resp *http.Response) *types.OpenAIError { + respMsg := &pushdeerResponse{} + err := json.NewDecoder(resp.Body).Decode(respMsg) + if err != nil { + return nil + } + + if respMsg.Message == "" { + return nil + } + + return &types.OpenAIError{ + Message: fmt.Sprintf("send msg err. err msg: %s", respMsg.Message), + Type: "pushdeer_error", + } +} diff --git a/common/notify/channel/telegram.go b/common/notify/channel/telegram.go new file mode 100644 index 00000000..26889479 --- /dev/null +++ b/common/notify/channel/telegram.go @@ -0,0 +1,114 @@ +package channel + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/types" +) + +const telegramURL = "https://api.telegram.org/bot" + +type Telegram struct { + secret string + chatID string +} + +type telegramMessage struct { + ChatID string `json:"chat_id"` + Text string `json:"text"` + ParseMode string `json:"parse_mode"` +} + +type telegramResponse struct { + Ok bool `json:"ok"` + Description string `json:"description"` +} + +func NewTelegram(secret string, chatID string) *Telegram { + return &Telegram{ + secret: secret, + chatID: chatID, + } +} + +func (t *Telegram) Name() string { + return "Telegram" +} + +func (t *Telegram) Send(ctx context.Context, title, message string) error { + const maxMessageLength = 4096 + message = fmt.Sprintf("*%s*\n%s", title, message) + messages := splitTelegramMessageIntoParts(message, maxMessageLength) + + client := requester.NewHTTPRequester("", telegramErrFunc) + client.Context = ctx + client.IsOpenAI = false + + for _, msg := range messages { + err := t.sendMessage(msg, client) + if err != nil { + return err + } + } + + return nil +} + +func (t *Telegram) sendMessage(message string, client *requester.HTTPRequester) error { + msg := telegramMessage{ + ChatID: t.chatID, + Text: message, + ParseMode: "Markdown", + } + + uri := telegramURL + t.secret + "/sendMessage" + + req, err := client.NewRequest(http.MethodPost, uri, client.WithHeader(requester.GetJsonHeaders()), client.WithBody(msg)) + if err != nil { + return err + } + + resp, errWithOP := client.SendRequestRaw(req) + if errWithOP != nil { + return fmt.Errorf("%s", errWithOP.Message) + } + defer resp.Body.Close() + + telegramErr := telegramErrFunc(resp) + if telegramErr != nil { + return fmt.Errorf("%s", telegramErr.Message) + } + + return nil +} + +func splitTelegramMessageIntoParts(message string, partSize int) []string { + var parts []string + for len(message) > partSize { + parts = append(parts, message[:partSize]) + message = message[partSize:] + } + parts = append(parts, message) + + return parts +} + +func telegramErrFunc(resp *http.Response) *types.OpenAIError { + respMsg := &telegramResponse{} + err := json.NewDecoder(resp.Body).Decode(respMsg) + if err != nil { + return nil + } + + if respMsg.Ok { + return nil + } + + return &types.OpenAIError{ + Message: fmt.Sprintf("send msg err. err msg: %s", respMsg.Description), + Type: "telegram_error", + } +} diff --git a/common/notify/notifier.go b/common/notify/notifier.go new file mode 100644 index 00000000..c70c4eb3 --- /dev/null +++ b/common/notify/notifier.go @@ -0,0 +1,98 @@ +package notify + +import ( + "context" + "one-api/common" + "one-api/common/notify/channel" + + "github.com/spf13/viper" +) + +type Notifier interface { + Send(context.Context, string, string) error + Name() string +} + +func InitNotifier() { + InitEmailNotifier() + InitDingTalkNotifier() + InitLarkNotifier() + InitPushdeerNotifier() + InitTelegramNotifier() +} + +func InitEmailNotifier() { + if viper.GetBool("notify.email.disable") { + common.SysLog("email notifier disabled") + return + } + smtp_to := viper.GetString("notify.email.smtp_to") + emailNotifier := channel.NewEmail(smtp_to) + AddNotifiers(emailNotifier) + common.SysLog("email notifier enable") +} + +func InitDingTalkNotifier() { + access_token := viper.GetString("notify.dingtalk.token") + secret := viper.GetString("notify.dingtalk.secret") + keyWord := viper.GetString("notify.dingtalk.keyWord") + if access_token == "" || (secret == "" && keyWord == "") { + return + } + + var dingTalkNotifier Notifier + + if secret != "" { + dingTalkNotifier = channel.NewDingTalk(access_token, secret) + } else { + dingTalkNotifier = channel.NewDingTalkWithKeyWord(access_token, keyWord) + } + + AddNotifiers(dingTalkNotifier) + common.SysLog("dingtalk notifier enable") +} + +func InitLarkNotifier() { + access_token := viper.GetString("notify.lark.token") + secret := viper.GetString("notify.lark.secret") + keyWord := viper.GetString("notify.lark.keyWord") + if access_token == "" || (secret == "" && keyWord == "") { + return + } + + var larkNotifier Notifier + + if secret != "" { + larkNotifier = channel.NewLark(access_token, secret) + } else { + larkNotifier = channel.NewLarkWithKeyWord(access_token, keyWord) + } + + AddNotifiers(larkNotifier) + common.SysLog("lark notifier enable") +} + +func InitPushdeerNotifier() { + pushkey := viper.GetString("notify.pushdeer.pushkey") + if pushkey == "" { + return + } + + pushdeerNotifier := channel.NewPushdeer(pushkey, viper.GetString("notify.pushdeer.url")) + + AddNotifiers(pushdeerNotifier) + common.SysLog("pushdeer notifier enable") +} + +func InitTelegramNotifier() { + bot_token := viper.GetString("notify.telegram.bot_api_key") + chat_id := viper.GetString("notify.telegram.chat_id") + if bot_token == "" || chat_id == "" { + return + } + + telegramNotifier := channel.NewTelegram(bot_token, chat_id) + + AddNotifiers(telegramNotifier) + common.SysLog("telegram notifier enable") +} diff --git a/common/notify/notify.go b/common/notify/notify.go new file mode 100644 index 00000000..a6c9b9f1 --- /dev/null +++ b/common/notify/notify.go @@ -0,0 +1,35 @@ +package notify + +var notifyChannels = New() + +type Notify struct { + notifiers map[string]Notifier +} + +func (n *Notify) addChannel(channel Notifier) { + if channel != nil { + channelName := channel.Name() + if _, ok := n.notifiers[channelName]; ok { + return + } + n.notifiers[channelName] = channel + } +} + +func (n *Notify) addChannels(channel ...Notifier) { + for _, s := range channel { + n.addChannel(s) + } +} + +func New() *Notify { + notify := &Notify{ + notifiers: make(map[string]Notifier, 0), + } + + return notify +} + +func AddNotifiers(channel ...Notifier) { + notifyChannels.addChannels(channel...) +} diff --git a/common/notify/send.go b/common/notify/send.go new file mode 100644 index 00000000..6246af6d --- /dev/null +++ b/common/notify/send.go @@ -0,0 +1,30 @@ +package notify + +import ( + "context" + "fmt" + "one-api/common" +) + +func (n *Notify) Send(ctx context.Context, title, message string) { + if ctx == nil { + ctx = context.Background() + } + + for channelName, channel := range n.notifiers { + if channel == nil { + continue + } + err := channel.Send(ctx, title, message) + if err != nil { + common.LogError(ctx, fmt.Sprintf("%s err: %s", channelName, err.Error())) + } + } +} + +func Send(title, message string) { + //lint:ignore SA1029 reason: 需要使用该类型作为错误处理 + ctx := context.WithValue(context.Background(), common.RequestIdKey, "NotifyTask") + + notifyChannels.Send(ctx, title, message) +} diff --git a/common/requester/http_requester.go b/common/requester/http_requester.go index ccc27eb8..f36601fe 100644 --- a/common/requester/http_requester.go +++ b/common/requester/http_requester.go @@ -24,6 +24,7 @@ type HTTPRequester struct { ErrorHandler HttpErrorHandler proxyAddr string Context context.Context + IsOpenAI bool } // NewHTTPRequester 创建一个新的 HTTPRequester 实例。 @@ -39,6 +40,7 @@ func NewHTTPRequester(proxyAddr string, errorHandler HttpErrorHandler) *HTTPRequ ErrorHandler: errorHandler, proxyAddr: proxyAddr, Context: context.Background(), + IsOpenAI: true, } } @@ -94,10 +96,14 @@ func (r *HTTPRequester) SendRequest(req *http.Request, response any, outputResp // 处理响应 if r.IsFailureStatusCode(resp) { - return nil, HandleErrorResp(resp, r.ErrorHandler) + return nil, HandleErrorResp(resp, r.ErrorHandler, r.IsOpenAI) } // 解析响应 + if response == nil { + return resp, nil + } + if outputResp { var buf bytes.Buffer tee := io.TeeReader(resp.Body, &buf) @@ -126,7 +132,7 @@ func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *type // 处理响应 if r.IsFailureStatusCode(resp) { - return nil, HandleErrorResp(resp, r.ErrorHandler) + return nil, HandleErrorResp(resp, r.ErrorHandler, r.IsOpenAI) } return resp, nil @@ -136,7 +142,7 @@ func (r *HTTPRequester) SendRequestRaw(req *http.Request) (*http.Response, *type func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, handlerPrefix HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) { // 如果返回的头是json格式 说明有错误 if strings.Contains(resp.Header.Get("Content-Type"), "application/json") { - return nil, HandleErrorResp(resp, requester.ErrorHandler) + return nil, HandleErrorResp(resp, requester.ErrorHandler, requester.IsOpenAI) } stream := &streamReader[T]{ @@ -180,7 +186,7 @@ func (r *HTTPRequester) IsFailureStatusCode(resp *http.Response) bool { } // 处理错误响应 -func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types.OpenAIErrorWithStatusCode { +func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler, isPrefix bool) *types.OpenAIErrorWithStatusCode { openAIErrorWithStatusCode := &types.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, @@ -199,12 +205,19 @@ func HandleErrorResp(resp *http.Response, toOpenAIError HttpErrorHandler) *types if errorResponse != nil && errorResponse.Message != "" { openAIErrorWithStatusCode.OpenAIError = *errorResponse - openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: %s", openAIErrorWithStatusCode.OpenAIError.Message) + if isPrefix { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: %s", openAIErrorWithStatusCode.OpenAIError.Message) + } } } if openAIErrorWithStatusCode.OpenAIError.Message == "" { openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: bad response status code %d", resp.StatusCode) + if isPrefix { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("Provider API error: bad response status code %d", resp.StatusCode) + } else { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } } return openAIErrorWithStatusCode @@ -218,6 +231,12 @@ func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +func GetJsonHeaders() map[string]string { + return map[string]string{ + "Content-type": "application/json", + } +} + type Stringer interface { GetString() *string } diff --git a/common/stmp/email.go b/common/stmp/email.go new file mode 100644 index 00000000..8d513452 --- /dev/null +++ b/common/stmp/email.go @@ -0,0 +1,168 @@ +package stmp + +import ( + "fmt" + "one-api/common" + "strings" + + "github.com/wneessen/go-mail" +) + +type StmpConfig struct { + Host string + Port int + Username string + Password string + From string +} + +func NewStmp(host string, port int, username string, password string, from string) *StmpConfig { + if from == "" { + from = username + } + + return &StmpConfig{ + Host: host, + Port: port, + Username: username, + Password: password, + From: from, + } +} + +func (s *StmpConfig) Send(to, subject, body string) error { + message := mail.NewMsg() + message.From(s.From) + message.To(to) + message.Subject(subject) + message.SetGenHeader("References", s.getReferences()) + message.SetBodyString(mail.TypeTextHTML, body) + message.SetUserAgent(fmt.Sprintf("One API %s // https://github.com/MartialBE/one-api", common.Version)) + + client, err := mail.NewClient( + s.Host, + mail.WithPort(s.Port), + mail.WithUsername(s.Username), + mail.WithPassword(s.Password), + mail.WithSMTPAuth(mail.SMTPAuthPlain)) + + if err != nil { + return err + } + + switch s.Port { + case 465: + client.SetSSL(true) + case 587: + client.SetTLSPolicy(mail.TLSMandatory) + } + + if err := client.DialAndSend(message); err != nil { + return err + } + + return nil +} + +func (s *StmpConfig) getReferences() string { + froms := strings.Split(s.From, "@") + return fmt.Sprintf("<%s.%s@%s>", froms[0], common.GetUUID(), froms[1]) +} + +func (s *StmpConfig) Render(to, subject, content string) error { + body := getDefaultTemplate(content) + + return s.Send(to, subject, body) +} + +func GetSystemStmp() (*StmpConfig, error) { + if common.SMTPServer == "" || common.SMTPPort == 0 || common.SMTPAccount == "" || common.SMTPToken == "" { + return nil, fmt.Errorf("SMTP 信息未配置") + } + + return NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom), nil +} + +func SendPasswordResetEmail(userName, email, link string) error { + stmp, err := GetSystemStmp() + + if err != nil { + return err + } + + contentTemp := `

Hi %s,

+

+ 您正在进行密码重置。点击下方按钮以重置密码。 +

+ +

+ 重置密码 +

+ +

+ 如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开
%s +

+

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

` + + subject := fmt.Sprintf("%s密码重置", common.SystemName) + content := fmt.Sprintf(contentTemp, userName, link, link, common.VerificationValidMinutes) + + return stmp.Render(email, subject, content) +} + +func SendVerificationCodeEmail(email, code string) error { + stmp, err := GetSystemStmp() + + if err != nil { + return err + } + + contentTemp := ` +

+ 您正在进行邮箱验证。您的验证码为: +

+ +

+ %s +

+ +

+ 验证码 %d 分钟内有效,如果不是本人操作,请忽略。 +

` + + subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) + content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes) + + return stmp.Render(email, subject, content) +} + +func SendQuotaWarningCodeEmail(userName, email string, quota int, noMoreQuota bool) error { + stmp, err := GetSystemStmp() + + if err != nil { + return err + } + + contentTemp := `

Hi %s,

+

+ %s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。 +

+ +

+ 点击充值 +

+ +

+ 如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开
%s +

` + + subject := "您的额度即将用尽" + if noMoreQuota { + subject = "您的额度已用尽" + } + topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) + + content := fmt.Sprintf(contentTemp, userName, subject, quota, topUpLink, topUpLink) + + return stmp.Render(email, subject, content) +} diff --git a/common/stmp/template.go b/common/stmp/template.go new file mode 100644 index 00000000..8ac44147 --- /dev/null +++ b/common/stmp/template.go @@ -0,0 +1,110 @@ +package stmp + +import ( + "one-api/common" +) + +func getLogo() string { + if common.Logo == "" { + return "" + } + return ` + + + + ` +} + +func getSystemName() string { + if common.SystemName == "" { + return "One API" + } + + return common.SystemName +} + +func getDefaultTemplate(content string) string { + return ` + + + + + + + + +
+ ` + getLogo() + ` + + + + +
+ ` + content + ` +
+ + + + + +
+ + ` +} diff --git a/config.example.yaml b/config.example.yaml index 2ed668c4..81a52065 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -42,4 +42,22 @@ data_gym_cache_dir: "" # Telegram设置 tg: bot_api_key: "" # 你的 Telegram bot 的 API 密钥 - webhook_secret: "" # 你的 webhook 密钥。你可以自定义这个密钥。如果设置了这个密钥,将使用webhook的方式接收消息,否则使用轮询(Polling)的方式。 \ No newline at end of file + webhook_secret: "" # 你的 webhook 密钥。你可以自定义这个密钥。如果设置了这个密钥,将使用webhook的方式接收消息,否则使用轮询(Polling)的方式。 +notify: # 通知设置, 配置了几个通知方式,就会同时发送几次通知 如果不需要通知,可以删除这个配置 + email: # 邮件通知 (具体stmp配置在后台设置) + disable: false # 是否禁用邮件通知 + smtp_to: "" # 收件人地址 (可空,如果为空则使用超级管理员邮箱) + dingTalk: # 钉钉机器人通知 + token: "" # webhook 地址最后一串字符 + secret: "" # 密钥 (secret/keyWord 二选一) + keyWord: "" # 关键字 (secret/keyWord 二选一) + lark: # 飞书机器人通知 + token: "" # webhook 地址最后一串字符 + secret: "" # 密钥 (secret/keyWord 二选一) + keyWord: "" # 关键字 (secret/keyWord 二选一) + pushdeer: # pushdeer 通知 + url: "https://api2.pushdeer.com" # pushdeer地址 (可空,如果自建需填写) + pushkey: "" # pushkey + telegram: # Telegram 通知 + bot_api_key: "" # 你的 Telegram bot 的 API 密钥 + chat_id: "" # 你的 Telegram chat_id \ No newline at end of file diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 51d07323..c6d5c511 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -122,7 +122,7 @@ func updateAllChannelsBalance() error { } else { // err is nil & balance <= 0 means quota is used up if balance <= 0 { - DisableChannel(channel.Id, channel.Name, "余额不足") + DisableChannel(channel.Id, channel.Name, "余额不足", true) } } time.Sleep(common.RequestInterval) diff --git a/controller/channel-test.go b/controller/channel-test.go index ea4ab886..8d1fda61 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "one-api/common" + "one-api/common/notify" "one-api/model" "one-api/providers" providers_base "one-api/providers/base" @@ -130,28 +131,7 @@ func TestChannel(c *gin.Context) { var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false -func notifyRootUser(subject string, content string) { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() - } - err := common.SendEmail(subject, common.RootUserEmail, content) - if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } -} - -// enable & notify -func enableChannel(channelId int, channelName string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - notifyRootUser(subject, content) -} - -func testAllChannels(notify bool) error { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() - } +func testAllChannels(isNotify bool) error { testAllChannelsLock.Lock() if testAllChannelsRunning { testAllChannelsLock.Unlock() @@ -168,33 +148,63 @@ func testAllChannels(notify bool) error { disableThreshold = 10000000 // a impossible value } go func() { + var sendMessage string for _, channel := range channels { + time.Sleep(common.RequestInterval) + isChannelEnabled := channel.Status == common.ChannelStatusEnabled + sendMessage += fmt.Sprintf("**通道 %s (#%d) [%s]** : \n\n", channel.Name, channel.Id, channel.StatusToStr()) tik := time.Now() err, openaiErr := testChannel(channel, "") tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() - if milliseconds > disableThreshold { - err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) - DisableChannel(channel.Id, channel.Name, err.Error()) - } - if isChannelEnabled && ShouldDisableChannel(openaiErr, -1) { - DisableChannel(channel.Id, channel.Name, err.Error()) - } - if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { - enableChannel(channel.Id, channel.Name) + // 通道为禁用状态,并且还是请求错误 或者 响应时间超过阈值 直接跳过,也不需要更新响应时间。 + if !isChannelEnabled { + if err != nil { + sendMessage += fmt.Sprintf("- 测试报错: %s \n\n- 无需改变状态,跳过\n\n", err.Error()) + continue + } + if milliseconds > disableThreshold { + sendMessage += fmt.Sprintf("- 响应时间 %.2fs 超过阈值 %.2fs \n\n- 无需改变状态,跳过\n\n", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) + continue + } + // 如果已被禁用,但是请求成功,需要判断是否需要恢复 + // 手动禁用的通道,不会自动恢复 + if shouldEnableChannel(err, openaiErr) { + if channel.Status == common.ChannelStatusAutoDisabled { + EnableChannel(channel.Id, channel.Name, false) + sendMessage += "- 已被启用 \n\n" + } else { + sendMessage += "- 手动禁用的通道,不会自动恢复 \n\n" + } + } + } else { + // 如果通道启用状态,但是返回了错误 或者 响应时间超过阈值,需要判断是否需要禁用 + if milliseconds > disableThreshold { + sendMessage += fmt.Sprintf("- 响应时间 %.2fs 超过阈值 %.2fs \n\n- 禁用\n\n", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) + DisableChannel(channel.Id, channel.Name, err.Error(), false) + continue + } + + if ShouldDisableChannel(openaiErr, -1) { + sendMessage += fmt.Sprintf("- 已被禁用,原因:%s\n\n", err.Error()) + DisableChannel(channel.Id, channel.Name, err.Error(), false) + continue + } + + if err != nil { + sendMessage += fmt.Sprintf("- 测试报错: %s \n\n", err.Error()) + continue + } } channel.UpdateResponseTime(milliseconds) - time.Sleep(common.RequestInterval) + sendMessage += fmt.Sprintf("- 测试完成,耗时 %.2fs\n\n", float64(milliseconds)/1000.0) } testAllChannelsLock.Lock() testAllChannelsRunning = false testAllChannelsLock.Unlock() - if notify { - err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") - if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } + if isNotify { + notify.Send("通道测试完成", sendMessage) } }() return nil diff --git a/controller/common.go b/controller/common.go index 90349b2a..bad43a3e 100644 --- a/controller/common.go +++ b/controller/common.go @@ -4,8 +4,10 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/notify" "one-api/model" "one-api/types" + "strings" "github.com/gin-gonic/gin" ) @@ -36,18 +38,62 @@ func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool { return true } - if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + switch err.Type { + case "insufficient_quota": + return true + // https://docs.anthropic.com/claude/reference/errors + case "authentication_error": + return true + case "permission_error": + return true + case "forbidden": + return true + } + if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic + return true + } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { + return true + } + + if strings.Contains(err.Message, "credit") { + return true + } + if strings.Contains(err.Message, "balance") { + return true + } + + if strings.Contains(err.Message, "Access denied") { return true } return false + } // disable & notify -func DisableChannel(channelId int, channelName string, reason string) { +func DisableChannel(channelId int, channelName string, reason string, sendNotify bool) { model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + if !sendNotify { + return + } + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - notifyRootUser(subject, content) + notify.Send(subject, content) +} + +// enable & notify +func EnableChannel(channelId int, channelName string, sendNotify bool) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + if !sendNotify { + return + } + + subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + notify.Send(subject, content) } func RelayNotImplemented(c *gin.Context) { diff --git a/controller/misc.go b/controller/misc.go index 810e02d6..3eae703d 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/stmp" "one-api/common/telegram" "one-api/model" "strings" @@ -109,11 +110,7 @@ func SendEmailVerification(c *gin.Context) { } code := common.GenerateVerificationCode(6) common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) - subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) - content := fmt.Sprintf("

您好,你正在进行%s邮箱验证。

"+ - "

您的验证码为: %s

"+ - "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

", common.SystemName, code, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := stmp.SendVerificationCodeEmail(email, code) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -136,22 +133,29 @@ func SendPasswordResetEmail(c *gin.Context) { }) return } - if !model.IsEmailAlreadyTaken(email) { + + user := &model.User{ + Email: email, + } + + if err := user.FillUserByEmail(); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该邮箱地址未注册", }) return } + + userName := user.DisplayName + if userName == "" { + userName = user.Username + } + code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) - subject := fmt.Sprintf("%s密码重置", common.SystemName) - content := fmt.Sprintf("

您好,你正在进行%s密码重置。

"+ - "

点击 此处 进行密码重置。

"+ - "

如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s

"+ - "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

", common.SystemName, link, link, common.VerificationValidMinutes) - err := common.SendEmail(subject, email, content) + err := stmp.SendPasswordResetEmail(userName, email, link) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -173,6 +177,14 @@ type PasswordResetRequest struct { func ResetPassword(c *gin.Context) { var req PasswordResetRequest err := json.NewDecoder(c.Request.Body).Decode(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if req.Email == "" || req.Token == "" { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/user.go b/controller/user.go index 42faad3e..7f69998f 100644 --- a/controller/user.go +++ b/controller/user.go @@ -696,9 +696,7 @@ func EmailBind(c *gin.Context) { }) return } - if user.Role == common.RoleRootUser { - common.RootUserEmail = email - } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/go.mod b/go.mod index 290c82dc..da655611 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/chenzhuoyu/iasm v0.9.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/magiconair/properties v1.8.7 // indirect @@ -42,6 +43,7 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/viper v1.18.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/wneessen/go-mail v0.4.1 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect diff --git a/go.sum b/go.sum index a426c315..a1fa358e 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0 h1:4gjrh/PN2MuWCCElk8/I4OCKRKWCCo2zEct3VKCbibU= +github.com/gomarkdown/markdown v0.0.0-20240328165702-4d01890c35c0/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -262,6 +264,8 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/wneessen/go-mail v0.4.1 h1:m2rSg/sc8FZQCdtrV5M8ymHYOFrC6KJAQAIcgrXvqoo= +github.com/wneessen/go-mail v0.4.1/go.mod h1:zxOlafWCP/r6FEhAaRgH4IC1vg2YXxO0Nar9u0IScZ8= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/main.go b/main.go index 02417e03..df5478df 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "one-api/common/config" + "one-api/common/notify" "one-api/common/requester" "one-api/common/telegram" "one-api/controller" @@ -46,6 +47,7 @@ func main() { telegram.InitTelegramBot() controller.InitMidjourneyTask() + notify.InitNotifier() initHttpServer() } diff --git a/model/channel.go b/model/channel.go index 7dc58115..c7cd9a75 100644 --- a/model/channel.go +++ b/model/channel.go @@ -265,6 +265,19 @@ func (channel *Channel) Delete() error { return err } +func (channel *Channel) StatusToStr() string { + switch channel.Status { + case common.ChannelStatusEnabled: + return "启用" + case common.ChannelStatusAutoDisabled: + return "自动禁用" + case common.ChannelStatusManuallyDisabled: + return "手动禁用" + } + + return "禁用" +} + func UpdateChannelStatusById(id int, status int) { err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { diff --git a/model/main.go b/model/main.go index 61e8c9b6..5e682af2 100644 --- a/model/main.go +++ b/model/main.go @@ -22,6 +22,7 @@ func SetupDB() { common.FatalLog("failed to initialize database: " + err.Error()) } ChannelGroup.Load() + common.RootUserEmail = GetRootUserEmail() if viper.GetBool("BATCH_UPDATE_ENABLED") { common.BatchUpdateEnabled = true diff --git a/model/token.go b/model/token.go index 11f10d1f..689ec49e 100644 --- a/model/token.go +++ b/model/token.go @@ -2,8 +2,8 @@ package model import ( "errors" - "fmt" "one-api/common" + "one-api/common/stmp" "gorm.io/gorm" ) @@ -114,15 +114,13 @@ func GetTokenById(id int) (*Token, error) { } func (token *Token) Insert() error { - var err error - err = DB.Create(token).Error + err := DB.Create(token).Error return err } // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { - var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err := DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error return err } @@ -132,8 +130,7 @@ func (token *Token) SelectUpdate() error { } func (token *Token) Delete() error { - var err error - err = DB.Delete(token).Error + err := DB.Delete(token).Error return err } @@ -228,26 +225,35 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { } func sendQuotaWarningEmail(userId int, userQuota int, noMoreQuota bool) { - email, err := GetUserEmail(userId) - if err != nil { + user := User{Id: userId} + + if err := user.FillUserById(); err != nil { common.SysError("failed to fetch user email: " + err.Error()) + return } - prompt := "您的额度即将用尽" - if noMoreQuota { - prompt = "您的额度已用尽" + + if user.Email == "" { + common.SysError("user email is empty") + return } - if email != "" { - topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) - err = common.SendEmail(prompt, email, - fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) - if err != nil { - common.SysError("failed to send email" + err.Error()) - } + + userName := user.DisplayName + if userName == "" { + userName = user.Username + } + + err := stmp.SendQuotaWarningCodeEmail(userName, user.Email, userQuota, noMoreQuota) + + if err != nil { + common.SysError("failed to send email" + err.Error()) } } func PostConsumeTokenQuota(tokenId int, quota int) (err error) { token, err := GetTokenById(tokenId) + if err != nil { + return err + } if quota > 0 { err = DecreaseUserQuota(token.UserId, quota) } else { diff --git a/model/user.go b/model/user.go index 5d93fdf6..05a47bb0 100644 --- a/model/user.go +++ b/model/user.go @@ -149,6 +149,11 @@ func (user *User) Update(updatePassword bool) error { } } err = DB.Model(user).Updates(user).Error + + if err == nil && user.Role == common.RoleRootUser { + common.RootUserEmail = user.Email + } + return err } @@ -201,7 +206,14 @@ func (user *User) FillUserById() error { if user.Id == 0 { return errors.New("id 为空!") } - DB.Where(User{Id: user.Id}).First(user) + + result := DB.Where(User{Id: user.Id}).First(user) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return errors.New("没有找到用户!") + } + return result.Error + } return nil } @@ -209,7 +221,14 @@ func (user *User) FillUserByEmail() error { if user.Email == "" { return errors.New("email 为空!") } - DB.Where(User{Email: user.Email}).First(user) + + result := DB.Where(User{Email: user.Email}).First(user) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return errors.New("没有找到用户!") + } + return result.Error + } return nil } diff --git a/providers/bedrock/stream_reader.go b/providers/bedrock/stream_reader.go index 835753f7..03afd3ab 100644 --- a/providers/bedrock/stream_reader.go +++ b/providers/bedrock/stream_reader.go @@ -112,7 +112,7 @@ func (stream *streamReader[T]) deserializeEventMessage(msg *eventstream.Message) func RequestStream[T any](resp *http.Response, handlerPrefix requester.HandlerPrefix[T]) (*streamReader[T], *types.OpenAIErrorWithStatusCode) { // 如果返回的头是json格式 说明有错误 if strings.Contains(resp.Header.Get("Content-Type"), "application/json") { - return nil, requester.HandleErrorResp(resp, requestErrorHandle) + return nil, requester.HandleErrorResp(resp, requestErrorHandle, true) } stream := &streamReader[T]{ diff --git a/providers/openai/speech.go b/providers/openai/speech.go index 1df76cad..5d3a6f65 100644 --- a/providers/openai/speech.go +++ b/providers/openai/speech.go @@ -22,7 +22,7 @@ func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http. } if resp.Header.Get("Content-Type") == "application/json" { - return nil, requester.HandleErrorResp(resp, p.Requester.ErrorHandler) + return nil, requester.HandleErrorResp(resp, p.Requester.ErrorHandler, p.Requester.IsOpenAI) } p.Usage.TotalTokens = p.Usage.PromptTokens diff --git a/relay/common.go b/relay/common.go index c3567999..7b2d9ffd 100644 --- a/relay/common.go +++ b/relay/common.go @@ -197,6 +197,6 @@ func shouldRetry(c *gin.Context, statusCode int) bool { func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *types.OpenAIErrorWithStatusCode) { common.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message)) if controller.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) { - controller.DisableChannel(channelId, channelName, err.Message) + controller.DisableChannel(channelId, channelName, err.Message, true) } } diff --git a/web/src/views/Channel/component/EditModal.js b/web/src/views/Channel/component/EditModal.js index 2814ff2b..fd75c738 100644 --- a/web/src/views/Channel/component/EditModal.js +++ b/web/src/views/Channel/component/EditModal.js @@ -171,7 +171,6 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => { const submit = async (values, { setErrors, setStatus, setSubmitting }) => { setSubmitting(true); - console.log(values); values = trims(values); if (values.base_url && values.base_url.endsWith('/')) { values.base_url = values.base_url.slice(0, values.base_url.length - 1); diff --git a/web/src/views/Channel/component/TableRow.js b/web/src/views/Channel/component/TableRow.js index 3939a71e..3f263df4 100644 --- a/web/src/views/Channel/component/TableRow.js +++ b/web/src/views/Channel/component/TableRow.js @@ -1,5 +1,5 @@ import PropTypes from 'prop-types'; -import { useState } from 'react'; +import { useState, useEffect } from 'react'; import { showInfo, showError, renderNumber } from 'utils/common'; import { API } from 'utils/api'; @@ -76,6 +76,19 @@ const StyledMenu = styled((props) => ( } })); +function statusInfo(status) { + switch (status) { + case 1: + return '启用'; + case 2: + return '手动'; + case 3: + return '自动'; + default: + return '未知'; + } +} + export default function ChannelTableRow({ item, manageChannel, handleOpenModal, setModalChannelId }) { const [open, setOpen] = useState(null); const [openTest, setOpenTest] = useState(false); @@ -189,6 +202,14 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal, await manageChannel(item.id, 'delete', ''); }; + useEffect(() => { + setStatusSwitch(item.status); + setPriority(item.priority); + setWeight(item.weight); + setItemBalance(item.balance); + setResponseTimeData({ test_time: item.test_time, response_time: item.response_time }); + }, [item]); + return ( <> @@ -219,6 +240,7 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal, + {statusInfo(statusSwitch)} diff --git a/web/src/views/Channel/index.js b/web/src/views/Channel/index.js index da979780..ea9d508b 100644 --- a/web/src/views/Channel/index.js +++ b/web/src/views/Channel/index.js @@ -335,7 +335,7 @@ export default function ChannelPage() { 搜索