[go: up one dir, main page]

Skip to content

Commit

Permalink
feat: 完善模型价格获取逻辑
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed May 15, 2024
1 parent ff044de commit 93858c3
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
16 changes: 8 additions & 8 deletions controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ func init() {
})
}
openAIModelsMap = make(map[string]dto.OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
for _, aiModel := range openAIModels {
openAIModelsMap[aiModel.Id] = aiModel
}
channelId2Models = make(map[int][]string)
for i := 1; i <= common.ChannelTypeDummy; i++ {
Expand Down Expand Up @@ -174,8 +174,8 @@ func DashboardListModels(c *gin.Context) {

func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
if aiModel, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, aiModel)
} else {
openAIError := dto.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Expand All @@ -191,12 +191,12 @@ func RetrieveModel(c *gin.Context) {

func GetPricing(c *gin.Context) {
userId := c.GetInt("id")
user, _ := model.GetUserById(userId, true)
group, err := model.CacheGetUserGroup(userId)
groupRatio := common.GetGroupRatio("default")
if user != nil {
groupRatio = common.GetGroupRatio(user.Group)
if err != nil {
groupRatio = common.GetGroupRatio(group)
}
pricing := model.GetPricing(user, openAIModels)
pricing := model.GetPricing(group)
c.JSON(200, gin.H{
"success": true,
"data": pricing,
Expand Down
10 changes: 5 additions & 5 deletions model/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ var (
updatePricingLock sync.Mutex
)

func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing {
func GetPricing(group string) []dto.ModelPricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()

if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing(openAIModels)
updatePricing()
}
if user != nil {
if group != "" {
userPricingMap := make([]dto.ModelPricing, 0)
models := GetGroupModels(user.Group)
models := GetGroupModels(group)
for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) {
pricing.Available = false
Expand All @@ -34,7 +34,7 @@ func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing
return pricingMap
}

func updatePricing(openAIModels []dto.OpenAIModels) {
func updatePricing() {
//modelRatios := common.GetModelRatios()
enabledModels := GetEnabledModels()
allModels := make(map[string]int)
Expand Down
2 changes: 1 addition & 1 deletion router/api-router.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/about", controller.GetAbout)
//apiRouter.GET("/midjourney", controller.GetMidjourney)
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/pricing", middleware.CriticalRateLimit(), middleware.TryUserAuth(), controller.GetPricing)
apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
Expand Down

0 comments on commit 93858c3

Please sign in to comment.