Init V4 community edition (#2265)

* Init V4 community edition

* Init V4 community edition
This commit is contained in:
AaronLiu
2025-04-20 17:31:25 +08:00
committed by GitHub
parent da4e44b77a
commit 21d158db07
597 changed files with 119415 additions and 41692 deletions

View File

@@ -1,209 +0,0 @@
package cluster
import (
"bytes"
"encoding/gob"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/jinzhu/gorm"
"net/url"
"sync"
)
var DefaultController Controller
// Controller controls communications between master and slave
type Controller interface {
// Handle heartbeat sent from master
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
// Get Aria2 Instance by master node ID
GetAria2Instance(string) (common.Aria2, error)
// Send event change message to master node
SendNotification(string, string, mq.Message) error
// Submit async task into task pool
SubmitTask(string, interface{}, string, func(interface{})) error
// Get master node info
GetMasterInfo(string) (*MasterInfo, error)
// Get master Oauth based policy credential
GetPolicyOauthToken(string, uint) (string, error)
}
type slaveController struct {
masters map[string]MasterInfo
lock sync.RWMutex
}
// info of master node
type MasterInfo struct {
ID string
TTL int
URL *url.URL
// used to invoke aria2 rpc calls
Instance Node
Client request.Client
jobTracker map[string]bool
}
func InitController() {
DefaultController = &slaveController{
masters: make(map[string]MasterInfo),
}
gob.Register(rpc.StatusInfo{})
}
func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) {
c.lock.Lock()
defer c.lock.Unlock()
req.Node.AfterFind()
// close old node if exist
origin, ok := c.masters[req.SiteID]
if (ok && req.IsUpdate) || !ok {
if ok {
origin.Instance.Kill()
}
masterUrl, err := url.Parse(req.SiteURL)
if err != nil {
return serializer.NodePingResp{}, err
}
c.masters[req.SiteID] = MasterInfo{
ID: req.SiteID,
URL: masterUrl,
TTL: req.CredentialTTL,
Client: request.NewClient(
request.WithEndpoint(masterUrl.String()),
request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)),
request.WithCredential(auth.HMACAuth{
SecretKey: []byte(req.Node.MasterKey),
}, int64(req.CredentialTTL)),
),
jobTracker: make(map[string]bool),
Instance: NewNodeFromDBModel(&model.Node{
Model: gorm.Model{ID: req.Node.ID},
MasterKey: req.Node.MasterKey,
Type: model.MasterNodeType,
Aria2Enabled: req.Node.Aria2Enabled,
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
}),
}
}
return serializer.NodePingResp{}, nil
}
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
return node.Instance.GetAria2Instance(), nil
}
return nil, ErrMasterNotFound
}
func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error {
c.lock.RLock()
if node, ok := c.masters[id]; ok {
c.lock.RUnlock()
body := bytes.Buffer{}
enc := gob.NewEncoder(&body)
if err := enc.Encode(&msg); err != nil {
return err
}
res, err := node.Client.Request(
"PUT",
fmt.Sprintf("/api/v3/slave/notification/%s", subject),
&body,
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
c.lock.RUnlock()
return ErrMasterNotFound
}
// SubmitTask 提交异步任务
func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
if _, ok := node.jobTracker[hash]; ok {
// 任务已存在,直接返回
return nil
}
node.jobTracker[hash] = true
submitter(job)
return nil
}
return ErrMasterNotFound
}
// GetMasterInfo 获取主机节点信息
func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
return &node, nil
}
return nil, ErrMasterNotFound
}
// GetPolicyOauthToken 获取主机存储策略 Oauth 凭证
func (c *slaveController) GetPolicyOauthToken(id string, policyID uint) (string, error) {
c.lock.RLock()
if node, ok := c.masters[id]; ok {
c.lock.RUnlock()
res, err := node.Client.Request(
"GET",
fmt.Sprintf("/api/v3/slave/credential/%d", policyID),
nil,
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return "", err
}
if res.Code != 0 {
return "", serializer.NewErrorFromResponse(res)
}
return res.Data.(string), nil
}
c.lock.RUnlock()
return "", ErrMasterNotFound
}

View File

@@ -1,385 +0,0 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
)
func TestInitController(t *testing.T) {
assert.NotPanics(t, func() {
InitController()
})
}
func TestSlaveController_HandleHeartBeat(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: make(map[string]MasterInfo),
}
// first heart beat
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
Node: &model.Node{},
})
a.NoError(err)
_, err = c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "2",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
}
// second heart beat, no fresh
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
SiteURL: "http://127.0.0.1",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
a.Empty(c.masters["1"].URL)
}
// second heart beat, fresh
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
IsUpdate: true,
SiteURL: "http://127.0.0.1",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
}
// second heart beat, fresh, url illegal
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
IsUpdate: true,
SiteURL: string([]byte{0x7f}),
Node: &model.Node{},
})
a.Error(err)
a.Len(c.masters, 2)
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
}
}
type nodeMock struct {
testMock.Mock
}
func (n nodeMock) Init(node *model.Node) {
n.Called(node)
}
func (n nodeMock) IsFeatureEnabled(feature string) bool {
args := n.Called(feature)
return args.Bool(0)
}
func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
n.Called(callback)
}
func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
args := n.Called(req)
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
}
func (n nodeMock) IsActive() bool {
args := n.Called()
return args.Bool(0)
}
func (n nodeMock) GetAria2Instance() common.Aria2 {
args := n.Called()
return args.Get(0).(common.Aria2)
}
func (n nodeMock) ID() uint {
args := n.Called()
return args.Get(0).(uint)
}
func (n nodeMock) Kill() {
n.Called()
}
func (n nodeMock) IsMater() bool {
args := n.Called()
return args.Bool(0)
}
func (n nodeMock) MasterAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n nodeMock) SlaveAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n nodeMock) DBModel() *model.Node {
args := n.Called()
return args.Get(0).(*model.Node)
}
func TestSlaveController_GetAria2Instance(t *testing.T) {
a := assert.New(t)
mockNode := &nodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Instance: mockNode},
},
}
// node node found
{
res, err := c.GetAria2Instance("2")
a.Nil(res)
a.Equal(ErrMasterNotFound, err)
}
// node found
{
res, err := c.GetAria2Instance("1")
a.NotNil(res)
a.NoError(err)
mockNode.AssertExpectations(t)
}
}
type requestMock struct {
testMock.Mock
}
func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
return r.Called(method, target, body, opts).Get(0).(*request.Response)
}
func TestSlaveController_SendNotification(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
}
// gob encode error
{
type randomType struct{}
a.Error(c.SendNotification("1", "", mq.Message{
Content: randomType{},
}))
}
// return none 200
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{StatusCode: http.StatusConflict},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.Error(c.SendNotification("1", "s1", mq.Message{}))
mockRequest.AssertExpectations(t)
}
// master return error
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
mockRequest.AssertExpectations(t)
}
// success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.NoError(c.SendNotification("1", "s3", mq.Message{}))
mockRequest.AssertExpectations(t)
}
}
func TestSlaveController_SubmitTask(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {
jobTracker: map[string]bool{},
},
},
}
// node not exit
{
a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil))
}
// success
{
submitted := false
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
submitted = true
}))
a.True(submitted)
}
// job already submitted
{
submitted := false
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
submitted = true
}))
a.False(submitted)
}
}
func TestSlaveController_GetMasterInfo(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
res, err := c.GetMasterInfo("2")
a.Equal(ErrMasterNotFound, err)
a.Nil(res)
}
// success
{
res, err := c.GetMasterInfo("1")
a.NoError(err)
a.NotNil(res)
}
}
func TestSlaveController_GetOneDriveToken(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
res, err := c.GetPolicyOauthToken("2", 1)
a.Equal(ErrMasterNotFound, err)
a.Empty(res)
}
// return none 200
{
mockRequest := &requestMock{}
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{StatusCode: http.StatusConflict},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
res, err := c.GetPolicyOauthToken("1", 1)
a.Error(err)
a.Empty(res)
mockRequest.AssertExpectations(t)
}
// master return error
{
mockRequest := &requestMock{}
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
res, err := c.GetPolicyOauthToken("1", 1)
a.Equal(1, err.(serializer.AppError).Code)
a.Empty(res)
mockRequest.AssertExpectations(t)
}
// success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
res, err := c.GetPolicyOauthToken("1", 1)
a.NoError(err)
a.Equal("expected", res)
mockRequest.AssertExpectations(t)
}
}

View File

@@ -1,12 +0,0 @@
package cluster
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
var (
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "Unknown master node id", nil)
)

View File

@@ -1,272 +0,0 @@
package cluster
import (
"context"
"encoding/json"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gofrs/uuid"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
const (
deleteTempFileDuration = 60 * time.Second
statusRetryDuration = 10 * time.Second
)
type MasterNode struct {
Model *model.Node
aria2RPC rpcService
lock sync.RWMutex
}
// RPCService 通过RPC服务的Aria2任务管理器
type rpcService struct {
Caller rpc.Client
Initialized bool
retryDuration time.Duration
deletePaddingDuration time.Duration
parent *MasterNode
options *clientOptions
}
type clientOptions struct {
Options map[string]interface{} // 创建下载时额外添加的设置
}
// Init 初始化节点
func (node *MasterNode) Init(nodeModel *model.Node) {
node.lock.Lock()
node.Model = nodeModel
node.aria2RPC.parent = node
node.aria2RPC.retryDuration = statusRetryDuration
node.aria2RPC.deletePaddingDuration = deleteTempFileDuration
node.lock.Unlock()
node.lock.RLock()
if node.Model.Aria2Enabled {
node.lock.RUnlock()
node.aria2RPC.Init()
return
}
node.lock.RUnlock()
}
func (node *MasterNode) ID() uint {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model.ID
}
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
return &serializer.NodePingResp{}, nil
}
// IsFeatureEnabled 查询节点的某项功能是否启用
func (node *MasterNode) IsFeatureEnabled(feature string) bool {
node.lock.RLock()
defer node.lock.RUnlock()
switch feature {
case "aria2":
return node.Model.Aria2Enabled
default:
return false
}
}
func (node *MasterNode) MasterAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
func (node *MasterNode) SlaveAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
}
// SubscribeStatusChange 订阅节点状态更改
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
}
// IsActive 返回节点是否在线
func (node *MasterNode) IsActive() bool {
return true
}
// Kill 结束aria2请求
func (node *MasterNode) Kill() {
if node.aria2RPC.Caller != nil {
node.aria2RPC.Caller.Close()
}
}
// GetAria2Instance 获取主机Aria2实例
func (node *MasterNode) GetAria2Instance() common.Aria2 {
node.lock.RLock()
if !node.Model.Aria2Enabled {
node.lock.RUnlock()
return &common.DummyAria2{}
}
if !node.aria2RPC.Initialized {
node.lock.RUnlock()
node.aria2RPC.Init()
return &common.DummyAria2{}
}
defer node.lock.RUnlock()
return &node.aria2RPC
}
func (node *MasterNode) IsMater() bool {
return true
}
func (node *MasterNode) DBModel() *model.Node {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model
}
func (r *rpcService) Init() error {
r.parent.lock.Lock()
defer r.parent.lock.Unlock()
r.Initialized = false
// 客户端已存在,则关闭先前连接
if r.Caller != nil {
r.Caller.Close()
}
// 解析RPC服务地址
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
if err != nil {
util.Log().Warning("Failed to parse Aria2 RPC server URL: %s", err)
return err
}
server.Path = "/jsonrpc"
// 加载自定义下载配置
var globalOptions map[string]interface{}
if r.parent.Model.Aria2OptionsSerialized.Options != "" {
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
if err != nil {
util.Log().Warning("Failed to parse aria2 options: %s", err)
return err
}
}
r.options = &clientOptions{
Options: globalOptions,
}
timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ)
r.Caller = caller
r.Initialized = err == nil
return err
}
func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
r.parent.lock.RLock()
// 生成存储路径
guid, _ := uuid.NewV4()
path := filepath.Join(
r.parent.Model.Aria2OptionsSerialized.TempPath,
"aria2",
guid.String(),
)
r.parent.lock.RUnlock()
// 创建下载任务
options := map[string]interface{}{
"dir": path,
}
for k, v := range r.options.Options {
options[k] = v
}
for k, v := range groupOptions {
options[k] = v
}
gid, err := r.Caller.AddURI(task.Source, options)
if err != nil || gid == "" {
return "", err
}
return gid, nil
}
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
res, err := r.Caller.TellStatus(task.GID)
if err != nil {
// 失败后重试
util.Log().Debug("Failed to get download task status, please retry later: %s", err)
time.Sleep(r.retryDuration)
res, err = r.Caller.TellStatus(task.GID)
}
return res, err
}
func (r *rpcService) Cancel(task *model.Download) error {
// 取消下载任务
_, err := r.Caller.Remove(task.GID)
if err != nil {
util.Log().Warning("Failed to cancel task %q: %s", task.GID, err)
}
return err
}
func (r *rpcService) Select(task *model.Download, files []int) error {
var selected = make([]string, len(files))
for i := 0; i < len(files); i++ {
selected[i] = strconv.Itoa(files[i])
}
_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
return err
}
func (r *rpcService) GetConfig() model.Aria2Option {
r.parent.lock.RLock()
defer r.parent.lock.RUnlock()
return r.parent.Model.Aria2OptionsSerialized
}
func (s *rpcService) DeleteTempFile(task *model.Download) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
// 避免被aria2占用异步执行删除
go func(d time.Duration, src string) {
time.Sleep(d)
err := os.RemoveAll(src)
if err != nil {
util.Log().Warning("Failed to delete temp download folder: %q: %s", src, err)
}
}(s.deletePaddingDuration, task.Parent)
return nil
}

View File

@@ -1,186 +0,0 @@
package cluster
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/stretchr/testify/assert"
"os"
"testing"
"time"
)
func TestMasterNode_Init(t *testing.T) {
a := assert.New(t)
m := &MasterNode{}
m.Init(&model.Node{Status: model.NodeSuspend})
a.Equal(model.NodeSuspend, m.DBModel().Status)
m.Init(&model.Node{Aria2Enabled: true})
}
func TestMasterNode_DummyMethods(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
}
m.Model.ID = 5
a.Equal(m.Model.ID, m.ID())
res, err := m.Ping(&serializer.NodePingReq{})
a.NoError(err)
a.NotNil(res)
a.True(m.IsActive())
a.True(m.IsMater())
m.SubscribeStatusChange(func(isActive bool, id uint) {})
}
func TestMasterNode_IsFeatureEnabled(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
}
a.False(m.IsFeatureEnabled("aria2"))
a.False(m.IsFeatureEnabled("random"))
m.Model.Aria2Enabled = true
a.True(m.IsFeatureEnabled("aria2"))
}
func TestMasterNode_AuthInstance(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
}
a.NotNil(m.MasterAuthInstance())
a.NotNil(m.SlaveAuthInstance())
}
func TestMasterNode_Kill(t *testing.T) {
m := &MasterNode{
Model: &model.Node{},
}
m.Kill()
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
m.aria2RPC.Caller = caller
m.Kill()
}
func TestMasterNode_GetAria2Instance(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
aria2RPC: rpcService{},
}
m.aria2RPC.parent = m
a.NotNil(m.GetAria2Instance())
m.Model.Aria2Enabled = true
a.NotNil(m.GetAria2Instance())
m.aria2RPC.Initialized = true
a.NotNil(m.GetAria2Instance())
}
func TestRpcService_Init(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{
Aria2OptionsSerialized: model.Aria2Option{
Options: "{",
},
},
aria2RPC: rpcService{},
}
m.aria2RPC.parent = m
// failed to decode address
{
m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f})
a.Error(m.aria2RPC.Init())
}
// failed to decode options
{
m.Model.Aria2OptionsSerialized.Server = ""
a.Error(m.aria2RPC.Init())
}
// failed to initialized
{
m.Model.Aria2OptionsSerialized.Server = ""
m.Model.Aria2OptionsSerialized.Options = "{}"
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
m.aria2RPC.Caller = caller
a.Error(m.aria2RPC.Init())
a.False(m.aria2RPC.Initialized)
}
}
func getTestRPCNode() *MasterNode {
m := &MasterNode{
Model: &model.Node{
Aria2OptionsSerialized: model.Aria2Option{},
},
aria2RPC: rpcService{
options: &clientOptions{
Options: map[string]interface{}{"1": "1"},
},
},
}
m.aria2RPC.parent = m
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
m.aria2RPC.Caller = caller
return m
}
func TestRpcService_CreateTask(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"})
a.Error(err)
a.Empty(res)
}
func TestRpcService_Status(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
res, err := m.aria2RPC.Status(&model.Download{})
a.Error(err)
a.Empty(res)
}
func TestRpcService_Cancel(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
a.Error(m.aria2RPC.Cancel(&model.Download{}))
}
func TestRpcService_Select(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
a.NotNil(m.aria2RPC.GetConfig())
a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3}))
}
func TestRpcService_DeleteTempFile(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
fdName := "TestRpcService_DeleteTempFile"
a.NoError(os.Mkdir(fdName, 0644))
a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName}))
time.Sleep(500 * time.Millisecond)
a.False(util.Exists(fdName))
}

View File

@@ -1,60 +1,413 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/constants"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/node"
"github.com/cloudreve/Cloudreve/v4/ent/task"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/aria2"
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/qbittorrent"
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/slave"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/queue"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"strconv"
)
type Node interface {
// Init a node from database model
Init(node *model.Node)
type (
Node interface {
fs.StatelessUploadManager
ID() int
Name() string
IsMaster() bool
// CreateTask creates a task on the node. It does not have effect on master node.
CreateTask(ctx context.Context, taskType string, state string) (int, error)
// GetTask returns the task summary of the task with the given id.
GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error)
// CleanupFolders cleans up the given folders on the node.
CleanupFolders(ctx context.Context, folders ...string) error
// AuthInstance returns the auth instance for the node.
AuthInstance() auth.Auth
// CreateDownloader creates a downloader instance from the node for remote download tasks.
CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error)
// Settings returns the settings of the node.
Settings(ctx context.Context) *types.NodeSetting
}
// Check if given feature is enabled
IsFeatureEnabled(feature string) bool
// Request body for creating tasks on slave node
CreateSlaveTask struct {
Type string `json:"type"`
State string `json:"state"`
}
// Subscribe node status change to a callback function
SubscribeStatusChange(callback func(isActive bool, id uint))
// Request body for cleaning up folders on slave node
FolderCleanup struct {
Path []string `json:"path" binding:"required"`
}
// Ping the node
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
SlaveTaskSummary struct {
Status task.Status `json:"status"`
Error string `json:"error"`
PrivateState string `json:"private_state"`
Progress queue.Progresses `json:"progress,omitempty"`
}
// Returns if the node is active
IsActive() bool
MasterSiteUrlCtx struct{}
MasterSiteVersionCtx struct{}
MasterSiteIDCtx struct{}
SlaveNodeIDCtx struct{}
masterNode struct {
nodeBase
client request.Client
}
)
// Get instances for aria2 calls
GetAria2Instance() common.Aria2
// Returns unique id of this node
ID() uint
// Kill node and recycle resources
Kill()
// Returns if current node is master node
IsMater() bool
// Get auth instance used to check RPC call from slave to master
MasterAuthInstance() auth.Auth
// Get auth instance used to check RPC call from master to slave
SlaveAuthInstance() auth.Auth
// Get node DB model
DBModel() *model.Node
func newNode(ctx context.Context, model *ent.Node, config conf.ConfigProvider, settings setting.Provider) Node {
if model.Type == node.TypeMaster {
return newMasterNode(model, config, settings)
}
return newSlaveNode(ctx, model, config, settings)
}
// Create new node from DB model
func NewNodeFromDBModel(node *model.Node) Node {
switch node.Type {
case model.SlaveNodeType:
slave := &SlaveNode{}
slave.Init(node)
return slave
default:
master := &MasterNode{}
master.Init(node)
return master
func newMasterNode(model *ent.Node, config conf.ConfigProvider, settings setting.Provider) *masterNode {
n := &masterNode{
nodeBase: nodeBase{
model: model,
},
}
if config.System().Mode == conf.SlaveMode {
n.client = request.NewClient(config,
request.WithCorrelationID(),
request.WithCredential(auth.HMACAuth{
[]byte(config.Slave().Secret),
}, int64(config.Slave().SignatureTTL)),
)
}
return n
}
func (b *masterNode) PrepareUpload(ctx context.Context, args *fs.StatelessPrepareUploadService) (*fs.StatelessPrepareUploadResponse, error) {
reqBody, err := json.Marshal(args)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "prepare")
resp, err := b.client.Request(
"PUT",
requestDst.String(),
bytes.NewReader(reqBody),
request.WithContext(ctx),
request.WithSlaveMeta(NodeIdFromContext(ctx)),
request.WithLogger(logging.FromContext(ctx)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return nil, err
}
// 处理列取结果
if resp.Code != 0 {
return nil, serializer.NewErrorFromResponse(resp)
}
uploadRequest := &fs.StatelessPrepareUploadResponse{}
resp.GobDecode(uploadRequest)
return uploadRequest, nil
}
func (b *masterNode) CompleteUpload(ctx context.Context, args *fs.StatelessCompleteUploadService) error {
reqBody, err := json.Marshal(args)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "complete")
resp, err := b.client.Request(
"POST",
requestDst.String(),
bytes.NewReader(reqBody),
request.WithContext(ctx),
request.WithSlaveMeta(NodeIdFromContext(ctx)),
request.WithLogger(logging.FromContext(ctx)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
// 处理列取结果
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}
func (b *masterNode) OnUploadFailed(ctx context.Context, args *fs.StatelessOnUploadFailedService) error {
reqBody, err := json.Marshal(args)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "failed")
resp, err := b.client.Request(
"POST",
requestDst.String(),
bytes.NewReader(reqBody),
request.WithContext(ctx),
request.WithSlaveMeta(NodeIdFromContext(ctx)),
request.WithLogger(logging.FromContext(ctx)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
// 处理列取结果
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}
func (b *masterNode) CreateFile(ctx context.Context, args *fs.StatelessCreateFileService) error {
reqBody, err := json.Marshal(args)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "create")
resp, err := b.client.Request(
"POST",
requestDst.String(),
bytes.NewReader(reqBody),
request.WithContext(ctx),
request.WithSlaveMeta(NodeIdFromContext(ctx)),
request.WithLogger(logging.FromContext(ctx)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
// 处理列取结果
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}
func (b *masterNode) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) {
return NewDownloader(ctx, c, settings, b.Settings(ctx))
}
// NewDownloader creates a new downloader instance from the node for remote download tasks.
func NewDownloader(ctx context.Context, c request.Client, settings setting.Provider, options *types.NodeSetting) (downloader.Downloader, error) {
if options.Provider == types.DownloaderProviderQBittorrent {
return qbittorrent.NewClient(logging.FromContext(ctx), c, settings, options.QBittorrentSetting)
} else if options.Provider == types.DownloaderProviderAria2 {
return aria2.New(logging.FromContext(ctx), settings, options.Aria2Setting), nil
} else if options.Provider == "" {
return nil, errors.New("downloader not configured for this node")
} else {
return nil, errors.New("unknown downloader provider")
}
}
type slaveNode struct {
nodeBase
client request.Client
}
func newSlaveNode(ctx context.Context, model *ent.Node, config conf.ConfigProvider, settings setting.Provider) *slaveNode {
siteBasic := settings.SiteBasic(ctx)
return &slaveNode{
nodeBase: nodeBase{
model: model,
},
client: request.NewClient(config,
request.WithCorrelationID(),
request.WithSlaveMeta(model.ID),
request.WithMasterMeta(siteBasic.ID, settings.SiteURL(setting.UseFirstSiteUrl(ctx)).String()),
request.WithCredential(auth.HMACAuth{[]byte(model.SlaveKey)}, int64(settings.SlaveRequestSignTTL(ctx))),
request.WithEndpoint(model.Server)),
}
}
func (n *slaveNode) CreateTask(ctx context.Context, taskType string, state string) (int, error) {
reqBody, err := json.Marshal(&CreateSlaveTask{
Type: taskType,
State: state,
})
if err != nil {
return 0, fmt.Errorf("failed to marshal request body: %w", err)
}
resp, err := n.client.Request(
"PUT",
constants.APIPrefixSlave+"/task",
bytes.NewReader(reqBody),
request.WithContext(ctx),
request.WithLogger(logging.FromContext(ctx)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return 0, err
}
// 处理列取结果
if resp.Code != 0 {
return 0, serializer.NewErrorFromResponse(resp)
}
taskId := 0
if resp.GobDecode(&taskId); taskId > 0 {
return taskId, nil
}
return 0, fmt.Errorf("unexpected response data: %v", resp.Data)
}
func (n *slaveNode) GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error) {
resp, err := n.client.Request(
"GET",
routes.SlaveGetTaskRoute(id, clearOnComplete),
nil,
request.WithContext(ctx),
request.WithLogger(logging.FromContext(ctx)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return nil, err
}
// 处理列取结果
if resp.Code != 0 {
return nil, serializer.NewErrorFromResponse(resp)
}
summary := &SlaveTaskSummary{}
resp.GobDecode(summary)
return summary, nil
}
func (b *slaveNode) CleanupFolders(ctx context.Context, folders ...string) error {
args := &FolderCleanup{
Path: folders,
}
reqBody, err := json.Marshal(args)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}
resp, err := b.client.Request(
"POST",
constants.APIPrefixSlave+"/task/cleanup",
bytes.NewReader(reqBody),
request.WithContext(ctx),
request.WithLogger(logging.FromContext(ctx)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
// 处理列取结果
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}
func (b *slaveNode) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) {
return slave.NewSlaveDownloader(b.client, b.Settings(ctx)), nil
}
type nodeBase struct {
model *ent.Node
}
func (b *nodeBase) ID() int {
return b.model.ID
}
func (b *nodeBase) Name() string {
return b.model.Name
}
func (b *nodeBase) IsMaster() bool {
return b.model.Type == node.TypeMaster
}
func (b *nodeBase) CreateTask(ctx context.Context, taskType string, state string) (int, error) {
return 0, errors.New("not implemented")
}
func (b *nodeBase) AuthInstance() auth.Auth {
return auth.HMACAuth{[]byte(b.model.SlaveKey)}
}
func (b *nodeBase) GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error) {
return nil, errors.New("not implemented")
}
func (b *nodeBase) CleanupFolders(ctx context.Context, folders ...string) error {
return errors.New("not implemented")
}
func (b *nodeBase) PrepareUpload(ctx context.Context, args *fs.StatelessPrepareUploadService) (*fs.StatelessPrepareUploadResponse, error) {
return nil, errors.New("not implemented")
}
func (b *nodeBase) CompleteUpload(ctx context.Context, args *fs.StatelessCompleteUploadService) error {
return errors.New("not implemented")
}
func (b *nodeBase) OnUploadFailed(ctx context.Context, args *fs.StatelessOnUploadFailedService) error {
return errors.New("not implemented")
}
func (b *nodeBase) CreateFile(ctx context.Context, args *fs.StatelessCreateFileService) error {
return errors.New("not implemented")
}
func (b *nodeBase) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) {
return nil, errors.New("not implemented")
}
func (b *nodeBase) Settings(ctx context.Context) *types.NodeSetting {
return b.model.Settings
}
func NodeIdFromContext(ctx context.Context) int {
nodeIdStr, ok := ctx.Value(SlaveNodeIDCtx{}).(string)
if !ok {
return 0
}
nodeId, _ := strconv.Atoi(nodeIdStr)
return nodeId
}
func MasterSiteUrlFromContext(ctx context.Context) string {
if u, ok := ctx.Value(MasterSiteUrlCtx{}).(string); ok {
return u
}
return ""
}

View File

@@ -1,17 +0,0 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewNodeFromDBModel(t *testing.T) {
a := assert.New(t)
a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{
Type: model.SlaveNodeType,
}))
a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{
Type: model.MasterNodeType,
}))
}

View File

@@ -1,190 +1,203 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"context"
"fmt"
"sync"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/node"
"github.com/cloudreve/Cloudreve/v4/inventory"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"github.com/samber/lo"
)
var Default *NodePool
// 需要分类的节点组
var featureGroup = []string{"aria2"}
// Pool 节点池
type Pool interface {
// Returns active node selected by given feature and load balancer
BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node)
// Returns node by ID
GetNodeByID(id uint) Node
// Add given node into pool. If node existed, refresh node.
Add(node *model.Node)
// Delete and kill node from pool by given node id
Delete(id uint)
type NodePool interface {
// Upsert updates or inserts a node into the pool.
Upsert(ctx context.Context, node *ent.Node)
// Get returns a node with the given capability and preferred node id. `allowed` is a list of allowed node ids.
// If `allowed` is empty, all nodes with the capability are considered.
Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error)
}
// NodePool 通用节点池
type NodePool struct {
active map[uint]Node
inactive map[uint]Node
type (
weightedNodePool struct {
lock sync.RWMutex
featureMap map[string][]Node
conf conf.ConfigProvider
settings setting.Provider
lock sync.RWMutex
}
// Init 初始化从机节点池
func Init() {
Default = &NodePool{}
Default.Init()
if err := Default.initFromDB(); err != nil {
util.Log().Warning("Failed to initialize node pool: %s", err)
}
}
func (pool *NodePool) Init() {
pool.lock.Lock()
defer pool.lock.Unlock()
pool.featureMap = make(map[string][]Node)
pool.active = make(map[uint]Node)
pool.inactive = make(map[uint]Node)
}
func (pool *NodePool) buildIndexMap() {
pool.lock.Lock()
for _, feature := range featureGroup {
pool.featureMap[feature] = make([]Node, 0)
nodes map[types.NodeCapability][]*nodeItem
}
for _, v := range pool.active {
for _, feature := range featureGroup {
if v.IsFeatureEnabled(feature) {
pool.featureMap[feature] = append(pool.featureMap[feature], v)
nodeItem struct {
node Node
weight int
current int
}
)
var (
ErrNoAvailableNode = fmt.Errorf("no available node found")
supportedCapabilities = []types.NodeCapability{
types.NodeCapabilityNone,
types.NodeCapabilityCreateArchive,
types.NodeCapabilityExtractArchive,
types.NodeCapabilityRemoteDownload,
}
)
func NewNodePool(ctx context.Context, l logging.Logger, config conf.ConfigProvider, settings setting.Provider,
client inventory.NodeClient) (NodePool, error) {
nodes, err := client.ListActiveNodes(ctx, nil)
if err != nil {
return nil, fmt.Errorf("failed to list active nodes: %w", err)
}
pool := &weightedNodePool{
nodes: make(map[types.NodeCapability][]*nodeItem),
conf: config,
settings: settings,
}
for _, node := range nodes {
for _, capability := range supportedCapabilities {
// If current capability is enabled, add it to pool slot.
if capability == types.NodeCapabilityNone ||
(node.Capabilities != nil && node.Capabilities.Enabled(int(capability))) {
if _, ok := pool.nodes[capability]; !ok {
pool.nodes[capability] = make([]*nodeItem, 0)
}
l.Debug("Add node %q to capability slot %d with weight %d", node.Name, capability, node.Weight)
pool.nodes[capability] = append(pool.nodes[capability], &nodeItem{
node: newNode(ctx, node, config, settings),
weight: node.Weight,
current: 0,
})
}
}
}
pool.lock.Unlock()
return pool, nil
}
func (pool *NodePool) GetNodeByID(id uint) Node {
pool.lock.RLock()
defer pool.lock.RUnlock()
func (p *weightedNodePool) Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error) {
l := logging.FromContext(ctx)
p.lock.RLock()
defer p.lock.RUnlock()
if node, ok := pool.active[id]; ok {
return node
nodes, ok := p.nodes[capability]
if !ok || len(nodes) == 0 {
return nil, fmt.Errorf("no node found with capability %d: %w", capability, ErrNoAvailableNode)
}
return pool.inactive[id]
}
var selected *nodeItem
func (pool *NodePool) nodeStatusChange(isActive bool, id uint) {
util.Log().Debug("Slave node [ID=%d] status changed to [Active=%t].", id, isActive)
var node Node
pool.lock.Lock()
if n, ok := pool.inactive[id]; ok {
node = n
delete(pool.inactive, id)
} else {
node = pool.active[id]
delete(pool.active, id)
}
if isActive {
pool.active[id] = node
} else {
pool.inactive[id] = node
}
pool.lock.Unlock()
pool.buildIndexMap()
}
func (pool *NodePool) initFromDB() error {
nodes, err := model.GetNodesByStatus(model.NodeActive)
if err != nil {
return err
}
pool.lock.Lock()
for i := 0; i < len(nodes); i++ {
pool.add(&nodes[i])
}
pool.lock.Unlock()
pool.buildIndexMap()
return nil
}
func (pool *NodePool) add(node *model.Node) {
newNode := NewNodeFromDBModel(node)
if newNode.IsActive() {
pool.active[node.ID] = newNode
} else {
pool.inactive[node.ID] = newNode
}
// 订阅节点状态变更
newNode.SubscribeStatusChange(func(isActive bool, id uint) {
pool.nodeStatusChange(isActive, id)
})
}
func (pool *NodePool) Add(node *model.Node) {
pool.lock.Lock()
defer pool.buildIndexMap()
defer pool.lock.Unlock()
var (
old Node
ok bool
)
if old, ok = pool.active[node.ID]; !ok {
old, ok = pool.inactive[node.ID]
}
if old != nil {
go old.Init(node)
return
}
pool.add(node)
}
func (pool *NodePool) Delete(id uint) {
pool.lock.Lock()
defer pool.buildIndexMap()
defer pool.lock.Unlock()
if node, ok := pool.active[id]; ok {
node.Kill()
delete(pool.active, id)
return
}
if node, ok := pool.inactive[id]; ok {
node.Kill()
delete(pool.inactive, id)
return
}
}
// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点
func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) {
pool.lock.RLock()
defer pool.lock.RUnlock()
if nodes, ok := pool.featureMap[feature]; ok {
err, res := lb.NextPeer(nodes)
if err == nil {
return nil, res.(Node)
if preferred > 0 {
// First try to find the preferred node.
for _, n := range nodes {
if n.node.ID() == preferred {
selected = n
break
}
}
return err, nil
if selected == nil {
l.Debug("Preferred node %d not found, fallback to select a node with the least current weight", preferred)
}
}
return ErrFeatureNotExist, nil
if selected == nil {
// If no preferred one, or the preferred one is not available, select a node with the least current weight.
// Total weight of all items.
var total int
// Loop through the list of items and add the item's weight to the current weight.
// Also increment the total weight counter.
var maxNode *nodeItem
for _, item := range nodes {
item.current += max(1, item.weight)
total += max(1, item.weight)
// Select the item with max weight.
if maxNode == nil || item.current > maxNode.current {
maxNode = item
}
}
// Select the item with the max weight.
selected = maxNode
if selected == nil {
return nil, fmt.Errorf("no node found with capability %d: %w", capability, ErrNoAvailableNode)
}
l.Debug("Selected node %q with weight=%d, current=%d, total=%d", selected.node.Name(), selected.weight, maxNode.current, total)
// Reduce the current weight of the selected item by the total weight.
maxNode.current -= total
}
return selected.node, nil
}
func (p *weightedNodePool) Upsert(ctx context.Context, n *ent.Node) {
p.lock.Lock()
defer p.lock.Unlock()
for _, capability := range supportedCapabilities {
_, index, found := lo.FindIndexOf(p.nodes[capability], func(i *nodeItem) bool {
return i.node.ID() == n.ID
})
if capability == types.NodeCapabilityNone ||
(n.Capabilities != nil && n.Capabilities.Enabled(int(capability))) {
if n.Status != node.StatusActive && found {
// Remove inactive node
p.nodes[capability] = append(p.nodes[capability][:index], p.nodes[capability][index+1:]...)
continue
}
if found {
p.nodes[capability][index].node = newNode(ctx, n, p.conf, p.settings)
} else {
p.nodes[capability] = append(p.nodes[capability], &nodeItem{
node: newNode(ctx, n, p.conf, p.settings),
weight: n.Weight,
current: 0,
})
}
} else if found {
// Capability changed, remove the old node.
p.nodes[capability] = append(p.nodes[capability][:index], p.nodes[capability][index+1:]...)
}
}
}
type slaveDummyNodePool struct {
conf conf.ConfigProvider
settings setting.Provider
masterNode Node
}
func NewSlaveDummyNodePool(ctx context.Context, config conf.ConfigProvider, settings setting.Provider) NodePool {
return &slaveDummyNodePool{
conf: config,
settings: settings,
masterNode: newNode(ctx, &ent.Node{
ID: 0,
Name: "Master",
Type: node.TypeMaster,
}, config, settings),
}
}
func (s *slaveDummyNodePool) Upsert(ctx context.Context, node *ent.Node) {
}
func (s *slaveDummyNodePool) Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error) {
return s.masterNode, nil
}

View File

@@ -1,161 +0,0 @@
package cluster
import (
"database/sql"
"errors"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
)
var mock sqlmock.Sqlmock
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestInitFailed(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
Init()
a.NoError(mock.ExpectationsWereMet())
}
func TestInitSuccess(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "aria2_enabled", "type"}).AddRow(1, true, model.MasterNodeType))
Init()
a.NoError(mock.ExpectationsWereMet())
}
func TestNodePool_GetNodeByID(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
mockNode := &nodeMock{}
// inactive
{
p.inactive[1] = mockNode
a.Equal(mockNode, p.GetNodeByID(1))
}
// active
{
delete(p.inactive, 1)
p.active[1] = mockNode
a.Equal(mockNode, p.GetNodeByID(1))
}
}
func TestNodePool_NodeStatusChange(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
n := &MasterNode{Model: &model.Node{}}
p.Init()
p.inactive[1] = n
p.nodeStatusChange(true, 1)
a.Len(p.inactive, 0)
a.Equal(n, p.active[1])
p.nodeStatusChange(false, 1)
a.Len(p.active, 0)
a.Equal(n, p.inactive[1])
p.nodeStatusChange(false, 1)
a.Len(p.active, 0)
a.Equal(n, p.inactive[1])
}
func TestNodePool_Add(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
// new node
{
p.Add(&model.Node{})
a.Len(p.active, 1)
}
// old node
{
p.inactive[0] = p.active[0]
delete(p.active, 0)
p.Add(&model.Node{})
a.Len(p.active, 0)
a.Len(p.inactive, 1)
}
}
func TestNodePool_Delete(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
// active
{
mockNode := &nodeMock{}
mockNode.On("Kill")
p.active[0] = mockNode
p.Delete(0)
a.Len(p.active, 0)
a.Len(p.inactive, 0)
mockNode.AssertExpectations(t)
}
p.Init()
// inactive
{
mockNode := &nodeMock{}
mockNode.On("Kill")
p.inactive[0] = mockNode
p.Delete(0)
a.Len(p.active, 0)
a.Len(p.inactive, 0)
mockNode.AssertExpectations(t)
}
}
func TestNodePool_BalanceNodeByFeature(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
// success
{
p.featureMap["test"] = []Node{&MasterNode{}}
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
a.NoError(err)
a.Equal(p.featureMap["test"][0], res)
}
// NoNodes
{
p.featureMap["test"] = []Node{}
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
a.Error(err)
a.Nil(res)
}
// No match feature
{
err, res := p.BalanceNodeByFeature("test2", balancer.NewBalancer("round-robin"))
a.Error(err)
a.Nil(res)
}
}

View File

@@ -0,0 +1,207 @@
package routes
import (
"encoding/base64"
"fmt"
"net/url"
"path"
"strconv"
"github.com/cloudreve/Cloudreve/v4/application/constants"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
)
const (
IsDownloadQuery = "download"
IsThumbQuery = "thumb"
SlaveClearTaskRegistryQuery = "deleteOnComplete"
)
var (
masterPing *url.URL
masterUserActivate *url.URL
masterUserReset *url.URL
masterHome *url.URL
)
func init() {
masterPing, _ = url.Parse(constants.APIPrefix + "/site/ping")
masterUserActivate, _ = url.Parse("/session/activate")
masterUserReset, _ = url.Parse("/session/reset")
}
func FrontendHomeUrl(base *url.URL, path string) *url.URL {
route, _ := url.Parse(fmt.Sprintf("/home"))
q := route.Query()
q.Set("path", path)
route.RawQuery = q.Encode()
return base.ResolveReference(route)
}
func MasterPingUrl(base *url.URL) *url.URL {
return base.ResolveReference(masterPing)
}
func MasterSlaveCallbackUrl(base *url.URL, driver, id, secret string) *url.URL {
apiBaseURI, _ := url.Parse(path.Join(constants.APIPrefix+"/callback", driver, id, secret))
return base.ResolveReference(apiBaseURI)
}
func MasterUserActivateAPIUrl(base *url.URL, uid string) *url.URL {
route, _ := url.Parse(constants.APIPrefix + "/user/activate/" + uid)
return base.ResolveReference(route)
}
func MasterUserActivateUrl(base *url.URL) *url.URL {
return base.ResolveReference(masterUserActivate)
}
func MasterUserResetUrl(base *url.URL) *url.URL {
return base.ResolveReference(masterUserReset)
}
func MasterShareUrl(base *url.URL, id, password string) *url.URL {
p := "/s/" + id
if password != "" {
p += ("/" + password)
}
route, _ := url.Parse(p)
return base.ResolveReference(route)
}
func MasterDirectLink(base *url.URL, id, name string) *url.URL {
p := path.Join("/f", id, url.PathEscape(name))
route, _ := url.Parse(p)
return base.ResolveReference(route)
}
// MasterShareLongUrl generates a long share URL for redirect.
func MasterShareLongUrl(id, password string) *url.URL {
base, _ := url.Parse("/home")
q := base.Query()
q.Set("path", fs.NewShareUri(id, password))
base.RawQuery = q.Encode()
return base
}
func MasterArchiveDownloadUrl(base *url.URL, sessionID string) *url.URL {
routes, err := url.Parse(path.Join(constants.APIPrefix, "file", "archive", sessionID, "archive.zip"))
if err != nil {
return nil
}
return base.ResolveReference(routes)
}
func MasterPolicyOAuthCallback(base *url.URL) *url.URL {
if base.Scheme != "https" {
base.Scheme = "https"
}
routes, err := url.Parse("/admin/policy/oauth")
if err != nil {
return nil
}
return base.ResolveReference(routes)
}
func MasterGetCredentialUrl(base, key string) *url.URL {
masterBase, err := url.Parse(base)
if err != nil {
return nil
}
routes, err := url.Parse(path.Join(constants.APIPrefixSlave, "credential", key))
if err != nil {
return nil
}
return masterBase.ResolveReference(routes)
}
func MasterStatelessUrl(base, method string) *url.URL {
masterBase, err := url.Parse(base)
if err != nil {
return nil
}
routes, err := url.Parse(path.Join(constants.APIPrefixSlave, "statelessUpload", method))
if err != nil {
return nil
}
return masterBase.ResolveReference(routes)
}
func SlaveUploadUrl(base *url.URL, sessionID string) *url.URL {
base.Path = path.Join(base.Path, constants.APIPrefixSlave, "/upload", sessionID)
return base
}
func MasterFileContentUrl(base *url.URL, entityId, name string, download, thumb bool, speed int64) *url.URL {
name = url.PathEscape(name)
route, _ := url.Parse(constants.APIPrefix + fmt.Sprintf("/file/content/%s/%d/%s", entityId, speed, name))
if base != nil {
route = base.ResolveReference(route)
}
values := url.Values{}
if download {
values.Set(IsDownloadQuery, "true")
}
if thumb {
values.Set(IsThumbQuery, "true")
}
route.RawQuery = values.Encode()
return route
}
func MasterWopiSrc(base *url.URL, sessionId string) *url.URL {
route, _ := url.Parse(constants.APIPrefix + "/file/wopi/" + sessionId)
return base.ResolveReference(route)
}
func SlaveFileContentUrl(base *url.URL, srcPath, name string, download bool, speed int64, nodeId int) *url.URL {
srcPath = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(srcPath)))
name = url.PathEscape(name)
route, _ := url.Parse(constants.APIPrefixSlave + fmt.Sprintf("/file/content/%d/%s/%d/%s", nodeId, srcPath, speed, name))
base = base.ResolveReference(route)
values := url.Values{}
if download {
values.Set(IsDownloadQuery, "true")
}
base.RawQuery = values.Encode()
return base
}
func SlaveMediaMetaRoute(src, ext string) string {
src = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(src)))
return fmt.Sprintf("file/meta/%s/%s", src, url.PathEscape(ext))
}
func SlaveThumbUrl(base *url.URL, srcPath, ext string) *url.URL {
srcPath = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(srcPath)))
ext = url.PathEscape(ext)
route, _ := url.Parse(constants.APIPrefixSlave + fmt.Sprintf("/file/thumb/%s/%s", srcPath, ext))
base = base.ResolveReference(route)
return base
}
func SlaveGetTaskRoute(id int, deleteOnComplete bool) string {
p := constants.APIPrefixSlave + "/task/" + strconv.Itoa(id)
if deleteOnComplete {
p += "?" + SlaveClearTaskRegistryQuery + "=true"
}
return p
}
func SlavePingRoute(base *url.URL) string {
route, _ := url.Parse(constants.APIPrefixSlave + "/ping")
return base.ResolveReference(route).String()
}

View File

@@ -1,451 +0,0 @@
package cluster
import (
"bytes"
"encoding/json"
"errors"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"net/url"
"strings"
"sync"
"time"
)
type SlaveNode struct {
Model *model.Node
Active bool
caller slaveCaller
callback func(bool, uint)
close chan bool
lock sync.RWMutex
}
type slaveCaller struct {
parent *SlaveNode
Client request.Client
}
// Init 初始化节点
func (node *SlaveNode) Init(nodeModel *model.Node) {
node.lock.Lock()
node.Model = nodeModel
// Init http request client
var endpoint *url.URL
if serverURL, err := url.Parse(node.Model.Server); err == nil {
var controller *url.URL
controller, _ = url.Parse("/api/v3/slave/")
endpoint = serverURL.ResolveReference(controller)
}
signTTL := model.GetIntSetting("slave_api_timeout", 60)
node.caller.Client = request.NewClient(
request.WithMasterMeta(),
request.WithTimeout(time.Duration(signTTL)*time.Second),
request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)),
request.WithEndpoint(endpoint.String()),
)
node.caller.parent = node
if node.close != nil {
node.lock.Unlock()
node.close <- true
go node.StartPingLoop()
} else {
node.Active = true
node.lock.Unlock()
go node.StartPingLoop()
}
}
// IsFeatureEnabled 查询节点的某项功能是否启用
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
node.lock.RLock()
defer node.lock.RUnlock()
switch feature {
case "aria2":
return node.Model.Aria2Enabled
default:
return false
}
}
// SubscribeStatusChange 订阅节点状态更改
func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) {
node.lock.Lock()
node.callback = callback
node.lock.Unlock()
}
// Ping 从机节点,返回从机负载
func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
node.lock.RLock()
defer node.lock.RUnlock()
reqBodyEncoded, err := json.Marshal(req)
if err != nil {
return nil, err
}
bodyReader := strings.NewReader(string(reqBodyEncoded))
resp, err := node.caller.Client.Request(
"POST",
"heartbeat",
bodyReader,
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return nil, err
}
// 处理列取结果
if resp.Code != 0 {
return nil, serializer.NewErrorFromResponse(resp)
}
var res serializer.NodePingResp
if resStr, ok := resp.Data.(string); ok {
err = json.Unmarshal([]byte(resStr), &res)
if err != nil {
return nil, err
}
}
return &res, nil
}
// IsActive 返回节点是否在线
func (node *SlaveNode) IsActive() bool {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Active
}
// Kill 结束节点内相关循环
func (node *SlaveNode) Kill() {
node.lock.RLock()
defer node.lock.RUnlock()
if node.close != nil {
close(node.close)
}
}
// GetAria2Instance 获取从机Aria2实例
func (node *SlaveNode) GetAria2Instance() common.Aria2 {
node.lock.RLock()
defer node.lock.RUnlock()
if !node.Model.Aria2Enabled {
return &common.DummyAria2{}
}
return &node.caller
}
func (node *SlaveNode) ID() uint {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model.ID
}
func (node *SlaveNode) StartPingLoop() {
node.lock.Lock()
node.close = make(chan bool)
node.lock.Unlock()
tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second
recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second
pingTicker := time.Duration(0)
util.Log().Debug("Slave node %q heartbeat loop started.", node.Model.Name)
retry := 0
recoverMode := false
isFirstLoop := true
loop:
for {
select {
case <-time.After(pingTicker):
if pingTicker == 0 {
pingTicker = tickDuration
}
util.Log().Debug("Slave node %q send ping.", node.Model.Name)
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
isFirstLoop = false
if err != nil {
util.Log().Debug("Error while ping slave node %q: %s", node.Model.Name, err)
retry++
if retry >= model.GetIntSetting("slave_node_retry", 3) {
util.Log().Debug("Retry threshold for pinging slave node %q exceeded, mark it as offline.", node.Model.Name)
node.changeStatus(false)
if !recoverMode {
// 启动恢复监控循环
util.Log().Debug("Slave node %q entered recovery mode.", node.Model.Name)
pingTicker = recoverDuration
recoverMode = true
}
}
} else {
if recoverMode {
util.Log().Debug("Slave node %q recovered.", node.Model.Name)
pingTicker = tickDuration
recoverMode = false
isFirstLoop = true
}
util.Log().Debug("Status of slave node %q: %s", node.Model.Name, res)
node.changeStatus(true)
retry = 0
}
case <-node.close:
util.Log().Debug("Slave node %q received shutdown signal.", node.Model.Name)
break loop
}
}
}
func (node *SlaveNode) IsMater() bool {
return false
}
func (node *SlaveNode) MasterAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
func (node *SlaveNode) SlaveAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
}
func (node *SlaveNode) DBModel() *model.Node {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model
}
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
return &serializer.NodePingReq{
SiteURL: model.GetSiteURL().String(),
IsUpdate: isUpdate,
SiteID: model.GetSettingByName("siteID"),
Node: node.Model,
CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
}
}
func (node *SlaveNode) changeStatus(isActive bool) {
node.lock.RLock()
id := node.Model.ID
if isActive != node.Active {
node.lock.RUnlock()
node.lock.Lock()
node.Active = isActive
node.lock.Unlock()
node.callback(isActive, id)
} else {
node.lock.RUnlock()
}
}
func (s *slaveCaller) Init() error {
return nil
}
// SendAria2Call send remote aria2 call to slave node
func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
reqReader, err := getAria2RequestBody(body)
if err != nil {
return nil, err
}
return s.Client.Request(
"POST",
"aria2/"+scope,
reqReader,
).CheckHTTPResponse(200).DecodeResponse()
}
func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
GroupOptions: options,
}
res, err := s.SendAria2Call(req, "task")
if err != nil {
return "", err
}
if res.Code != 0 {
return "", serializer.NewErrorFromResponse(res)
}
return res.Data.(string), err
}
func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "status")
if err != nil {
return rpc.StatusInfo{}, err
}
if res.Code != 0 {
return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res)
}
var status rpc.StatusInfo
res.GobDecode(&status)
return status, err
}
func (s *slaveCaller) Cancel(task *model.Download) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "cancel")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func (s *slaveCaller) Select(task *model.Download, files []int) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
Files: files,
}
res, err := s.SendAria2Call(req, "select")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func (s *slaveCaller) GetConfig() model.Aria2Option {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
return s.parent.Model.Aria2OptionsSerialized
}
func (s *slaveCaller) DeleteTempFile(task *model.Download) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "delete")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
reqBodyEncoded, err := json.Marshal(body)
if err != nil {
return nil, err
}
return strings.NewReader(string(reqBodyEncoded)), nil
}
// RemoteCallback 发送远程存储策略上传回调请求
func RemoteCallback(url string, body serializer.UploadCallback) error {
callbackBody, err := json.Marshal(struct {
Data serializer.UploadCallback `json:"data"`
}{
Data: body,
})
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "Failed to encode callback content", err)
}
resp := request.GeneralClient.Request(
"POST",
url,
bytes.NewReader(callbackBody),
request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err)
}
// 解析回调服务端响应
response, err := resp.DecodeResponse()
if err != nil {
msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode)
return serializer.NewError(serializer.CodeCallbackError, msg, err)
}
if response.Code != 0 {
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
}
return nil
}

View File

@@ -1,559 +0,0 @@
package cluster
import (
"bytes"
"encoding/json"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"
)
func TestSlaveNode_InitAndKill(t *testing.T) {
a := assert.New(t)
n := &SlaveNode{
callback: func(b bool, u uint) {
},
}
a.NotPanics(func() {
n.Init(&model.Node{})
time.Sleep(time.Millisecond * 500)
n.Init(&model.Node{})
n.Kill()
})
}
func TestSlaveNode_DummyMethods(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
m.Model.ID = 5
a.Equal(m.Model.ID, m.ID())
a.Equal(m.Model.ID, m.DBModel().ID)
a.False(m.IsActive())
a.False(m.IsMater())
m.SubscribeStatusChange(func(isActive bool, id uint) {})
}
func TestSlaveNode_IsFeatureEnabled(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
a.False(m.IsFeatureEnabled("aria2"))
a.False(m.IsFeatureEnabled("random"))
m.Model.Aria2Enabled = true
a.True(m.IsFeatureEnabled("aria2"))
}
func TestSlaveNode_Ping(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
// master return error code
{
mockRequest := &requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
res, err := m.Ping(&serializer.NodePingReq{})
a.Error(err)
a.Nil(res)
a.Equal(1, err.(serializer.AppError).Code)
}
// return unexpected json
{
mockRequest := &requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"233\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.Ping(&serializer.NodePingReq{})
a.Error(err)
a.Nil(res)
}
// return success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.Ping(&serializer.NodePingReq{})
a.NoError(err)
a.NotNil(res)
}
}
func TestSlaveNode_GetAria2Instance(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
a.NotNil(m.GetAria2Instance())
m.Model.Aria2Enabled = true
a.NotNil(m.GetAria2Instance())
a.NotNil(m.GetAria2Instance())
}
func TestSlaveNode_StartPingLoop(t *testing.T) {
callbackCount := 0
finishedChan := make(chan struct{})
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m := &SlaveNode{
Active: true,
Model: &model.Node{},
callback: func(b bool, u uint) {
callbackCount++
if callbackCount == 2 {
close(finishedChan)
}
if callbackCount == 1 {
mockRequest.AssertExpectations(t)
mockRequest = requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
},
})
}
},
}
cache.Set("setting_slave_ping_interval", "0", 0)
cache.Set("setting_slave_recover_interval", "0", 0)
cache.Set("setting_slave_node_retry", "1", 0)
m.caller.Client = &mockRequest
go func() {
select {
case <-finishedChan:
m.Kill()
}
}()
m.StartPingLoop()
mockRequest.AssertExpectations(t)
}
func TestSlaveNode_AuthInstance(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
a.NotNil(m.MasterAuthInstance())
a.NotNil(m.SlaveAuthInstance())
}
func TestSlaveNode_ChangeStatus(t *testing.T) {
a := assert.New(t)
isActive := false
m := &SlaveNode{
Model: &model.Node{},
callback: func(b bool, u uint) {
isActive = b
},
}
a.NotPanics(func() {
m.changeStatus(false)
})
m.changeStatus(true)
a.True(isActive)
}
func getTestRPCNodeSlave() *SlaveNode {
m := &SlaveNode{
Model: &model.Node{},
}
m.caller.parent = m
return m
}
func TestSlaveCaller_CreateTask(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
res, err := m.caller.CreateTask(&model.Download{}, nil)
a.Empty(res)
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.CreateTask(&model.Download{}, nil)
a.Empty(res)
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.CreateTask(&model.Download{}, nil)
a.Equal("res", res)
a.NoError(err)
}
}
func TestSlaveCaller_Status(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
res, err := m.caller.Status(&model.Download{})
a.Empty(res.Status)
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.Status(&model.Download{})
a.Empty(res.Status)
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"re456456s\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.Status(&model.Download{})
a.Empty(res.Status)
a.NoError(err)
}
}
func TestSlaveCaller_Cancel(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
err := m.caller.Cancel(&model.Download{})
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Cancel(&model.Download{})
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Cancel(&model.Download{})
a.NoError(err)
}
}
func TestSlaveCaller_Select(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
m.caller.Init()
m.caller.GetConfig()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
err := m.caller.Select(&model.Download{}, nil)
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Select(&model.Download{}, nil)
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Select(&model.Download{}, nil)
a.NoError(err)
}
}
func TestSlaveCaller_DeleteTempFile(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
m.caller.Init()
m.caller.GetConfig()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
err := m.caller.DeleteTempFile(&model.Download{})
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
err := m.caller.DeleteTempFile(&model.Download{})
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
err := m.caller.DeleteTempFile(&model.Download{})
a.NoError(err)
}
}
func TestRemoteCallback(t *testing.T) {
asserts := assert.New(t)
// 回调成功
{
clientMock := requestmock.RequestMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 0})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.NoError(resp)
clientMock.AssertExpectations(t)
}
// 服务端返回业务错误
{
clientMock := requestmock.RequestMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 401})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.EqualValues(401, resp.(serializer.AppError).Code)
clientMock.AssertExpectations(t)
}
// 无法解析回调响应
{
clientMock := requestmock.RequestMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// HTTP状态码非200
{
clientMock := requestmock.RequestMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 404,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// 无法发起回调
{
clientMock := requestmock.RequestMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: errors.New("error"),
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
}