[go: up one dir, main page]

Skip to content

Commit

Permalink
optimize: MJ 部分调整、优化
Browse files Browse the repository at this point in the history
MJ
增加simple-change、list接口,
变换和重试操作区别出来,价格与绘图一样
优化图片返回
  • Loading branch information
xyfacai committed Jan 1, 2024
1 parent 89dd0e0 commit 5c747df
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 56 deletions.
8 changes: 7 additions & 1 deletion common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
"midjourney": 50,
//"midjourney": 50,
"gpt-4-gizmo-*": 15,
"gpt-4": 15,
"gpt-4-0314": 15,
Expand Down Expand Up @@ -80,6 +80,12 @@ var ModelRatio = map[string]float64{

// ModelPrice lists models that are billed at a fixed price per call rather
// than by token ratio.
//
// Midjourney entries are keyed "mj_" + the lower-cased submit action
// (imagine, variation, reroll, blend, describe, upscale).
// NOTE(review): the price unit is presumably USD per call — confirm against
// how the billing code converts it to quota.
var ModelPrice = map[string]float64{
	// Drawing-style Midjourney actions share the same price.
	"mj_imagine":   0.1,
	"mj_variation": 0.1,
	"mj_reroll":    0.1,
	"mj_blend":     0.1,

	// Cheaper single-image Midjourney actions.
	"mj_describe": 0.05,
	"mj_upscale":  0.05,

	// GPTs (gizmo) models.
	"gpt-4-gizmo-*": 0.1,
}

func ModelPrice2JSONString() string {
Expand Down
229 changes: 177 additions & 52 deletions controller/relay-mj.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type MidjourneyWithoutStatus struct {

func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByMJId(taskId)
midjourneyTask := model.GetByOnlyMJId(taskId)
if midjourneyTask == nil {
c.JSON(400, gin.H{
"error": "midjourney_task_not_found",
Expand All @@ -71,14 +71,27 @@ func RelayMidjourneyImage(c *gin.Context) {
})
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
if resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
c.JSON(resp.StatusCode, gin.H{
"error": string(responseBody),
})
return
}
c.Header("Content-Type", "image/jpeg")
//c.HeaderBar("Content-Length", string(rune(len(data))))
c.Data(http.StatusOK, "image/jpeg", data)
// 从Content-Type头获取MIME类型
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
// 如果无法确定内容类型,则默认为jpeg
contentType = "image/jpeg"
}
// 设置响应的内容类型
c.Writer.Header().Set("Content-Type", contentType)
// 将图片流式传输到响应体
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
log.Println("Failed to stream image:", err)
}
return
}

func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
Expand All @@ -92,7 +105,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
Result: "",
}
}
midjourneyTask := model.GetByMJId(midjRequest.MjId)
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
if midjourneyTask == nil {
return &MidjourneyResponse{
Code: 4,
Expand Down Expand Up @@ -121,16 +134,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
return nil
}

func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
taskId := c.Param("id")
originTask := model.GetByMJId(taskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
var midjourneyTask Midjourney
func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
Expand All @@ -150,14 +154,65 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
jsonMap, err := json.Marshal(midjourneyTask)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
return
}

func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
userId := c.GetInt("id")
var err error
var respBody []byte
switch relayMode {
case RelayModeMidjourneyTaskFetch:
taskId := c.Param("id")
originTask := model.GetByMJId(userId, taskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
midjourneyTask := getMidjourneyTaskModel(c, originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
case RelayModeMidjourneyTaskFetchByCondition:
var condition = struct {
IDs []string `json:"ids"`
}{}
err = c.BindJSON(&condition)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
var tasks []Midjourney
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
midjourneyTask := getMidjourneyTaskModel(c, originTask)
tasks = append(tasks, midjourneyTask)
}
}
if tasks == nil {
tasks = make([]Midjourney, 0)
}
respBody, err = json.Marshal(tasks)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap))

c.Writer.Header().Set("Content-Type", "application/json")

_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return &MidjourneyResponse{
Code: 4,
Expand All @@ -167,6 +222,18 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
return nil
}

// Midjourney submit action names.
//
// Group 1: drawing-style actions (originally annotated "price varies by
// mode"). Variation and reroll are kept distinct from imagine but are
// priced like a drawing.
const (
	MJSubmitActionImagine   = "IMAGINE"
	MJSubmitActionVariation = "VARIATION" // produce a variation of an image
	MJSubmitActionReroll    = "REROLL"    // regenerate the same prompt
	MJSubmitActionBlend     = "BLEND"     // blend several images together
)

// Group 2: fixed-price actions.
const (
	MJSubmitActionDescribe = "DESCRIBE"
	MJSubmitActionUpscale  = "UPSCALE" // upscale a single image
)

func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
imageModel := "midjourney"

Expand All @@ -186,6 +253,9 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
}
}
}

action := midjRequest.Action

if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return &MidjourneyResponse{
Expand All @@ -199,7 +269,44 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = "BLEND"
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
originTask := model.GetByMJId(midjRequest.TaskId)
mjId := ""
if relayMode == RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return &MidjourneyResponse{
Code: 4,
Description: "taskId_is_required",
}
} else if midjRequest.Action == "" {
return &MidjourneyResponse{
Code: 4,
Description: "action_is_required",
}
} else if midjRequest.Index == 0 {
return &MidjourneyResponse{
Code: 4,
Description: "index_can_only_be_1_2_3_4",
}
}
action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
return &MidjourneyResponse{
Code: 4,
Description: "content_is_required",
}
}
params := convertSimpleChangeParams(midjRequest.Content)
if params == nil {
return &MidjourneyResponse{
Code: 4,
Description: "content_parse_failed",
}
}
mjId = params.ID
action = params.Action
}
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Expand Down Expand Up @@ -229,23 +336,6 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
} else if relayMode == RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return &MidjourneyResponse{
Code: 4,
Description: "taskId_is_required",
}
} else if midjRequest.Action == "" {
return &MidjourneyResponse{
Code: 4,
Description: "action_is_required",
}
} else if midjRequest.Index == 0 {
return &MidjourneyResponse{
Code: 4,
Description: "index_can_only_be_1_2_3_4",
}
}
}

// map model name
Expand Down Expand Up @@ -293,17 +383,17 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
requestBody = c.Request.Body
}

modelRatio := common.GetModelRatio(imageModel)
modelPrice := common.GetModelPrice("mj_" + strings.ToLower(action))
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)

sizeRatio := 1.0
if midjRequest.Action == "UPSCALE" {
sizeRatio = 0.2
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: err.Error(),
}
}

quota := int(ratio * sizeRatio * 1000)
quota := int(ratio * common.QuotaPerUnit)

if consumeQuota && userQuota-quota < 0 {
return &MidjourneyResponse{
Expand Down Expand Up @@ -369,7 +459,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, action)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
Expand Down Expand Up @@ -423,7 +513,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: midjRequest.Action,
Action: action,
MjId: midjResponse.Result,
Prompt: midjRequest.Prompt,
PromptEn: "",
Expand Down Expand Up @@ -504,3 +594,38 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
}
return nil
}

// taskChangeParams is the parsed form of a simple-change "content" string
// such as "1320098173412546 U2".
type taskChangeParams struct {
	ID     string // ID of the original task the action applies to
	Action string // "UPSCALE", "VARIATION" or "REROLL"
	Index  int    // image index 1-4 for UPSCALE/VARIATION; 0 for REROLL
}

// convertSimpleChangeParams parses a simple-change command of the form
// "<taskId> <action>". The action is case-insensitive and is one of:
//
//	U1..U4  upscale image N
//	V1..V4  variation of image N
//	R       reroll
//
// Fields may be separated (and surrounded) by any amount of whitespace.
// It returns nil for malformed content: the wrong number of fields, an
// unknown action, a missing or extra index character, or an index outside
// 1-4. Content is client-supplied, so malformed input must never panic.
func convertSimpleChangeParams(content string) *taskChangeParams {
	fields := strings.Fields(content)
	if len(fields) != 2 {
		return nil
	}
	id := fields[0]
	action := strings.ToLower(fields[1])

	if action == "r" {
		return &taskChangeParams{ID: id, Action: "REROLL"}
	}

	// U/V actions are exactly a letter plus one digit. Checking the length
	// first avoids out-of-range indexing on inputs such as "u", and rejects
	// trailing garbage such as "u12" instead of silently reading it as "u1".
	if len(action) != 2 {
		return nil
	}

	var changeAction string
	switch action[0] {
	case 'u':
		changeAction = "UPSCALE"
	case 'v':
		changeAction = "VARIATION"
	default:
		return nil
	}

	index, err := strconv.Atoi(action[1:])
	if err != nil || index < 1 || index > 4 {
		return nil
	}
	return &taskChangeParams{ID: id, Action: changeAction, Index: index}
}
12 changes: 10 additions & 2 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ const (
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
RelayModeMidjourneyChange
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudio
)

Expand Down Expand Up @@ -263,6 +265,7 @@ type MidjourneyRequest struct {
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
}

type MidjourneyResponse struct {
Expand Down Expand Up @@ -342,14 +345,19 @@ func RelayMidjourney(c *gin.Context) {
relayMode = RelayModeMidjourneyNotify
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") {
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
relayMode = RelayModeMidjourneyTaskFetch
} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
relayMode = RelayModeMidjourneyTaskFetchByCondition
}

var err *MidjourneyResponse
switch relayMode {
case RelayModeMidjourneyNotify:
err = relayMidjourneyNotify(c)
case RelayModeMidjourneyTaskFetch:
case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
err = relayMidjourneyTask(c, relayMode)
default:
err = relayMidjourneySubmit(c, relayMode)
Expand Down
Loading

0 comments on commit 5c747df

Please sign in to comment.