[go: up one dir, main page]

Skip to content

Commit

Permalink
feat: support cohere rerank
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Jul 6, 2024
1 parent afe02c6 commit 8af4e28
Show file tree
Hide file tree
Showing 25 changed files with 347 additions and 11 deletions.
2 changes: 2 additions & 0 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
fallthrough
case relayconstant.RelayModeAudioTranscription:
err = relay.AudioHelper(c, relayMode)
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
default:
err = relay.TextHelper(c)
}
Expand Down
19 changes: 19 additions & 0 deletions dto/rerank.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dto

// RerankRequest is the JSON body of a rerank call. The cohere channel copies
// its fields one-to-one into CohereRerankRequest.
type RerankRequest struct {
// Documents keeps each element as a decoded JSON value (any).
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
// TopN is zero when the client omits top_n.
TopN int `json:"top_n"`
}

// RerankResponseDocument is one entry of a rerank result list. The cohere
// channel decodes upstream results directly into this type.
type RerankResponseDocument struct {
Document any `json:"document"`
// NOTE(review): presumably the position of the document in the request's
// documents list — confirm against the upstream API.
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}

// RerankResponse is the JSON reply of a rerank call: the result entries plus
// the token usage charged for the request.
type RerankResponse struct {
Results []RerankResponseDocument `json:"results"`
Usage Usage `json:"usage"`
}
2 changes: 2 additions & 0 deletions relay/channel/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import (
type Adaptor interface {
// Init IsStream bool
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string
Expand Down
7 changes: 7 additions & 0 deletions relay/channel/ali/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ import (
// Adaptor is the ali channel adaptor. It has no fields.
type Adaptor struct {
}

// InitRerank is a no-op for the ali channel; both arguments are ignored.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

// Init is a no-op for the ali channel; both arguments are ignored.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {

}
Expand Down Expand Up @@ -53,6 +56,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}

// ConvertRerankRequest is a placeholder for the ali channel: it ignores its
// arguments and returns a nil request body with a nil error.
// NOTE(review): callers cannot tell "unsupported" from success — confirm this is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

// DoRequest delegates to channel.DoApiRequest, passing this adaptor along, and
// returns its result unchanged.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
Expand Down
9 changes: 9 additions & 0 deletions relay/channel/aws/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ type Adaptor struct {
RequestMode int
}

// InitRerank is a no-op for the aws channel; both arguments are ignored.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me

}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
Expand Down Expand Up @@ -53,6 +58,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return claudeReq, err
}

// ConvertRerankRequest is a placeholder for the aws channel: it ignores its
// arguments and returns a nil request body with a nil error.
// NOTE(review): callers cannot tell "unsupported" from success — confirm this is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

// DoRequest always returns a nil response and a nil error.
// NOTE(review): presumably the aws channel performs its upstream call
// elsewhere — confirm against the caller.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}
Expand Down
9 changes: 9 additions & 0 deletions relay/channel/baidu/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ import (
// Adaptor is the baidu channel adaptor. It has no fields.
type Adaptor struct {
}

// InitRerank is a no-op for the baidu channel; both arguments are ignored.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me

}

// Init is a no-op for the baidu channel; both arguments are ignored.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {

}
Expand Down Expand Up @@ -108,6 +113,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}

// ConvertRerankRequest is a placeholder for the baidu channel: it ignores its
// arguments and returns a nil request body with a nil error.
// NOTE(review): callers cannot tell "unsupported" from success — confirm this is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

// DoRequest delegates to channel.DoApiRequest, passing this adaptor along, and
// returns its result unchanged.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
Expand Down
9 changes: 9 additions & 0 deletions relay/channel/claude/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ type Adaptor struct {
RequestMode int
}

// InitRerank is a no-op for the claude channel; both arguments are ignored.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me

}

func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
Expand Down Expand Up @@ -59,6 +64,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}

// ConvertRerankRequest is a placeholder for the claude channel: it ignores its
// arguments and returns a nil request body with a nil error.
// NOTE(review): callers cannot tell "unsupported" from success — confirm this is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

// DoRequest delegates to channel.DoApiRequest, passing this adaptor along, and
// returns its result unchanged.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
Expand Down
24 changes: 20 additions & 4 deletions relay/channel/cohere/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,24 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)

// Adaptor is the cohere channel adaptor. It has no fields.
type Adaptor struct {
}

// InitRerank is a no-op for the cohere channel; both arguments are ignored.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

// Init is a no-op for the cohere channel; both arguments are ignored.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

// GetRequestURL returns the Cohere endpoint for the current relay mode:
// <BaseUrl>/v1/rerank when info.RelayMode is constant.RelayModeRerank and
// <BaseUrl>/v1/chat for every other mode. The returned error is always nil.
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	// The mode check must come first: an unconditional chat return ahead of it
	// would make the rerank endpoint unreachable.
	if info.RelayMode == constant.RelayModeRerank {
		return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
	}
	return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
Expand All @@ -34,11 +42,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}

// ConvertRerankRequest builds the Cohere rerank payload by calling
// requestConvertRerank2Cohere on request. c and relayMode are unused and the
// returned error is always nil.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return requestConvertRerank2Cohere(request), nil
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info)
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}
}
return
}
Expand Down
1 change: 1 addition & 0 deletions relay/channel/cohere/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cohere

// ModelList holds the model names listed for the cohere channel: the
// command-* models followed by the rerank-* models.
var ModelList = []string{
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
}

// ChannelName is the name string of this channel.
var ChannelName = "cohere"
15 changes: 15 additions & 0 deletions relay/channel/cohere/dto.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package cohere

import "one-api/dto"

type CohereRequest struct {
Model string `json:"model"`
ChatHistory []ChatHistory `json:"chat_history"`
Expand Down Expand Up @@ -28,6 +30,19 @@ type CohereResponseResult struct {
Meta CohereMeta `json:"meta"`
}

// CohereRerankRequest is the JSON body sent to Cohere's rerank endpoint.
type CohereRerankRequest struct {
	// Documents holds the documents to rank, forwarded as given by the client.
	Documents []any  `json:"documents"`
	Query     string `json:"query"`
	Model     string `json:"model"`
	// TopN is omitted from the payload when zero (i.e. the client did not set
	// top_n) so that Cohere applies its own default instead of receiving an
	// explicit 0.
	TopN            int  `json:"top_n,omitempty"`
	ReturnDocuments bool `json:"return_documents"`
}

// CohereRerankResponseResult is the shape cohereRerankHandler decodes Cohere's
// rerank response body into: the result entries plus the response meta.
type CohereRerankResponseResult struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta CohereMeta `json:"meta"`
}

type CohereMeta struct {
//Tokens CohereTokens `json:"tokens"`
BilledUnits CohereBilledUnits `json:"billed_units"`
Expand Down
53 changes: 53 additions & 0 deletions relay/channel/cohere/relay-cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
return &cohereReq
}

// requestConvertRerank2Cohere maps a relay rerank request onto the Cohere
// rerank payload. Query, Documents, Model and TopN are copied as-is and
// ReturnDocuments is always enabled.
//
// Documents is assigned exactly once. Appending the request's documents on
// top of that assignment would send every document twice and could write past
// the caller's slice length into its backing array.
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
	return &CohereRerankRequest{
		Query:           rerankRequest.Query,
		Documents:       rerankRequest.Documents,
		Model:           rerankRequest.Model,
		TopN:            rerankRequest.TopN,
		ReturnDocuments: true,
	}
}

func stopReasonCohere2OpenAI(reason string) string {
switch reason {
case "COMPLETE":
Expand Down Expand Up @@ -194,3 +208,42 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

// cohereRerankHandler converts a Cohere rerank response into a
// dto.RerankResponse and writes it to the client.
//
// The upstream body is read in full, closed, and decoded as a
// CohereRerankResponseResult. Usage comes from the billed units in the
// response meta; when the billed input token count is zero, the prompt token
// count carried in info is used instead, with no completion tokens. The JSON
// reply is sent with the upstream status code.
//
// A read, close, decode, or encode failure yields a wrapped error with HTTP
// 500 and a nil usage. Otherwise the error is nil and the usage is returned.
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	if err = resp.Body.Close(); err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	var upstream CohereRerankResponseResult
	if err = json.Unmarshal(raw, &upstream); err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	// Default to the prompt tokens from info; replace them with the billed
	// units only when the upstream reported a non-zero input count.
	usage := dto.Usage{
		PromptTokens:     info.PromptTokens,
		CompletionTokens: 0,
		TotalTokens:      info.PromptTokens,
	}
	if billedIn := upstream.Meta.BilledUnits.InputTokens; billedIn != 0 {
		billedOut := upstream.Meta.BilledUnits.OutputTokens
		usage.PromptTokens = billedIn
		usage.CompletionTokens = billedOut
		usage.TotalTokens = billedIn + billedOut
	}

	payload, err := json.Marshal(dto.RerankResponse{
		Results: upstream.Results,
		Usage:   usage,
	})
	if err != nil {
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}

	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// The write result is not surfaced to the caller.
	_, _ = c.Writer.Write(payload)
	return nil, &usage
}
9 changes: 9 additions & 0 deletions relay/channel/dify/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import (
// Adaptor is the dify channel adaptor. It has no fields.
type Adaptor struct {
}

// InitRerank is a no-op for the dify channel; both arguments are ignored.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me

}

// Init is a no-op for the dify channel; both arguments are ignored.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

Expand All @@ -34,6 +39,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return requestOpenAI2Dify(*request), nil
}

// ConvertRerankRequest is a placeholder for the dify channel: it ignores its
// arguments and returns a nil request body with a nil error.
// NOTE(review): callers cannot tell "unsupported" from success — confirm this is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

// DoRequest delegates to channel.DoApiRequest, passing this adaptor along, and
// returns its result unchanged.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
Expand Down
7 changes: 7 additions & 0 deletions relay/channel/gemini/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ import (
// Adaptor is the gemini channel adaptor. It has no fields.
type Adaptor struct {
}

// InitRerank is a no-op for the gemini channel; both arguments are ignored.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

// Init is a no-op for the gemini channel; both arguments are ignored.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

Expand Down Expand Up @@ -56,6 +59,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return CovertGemini2OpenAI(*request), nil
}

// ConvertRerankRequest is a placeholder for the gemini channel: it ignores its
// arguments and returns a nil request body with a nil error.
// NOTE(review): callers cannot tell "unsupported" from success — confirm this is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

// DoRequest delegates to channel.DoApiRequest, passing this adaptor along, and
// returns its result unchanged.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
Expand Down
7 changes: 7 additions & 0 deletions relay/channel/ollama/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ import (
// Adaptor is the ollama channel adaptor. It has no fields.
type Adaptor struct {
}

// InitRerank is a no-op for the ollama channel; both arguments are ignored.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

// Init is a no-op for the ollama channel; both arguments are ignored.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

Expand Down Expand Up @@ -45,6 +48,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
}
}

// ConvertRerankRequest is a placeholder for the ollama channel: it ignores its
// arguments and returns a nil request body with a nil error.
// NOTE(review): callers cannot tell "unsupported" from success — confirm this is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

// DoRequest delegates to channel.DoApiRequest, passing this adaptor along, and
// returns its result unchanged.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
Expand Down
7 changes: 7 additions & 0 deletions relay/channel/openai/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ type Adaptor struct {
ChannelType int
}

// ConvertRerankRequest is a placeholder for the openai channel: it ignores its
// arguments and returns a nil request body with a nil error.
// NOTE(review): callers cannot tell "unsupported" from success — confirm this is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

// InitRerank is a no-op for the openai channel; both arguments are ignored.
// Unlike Init, it does not record info.ChannelType on the adaptor.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}

// Init copies info.ChannelType onto the adaptor; request is ignored.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.ChannelType = info.ChannelType
}
Expand Down
9 changes: 9 additions & 0 deletions relay/channel/palm/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ import (
// Adaptor is the palm channel adaptor. It has no fields.
type Adaptor struct {
}

// InitRerank is a no-op for the palm channel; both arguments are ignored.
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
//TODO implement me

}

// Init is a no-op for the palm channel; both arguments are ignored.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}

Expand All @@ -35,6 +40,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return request, nil
}

// ConvertRerankRequest is a placeholder for the palm channel: it ignores its
// arguments and returns a nil request body with a nil error.
// NOTE(review): callers cannot tell "unsupported" from success — confirm this is intended.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}

// DoRequest delegates to channel.DoApiRequest, passing this adaptor along, and
// returns its result unchanged.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
Expand Down
Loading

0 comments on commit 8af4e28

Please sign in to comment.