Init V4 community edition (#2265)
* Init V4 community edition * Init V4 community edition
This commit is contained in:
@@ -1,209 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/jinzhu/gorm"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var DefaultController Controller
|
||||
|
||||
// Controller controls communications between master and slave
|
||||
type Controller interface {
|
||||
// Handle heartbeat sent from master
|
||||
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
|
||||
|
||||
// Get Aria2 Instance by master node ID
|
||||
GetAria2Instance(string) (common.Aria2, error)
|
||||
|
||||
// Send event change message to master node
|
||||
SendNotification(string, string, mq.Message) error
|
||||
|
||||
// Submit async task into task pool
|
||||
SubmitTask(string, interface{}, string, func(interface{})) error
|
||||
|
||||
// Get master node info
|
||||
GetMasterInfo(string) (*MasterInfo, error)
|
||||
|
||||
// Get master Oauth based policy credential
|
||||
GetPolicyOauthToken(string, uint) (string, error)
|
||||
}
|
||||
|
||||
type slaveController struct {
|
||||
masters map[string]MasterInfo
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// info of master node
|
||||
type MasterInfo struct {
|
||||
ID string
|
||||
TTL int
|
||||
URL *url.URL
|
||||
// used to invoke aria2 rpc calls
|
||||
Instance Node
|
||||
Client request.Client
|
||||
|
||||
jobTracker map[string]bool
|
||||
}
|
||||
|
||||
func InitController() {
|
||||
DefaultController = &slaveController{
|
||||
masters: make(map[string]MasterInfo),
|
||||
}
|
||||
gob.Register(rpc.StatusInfo{})
|
||||
}
|
||||
|
||||
func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
req.Node.AfterFind()
|
||||
|
||||
// close old node if exist
|
||||
origin, ok := c.masters[req.SiteID]
|
||||
|
||||
if (ok && req.IsUpdate) || !ok {
|
||||
if ok {
|
||||
origin.Instance.Kill()
|
||||
}
|
||||
|
||||
masterUrl, err := url.Parse(req.SiteURL)
|
||||
if err != nil {
|
||||
return serializer.NodePingResp{}, err
|
||||
}
|
||||
|
||||
c.masters[req.SiteID] = MasterInfo{
|
||||
ID: req.SiteID,
|
||||
URL: masterUrl,
|
||||
TTL: req.CredentialTTL,
|
||||
Client: request.NewClient(
|
||||
request.WithEndpoint(masterUrl.String()),
|
||||
request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)),
|
||||
request.WithCredential(auth.HMACAuth{
|
||||
SecretKey: []byte(req.Node.MasterKey),
|
||||
}, int64(req.CredentialTTL)),
|
||||
),
|
||||
jobTracker: make(map[string]bool),
|
||||
Instance: NewNodeFromDBModel(&model.Node{
|
||||
Model: gorm.Model{ID: req.Node.ID},
|
||||
MasterKey: req.Node.MasterKey,
|
||||
Type: model.MasterNodeType,
|
||||
Aria2Enabled: req.Node.Aria2Enabled,
|
||||
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
return serializer.NodePingResp{}, nil
|
||||
}
|
||||
|
||||
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
return node.Instance.GetAria2Instance(), nil
|
||||
}
|
||||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
||||
func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error {
|
||||
c.lock.RLock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
body := bytes.Buffer{}
|
||||
enc := gob.NewEncoder(&body)
|
||||
if err := enc.Encode(&msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := node.Client.Request(
|
||||
"PUT",
|
||||
fmt.Sprintf("/api/v3/slave/notification/%s", subject),
|
||||
&body,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
c.lock.RUnlock()
|
||||
return ErrMasterNotFound
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
if _, ok := node.jobTracker[hash]; ok {
|
||||
// 任务已存在,直接返回
|
||||
return nil
|
||||
}
|
||||
|
||||
node.jobTracker[hash] = true
|
||||
submitter(job)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrMasterNotFound
|
||||
}
|
||||
|
||||
// GetMasterInfo 获取主机节点信息
|
||||
func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
||||
// GetPolicyOauthToken 获取主机存储策略 Oauth 凭证
|
||||
func (c *slaveController) GetPolicyOauthToken(id string, policyID uint) (string, error) {
|
||||
c.lock.RLock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
res, err := node.Client.Request(
|
||||
"GET",
|
||||
fmt.Sprintf("/api/v3/slave/credential/%d", policyID),
|
||||
nil,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return "", serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return res.Data.(string), nil
|
||||
}
|
||||
|
||||
c.lock.RUnlock()
|
||||
return "", ErrMasterNotFound
|
||||
}
|
||||
@@ -1,385 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInitController(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
InitController()
|
||||
})
|
||||
}
|
||||
|
||||
func TestSlaveController_HandleHeartBeat(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: make(map[string]MasterInfo),
|
||||
}
|
||||
|
||||
// first heart beat
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
|
||||
_, err = c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "2",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
|
||||
a.Len(c.masters, 2)
|
||||
}
|
||||
|
||||
// second heart beat, no fresh
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
SiteURL: "http://127.0.0.1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Empty(c.masters["1"].URL)
|
||||
}
|
||||
|
||||
// second heart beat, fresh
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
IsUpdate: true,
|
||||
SiteURL: "http://127.0.0.1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||
}
|
||||
|
||||
// second heart beat, fresh, url illegal
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
IsUpdate: true,
|
||||
SiteURL: string([]byte{0x7f}),
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.Error(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||
}
|
||||
}
|
||||
|
||||
type nodeMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (n nodeMock) Init(node *model.Node) {
|
||||
n.Called(node)
|
||||
}
|
||||
|
||||
func (n nodeMock) IsFeatureEnabled(feature string) bool {
|
||||
args := n.Called(feature)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
n.Called(callback)
|
||||
}
|
||||
|
||||
func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
args := n.Called(req)
|
||||
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
|
||||
}
|
||||
|
||||
func (n nodeMock) IsActive() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) GetAria2Instance() common.Aria2 {
|
||||
args := n.Called()
|
||||
return args.Get(0).(common.Aria2)
|
||||
}
|
||||
|
||||
func (n nodeMock) ID() uint {
|
||||
args := n.Called()
|
||||
return args.Get(0).(uint)
|
||||
}
|
||||
|
||||
func (n nodeMock) Kill() {
|
||||
n.Called()
|
||||
}
|
||||
|
||||
func (n nodeMock) IsMater() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) MasterAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n nodeMock) SlaveAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n nodeMock) DBModel() *model.Node {
|
||||
args := n.Called()
|
||||
return args.Get(0).(*model.Node)
|
||||
}
|
||||
|
||||
func TestSlaveController_GetAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockNode := &nodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Instance: mockNode},
|
||||
},
|
||||
}
|
||||
|
||||
// node node found
|
||||
{
|
||||
res, err := c.GetAria2Instance("2")
|
||||
a.Nil(res)
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
}
|
||||
|
||||
// node found
|
||||
{
|
||||
res, err := c.GetAria2Instance("1")
|
||||
a.NotNil(res)
|
||||
a.NoError(err)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type requestMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
|
||||
return r.Called(method, target, body, opts).Get(0).(*request.Response)
|
||||
}
|
||||
|
||||
func TestSlaveController_SendNotification(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
|
||||
}
|
||||
|
||||
// gob encode error
|
||||
{
|
||||
type randomType struct{}
|
||||
a.Error(c.SendNotification("1", "", mq.Message{
|
||||
Content: randomType{},
|
||||
}))
|
||||
}
|
||||
|
||||
// return none 200
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{StatusCode: http.StatusConflict},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.Error(c.SendNotification("1", "s1", mq.Message{}))
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.NoError(c.SendNotification("1", "s3", mq.Message{}))
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveController_SubmitTask(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {
|
||||
jobTracker: map[string]bool{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil))
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
submitted := false
|
||||
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
|
||||
submitted = true
|
||||
}))
|
||||
a.True(submitted)
|
||||
}
|
||||
|
||||
// job already submitted
|
||||
{
|
||||
submitted := false
|
||||
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
|
||||
submitted = true
|
||||
}))
|
||||
a.False(submitted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveController_GetMasterInfo(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
res, err := c.GetMasterInfo("2")
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
a.Nil(res)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
res, err := c.GetMasterInfo("1")
|
||||
a.NoError(err)
|
||||
a.NotNil(res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveController_GetOneDriveToken(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
res, err := c.GetPolicyOauthToken("2", 1)
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
a.Empty(res)
|
||||
}
|
||||
|
||||
// return none 200
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{StatusCode: http.StatusConflict},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
res, err := c.GetPolicyOauthToken("1", 1)
|
||||
a.Error(err)
|
||||
a.Empty(res)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
res, err := c.GetPolicyOauthToken("1", 1)
|
||||
a.Equal(1, err.(serializer.AppError).Code)
|
||||
a.Empty(res)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
res, err := c.GetPolicyOauthToken("1", 1)
|
||||
a.NoError(err)
|
||||
a.Equal("expected", res)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
|
||||
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
|
||||
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "Unknown master node id", nil)
|
||||
)
|
||||
@@ -1,272 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gofrs/uuid"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
deleteTempFileDuration = 60 * time.Second
|
||||
statusRetryDuration = 10 * time.Second
|
||||
)
|
||||
|
||||
type MasterNode struct {
|
||||
Model *model.Node
|
||||
aria2RPC rpcService
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// RPCService 通过RPC服务的Aria2任务管理器
|
||||
type rpcService struct {
|
||||
Caller rpc.Client
|
||||
Initialized bool
|
||||
|
||||
retryDuration time.Duration
|
||||
deletePaddingDuration time.Duration
|
||||
parent *MasterNode
|
||||
options *clientOptions
|
||||
}
|
||||
|
||||
type clientOptions struct {
|
||||
Options map[string]interface{} // 创建下载时额外添加的设置
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *MasterNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
node.aria2RPC.parent = node
|
||||
node.aria2RPC.retryDuration = statusRetryDuration
|
||||
node.aria2RPC.deletePaddingDuration = deleteTempFileDuration
|
||||
node.lock.Unlock()
|
||||
|
||||
node.lock.RLock()
|
||||
if node.Model.Aria2Enabled {
|
||||
node.lock.RUnlock()
|
||||
node.aria2RPC.Init()
|
||||
return
|
||||
}
|
||||
node.lock.RUnlock()
|
||||
}
|
||||
|
||||
func (node *MasterNode) ID() uint {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model.ID
|
||||
}
|
||||
|
||||
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
return &serializer.NodePingResp{}, nil
|
||||
}
|
||||
|
||||
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||
func (node *MasterNode) IsFeatureEnabled(feature string) bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
switch feature {
|
||||
case "aria2":
|
||||
return node.Model.Aria2Enabled
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (node *MasterNode) MasterAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||
}
|
||||
|
||||
func (node *MasterNode) SlaveAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
|
||||
}
|
||||
|
||||
// SubscribeStatusChange 订阅节点状态更改
|
||||
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
}
|
||||
|
||||
// IsActive 返回节点是否在线
|
||||
func (node *MasterNode) IsActive() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Kill 结束aria2请求
|
||||
func (node *MasterNode) Kill() {
|
||||
if node.aria2RPC.Caller != nil {
|
||||
node.aria2RPC.Caller.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// GetAria2Instance 获取主机Aria2实例
|
||||
func (node *MasterNode) GetAria2Instance() common.Aria2 {
|
||||
node.lock.RLock()
|
||||
|
||||
if !node.Model.Aria2Enabled {
|
||||
node.lock.RUnlock()
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
if !node.aria2RPC.Initialized {
|
||||
node.lock.RUnlock()
|
||||
node.aria2RPC.Init()
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
defer node.lock.RUnlock()
|
||||
return &node.aria2RPC
|
||||
}
|
||||
|
||||
func (node *MasterNode) IsMater() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (node *MasterNode) DBModel() *model.Node {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model
|
||||
}
|
||||
|
||||
func (r *rpcService) Init() error {
|
||||
r.parent.lock.Lock()
|
||||
defer r.parent.lock.Unlock()
|
||||
r.Initialized = false
|
||||
|
||||
// 客户端已存在,则关闭先前连接
|
||||
if r.Caller != nil {
|
||||
r.Caller.Close()
|
||||
}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to parse Aria2 RPC server URL: %s", err)
|
||||
return err
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
// 加载自定义下载配置
|
||||
var globalOptions map[string]interface{}
|
||||
if r.parent.Model.Aria2OptionsSerialized.Options != "" {
|
||||
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to parse aria2 options: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.options = &clientOptions{
|
||||
Options: globalOptions,
|
||||
}
|
||||
timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
|
||||
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ)
|
||||
|
||||
r.Caller = caller
|
||||
r.Initialized = err == nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
|
||||
r.parent.lock.RLock()
|
||||
// 生成存储路径
|
||||
guid, _ := uuid.NewV4()
|
||||
path := filepath.Join(
|
||||
r.parent.Model.Aria2OptionsSerialized.TempPath,
|
||||
"aria2",
|
||||
guid.String(),
|
||||
)
|
||||
r.parent.lock.RUnlock()
|
||||
|
||||
// 创建下载任务
|
||||
options := map[string]interface{}{
|
||||
"dir": path,
|
||||
}
|
||||
for k, v := range r.options.Options {
|
||||
options[k] = v
|
||||
}
|
||||
for k, v := range groupOptions {
|
||||
options[k] = v
|
||||
}
|
||||
|
||||
gid, err := r.Caller.AddURI(task.Source, options)
|
||||
if err != nil || gid == "" {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return gid, nil
|
||||
}
|
||||
|
||||
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
res, err := r.Caller.TellStatus(task.GID)
|
||||
if err != nil {
|
||||
// 失败后重试
|
||||
util.Log().Debug("Failed to get download task status, please retry later: %s", err)
|
||||
time.Sleep(r.retryDuration)
|
||||
res, err = r.Caller.TellStatus(task.GID)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *rpcService) Cancel(task *model.Download) error {
|
||||
// 取消下载任务
|
||||
_, err := r.Caller.Remove(task.GID)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to cancel task %q: %s", task.GID, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) Select(task *model.Download, files []int) error {
|
||||
var selected = make([]string, len(files))
|
||||
for i := 0; i < len(files); i++ {
|
||||
selected[i] = strconv.Itoa(files[i])
|
||||
}
|
||||
_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) GetConfig() model.Aria2Option {
|
||||
r.parent.lock.RLock()
|
||||
defer r.parent.lock.RUnlock()
|
||||
|
||||
return r.parent.Model.Aria2OptionsSerialized
|
||||
}
|
||||
|
||||
func (s *rpcService) DeleteTempFile(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
// 避免被aria2占用,异步执行删除
|
||||
go func(d time.Duration, src string) {
|
||||
time.Sleep(d)
|
||||
err := os.RemoveAll(src)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to delete temp download folder: %q: %s", src, err)
|
||||
}
|
||||
}(s.deletePaddingDuration, task.Parent)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,186 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMasterNode_Init(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{}
|
||||
m.Init(&model.Node{Status: model.NodeSuspend})
|
||||
a.Equal(model.NodeSuspend, m.DBModel().Status)
|
||||
m.Init(&model.Node{Aria2Enabled: true})
|
||||
}
|
||||
|
||||
func TestMasterNode_DummyMethods(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
m.Model.ID = 5
|
||||
a.Equal(m.Model.ID, m.ID())
|
||||
|
||||
res, err := m.Ping(&serializer.NodePingReq{})
|
||||
a.NoError(err)
|
||||
a.NotNil(res)
|
||||
|
||||
a.True(m.IsActive())
|
||||
a.True(m.IsMater())
|
||||
|
||||
m.SubscribeStatusChange(func(isActive bool, id uint) {})
|
||||
}
|
||||
|
||||
func TestMasterNode_IsFeatureEnabled(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.False(m.IsFeatureEnabled("aria2"))
|
||||
a.False(m.IsFeatureEnabled("random"))
|
||||
m.Model.Aria2Enabled = true
|
||||
a.True(m.IsFeatureEnabled("aria2"))
|
||||
}
|
||||
|
||||
func TestMasterNode_AuthInstance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.NotNil(m.MasterAuthInstance())
|
||||
a.NotNil(m.SlaveAuthInstance())
|
||||
}
|
||||
|
||||
func TestMasterNode_Kill(t *testing.T) {
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
m.Kill()
|
||||
|
||||
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
|
||||
m.aria2RPC.Caller = caller
|
||||
m.Kill()
|
||||
}
|
||||
|
||||
func TestMasterNode_GetAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
aria2RPC: rpcService{},
|
||||
}
|
||||
|
||||
m.aria2RPC.parent = m
|
||||
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
m.Model.Aria2Enabled = true
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
m.aria2RPC.Initialized = true
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
}
|
||||
|
||||
func TestRpcService_Init(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{
|
||||
Aria2OptionsSerialized: model.Aria2Option{
|
||||
Options: "{",
|
||||
},
|
||||
},
|
||||
aria2RPC: rpcService{},
|
||||
}
|
||||
m.aria2RPC.parent = m
|
||||
|
||||
// failed to decode address
|
||||
{
|
||||
m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f})
|
||||
a.Error(m.aria2RPC.Init())
|
||||
}
|
||||
|
||||
// failed to decode options
|
||||
{
|
||||
m.Model.Aria2OptionsSerialized.Server = ""
|
||||
a.Error(m.aria2RPC.Init())
|
||||
}
|
||||
|
||||
// failed to initialized
|
||||
{
|
||||
m.Model.Aria2OptionsSerialized.Server = ""
|
||||
m.Model.Aria2OptionsSerialized.Options = "{}"
|
||||
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
|
||||
m.aria2RPC.Caller = caller
|
||||
a.Error(m.aria2RPC.Init())
|
||||
a.False(m.aria2RPC.Initialized)
|
||||
}
|
||||
}
|
||||
|
||||
func getTestRPCNode() *MasterNode {
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{
|
||||
Aria2OptionsSerialized: model.Aria2Option{},
|
||||
},
|
||||
aria2RPC: rpcService{
|
||||
options: &clientOptions{
|
||||
Options: map[string]interface{}{"1": "1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
m.aria2RPC.parent = m
|
||||
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
|
||||
m.aria2RPC.Caller = caller
|
||||
return m
|
||||
}
|
||||
|
||||
func TestRpcService_CreateTask(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"})
|
||||
a.Error(err)
|
||||
a.Empty(res)
|
||||
}
|
||||
|
||||
func TestRpcService_Status(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
res, err := m.aria2RPC.Status(&model.Download{})
|
||||
a.Error(err)
|
||||
a.Empty(res)
|
||||
}
|
||||
|
||||
func TestRpcService_Cancel(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
a.Error(m.aria2RPC.Cancel(&model.Download{}))
|
||||
}
|
||||
|
||||
func TestRpcService_Select(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
a.NotNil(m.aria2RPC.GetConfig())
|
||||
a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3}))
|
||||
}
|
||||
|
||||
func TestRpcService_DeleteTempFile(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
fdName := "TestRpcService_DeleteTempFile"
|
||||
a.NoError(os.Mkdir(fdName, 0644))
|
||||
|
||||
a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName}))
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
a.False(util.Exists(fdName))
|
||||
}
|
||||
@@ -1,60 +1,413 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/node"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/task"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/qbittorrent"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/slave"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/queue"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Node interface {
|
||||
// Init a node from database model
|
||||
Init(node *model.Node)
|
||||
type (
|
||||
Node interface {
|
||||
fs.StatelessUploadManager
|
||||
ID() int
|
||||
Name() string
|
||||
IsMaster() bool
|
||||
// CreateTask creates a task on the node. It does not have effect on master node.
|
||||
CreateTask(ctx context.Context, taskType string, state string) (int, error)
|
||||
// GetTask returns the task summary of the task with the given id.
|
||||
GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error)
|
||||
// CleanupFolders cleans up the given folders on the node.
|
||||
CleanupFolders(ctx context.Context, folders ...string) error
|
||||
// AuthInstance returns the auth instance for the node.
|
||||
AuthInstance() auth.Auth
|
||||
// CreateDownloader creates a downloader instance from the node for remote download tasks.
|
||||
CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error)
|
||||
// Settings returns the settings of the node.
|
||||
Settings(ctx context.Context) *types.NodeSetting
|
||||
}
|
||||
|
||||
// Check if given feature is enabled
|
||||
IsFeatureEnabled(feature string) bool
|
||||
// Request body for creating tasks on slave node
|
||||
CreateSlaveTask struct {
|
||||
Type string `json:"type"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// Subscribe node status change to a callback function
|
||||
SubscribeStatusChange(callback func(isActive bool, id uint))
|
||||
// Request body for cleaning up folders on slave node
|
||||
FolderCleanup struct {
|
||||
Path []string `json:"path" binding:"required"`
|
||||
}
|
||||
|
||||
// Ping the node
|
||||
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
|
||||
SlaveTaskSummary struct {
|
||||
Status task.Status `json:"status"`
|
||||
Error string `json:"error"`
|
||||
PrivateState string `json:"private_state"`
|
||||
Progress queue.Progresses `json:"progress,omitempty"`
|
||||
}
|
||||
|
||||
// Returns if the node is active
|
||||
IsActive() bool
|
||||
MasterSiteUrlCtx struct{}
|
||||
MasterSiteVersionCtx struct{}
|
||||
MasterSiteIDCtx struct{}
|
||||
SlaveNodeIDCtx struct{}
|
||||
masterNode struct {
|
||||
nodeBase
|
||||
client request.Client
|
||||
}
|
||||
)
|
||||
|
||||
// Get instances for aria2 calls
|
||||
GetAria2Instance() common.Aria2
|
||||
|
||||
// Returns unique id of this node
|
||||
ID() uint
|
||||
|
||||
// Kill node and recycle resources
|
||||
Kill()
|
||||
|
||||
// Returns if current node is master node
|
||||
IsMater() bool
|
||||
|
||||
// Get auth instance used to check RPC call from slave to master
|
||||
MasterAuthInstance() auth.Auth
|
||||
|
||||
// Get auth instance used to check RPC call from master to slave
|
||||
SlaveAuthInstance() auth.Auth
|
||||
|
||||
// Get node DB model
|
||||
DBModel() *model.Node
|
||||
func newNode(ctx context.Context, model *ent.Node, config conf.ConfigProvider, settings setting.Provider) Node {
|
||||
if model.Type == node.TypeMaster {
|
||||
return newMasterNode(model, config, settings)
|
||||
}
|
||||
return newSlaveNode(ctx, model, config, settings)
|
||||
}
|
||||
|
||||
// Create new node from DB model
|
||||
func NewNodeFromDBModel(node *model.Node) Node {
|
||||
switch node.Type {
|
||||
case model.SlaveNodeType:
|
||||
slave := &SlaveNode{}
|
||||
slave.Init(node)
|
||||
return slave
|
||||
default:
|
||||
master := &MasterNode{}
|
||||
master.Init(node)
|
||||
return master
|
||||
func newMasterNode(model *ent.Node, config conf.ConfigProvider, settings setting.Provider) *masterNode {
|
||||
n := &masterNode{
|
||||
nodeBase: nodeBase{
|
||||
model: model,
|
||||
},
|
||||
}
|
||||
|
||||
if config.System().Mode == conf.SlaveMode {
|
||||
n.client = request.NewClient(config,
|
||||
request.WithCorrelationID(),
|
||||
request.WithCredential(auth.HMACAuth{
|
||||
[]byte(config.Slave().Secret),
|
||||
}, int64(config.Slave().SignatureTTL)),
|
||||
)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (b *masterNode) PrepareUpload(ctx context.Context, args *fs.StatelessPrepareUploadService) (*fs.StatelessPrepareUploadResponse, error) {
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "prepare")
|
||||
resp, err := b.client.Request(
|
||||
"PUT",
|
||||
requestDst.String(),
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithSlaveMeta(NodeIdFromContext(ctx)),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return nil, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
uploadRequest := &fs.StatelessPrepareUploadResponse{}
|
||||
resp.GobDecode(uploadRequest)
|
||||
|
||||
return uploadRequest, nil
|
||||
}
|
||||
|
||||
func (b *masterNode) CompleteUpload(ctx context.Context, args *fs.StatelessCompleteUploadService) error {
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "complete")
|
||||
resp, err := b.client.Request(
|
||||
"POST",
|
||||
requestDst.String(),
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithSlaveMeta(NodeIdFromContext(ctx)),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *masterNode) OnUploadFailed(ctx context.Context, args *fs.StatelessOnUploadFailedService) error {
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "failed")
|
||||
resp, err := b.client.Request(
|
||||
"POST",
|
||||
requestDst.String(),
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithSlaveMeta(NodeIdFromContext(ctx)),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *masterNode) CreateFile(ctx context.Context, args *fs.StatelessCreateFileService) error {
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "create")
|
||||
resp, err := b.client.Request(
|
||||
"POST",
|
||||
requestDst.String(),
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithSlaveMeta(NodeIdFromContext(ctx)),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (b *masterNode) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) {
|
||||
return NewDownloader(ctx, c, settings, b.Settings(ctx))
|
||||
}
|
||||
|
||||
// NewDownloader creates a new downloader instance from the node for remote download tasks.
|
||||
func NewDownloader(ctx context.Context, c request.Client, settings setting.Provider, options *types.NodeSetting) (downloader.Downloader, error) {
|
||||
if options.Provider == types.DownloaderProviderQBittorrent {
|
||||
return qbittorrent.NewClient(logging.FromContext(ctx), c, settings, options.QBittorrentSetting)
|
||||
} else if options.Provider == types.DownloaderProviderAria2 {
|
||||
return aria2.New(logging.FromContext(ctx), settings, options.Aria2Setting), nil
|
||||
} else if options.Provider == "" {
|
||||
return nil, errors.New("downloader not configured for this node")
|
||||
} else {
|
||||
return nil, errors.New("unknown downloader provider")
|
||||
}
|
||||
}
|
||||
|
||||
type slaveNode struct {
|
||||
nodeBase
|
||||
client request.Client
|
||||
}
|
||||
|
||||
func newSlaveNode(ctx context.Context, model *ent.Node, config conf.ConfigProvider, settings setting.Provider) *slaveNode {
|
||||
siteBasic := settings.SiteBasic(ctx)
|
||||
return &slaveNode{
|
||||
nodeBase: nodeBase{
|
||||
model: model,
|
||||
},
|
||||
client: request.NewClient(config,
|
||||
request.WithCorrelationID(),
|
||||
request.WithSlaveMeta(model.ID),
|
||||
request.WithMasterMeta(siteBasic.ID, settings.SiteURL(setting.UseFirstSiteUrl(ctx)).String()),
|
||||
request.WithCredential(auth.HMACAuth{[]byte(model.SlaveKey)}, int64(settings.SlaveRequestSignTTL(ctx))),
|
||||
request.WithEndpoint(model.Server)),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *slaveNode) CreateTask(ctx context.Context, taskType string, state string) (int, error) {
|
||||
reqBody, err := json.Marshal(&CreateSlaveTask{
|
||||
Type: taskType,
|
||||
State: state,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
resp, err := n.client.Request(
|
||||
"PUT",
|
||||
constants.APIPrefixSlave+"/task",
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return 0, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
taskId := 0
|
||||
if resp.GobDecode(&taskId); taskId > 0 {
|
||||
return taskId, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("unexpected response data: %v", resp.Data)
|
||||
}
|
||||
|
||||
func (n *slaveNode) GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error) {
|
||||
resp, err := n.client.Request(
|
||||
"GET",
|
||||
routes.SlaveGetTaskRoute(id, clearOnComplete),
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return nil, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
summary := &SlaveTaskSummary{}
|
||||
resp.GobDecode(summary)
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
func (b *slaveNode) CleanupFolders(ctx context.Context, folders ...string) error {
|
||||
args := &FolderCleanup{
|
||||
Path: folders,
|
||||
}
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
resp, err := b.client.Request(
|
||||
"POST",
|
||||
constants.APIPrefixSlave+"/task/cleanup",
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *slaveNode) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) {
|
||||
return slave.NewSlaveDownloader(b.client, b.Settings(ctx)), nil
|
||||
}
|
||||
|
||||
type nodeBase struct {
|
||||
model *ent.Node
|
||||
}
|
||||
|
||||
func (b *nodeBase) ID() int {
|
||||
return b.model.ID
|
||||
}
|
||||
|
||||
func (b *nodeBase) Name() string {
|
||||
return b.model.Name
|
||||
}
|
||||
|
||||
func (b *nodeBase) IsMaster() bool {
|
||||
return b.model.Type == node.TypeMaster
|
||||
}
|
||||
|
||||
func (b *nodeBase) CreateTask(ctx context.Context, taskType string, state string) (int, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) AuthInstance() auth.Auth {
|
||||
return auth.HMACAuth{[]byte(b.model.SlaveKey)}
|
||||
}
|
||||
|
||||
func (b *nodeBase) GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) CleanupFolders(ctx context.Context, folders ...string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) PrepareUpload(ctx context.Context, args *fs.StatelessPrepareUploadService) (*fs.StatelessPrepareUploadResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) CompleteUpload(ctx context.Context, args *fs.StatelessCompleteUploadService) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) OnUploadFailed(ctx context.Context, args *fs.StatelessOnUploadFailedService) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) CreateFile(ctx context.Context, args *fs.StatelessCreateFileService) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) Settings(ctx context.Context) *types.NodeSetting {
|
||||
return b.model.Settings
|
||||
}
|
||||
|
||||
func NodeIdFromContext(ctx context.Context) int {
|
||||
nodeIdStr, ok := ctx.Value(SlaveNodeIDCtx{}).(string)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
nodeId, _ := strconv.Atoi(nodeIdStr)
|
||||
return nodeId
|
||||
}
|
||||
|
||||
func MasterSiteUrlFromContext(ctx context.Context) string {
|
||||
if u, ok := ctx.Value(MasterSiteUrlCtx{}).(string); ok {
|
||||
return u
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewNodeFromDBModel(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{
|
||||
Type: model.SlaveNodeType,
|
||||
}))
|
||||
a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{
|
||||
Type: model.MasterNodeType,
|
||||
}))
|
||||
}
|
||||
@@ -1,190 +1,203 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/node"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var Default *NodePool
|
||||
|
||||
// 需要分类的节点组
|
||||
var featureGroup = []string{"aria2"}
|
||||
|
||||
// Pool 节点池
|
||||
type Pool interface {
|
||||
// Returns active node selected by given feature and load balancer
|
||||
BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node)
|
||||
|
||||
// Returns node by ID
|
||||
GetNodeByID(id uint) Node
|
||||
|
||||
// Add given node into pool. If node existed, refresh node.
|
||||
Add(node *model.Node)
|
||||
|
||||
// Delete and kill node from pool by given node id
|
||||
Delete(id uint)
|
||||
type NodePool interface {
|
||||
// Upsert updates or inserts a node into the pool.
|
||||
Upsert(ctx context.Context, node *ent.Node)
|
||||
// Get returns a node with the given capability and preferred node id. `allowed` is a list of allowed node ids.
|
||||
// If `allowed` is empty, all nodes with the capability are considered.
|
||||
Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error)
|
||||
}
|
||||
|
||||
// NodePool 通用节点池
|
||||
type NodePool struct {
|
||||
active map[uint]Node
|
||||
inactive map[uint]Node
|
||||
type (
|
||||
weightedNodePool struct {
|
||||
lock sync.RWMutex
|
||||
|
||||
featureMap map[string][]Node
|
||||
conf conf.ConfigProvider
|
||||
settings setting.Provider
|
||||
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// Init 初始化从机节点池
|
||||
func Init() {
|
||||
Default = &NodePool{}
|
||||
Default.Init()
|
||||
if err := Default.initFromDB(); err != nil {
|
||||
util.Log().Warning("Failed to initialize node pool: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (pool *NodePool) Init() {
|
||||
pool.lock.Lock()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
pool.featureMap = make(map[string][]Node)
|
||||
pool.active = make(map[uint]Node)
|
||||
pool.inactive = make(map[uint]Node)
|
||||
}
|
||||
|
||||
func (pool *NodePool) buildIndexMap() {
|
||||
pool.lock.Lock()
|
||||
for _, feature := range featureGroup {
|
||||
pool.featureMap[feature] = make([]Node, 0)
|
||||
nodes map[types.NodeCapability][]*nodeItem
|
||||
}
|
||||
|
||||
for _, v := range pool.active {
|
||||
for _, feature := range featureGroup {
|
||||
if v.IsFeatureEnabled(feature) {
|
||||
pool.featureMap[feature] = append(pool.featureMap[feature], v)
|
||||
nodeItem struct {
|
||||
node Node
|
||||
weight int
|
||||
current int
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoAvailableNode = fmt.Errorf("no available node found")
|
||||
|
||||
supportedCapabilities = []types.NodeCapability{
|
||||
types.NodeCapabilityNone,
|
||||
types.NodeCapabilityCreateArchive,
|
||||
types.NodeCapabilityExtractArchive,
|
||||
types.NodeCapabilityRemoteDownload,
|
||||
}
|
||||
)
|
||||
|
||||
func NewNodePool(ctx context.Context, l logging.Logger, config conf.ConfigProvider, settings setting.Provider,
|
||||
client inventory.NodeClient) (NodePool, error) {
|
||||
nodes, err := client.ListActiveNodes(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list active nodes: %w", err)
|
||||
}
|
||||
|
||||
pool := &weightedNodePool{
|
||||
nodes: make(map[types.NodeCapability][]*nodeItem),
|
||||
conf: config,
|
||||
settings: settings,
|
||||
}
|
||||
for _, node := range nodes {
|
||||
for _, capability := range supportedCapabilities {
|
||||
// If current capability is enabled, add it to pool slot.
|
||||
if capability == types.NodeCapabilityNone ||
|
||||
(node.Capabilities != nil && node.Capabilities.Enabled(int(capability))) {
|
||||
if _, ok := pool.nodes[capability]; !ok {
|
||||
pool.nodes[capability] = make([]*nodeItem, 0)
|
||||
}
|
||||
|
||||
l.Debug("Add node %q to capability slot %d with weight %d", node.Name, capability, node.Weight)
|
||||
pool.nodes[capability] = append(pool.nodes[capability], &nodeItem{
|
||||
node: newNode(ctx, node, config, settings),
|
||||
weight: node.Weight,
|
||||
current: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
pool.lock.Unlock()
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func (pool *NodePool) GetNodeByID(id uint) Node {
|
||||
pool.lock.RLock()
|
||||
defer pool.lock.RUnlock()
|
||||
func (p *weightedNodePool) Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error) {
|
||||
l := logging.FromContext(ctx)
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
|
||||
if node, ok := pool.active[id]; ok {
|
||||
return node
|
||||
nodes, ok := p.nodes[capability]
|
||||
if !ok || len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no node found with capability %d: %w", capability, ErrNoAvailableNode)
|
||||
}
|
||||
|
||||
return pool.inactive[id]
|
||||
}
|
||||
var selected *nodeItem
|
||||
|
||||
func (pool *NodePool) nodeStatusChange(isActive bool, id uint) {
|
||||
util.Log().Debug("Slave node [ID=%d] status changed to [Active=%t].", id, isActive)
|
||||
var node Node
|
||||
pool.lock.Lock()
|
||||
if n, ok := pool.inactive[id]; ok {
|
||||
node = n
|
||||
delete(pool.inactive, id)
|
||||
} else {
|
||||
node = pool.active[id]
|
||||
delete(pool.active, id)
|
||||
}
|
||||
|
||||
if isActive {
|
||||
pool.active[id] = node
|
||||
} else {
|
||||
pool.inactive[id] = node
|
||||
}
|
||||
pool.lock.Unlock()
|
||||
|
||||
pool.buildIndexMap()
|
||||
}
|
||||
|
||||
func (pool *NodePool) initFromDB() error {
|
||||
nodes, err := model.GetNodesByStatus(model.NodeActive)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pool.lock.Lock()
|
||||
for i := 0; i < len(nodes); i++ {
|
||||
pool.add(&nodes[i])
|
||||
}
|
||||
pool.lock.Unlock()
|
||||
|
||||
pool.buildIndexMap()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pool *NodePool) add(node *model.Node) {
|
||||
newNode := NewNodeFromDBModel(node)
|
||||
if newNode.IsActive() {
|
||||
pool.active[node.ID] = newNode
|
||||
} else {
|
||||
pool.inactive[node.ID] = newNode
|
||||
}
|
||||
|
||||
// 订阅节点状态变更
|
||||
newNode.SubscribeStatusChange(func(isActive bool, id uint) {
|
||||
pool.nodeStatusChange(isActive, id)
|
||||
})
|
||||
}
|
||||
|
||||
func (pool *NodePool) Add(node *model.Node) {
|
||||
pool.lock.Lock()
|
||||
defer pool.buildIndexMap()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
var (
|
||||
old Node
|
||||
ok bool
|
||||
)
|
||||
if old, ok = pool.active[node.ID]; !ok {
|
||||
old, ok = pool.inactive[node.ID]
|
||||
}
|
||||
if old != nil {
|
||||
go old.Init(node)
|
||||
return
|
||||
}
|
||||
|
||||
pool.add(node)
|
||||
}
|
||||
|
||||
func (pool *NodePool) Delete(id uint) {
|
||||
pool.lock.Lock()
|
||||
defer pool.buildIndexMap()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
if node, ok := pool.active[id]; ok {
|
||||
node.Kill()
|
||||
delete(pool.active, id)
|
||||
return
|
||||
}
|
||||
|
||||
if node, ok := pool.inactive[id]; ok {
|
||||
node.Kill()
|
||||
delete(pool.inactive, id)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点
|
||||
func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) {
|
||||
pool.lock.RLock()
|
||||
defer pool.lock.RUnlock()
|
||||
if nodes, ok := pool.featureMap[feature]; ok {
|
||||
err, res := lb.NextPeer(nodes)
|
||||
if err == nil {
|
||||
return nil, res.(Node)
|
||||
if preferred > 0 {
|
||||
// First try to find the preferred node.
|
||||
for _, n := range nodes {
|
||||
if n.node.ID() == preferred {
|
||||
selected = n
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return err, nil
|
||||
if selected == nil {
|
||||
l.Debug("Preferred node %d not found, fallback to select a node with the least current weight", preferred)
|
||||
}
|
||||
}
|
||||
|
||||
return ErrFeatureNotExist, nil
|
||||
if selected == nil {
|
||||
// If no preferred one, or the preferred one is not available, select a node with the least current weight.
|
||||
|
||||
// Total weight of all items.
|
||||
var total int
|
||||
|
||||
// Loop through the list of items and add the item's weight to the current weight.
|
||||
// Also increment the total weight counter.
|
||||
var maxNode *nodeItem
|
||||
for _, item := range nodes {
|
||||
item.current += max(1, item.weight)
|
||||
total += max(1, item.weight)
|
||||
|
||||
// Select the item with max weight.
|
||||
if maxNode == nil || item.current > maxNode.current {
|
||||
maxNode = item
|
||||
}
|
||||
}
|
||||
|
||||
// Select the item with the max weight.
|
||||
selected = maxNode
|
||||
if selected == nil {
|
||||
return nil, fmt.Errorf("no node found with capability %d: %w", capability, ErrNoAvailableNode)
|
||||
}
|
||||
|
||||
l.Debug("Selected node %q with weight=%d, current=%d, total=%d", selected.node.Name(), selected.weight, maxNode.current, total)
|
||||
|
||||
// Reduce the current weight of the selected item by the total weight.
|
||||
maxNode.current -= total
|
||||
}
|
||||
|
||||
return selected.node, nil
|
||||
}
|
||||
|
||||
func (p *weightedNodePool) Upsert(ctx context.Context, n *ent.Node) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
for _, capability := range supportedCapabilities {
|
||||
_, index, found := lo.FindIndexOf(p.nodes[capability], func(i *nodeItem) bool {
|
||||
return i.node.ID() == n.ID
|
||||
})
|
||||
if capability == types.NodeCapabilityNone ||
|
||||
(n.Capabilities != nil && n.Capabilities.Enabled(int(capability))) {
|
||||
if n.Status != node.StatusActive && found {
|
||||
// Remove inactive node
|
||||
p.nodes[capability] = append(p.nodes[capability][:index], p.nodes[capability][index+1:]...)
|
||||
continue
|
||||
}
|
||||
|
||||
if found {
|
||||
p.nodes[capability][index].node = newNode(ctx, n, p.conf, p.settings)
|
||||
} else {
|
||||
p.nodes[capability] = append(p.nodes[capability], &nodeItem{
|
||||
node: newNode(ctx, n, p.conf, p.settings),
|
||||
weight: n.Weight,
|
||||
current: 0,
|
||||
})
|
||||
}
|
||||
} else if found {
|
||||
// Capability changed, remove the old node.
|
||||
p.nodes[capability] = append(p.nodes[capability][:index], p.nodes[capability][index+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type slaveDummyNodePool struct {
|
||||
conf conf.ConfigProvider
|
||||
settings setting.Provider
|
||||
masterNode Node
|
||||
}
|
||||
|
||||
func NewSlaveDummyNodePool(ctx context.Context, config conf.ConfigProvider, settings setting.Provider) NodePool {
|
||||
return &slaveDummyNodePool{
|
||||
conf: config,
|
||||
settings: settings,
|
||||
masterNode: newNode(ctx, &ent.Node{
|
||||
ID: 0,
|
||||
Name: "Master",
|
||||
Type: node.TypeMaster,
|
||||
}, config, settings),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slaveDummyNodePool) Upsert(ctx context.Context, node *ent.Node) {
|
||||
}
|
||||
|
||||
func (s *slaveDummyNodePool) Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error) {
|
||||
return s.masterNode, nil
|
||||
}
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
// TestMain 初始化数据库Mock
|
||||
func TestMain(m *testing.M) {
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
model.DB, _ = gorm.Open("mysql", db)
|
||||
defer db.Close()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestInitFailed(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
|
||||
Init()
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestInitSuccess(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "aria2_enabled", "type"}).AddRow(1, true, model.MasterNodeType))
|
||||
Init()
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestNodePool_GetNodeByID(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
p.Init()
|
||||
mockNode := &nodeMock{}
|
||||
|
||||
// inactive
|
||||
{
|
||||
p.inactive[1] = mockNode
|
||||
a.Equal(mockNode, p.GetNodeByID(1))
|
||||
}
|
||||
|
||||
// active
|
||||
{
|
||||
delete(p.inactive, 1)
|
||||
p.active[1] = mockNode
|
||||
a.Equal(mockNode, p.GetNodeByID(1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodePool_NodeStatusChange(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
n := &MasterNode{Model: &model.Node{}}
|
||||
p.Init()
|
||||
p.inactive[1] = n
|
||||
|
||||
p.nodeStatusChange(true, 1)
|
||||
a.Len(p.inactive, 0)
|
||||
a.Equal(n, p.active[1])
|
||||
|
||||
p.nodeStatusChange(false, 1)
|
||||
a.Len(p.active, 0)
|
||||
a.Equal(n, p.inactive[1])
|
||||
|
||||
p.nodeStatusChange(false, 1)
|
||||
a.Len(p.active, 0)
|
||||
a.Equal(n, p.inactive[1])
|
||||
}
|
||||
|
||||
func TestNodePool_Add(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
p.Init()
|
||||
|
||||
// new node
|
||||
{
|
||||
p.Add(&model.Node{})
|
||||
a.Len(p.active, 1)
|
||||
}
|
||||
|
||||
// old node
|
||||
{
|
||||
p.inactive[0] = p.active[0]
|
||||
delete(p.active, 0)
|
||||
p.Add(&model.Node{})
|
||||
a.Len(p.active, 0)
|
||||
a.Len(p.inactive, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodePool_Delete(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
p.Init()
|
||||
|
||||
// active
|
||||
{
|
||||
mockNode := &nodeMock{}
|
||||
mockNode.On("Kill")
|
||||
p.active[0] = mockNode
|
||||
p.Delete(0)
|
||||
a.Len(p.active, 0)
|
||||
a.Len(p.inactive, 0)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
p.Init()
|
||||
|
||||
// inactive
|
||||
{
|
||||
mockNode := &nodeMock{}
|
||||
mockNode.On("Kill")
|
||||
p.inactive[0] = mockNode
|
||||
p.Delete(0)
|
||||
a.Len(p.active, 0)
|
||||
a.Len(p.inactive, 0)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodePool_BalanceNodeByFeature(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
p.Init()
|
||||
|
||||
// success
|
||||
{
|
||||
p.featureMap["test"] = []Node{&MasterNode{}}
|
||||
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
|
||||
a.NoError(err)
|
||||
a.Equal(p.featureMap["test"][0], res)
|
||||
}
|
||||
|
||||
// NoNodes
|
||||
{
|
||||
p.featureMap["test"] = []Node{}
|
||||
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
|
||||
a.Error(err)
|
||||
a.Nil(res)
|
||||
}
|
||||
|
||||
// No match feature
|
||||
{
|
||||
err, res := p.BalanceNodeByFeature("test2", balancer.NewBalancer("round-robin"))
|
||||
a.Error(err)
|
||||
a.Nil(res)
|
||||
}
|
||||
}
|
||||
207
pkg/cluster/routes/routes.go
Normal file
207
pkg/cluster/routes/routes.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
)
|
||||
|
||||
const (
|
||||
IsDownloadQuery = "download"
|
||||
IsThumbQuery = "thumb"
|
||||
SlaveClearTaskRegistryQuery = "deleteOnComplete"
|
||||
)
|
||||
|
||||
var (
|
||||
masterPing *url.URL
|
||||
masterUserActivate *url.URL
|
||||
masterUserReset *url.URL
|
||||
masterHome *url.URL
|
||||
)
|
||||
|
||||
func init() {
|
||||
masterPing, _ = url.Parse(constants.APIPrefix + "/site/ping")
|
||||
masterUserActivate, _ = url.Parse("/session/activate")
|
||||
masterUserReset, _ = url.Parse("/session/reset")
|
||||
}
|
||||
|
||||
func FrontendHomeUrl(base *url.URL, path string) *url.URL {
|
||||
route, _ := url.Parse(fmt.Sprintf("/home"))
|
||||
q := route.Query()
|
||||
q.Set("path", path)
|
||||
route.RawQuery = q.Encode()
|
||||
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
func MasterPingUrl(base *url.URL) *url.URL {
|
||||
return base.ResolveReference(masterPing)
|
||||
}
|
||||
|
||||
func MasterSlaveCallbackUrl(base *url.URL, driver, id, secret string) *url.URL {
|
||||
apiBaseURI, _ := url.Parse(path.Join(constants.APIPrefix+"/callback", driver, id, secret))
|
||||
return base.ResolveReference(apiBaseURI)
|
||||
}
|
||||
|
||||
func MasterUserActivateAPIUrl(base *url.URL, uid string) *url.URL {
|
||||
route, _ := url.Parse(constants.APIPrefix + "/user/activate/" + uid)
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
func MasterUserActivateUrl(base *url.URL) *url.URL {
|
||||
return base.ResolveReference(masterUserActivate)
|
||||
}
|
||||
|
||||
func MasterUserResetUrl(base *url.URL) *url.URL {
|
||||
return base.ResolveReference(masterUserReset)
|
||||
}
|
||||
|
||||
func MasterShareUrl(base *url.URL, id, password string) *url.URL {
|
||||
p := "/s/" + id
|
||||
if password != "" {
|
||||
p += ("/" + password)
|
||||
}
|
||||
route, _ := url.Parse(p)
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
func MasterDirectLink(base *url.URL, id, name string) *url.URL {
|
||||
p := path.Join("/f", id, url.PathEscape(name))
|
||||
route, _ := url.Parse(p)
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
// MasterShareLongUrl generates a long share URL for redirect.
|
||||
func MasterShareLongUrl(id, password string) *url.URL {
|
||||
base, _ := url.Parse("/home")
|
||||
q := base.Query()
|
||||
|
||||
q.Set("path", fs.NewShareUri(id, password))
|
||||
base.RawQuery = q.Encode()
|
||||
return base
|
||||
}
|
||||
|
||||
func MasterArchiveDownloadUrl(base *url.URL, sessionID string) *url.URL {
|
||||
routes, err := url.Parse(path.Join(constants.APIPrefix, "file", "archive", sessionID, "archive.zip"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return base.ResolveReference(routes)
|
||||
}
|
||||
|
||||
func MasterPolicyOAuthCallback(base *url.URL) *url.URL {
|
||||
if base.Scheme != "https" {
|
||||
base.Scheme = "https"
|
||||
}
|
||||
routes, err := url.Parse("/admin/policy/oauth")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return base.ResolveReference(routes)
|
||||
}
|
||||
|
||||
func MasterGetCredentialUrl(base, key string) *url.URL {
|
||||
masterBase, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
routes, err := url.Parse(path.Join(constants.APIPrefixSlave, "credential", key))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return masterBase.ResolveReference(routes)
|
||||
}
|
||||
|
||||
func MasterStatelessUrl(base, method string) *url.URL {
|
||||
masterBase, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
routes, err := url.Parse(path.Join(constants.APIPrefixSlave, "statelessUpload", method))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return masterBase.ResolveReference(routes)
|
||||
}
|
||||
|
||||
func SlaveUploadUrl(base *url.URL, sessionID string) *url.URL {
|
||||
base.Path = path.Join(base.Path, constants.APIPrefixSlave, "/upload", sessionID)
|
||||
return base
|
||||
}
|
||||
|
||||
func MasterFileContentUrl(base *url.URL, entityId, name string, download, thumb bool, speed int64) *url.URL {
|
||||
name = url.PathEscape(name)
|
||||
|
||||
route, _ := url.Parse(constants.APIPrefix + fmt.Sprintf("/file/content/%s/%d/%s", entityId, speed, name))
|
||||
if base != nil {
|
||||
route = base.ResolveReference(route)
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
if download {
|
||||
values.Set(IsDownloadQuery, "true")
|
||||
}
|
||||
|
||||
if thumb {
|
||||
values.Set(IsThumbQuery, "true")
|
||||
}
|
||||
|
||||
route.RawQuery = values.Encode()
|
||||
return route
|
||||
}
|
||||
|
||||
func MasterWopiSrc(base *url.URL, sessionId string) *url.URL {
|
||||
route, _ := url.Parse(constants.APIPrefix + "/file/wopi/" + sessionId)
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
func SlaveFileContentUrl(base *url.URL, srcPath, name string, download bool, speed int64, nodeId int) *url.URL {
|
||||
srcPath = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(srcPath)))
|
||||
name = url.PathEscape(name)
|
||||
route, _ := url.Parse(constants.APIPrefixSlave + fmt.Sprintf("/file/content/%d/%s/%d/%s", nodeId, srcPath, speed, name))
|
||||
base = base.ResolveReference(route)
|
||||
|
||||
values := url.Values{}
|
||||
if download {
|
||||
values.Set(IsDownloadQuery, "true")
|
||||
}
|
||||
|
||||
base.RawQuery = values.Encode()
|
||||
return base
|
||||
}
|
||||
|
||||
func SlaveMediaMetaRoute(src, ext string) string {
|
||||
src = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(src)))
|
||||
return fmt.Sprintf("file/meta/%s/%s", src, url.PathEscape(ext))
|
||||
}
|
||||
|
||||
func SlaveThumbUrl(base *url.URL, srcPath, ext string) *url.URL {
|
||||
srcPath = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(srcPath)))
|
||||
ext = url.PathEscape(ext)
|
||||
route, _ := url.Parse(constants.APIPrefixSlave + fmt.Sprintf("/file/thumb/%s/%s", srcPath, ext))
|
||||
base = base.ResolveReference(route)
|
||||
return base
|
||||
}
|
||||
|
||||
func SlaveGetTaskRoute(id int, deleteOnComplete bool) string {
|
||||
p := constants.APIPrefixSlave + "/task/" + strconv.Itoa(id)
|
||||
if deleteOnComplete {
|
||||
p += "?" + SlaveClearTaskRegistryQuery + "=true"
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func SlavePingRoute(base *url.URL) string {
|
||||
route, _ := url.Parse(constants.APIPrefixSlave + "/ping")
|
||||
return base.ResolveReference(route).String()
|
||||
}
|
||||
@@ -1,451 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SlaveNode struct {
|
||||
Model *model.Node
|
||||
Active bool
|
||||
|
||||
caller slaveCaller
|
||||
callback func(bool, uint)
|
||||
close chan bool
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
type slaveCaller struct {
|
||||
parent *SlaveNode
|
||||
Client request.Client
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *SlaveNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
|
||||
// Init http request client
|
||||
var endpoint *url.URL
|
||||
if serverURL, err := url.Parse(node.Model.Server); err == nil {
|
||||
var controller *url.URL
|
||||
controller, _ = url.Parse("/api/v3/slave/")
|
||||
endpoint = serverURL.ResolveReference(controller)
|
||||
}
|
||||
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
node.caller.Client = request.NewClient(
|
||||
request.WithMasterMeta(),
|
||||
request.WithTimeout(time.Duration(signTTL)*time.Second),
|
||||
request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)),
|
||||
request.WithEndpoint(endpoint.String()),
|
||||
)
|
||||
|
||||
node.caller.parent = node
|
||||
if node.close != nil {
|
||||
node.lock.Unlock()
|
||||
node.close <- true
|
||||
go node.StartPingLoop()
|
||||
} else {
|
||||
node.Active = true
|
||||
node.lock.Unlock()
|
||||
go node.StartPingLoop()
|
||||
}
|
||||
}
|
||||
|
||||
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
switch feature {
|
||||
case "aria2":
|
||||
return node.Model.Aria2Enabled
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SubscribeStatusChange 订阅节点状态更改
|
||||
func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) {
|
||||
node.lock.Lock()
|
||||
node.callback = callback
|
||||
node.lock.Unlock()
|
||||
}
|
||||
|
||||
// Ping 从机节点,返回从机负载
|
||||
func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
reqBodyEncoded, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
|
||||
resp, err := node.caller.Client.Request(
|
||||
"POST",
|
||||
"heartbeat",
|
||||
bodyReader,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return nil, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
var res serializer.NodePingResp
|
||||
|
||||
if resStr, ok := resp.Data.(string); ok {
|
||||
err = json.Unmarshal([]byte(resStr), &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// IsActive 返回节点是否在线
|
||||
func (node *SlaveNode) IsActive() bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Active
|
||||
}
|
||||
|
||||
// Kill 结束节点内相关循环
|
||||
func (node *SlaveNode) Kill() {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
if node.close != nil {
|
||||
close(node.close)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAria2Instance 获取从机Aria2实例
|
||||
func (node *SlaveNode) GetAria2Instance() common.Aria2 {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
if !node.Model.Aria2Enabled {
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
return &node.caller
|
||||
}
|
||||
|
||||
func (node *SlaveNode) ID() uint {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model.ID
|
||||
}
|
||||
|
||||
func (node *SlaveNode) StartPingLoop() {
|
||||
node.lock.Lock()
|
||||
node.close = make(chan bool)
|
||||
node.lock.Unlock()
|
||||
|
||||
tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second
|
||||
recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second
|
||||
pingTicker := time.Duration(0)
|
||||
|
||||
util.Log().Debug("Slave node %q heartbeat loop started.", node.Model.Name)
|
||||
retry := 0
|
||||
recoverMode := false
|
||||
isFirstLoop := true
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-time.After(pingTicker):
|
||||
if pingTicker == 0 {
|
||||
pingTicker = tickDuration
|
||||
}
|
||||
|
||||
util.Log().Debug("Slave node %q send ping.", node.Model.Name)
|
||||
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
|
||||
isFirstLoop = false
|
||||
|
||||
if err != nil {
|
||||
util.Log().Debug("Error while ping slave node %q: %s", node.Model.Name, err)
|
||||
retry++
|
||||
if retry >= model.GetIntSetting("slave_node_retry", 3) {
|
||||
util.Log().Debug("Retry threshold for pinging slave node %q exceeded, mark it as offline.", node.Model.Name)
|
||||
node.changeStatus(false)
|
||||
|
||||
if !recoverMode {
|
||||
// 启动恢复监控循环
|
||||
util.Log().Debug("Slave node %q entered recovery mode.", node.Model.Name)
|
||||
pingTicker = recoverDuration
|
||||
recoverMode = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if recoverMode {
|
||||
util.Log().Debug("Slave node %q recovered.", node.Model.Name)
|
||||
pingTicker = tickDuration
|
||||
recoverMode = false
|
||||
isFirstLoop = true
|
||||
}
|
||||
|
||||
util.Log().Debug("Status of slave node %q: %s", node.Model.Name, res)
|
||||
node.changeStatus(true)
|
||||
retry = 0
|
||||
}
|
||||
|
||||
case <-node.close:
|
||||
util.Log().Debug("Slave node %q received shutdown signal.", node.Model.Name)
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) IsMater() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (node *SlaveNode) MasterAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) SlaveAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) DBModel() *model.Node {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model
|
||||
}
|
||||
|
||||
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
|
||||
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
|
||||
return &serializer.NodePingReq{
|
||||
SiteURL: model.GetSiteURL().String(),
|
||||
IsUpdate: isUpdate,
|
||||
SiteID: model.GetSettingByName("siteID"),
|
||||
Node: node.Model,
|
||||
CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
|
||||
}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) changeStatus(isActive bool) {
|
||||
node.lock.RLock()
|
||||
id := node.Model.ID
|
||||
if isActive != node.Active {
|
||||
node.lock.RUnlock()
|
||||
node.lock.Lock()
|
||||
node.Active = isActive
|
||||
node.lock.Unlock()
|
||||
node.callback(isActive, id)
|
||||
} else {
|
||||
node.lock.RUnlock()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendAria2Call send remote aria2 call to slave node
|
||||
func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
|
||||
reqReader, err := getAria2RequestBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Client.Request(
|
||||
"POST",
|
||||
"aria2/"+scope,
|
||||
reqReader,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
}
|
||||
|
||||
func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
GroupOptions: options,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "task")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return "", serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return res.Data.(string), err
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "status")
|
||||
if err != nil {
|
||||
return rpc.StatusInfo{}, err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
var status rpc.StatusInfo
|
||||
res.GobDecode(&status)
|
||||
|
||||
return status, err
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Cancel(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "cancel")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Select(task *model.Download, files []int) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
Files: files,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "select")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveCaller) GetConfig() model.Aria2Option {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
return s.parent.Model.Aria2OptionsSerialized
|
||||
}
|
||||
|
||||
func (s *slaveCaller) DeleteTempFile(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "delete")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
|
||||
reqBodyEncoded, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return strings.NewReader(string(reqBodyEncoded)), nil
|
||||
}
|
||||
|
||||
// RemoteCallback 发送远程存储策略上传回调请求
|
||||
func RemoteCallback(url string, body serializer.UploadCallback) error {
|
||||
callbackBody, err := json.Marshal(struct {
|
||||
Data serializer.UploadCallback `json:"data"`
|
||||
}{
|
||||
Data: body,
|
||||
})
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeCallbackError, "Failed to encode callback content", err)
|
||||
}
|
||||
|
||||
resp := request.GeneralClient.Request(
|
||||
"POST",
|
||||
url,
|
||||
bytes.NewReader(callbackBody),
|
||||
request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
|
||||
request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
|
||||
)
|
||||
|
||||
if resp.Err != nil {
|
||||
return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err)
|
||||
}
|
||||
|
||||
// 解析回调服务端响应
|
||||
response, err := resp.DecodeResponse()
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode)
|
||||
return serializer.NewError(serializer.CodeCallbackError, msg, err)
|
||||
}
|
||||
|
||||
if response.Code != 0 {
|
||||
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,559 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSlaveNode_InitAndKill(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
n := &SlaveNode{
|
||||
callback: func(b bool, u uint) {
|
||||
|
||||
},
|
||||
}
|
||||
|
||||
a.NotPanics(func() {
|
||||
n.Init(&model.Node{})
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
n.Init(&model.Node{})
|
||||
n.Kill()
|
||||
})
|
||||
}
|
||||
|
||||
func TestSlaveNode_DummyMethods(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
m.Model.ID = 5
|
||||
a.Equal(m.Model.ID, m.ID())
|
||||
a.Equal(m.Model.ID, m.DBModel().ID)
|
||||
|
||||
a.False(m.IsActive())
|
||||
a.False(m.IsMater())
|
||||
|
||||
m.SubscribeStatusChange(func(isActive bool, id uint) {})
|
||||
}
|
||||
|
||||
func TestSlaveNode_IsFeatureEnabled(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.False(m.IsFeatureEnabled("aria2"))
|
||||
a.False(m.IsFeatureEnabled("random"))
|
||||
m.Model.Aria2Enabled = true
|
||||
a.True(m.IsFeatureEnabled("aria2"))
|
||||
}
|
||||
|
||||
func TestSlaveNode_Ping(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
// master return error code
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.Ping(&serializer.NodePingReq{})
|
||||
a.Error(err)
|
||||
a.Nil(res)
|
||||
a.Equal(1, err.(serializer.AppError).Code)
|
||||
}
|
||||
|
||||
// return unexpected json
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"233\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.Ping(&serializer.NodePingReq{})
|
||||
a.Error(err)
|
||||
a.Nil(res)
|
||||
}
|
||||
|
||||
// return success
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.Ping(&serializer.NodePingReq{})
|
||||
a.NoError(err)
|
||||
a.NotNil(res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveNode_GetAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
m.Model.Aria2Enabled = true
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
}
|
||||
|
||||
func TestSlaveNode_StartPingLoop(t *testing.T) {
|
||||
callbackCount := 0
|
||||
finishedChan := make(chan struct{})
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m := &SlaveNode{
|
||||
Active: true,
|
||||
Model: &model.Node{},
|
||||
callback: func(b bool, u uint) {
|
||||
callbackCount++
|
||||
if callbackCount == 2 {
|
||||
close(finishedChan)
|
||||
}
|
||||
if callbackCount == 1 {
|
||||
mockRequest.AssertExpectations(t)
|
||||
mockRequest = requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
|
||||
},
|
||||
})
|
||||
}
|
||||
},
|
||||
}
|
||||
cache.Set("setting_slave_ping_interval", "0", 0)
|
||||
cache.Set("setting_slave_recover_interval", "0", 0)
|
||||
cache.Set("setting_slave_node_retry", "1", 0)
|
||||
|
||||
m.caller.Client = &mockRequest
|
||||
go func() {
|
||||
select {
|
||||
case <-finishedChan:
|
||||
m.Kill()
|
||||
}
|
||||
}()
|
||||
m.StartPingLoop()
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestSlaveNode_AuthInstance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.NotNil(m.MasterAuthInstance())
|
||||
a.NotNil(m.SlaveAuthInstance())
|
||||
}
|
||||
|
||||
func TestSlaveNode_ChangeStatus(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
isActive := false
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
callback: func(b bool, u uint) {
|
||||
isActive = b
|
||||
},
|
||||
}
|
||||
|
||||
a.NotPanics(func() {
|
||||
m.changeStatus(false)
|
||||
})
|
||||
m.changeStatus(true)
|
||||
a.True(isActive)
|
||||
}
|
||||
|
||||
func getTestRPCNodeSlave() *SlaveNode {
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
m.caller.parent = m
|
||||
return m
|
||||
}
|
||||
|
||||
func TestSlaveCaller_CreateTask(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.CreateTask(&model.Download{}, nil)
|
||||
a.Empty(res)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.CreateTask(&model.Download{}, nil)
|
||||
a.Empty(res)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.CreateTask(&model.Download{}, nil)
|
||||
a.Equal("res", res)
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveCaller_Status(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.Status(&model.Download{})
|
||||
a.Empty(res.Status)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.Status(&model.Download{})
|
||||
a.Empty(res.Status)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"re456456s\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.Status(&model.Download{})
|
||||
a.Empty(res.Status)
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveCaller_Cancel(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Cancel(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Cancel(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Cancel(&model.Download{})
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveCaller_Select(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
m.caller.Init()
|
||||
m.caller.GetConfig()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Select(&model.Download{}, nil)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Select(&model.Download{}, nil)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Select(&model.Download{}, nil)
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveCaller_DeleteTempFile(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
m.caller.Init()
|
||||
m.caller.GetConfig()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.DeleteTempFile(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.DeleteTempFile(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.DeleteTempFile(&model.Download{})
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteCallback(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
// 回调成功
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
mockResp, _ := json.Marshal(serializer.Response{Code: 0})
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
|
||||
},
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.NoError(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 服务端返回业务错误
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
mockResp, _ := json.Marshal(serializer.Response{Code: 401})
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
|
||||
},
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.EqualValues(401, resp.(serializer.AppError).Code)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 无法解析回调响应
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
|
||||
},
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.Error(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// HTTP状态码非200
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
|
||||
},
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.Error(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 无法发起回调
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: errors.New("error"),
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.Error(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user