[go: up one dir, main page]

Skip to content

Commit

Permalink
feat: 初步重构
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Feb 28, 2024
1 parent 9b42147 commit 5b18cd6
Show file tree
Hide file tree
Showing 67 changed files with 2,648 additions and 2,245 deletions.
2 changes: 1 addition & 1 deletion common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func StringsContains(strs []string, str string) bool {
return false
}

// []byte only read, panic on append
// StringToByteSlice []byte only read, panic on append
func StringToByteSlice(s string) []byte {
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
Expand Down
5 changes: 3 additions & 2 deletions controller/channel-billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/service"
"strconv"
"time"

Expand Down Expand Up @@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
res, err := httpClient.Do(req)
res, err := service.GetHttpClient().Do(req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -310,7 +311,7 @@ func updateAllChannelsBalance() error {
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
disableChannel(channel.Id, channel.Name, "余额不足")
service.DisableChannel(channel.Id, channel.Name, "余额不足")
}
}
time.Sleep(common.RequestInterval)
Expand Down
164 changes: 65 additions & 99 deletions controller/channel-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,99 +5,95 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"one-api/common"
"one-api/dto"
"one-api/model"
relaychannel "one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
"strconv"
"sync"
"time"

"github.com/gin-gonic/gin"
)

func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, request.Model))
switch channel.Type {
case common.ChannelTypePaLM:
fallthrough
case common.ChannelTypeAnthropic:
fallthrough
case common.ChannelTypeBaidu:
fallthrough
case common.ChannelTypeZhipu:
fallthrough
case common.ChannelTypeAli:
fallthrough
case common.ChannelType360:
fallthrough
case common.ChannelTypeGemini:
fallthrough
case common.ChannelTypeXunfei:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:
if request.Model == "" {
request.Model = "gpt-35-turbo"
}
defer func() {
if err != nil {
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
}
}()
default:
if request.Model == "" {
request.Model = "gpt-3.5-turbo"
}
}
baseUrl := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseUrl = channel.GetBaseURL()
}
requestURL := getFullRequestURL(baseUrl, "/v1/chat/completions", channel.Type)
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
Body: nil,
Header: make(http.Header),
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
meta := relaycommon.GenRelayInfo(c)
apiType := constant.ChannelType2APIType(channel.Type)
adaptor := relaychannel.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
if testModel == "" {
testModel = adaptor.GetModelList()[0]
}
request := buildTestRequest()

if channel.Type == common.ChannelTypeAzure {
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
}
adaptor.Init(meta, *request)

jsonData, err := json.Marshal(request)
request.Model = testModel
meta.UpstreamModelName = testModel
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
if err != nil {
return err, nil
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return err, nil
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(req)
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return err, nil
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if resp.StatusCode != http.StatusOK {
err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
}
if usage == nil {
return errors.New("usage is nil"), nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
}
if response.Usage.CompletionTokens == 0 {
if response.Error.Message == "" {
response.Error.Message = "补全 tokens 非预期返回 0"
}
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}

func buildTestRequest() *ChatRequest {
testRequest := &ChatRequest{
func buildTestRequest() *dto.GeneralOpenAIRequest {
testRequest := &dto.GeneralOpenAIRequest{
Model: "", // this will be set later
MaxTokens: 1,
}
content, _ := json.Marshal("hi")
testMessage := Message{
testMessage := dto.Message{
Role: "user",
Content: content,
}
Expand All @@ -114,7 +110,6 @@ func TestChannel(c *gin.Context) {
})
return
}
testModel := c.Query("model")
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
Expand All @@ -123,12 +118,9 @@ func TestChannel(c *gin.Context) {
})
return
}
testRequest := buildTestRequest()
if testModel != "" {
testRequest.Model = testModel
}
testModel := c.Query("model")
tik := time.Now()
err, _ = testChannel(channel, *testRequest)
err, _ = testChannel(channel, testModel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
Expand All @@ -152,31 +144,6 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false

// disableChannel sets the channel's status to common.ChannelStatusAutoDisabled
// and then notifies the root user by email, including the given reason.
// Any result from model.UpdateChannelStatusById is not inspected here.
func disableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
// The subject names the channel; the content additionally carries the reason.
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}

// enableChannel sets the channel's status to common.ChannelStatusEnabled and
// then notifies the root user by email that the channel is enabled again.
// Any result from model.UpdateChannelStatusById is not inspected here.
func enableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
// Subject and content carry the same message: channel name and id.
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
notifyRootUser(subject, content)
}

// notifyRootUser emails the given subject and content to the root user.
// Send failures are logged through common.SysError and are not returned.
func notifyRootUser(subject string, content string) {
// Look up the root user's address only while common.RootUserEmail is still
// empty, and keep it in that package-level variable for later calls.
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}

func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
Expand All @@ -192,7 +159,6 @@ func testAllChannels(notify bool) error {
if err != nil {
return err
}
testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
Expand All @@ -201,7 +167,7 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
err, openaiErr := testChannel(channel, *testRequest)
err, openaiErr := testChannel(channel, "")
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()

Expand All @@ -218,11 +184,11 @@ func testAllChannels(notify bool) error {
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) && ban {
disableChannel(channel.Id, channel.Name, err.Error())
if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
service.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
Expand Down
10 changes: 6 additions & 4 deletions controller/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/controller/relay"
"one-api/model"
relay2 "one-api/relay"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -63,7 +65,7 @@ import (
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
//req.Header.Set("Authorization", "Bearer midjourney-proxy")
//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
Expand Down Expand Up @@ -221,7 +223,7 @@ func UpdateMidjourneyTaskBulk() {
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
resp, err := relay.httpClient.Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
Expand All @@ -231,7 +233,7 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []Midjourney
var responseItems []relay2.Midjourney
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
Expand Down Expand Up @@ -284,7 +286,7 @@ func UpdateMidjourneyTaskBulk() {
}
}

func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
if oldTask.Code != 1 {
return true
}
Expand Down
1 change: 0 additions & 1 deletion controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package controller

import (
"fmt"

"github.com/gin-gonic/gin"
)

Expand Down
Loading

0 comments on commit 5b18cd6

Please sign in to comment.