Feat: aria2 download and transfer in slave node (#1040)

* Feat: retrieve nodes from data table

* Feat: master node ping slave node in REST API

* Feat: master send scheduled ping request

* Feat: inactive nodes recover loop

* Modify: remove database operations from aria2 RPC caller implementation

* Feat: init aria2 client in master node

* Feat: Round Robin load balancer

* Feat: create and monitor aria2 task in master node

* Feat: salve receive and handle heartbeat

* Fix: Node ID will be 0 in download record generated in older version

* Feat: sign request headers with all `X-` prefix

* Feat: API call to slave node will carry meta data in headers

* Feat: call slave aria2 rpc method from master

* Feat: get slave aria2 task status
Feat: encode slave response data using gob

* Feat: aria2 callback to master node / cancel or select task to slave node

* Fix: use dummy aria2 client when caller initialize failed in master node

* Feat: slave aria2 status event callback / salve RPC auth

* Feat: prototype for slave driven filesystem

* Feat: retry for init aria2 client in master node

* Feat: init request client with global options

* Feat: slave receive async task from master

* Fix: competition write in request header

* Refactor: dependency initialize order

* Feat: generic message queue implementation

* Feat: message queue implementation

* Feat: master waiting slave transfer result

* Feat: slave transfer file in stateless policy

* Feat: slave transfer file in slave policy

* Feat: slave transfer file in local policy

* Feat: slave transfer file in OneDrive policy

* Fix: failed to initialize update checker http client

* Feat: list slave nodes for dashboard

* Feat: test aria2 rpc connection in slave

* Feat: add and save node

* Feat: add and delete node in node pool

* Fix: temp file cannot be removed when aria2 task fails

* Fix: delete node in admin panel

* Feat: edit node and get node info

* Modify: delete unused settings
This commit is contained in:
AaronLiu
2021-10-31 09:41:56 +08:00
committed by GitHub
parent a3b4a22dbc
commit 056de22edb
74 changed files with 3647 additions and 715 deletions

View File

@@ -1,169 +1,65 @@
package aria2
import (
"encoding/json"
"context"
"fmt"
"net/url"
"sync"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
)
// Instance 默认使用的Aria2处理实例
var Instance Aria2 = &DummyAria2{}
var Instance common.Aria2 = &common.DummyAria2{}
// LB 获取 Aria2 节点的负载均衡器
var LB balancer.Balancer
// Lock Instance的读写锁
var Lock sync.RWMutex
// EventNotifier 任务状态更新通知处理
var EventNotifier = &Notifier{}
// Aria2 离线下载处理接口
type Aria2 interface {
// CreateTask 创建新的任务
CreateTask(task *model.Download, options map[string]interface{}) error
// 返回状态信息
Status(task *model.Download) (rpc.StatusInfo, error)
// 取消任务
Cancel(task *model.Download) error
// 选择要下载的文件
Select(task *model.Download, files []int) error
}
const (
// URLTask 从URL添加的任务
URLTask = iota
// TorrentTask 种子任务
TorrentTask
)
const (
// Ready 准备就绪
Ready = iota
// Downloading 下载中
Downloading
// Paused 暂停中
Paused
// Error 出错
Error
// Complete 完成
Complete
// Canceled 取消/停止
Canceled
// Unknown 未知状态
Unknown
)
var (
// ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
// ErrUserNotFound 未找到下载任务创建者
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
)
// DummyAria2 未开启Aria2功能时使用的默认处理器
type DummyAria2 struct {
}
// CreateTask 创建新任务,此处直接返回未开启错误
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) error {
return ErrNotEnabled
}
// Status 返回未开启错误
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
return rpc.StatusInfo{}, ErrNotEnabled
}
// Cancel 返回未开启错误
func (instance *DummyAria2) Cancel(task *model.Download) error {
return ErrNotEnabled
}
// Select 返回未开启错误
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
return ErrNotEnabled
// GetLoadBalancer 返回供Aria2使用的负载均衡
func GetLoadBalancer() balancer.Balancer {
Lock.RLock()
defer Lock.RUnlock()
return LB
}
// Init 初始化
func Init(isReload bool) {
Lock.Lock()
defer Lock.Unlock()
// 关闭上个初始连接
if previousClient, ok := Instance.(*RPCService); ok {
if previousClient.Caller != nil {
util.Log().Debug("关闭上个 aria2 连接")
previousClient.Caller.Close()
}
}
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
timeout := model.GetIntSetting("aria2_call_timeout", 5)
if options["aria2_rpcurl"] == "" {
Instance = &DummyAria2{}
return
}
util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"])
client := &RPCService{}
// 解析RPC服务地址
server, err := url.Parse(options["aria2_rpcurl"])
if err != nil {
util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err)
Instance = &DummyAria2{}
return
}
server.Path = "/jsonrpc"
// 加载自定义下载配置
var globalOptions map[string]interface{}
err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions)
if err != nil {
util.Log().Warning("无法解析 aria2 全局配置,%s", err)
Instance = &DummyAria2{}
return
}
if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil {
util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err)
Instance = &DummyAria2{}
return
}
Instance = client
LB = balancer.NewBalancer("RoundRobin")
Lock.Unlock()
if !isReload {
// 从数据库中读取未完成任务,创建监控
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading)
for i := 0; i < len(unfinished); i++ {
// 创建任务监控
NewMonitor(&unfinished[i])
monitor.NewMonitor(&unfinished[i])
}
}
}
// getStatus 将给定的状态字符串转换为状态标识数字
func getStatus(status string) int {
switch status {
case "complete":
return Complete
case "active":
return Downloading
case "waiting":
return Ready
case "paused":
return Paused
case "error":
return Error
case "removed":
return Canceled
default:
return Unknown
// TestRPCConnection 发送测试用的 RPC 请求,测试服务连通性
func TestRPCConnection(server, secret string, timeout int) (rpc.VersionInfo, error) {
// 解析RPC服务地址
rpcServer, err := url.Parse(server)
if err != nil {
return rpc.VersionInfo{}, fmt.Errorf("cannot parse RPC server: %w", err)
}
rpcServer.Path = "/jsonrpc"
caller, err := rpc.New(context.Background(), rpcServer.String(), secret, time.Duration(timeout)*time.Second, nil)
if err != nil {
return rpc.VersionInfo{}, fmt.Errorf("cannot initialize rpc connection: %w", err)
}
return caller.GetVersion()
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
@@ -37,7 +38,7 @@ func TestDummyAria2(t *testing.T) {
}
func TestInit(t *testing.T) {
MAX_RETRY = 0
monitor.MAX_RETRY = 0
asserts := assert.New(t)
cache.Set("setting_aria2_token", "1", 0)
cache.Set("setting_aria2_call_timeout", "5", 0)
@@ -81,11 +82,11 @@ func TestInit(t *testing.T) {
func TestGetStatus(t *testing.T) {
asserts := assert.New(t)
asserts.Equal(4, getStatus("complete"))
asserts.Equal(1, getStatus("active"))
asserts.Equal(0, getStatus("waiting"))
asserts.Equal(2, getStatus("paused"))
asserts.Equal(3, getStatus("error"))
asserts.Equal(5, getStatus("removed"))
asserts.Equal(6, getStatus("?"))
asserts.Equal(4, GetStatus("complete"))
asserts.Equal(1, GetStatus("active"))
asserts.Equal(0, GetStatus("waiting"))
asserts.Equal(2, GetStatus("paused"))
asserts.Equal(3, GetStatus("error"))
asserts.Equal(5, GetStatus("removed"))
asserts.Equal(6, GetStatus("?"))
}

View File

@@ -9,6 +9,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
@@ -33,7 +34,7 @@ func (client *RPCService) Init(server, secret string, timeout int, options map[s
Options: options,
}
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
EventNotifier)
mq.GlobalMQ)
client.Caller = caller
return err
}
@@ -85,7 +86,7 @@ func (client *RPCService) Select(task *model.Download, files []int) error {
}
// CreateTask 创建新任务
func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) error {
func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
// 生成存储路径
path := filepath.Join(
model.GetSettingByName("aria2_temp_path"),
@@ -106,18 +107,8 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri
gid, err := client.Caller.AddURI(task.Source, options)
if err != nil || gid == "" {
return err
return "", err
}
// 保存到数据库
task.GID = gid
_, err = task.Create()
if err != nil {
return err
}
// 创建任务监控
NewMonitor(task)
return nil
return gid, nil
}

114
pkg/aria2/common/common.go Normal file
View File

@@ -0,0 +1,114 @@
package common
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Aria2 离线下载处理接口
type Aria2 interface {
// Init 初始化客户端连接
Init() error
// CreateTask 创建新的任务
CreateTask(task *model.Download, options map[string]interface{}) (string, error)
// 返回状态信息
Status(task *model.Download) (rpc.StatusInfo, error)
// 取消任务
Cancel(task *model.Download) error
// 选择要下载的文件
Select(task *model.Download, files []int) error
// 获取离线下载配置
GetConfig() model.Aria2Option
// 删除临时下载文件
DeleteTempFile(*model.Download) error
}
const (
// URLTask 从URL添加的任务
URLTask = iota
// TorrentTask 种子任务
TorrentTask
)
const (
// Ready 准备就绪
Ready = iota
// Downloading 下载中
Downloading
// Paused 暂停中
Paused
// Error 出错
Error
// Complete 完成
Complete
// Canceled 取消/停止
Canceled
// Unknown 未知状态
Unknown
)
var (
// ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
// ErrUserNotFound 未找到下载任务创建者
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
)
// DummyAria2 未开启Aria2功能时使用的默认处理器
type DummyAria2 struct {
}
func (instance *DummyAria2) Init() error {
return nil
}
// CreateTask 创建新任务,此处直接返回未开启错误
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) {
return "", ErrNotEnabled
}
// Status 返回未开启错误
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
return rpc.StatusInfo{}, ErrNotEnabled
}
// Cancel 返回未开启错误
func (instance *DummyAria2) Cancel(task *model.Download) error {
return ErrNotEnabled
}
// Select 返回未开启错误
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
return ErrNotEnabled
}
// GetConfig 返回空的
func (instance *DummyAria2) GetConfig() model.Aria2Option {
return model.Aria2Option{}
}
// GetConfig 返回空的
func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
return ErrNotEnabled
}
// GetStatus 将给定的状态字符串转换为状态标识数字
func GetStatus(status string) int {
switch status {
case "complete":
return Complete
case "active":
return Downloading
case "waiting":
return Ready
case "paused":
return Paused
case "error":
return Error
case "removed":
return Canceled
default:
return Unknown
}
}

View File

@@ -1,19 +1,21 @@
package aria2
package monitor
import (
"context"
"encoding/json"
"errors"
"os"
"path/filepath"
"strconv"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
@@ -23,32 +25,34 @@ type Monitor struct {
Task *model.Download
Interval time.Duration
notifier chan StatusEvent
notifier <-chan mq.Message
node cluster.Node
retried int
}
// StatusEvent 状态改变事件
type StatusEvent struct {
GID string
Status int
}
var MAX_RETRY = 10
// NewMonitor 新建上传状态监控
// NewMonitor 新建离线下载状态监控
func NewMonitor(task *model.Download) {
monitor := &Monitor{
Task: task,
Interval: time.Duration(model.GetIntSetting("aria2_interval", 10)) * time.Second,
notifier: make(chan StatusEvent),
notifier: make(chan mq.Message),
node: cluster.Default.GetNodeByID(task.GetNodeID()),
}
if monitor.node != nil {
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
go monitor.Loop()
monitor.notifier = mq.GlobalMQ.Subscribe(monitor.Task.GID, 0)
} else {
monitor.setErrorStatus(errors.New("节点不可用"))
}
go monitor.Loop()
EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID)
}
// Loop 开启监控循环
func (monitor *Monitor) Loop() {
defer EventNotifier.Unsubscribe(monitor.Task.GID)
defer mq.GlobalMQ.Unsubscribe(monitor.Task.GID, monitor.notifier)
// 首次循环立即更新
interval := time.Duration(0)
@@ -70,9 +74,7 @@ func (monitor *Monitor) Loop() {
// Update 更新状态,返回值表示是否退出监控
func (monitor *Monitor) Update() bool {
Lock.RLock()
status, err := Instance.Status(monitor.Task)
Lock.RUnlock()
status, err := monitor.node.GetAria2Instance().Status(monitor.Task)
if err != nil {
monitor.retried++
@@ -102,6 +104,7 @@ func (monitor *Monitor) Update() bool {
if err := monitor.UpdateTaskInfo(status); err != nil {
util.Log().Warning("无法更新下载任务[%s]的任务信息[%s]", monitor.Task.GID, err)
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
@@ -115,7 +118,7 @@ func (monitor *Monitor) Update() bool {
case "active", "waiting", "paused":
return false
case "removed":
monitor.Task.Status = Canceled
monitor.Task.Status = common.Canceled
monitor.Task.Save()
monitor.RemoveTempFolder()
return true
@@ -130,7 +133,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
originSize := monitor.Task.TotalSize
monitor.Task.GID = status.Gid
monitor.Task.Status = getStatus(status.Status)
monitor.Task.Status = common.GetStatus(status.Status)
// 文件大小、已下载大小
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
@@ -164,9 +167,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
// 文件大小更新后,对文件限制等进行校验
if err := monitor.ValidateFile(); err != nil {
// 验证失败时取消任务
Lock.RLock()
Instance.Cancel(monitor.Task)
Lock.RUnlock()
monitor.node.GetAria2Instance().Cancel(monitor.Task)
return err
}
}
@@ -179,7 +180,7 @@ func (monitor *Monitor) ValidateFile() error {
// 找到任务创建者
user := monitor.Task.GetOwner()
if user == nil {
return ErrUserNotFound
return common.ErrUserNotFound
}
// 创建文件系统
@@ -230,28 +231,31 @@ func (monitor *Monitor) Error(status rpc.StatusInfo) bool {
// RemoveTempFolder 清理下载临时目录
func (monitor *Monitor) RemoveTempFolder() {
err := os.RemoveAll(monitor.Task.Parent)
if err != nil {
util.Log().Warning("无法删除离线下载临时目录[%s], %s", monitor.Task.Parent, err)
}
monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task)
}
// Complete 完成下载,返回是否中断监控
func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
// 创建中转任务
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ {
if monitor.Task.StatusInfo.Files[i].Selected == "true" {
file = append(file, monitor.Task.StatusInfo.Files[i].Path)
fileInfo := monitor.Task.StatusInfo.Files[i]
if fileInfo.Selected == "true" {
file = append(file, fileInfo.Path)
size, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
sizes[fileInfo.Path] = size
}
}
job, err := task.NewTransferTask(
monitor.Task.UserID,
file,
monitor.Task.Dst,
monitor.Task.Parent,
true,
monitor.node.ID(),
sizes,
)
if err != nil {
monitor.setErrorStatus(err)
@@ -269,7 +273,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
}
func (monitor *Monitor) setErrorStatus(err error) {
monitor.Task.Status = Error
monitor.Task.Status = common.Error
monitor.Task.Error = err.Error()
monitor.Task.Save()
}

View File

@@ -1,4 +1,4 @@
package aria2
package monitor
import (
"errors"
@@ -7,6 +7,8 @@ import (
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
@@ -44,13 +46,13 @@ func (m InstanceMock) Select(task *model.Download, files []int) error {
func TestNewMonitor(t *testing.T) {
asserts := assert.New(t)
NewMonitor(&model.Download{GID: "gid"})
_, ok := EventNotifier.Subscribes.Load("gid")
_, ok := common.EventNotifier.Subscribes.Load("gid")
asserts.True(ok)
}
func TestMonitor_Loop(t *testing.T) {
asserts := assert.New(t)
notifier := make(chan StatusEvent)
notifier := make(chan common.StatusEvent)
MAX_RETRY = 0
monitor := &Monitor{
Task: &model.Download{GID: "gid"},
@@ -76,10 +78,10 @@ func TestMonitor_Update(t *testing.T) {
{
MAX_RETRY = 1
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
file.Close()
Instance = testInstance
aria2.Instance = testInstance
asserts.False(monitor.Update())
asserts.True(monitor.Update())
testInstance.AssertExpectations(t)
@@ -89,16 +91,16 @@ func TestMonitor_Update(t *testing.T) {
// 磁力链下载重定向
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{
FollowedBy: []string{"1"},
}, nil)
monitor.Task.ID = 1
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.False(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
asserts.EqualValues("1", monitor.Task.GID)
}
@@ -106,82 +108,82 @@ func TestMonitor_Update(t *testing.T) {
// 无法更新任务信息
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, nil)
monitor.Task.ID = 1
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
Instance = testInstance
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
aria2.mock.ExpectRollback()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回未知状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回被取消状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回活跃状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.False(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回错误状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回完成
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
aria2.Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
}
@@ -198,34 +200,34 @@ func TestMonitor_UpdateTaskInfo(t *testing.T) {
// 失败
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
aria2.mock.ExpectRollback()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.Error(err)
}
// 更新成功,无需校验
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.NoError(err)
}
// 更新成功,大小改变,需要校验,校验失败
{
testInstance := new(InstanceMock)
testInstance.On("Cancel", testMock.Anything).Return(nil)
Instance = testInstance
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
testInstance.On("SlaveCancel", testMock.Anything).Return(nil)
aria2.Instance = testInstance
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
asserts.Error(err)
testInstance.AssertExpectations(t)
}
@@ -308,17 +310,17 @@ func TestMonitor_Complete(t *testing.T) {
}
cache.Set("setting_max_worker_num", "1", 0)
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
aria2.mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
task.Init()
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
aria2.mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
aria2.mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
asserts.True(monitor.Complete(rpc.StatusInfo{}))
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(aria2.mock.ExpectationsWereMet())
}

View File

@@ -1,64 +0,0 @@
package aria2
import (
"sync"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
)
// Notifier aria2实践通知处理
type Notifier struct {
Subscribes sync.Map
}
// Subscribe 订阅事件通知
func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) {
notifier.Subscribes.Store(gid, target)
}
// Unsubscribe 取消订阅事件通知
func (notifier *Notifier) Unsubscribe(gid string) {
notifier.Subscribes.Delete(gid)
}
// Notify 发送通知
func (notifier *Notifier) Notify(events []rpc.Event, status int) {
for _, event := range events {
if target, ok := notifier.Subscribes.Load(event.Gid); ok {
target.(chan StatusEvent) <- StatusEvent{
GID: event.Gid,
Status: status,
}
}
}
}
// OnDownloadStart 下载开始
func (notifier *Notifier) OnDownloadStart(events []rpc.Event) {
notifier.Notify(events, Downloading)
}
// OnDownloadPause 下载暂停
func (notifier *Notifier) OnDownloadPause(events []rpc.Event) {
notifier.Notify(events, Paused)
}
// OnDownloadStop 下载停止
func (notifier *Notifier) OnDownloadStop(events []rpc.Event) {
notifier.Notify(events, Canceled)
}
// OnDownloadComplete 下载完成
func (notifier *Notifier) OnDownloadComplete(events []rpc.Event) {
notifier.Notify(events, Complete)
}
// OnDownloadError 下载出错
func (notifier *Notifier) OnDownloadError(events []rpc.Event) {
notifier.Notify(events, Error)
}
// OnBtDownloadComplete BT下载完成
func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) {
notifier.Notify(events, Complete)
}