[go: up one dir, main page]

Skip to content

Commit

Permalink
feat: 初步兼容midjourney-proxy-plus
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Mar 13, 2024
1 parent 95d8059 commit 37c0c8e
Show file tree
Hide file tree
Showing 14 changed files with 246 additions and 153 deletions.
2 changes: 1 addition & 1 deletion common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ const (
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
ChannelTypeOpenAISB = 5
ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
Expand Down
16 changes: 16 additions & 0 deletions constant/midjourney.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package constant

// Numeric Midjourney codes.
// NOTE(review): the names suggest these classify Midjourney relay failures,
// but no consumer of these constants is visible here — confirm the intended
// meaning against the call sites before relying on it.
const (
// MjErrorUnknown has the value 5; presumably an unclassified error — TODO confirm.
MjErrorUnknown = 5
// MjRequestError has the value 4; presumably a request-level error — TODO confirm.
MjRequestError = 4
)

// Midjourney action names, spelled as upper-case string constants.
// NOTE(review): presumably these are the values carried in the "action" field
// of Midjourney requests/tasks — confirm against the code that compares them.
const (
MjActionImagine = "IMAGINE"
MjActionDescribe = "DESCRIBE"
MjActionBlend = "BLEND"
MjActionUpscale = "UPSCALE"
MjActionVariation = "VARIATION"
// The two in-paint variants are distinct strings; how "INPAINT_PRE" differs
// from "INPAINT" in the workflow is not shown here — TODO confirm.
MjActionInPaint = "INPAINT"
MjActionInPaintPre = "INPAINT_PRE"
)
4 changes: 2 additions & 2 deletions controller/channel-billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
case common.ChannelTypeOpenAISB:
return updateChannelOpenAISBBalance(channel)
//case common.ChannelTypeOpenAISB:
// return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
Expand Down
18 changes: 13 additions & 5 deletions controller/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/model"
relay2 "one-api/relay"
"one-api/service"
"strconv"
"strings"
Expand Down Expand Up @@ -75,11 +75,11 @@ import (
responseBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
log.Printf("responseBody: %s", string(responseBody))
var responseItem Midjourney
var responseItem MidjourneyDto
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") {
var responseWithoutStatus MidjourneyWithoutStatus
var responseStatus MidjourneyStatus
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
Expand Down Expand Up @@ -228,12 +228,16 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []relay2.Midjourney
var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
Expand All @@ -259,6 +263,10 @@ func UpdateMidjourneyTaskBulk() {
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if responseItem.Buttons != nil {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
Expand Down Expand Up @@ -286,7 +294,7 @@ func UpdateMidjourneyTaskBulk() {
}
}

func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
if oldTask.Code != 1 {
return true
}
Expand Down
37 changes: 16 additions & 21 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ func Relay(c *gin.Context) {

func RelayMidjourney(c *gin.Context) {
relayMode := relayconstant.RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/action") {
// midjourney plus
relayMode = relayconstant.RelayModeMidjourneyAction
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/modal") {
// midjourney plus
relayMode = relayconstant.RelayModeMidjourneyModal
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = relayconstant.RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
relayMode = relayconstant.RelayModeMidjourneyBlend
Expand All @@ -86,35 +92,24 @@ func RelayMidjourney(c *gin.Context) {
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
//case relayconstant.RelayModeMidjourneyModal:
// err = relay.RelayMidjournneyModal(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(429, gin.H{
"error": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": "upstream_error",
})
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(429, gin.H{
"error": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": "upstream_error",
"code": err.Code,
})
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
//if shouldDisableChannel(&err.Error) {
// channelId := c.GetInt("channel_id")
// channelName := c.GetString("channel_name")
// disableChannel(channelId, channelName, err.Result)
//};''''''''''''''''''''''''''''''''
}
}

Expand Down
52 changes: 52 additions & 0 deletions dto/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package dto

// MidjourneyRequest is the JSON-mapped payload of a Midjourney request.
// Every field is bound to a camelCase JSON key through its struct tag; a key
// that is absent from the JSON leaves the field at its zero value.
// NOTE(review): which fields each endpoint actually requires is not visible
// here — confirm against the relay handlers.
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
// CustomId and BotType: presumably used by the midjourney-proxy-plus
// action/modal endpoints — TODO confirm.
CustomId string `json:"customId"`
BotType string `json:"botType"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
// Base64Array holds a list of strings; presumably base64-encoded images — TODO confirm.
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
// MaskBase64: presumably a base64-encoded mask image for in-painting — TODO confirm.
MaskBase64 string `json:"maskBase64"`
}

type MidjourneyResponse struct {
Expand All @@ -17,3 +20,52 @@ type MidjourneyResponse struct {
Properties interface{} `json:"properties"`
Result string `json:"result"`
}

// MidjourneyDto is the JSON-mapped representation of a Midjourney task using
// camelCase keys. The controller code unmarshals upstream task responses into
// this type (a single item, or a slice when fetching tasks in bulk).
type MidjourneyDto struct {
// MjId is read from the JSON key "id".
MjId string `json:"id"`
Action string `json:"action"`
CustomId string `json:"customId"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
// Timestamps are plain int64 values; the unit is not shown here —
// presumably Unix time, TODO confirm.
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
// Status is a string; a response whose "status" is a JSON number fails to
// unmarshal into this field, which the controller detects by matching the
// decoder's error text.
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
// Buttons is left untyped so any JSON shape is accepted; the bulk task
// updater re-marshals it to a JSON string before storing it on the task.
Buttons any `json:"buttons"`
MaskBase64 string `json:"maskBase64"`
}

// MidjourneyStatus decodes only the "status" key of a JSON document, as an
// integer. The controller declares a value of this type in its fallback path
// for responses whose status is a number rather than a string.
type MidjourneyStatus struct {
Status int `json:"status"`
}
// MidjourneyWithoutStatus is a Midjourney task record with snake_case JSON
// keys and no status field, so it can be unmarshalled from a document
// regardless of the JSON type of "status". The controller pairs it with
// MidjourneyStatus when the status arrives as a number.
// NOTE(review): the gorm tags suggest this mirrors a database model's
// columns — confirm whether they are needed on this type.
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
// Timestamps are plain int64 values; unit not shown — presumably Unix time, TODO confirm.
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}

// ActionButton describes one button with camelCase JSON keys. Every field is
// untyped (any), so each accepts whatever JSON value is supplied.
// NOTE(review): presumably this is the element shape of a task's "buttons"
// list; nothing visible here references the type — confirm before use.
type ActionButton struct {
CustomId any `json:"customId"`
Emoji any `json:"emoji"`
Label any `json:"label"`
Type any `json:"type"`
Style any `json:"style"`
}
6 changes: 0 additions & 6 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,6 @@ func TokenAuth() func(c *gin.Context) {
} else {
c.Set("token_model_limit_enabled", false)
}
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
Expand Down
1 change: 1 addition & 0 deletions model/midjourney.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Midjourney struct {
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
Buttons string `json:"buttons"`
}

// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
Expand Down
2 changes: 2 additions & 0 deletions relay/constant/relay_mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const (
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
RelayModeMidjourneyAction
RelayModeMidjourneyModal
)

func Path2RelayMode(path string) int {
Expand Down
Loading

0 comments on commit 37c0c8e

Please sign in to comment.