[go: up one dir, main page]

Skip to content

Commit

Permalink
Merge branch 'main' into telegram-login
Browse files Browse the repository at this point in the history
  • Loading branch information
Ehco1996 authored Mar 3, 2024
2 parents 699fe25 + 54088bc commit 02d5a5f
Show file tree
Hide file tree
Showing 27 changed files with 233 additions and 116 deletions.
14 changes: 10 additions & 4 deletions common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@ import (
"github.com/google/uuid"
)

var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
// Pay Settings

var PayAddress = ""
var CustomCallbackAddress = ""
var EpayId = ""
var EpayKey = ""
var Price = 7.3
var MinTopUp = 1

var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
Expand All @@ -29,6 +34,7 @@ var DrawingEnabled = true
var DataExportEnabled = true
var DataExportInterval = 5 // unit: minute
var DataExportDefaultTime = "hour" // unit: minute
var DefaultCollapseSidebar = false // default value of collapse sidebar

// Any option with "Secret" or "Token" in its key won't be returned by GetOptions

Expand Down
5 changes: 4 additions & 1 deletion common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ var ModelRatio = map[string]float64{
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
Expand Down
3 changes: 2 additions & 1 deletion controller/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ func FixChannelsAbilities(c *gin.Context) {
func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
//idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.SearchChannels(keyword, group)
channels, err := model.SearchChannels(keyword, group, modelKeyword)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
Expand Down
2 changes: 2 additions & 0 deletions controller/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func GetStatus(c *gin.Context) {
"wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress,
"price": common.Price,
"min_topup": common.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
Expand All @@ -40,6 +41,7 @@ func GetStatus(c *gin.Context) {
"enable_drawing": common.DrawingEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": common.PayAddress != "" && common.EpayId != "" && common.EpayKey != "",
},
})
Expand Down
7 changes: 7 additions & 0 deletions controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ func ListModels(c *gin.Context) {
})
}

// ChannelListModels writes a JSON response with status code 200 containing
// an "object" field set to "list" and a "data" field holding the
// package-level openAIModels value.
func ChannelListModels(c *gin.Context) {
	payload := gin.H{
		"object": "list",
		"data":   openAIModels,
	}
	c.JSON(200, payload)
}

func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok {
Expand Down
16 changes: 8 additions & 8 deletions controller/topup.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/url"
"one-api/common"
"one-api/model"
"one-api/service"
"strconv"
"time"
)
Expand Down Expand Up @@ -55,14 +56,14 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return
}
if req.Amount < 1 {
c.JSON(200, gin.H{"message": "充值金额不能小于1", "data": 10})
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
return
}

id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
amount := GetAmount(float64(req.Amount), *user)
payMoney := GetAmount(float64(req.Amount), *user)

var payType epay.PurchaseType
if req.PaymentMethod == "zfb" {
Expand All @@ -72,11 +73,10 @@ func RequestEpay(c *gin.Context) {
req.PaymentMethod = "wxpay"
payType = epay.WechatPay
}

callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(common.ServerAddress + "/log")
notifyUrl, _ := url.Parse(common.ServerAddress + "/api/user/epay/notify")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := strconv.FormatInt(time.Now().Unix(), 10)
payMoney := amount
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
Expand Down Expand Up @@ -169,8 +169,8 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < 1 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额不能小于1"})
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
return
}
id := c.GetInt("id")
Expand Down
17 changes: 10 additions & 7 deletions model/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,24 +291,27 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
}
}
}

// 平滑系数
smoothingFactor := 10
// Calculate the total weight of all channels up to endIdx
totalWeight := 0
for _, channel := range channels[:endIdx] {
totalWeight += channel.GetWeight()
totalWeight += channel.GetWeight() + smoothingFactor
}

if totalWeight == 0 {
// If all weights are 0, select a channel randomly
return channels[rand.Intn(endIdx)], nil
}
//if totalWeight == 0 {
// // If all weights are 0, select a channel randomly
// return channels[rand.Intn(endIdx)], nil
//}

// Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight)

// Find a channel based on its weight
for _, channel := range channels[:endIdx] {
randomWeight -= channel.GetWeight()
if randomWeight <= 0 {
randomWeight -= channel.GetWeight() + smoothingFactor
if randomWeight < 0 {
return channel, nil
}
}
Expand Down
34 changes: 26 additions & 8 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,39 @@ func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Chan
return channels, err
}

// SearchChannels looks up channels whose id equals common.String2Int(keyword),
// whose name matches a LIKE pattern built from keyword, or whose key column
// equals keyword. The result is narrowed by a LIKE filter on the group column
// when group is non-empty, and by a LIKE filter on the models column.
// NOTE(review): this span interleaves the removed two-argument version with
// the added three-argument version, as rendered by the commit diff.
func SearchChannels(keyword string, group string) (channels []*Channel, err error) {
func SearchChannels(keyword string, group string, model string) ([]*Channel, error) {
var channels []*Channel
keyCol := "`key`"
groupCol := "`group`"
modelsCol := "`models`"

// For PostgreSQL, quote the identifiers with double quotes instead of backticks.
if common.UsingPostgreSQL {
keyCol = `"key"`
groupCol = `"group"`
modelsCol = `"models"`
}

// Build the base query.
// NOTE(review): Omit receives the quoted identifier here, whereas the removed
// lines passed the bare name "key" — confirm the key column is still omitted.
baseQuery := DB.Model(&Channel{}).Omit(keyCol)

// Build the WHERE clause and its positional arguments.
var whereClause string
var args []interface{}
if group != "" {
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
err = DB.Omit("key").Where("(id = ? or name LIKE ? or "+keyCol+" = ?) and "+groupCol+" LIKE ?", common.String2Int(keyword), keyword+"%", keyword, "%"+group+"%").Find(&channels).Error
// An empty model produces the pattern "%%" for the models filter.
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%")
} else {
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
// Same search without the group filter.
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
}
return channels, err

// Run the query; on failure return nil instead of a partially filled slice.
err := baseQuery.Where(whereClause, args...).Find(&channels).Error
if err != nil {
return nil, err
}
return channels, nil
}

func GetChannelById(id int, selectAll bool) (*Channel, error) {
Expand Down
11 changes: 10 additions & 1 deletion model/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ func InitOptionMap() {
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = ""
common.OptionMap["Price"] = strconv.FormatFloat(common.Price, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
Expand All @@ -85,6 +87,7 @@ func InitOptionMap() {
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)

common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
Expand Down Expand Up @@ -141,7 +144,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") {
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
Expand Down Expand Up @@ -176,6 +179,8 @@ func updateOptionMap(key string, value string) (err error) {
common.DrawingEnabled = boolValue
case "DataExportEnabled":
common.DataExportEnabled = boolValue
case "DefaultCollapseSidebar":
common.DefaultCollapseSidebar = boolValue
}
}
switch key {
Expand All @@ -196,12 +201,16 @@ func updateOptionMap(key string, value string) (err error) {
common.ServerAddress = value
case "PayAddress":
common.PayAddress = value
case "CustomCallbackAddress":
common.CustomCallbackAddress = value
case "EpayId":
common.EpayId = value
case "EpayKey":
common.EpayKey = value
case "Price":
common.Price, _ = strconv.ParseFloat(value, 64)
case "MinTopUp":
common.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
Expand Down
4 changes: 2 additions & 2 deletions relay/channel/openai/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openaiStreamHandler(c, resp, info.RelayMode)
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
Expand Down
4 changes: 2 additions & 2 deletions relay/channel/openai/relay-openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"time"
)

func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
Expand Down Expand Up @@ -111,7 +111,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
return nil, responseTextBuilder.String()
}

func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var textResponse dto.TextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package zhipu_v4
package zhipu_4v

import (
"errors"
Expand All @@ -8,7 +8,9 @@ import (
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/service"
)

type Adaptor struct {
Expand Down Expand Up @@ -41,9 +43,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = zhipuStreamHandler(c, resp)
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = zhipuHandler(c, resp)
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package zhipu_v4
package zhipu_4v

// ModelList holds the model name strings declared by this package.
var ModelList = []string{
	"glm-4",
	"glm-4v",
	"glm-3-turbo",
}

var ChannelName = "zhipu_v4"
var ChannelName = "zhipu_4v"
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package zhipu_v4
package zhipu_4v

import (
"one-api/dto"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package zhipu_v4
package zhipu_4v

import (
"bufio"
Expand Down
1 change: 1 addition & 0 deletions relay/relay-text.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
}
}
relayInfo.IsStream = textRequest.Stream
relayInfo.UpstreamModelName = textRequest.Model
return textRequest, nil
}

Expand Down
4 changes: 2 additions & 2 deletions relay/relay_adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_v4"
"one-api/relay/channel/zhipu_4v"
"one-api/relay/constant"
)

Expand All @@ -38,7 +38,7 @@ func GetAdaptor(apiType int) channel.Adaptor {
case constant.APITypeZhipu:
return &zhipu.Adaptor{}
case constant.APITypeZhipu_v4:
return &zhipu_v4.Adaptor{}
return &zhipu_4v.Adaptor{}
}
return nil
}
2 changes: 1 addition & 1 deletion router/api-router.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func SetApiRouter(router *gin.Engine) {
{
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/models", controller.ListModels)
channelRoute.GET("/models", controller.ChannelListModels)
channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test", controller.TestAllChannels)
channelRoute.GET("/test/:id", controller.TestChannel)
Expand Down
Loading

0 comments on commit 02d5a5f

Please sign in to comment.