Init V4 community edition (#2265)
* Init V4 community edition
This commit is contained in:
@@ -1,67 +0,0 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
)
|
||||
|
||||
// Instance 默认使用的Aria2处理实例
|
||||
var Instance common.Aria2 = &common.DummyAria2{}
|
||||
|
||||
// LB 获取 Aria2 节点的负载均衡器
|
||||
var LB balancer.Balancer
|
||||
|
||||
// Lock Instance的读写锁
|
||||
var Lock sync.RWMutex
|
||||
|
||||
// GetLoadBalancer 返回供Aria2使用的负载均衡器
|
||||
func GetLoadBalancer() balancer.Balancer {
|
||||
Lock.RLock()
|
||||
defer Lock.RUnlock()
|
||||
return LB
|
||||
}
|
||||
|
||||
// Init 初始化
|
||||
func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
|
||||
Lock.Lock()
|
||||
LB = balancer.NewBalancer("RoundRobin")
|
||||
Lock.Unlock()
|
||||
|
||||
if !isReload {
|
||||
// 从数据库中读取未完成任务,创建监控
|
||||
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding)
|
||||
|
||||
for i := 0; i < len(unfinished); i++ {
|
||||
// 创建任务监控
|
||||
monitor.NewMonitor(&unfinished[i], pool, mqClient)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRPCConnection 发送测试用的 RPC 请求,测试服务连通性
|
||||
func TestRPCConnection(server, secret string, timeout int) (rpc.VersionInfo, error) {
|
||||
// 解析RPC服务地址
|
||||
rpcServer, err := url.Parse(server)
|
||||
if err != nil {
|
||||
return rpc.VersionInfo{}, fmt.Errorf("cannot parse RPC server: %w", err)
|
||||
}
|
||||
|
||||
rpcServer.Path = "/jsonrpc"
|
||||
caller, err := rpc.New(context.Background(), rpcServer.String(), secret, time.Duration(timeout)*time.Second, nil)
|
||||
if err != nil {
|
||||
return rpc.VersionInfo{}, fmt.Errorf("cannot initialize rpc connection: %w", err)
|
||||
}
|
||||
|
||||
return caller.GetVersion()
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
// TestMain sets up a sqlmock-backed gorm DB shared by all tests in this
// package, then runs the test suite.
func TestMain(m *testing.M) {
	var db *sql.DB
	var err error
	db, mock, err = sqlmock.New()
	if err != nil {
		panic("An error was not expected when opening a stub database connection")
	}
	model.DB, _ = gorm.Open("mysql", db)
	// NOTE(review): this defer runs only after m.Run() returns, just before
	// TestMain exits — cleanup ordering is fine but largely cosmetic here.
	defer db.Close()
	m.Run()
}
|
||||
|
||||
// TestInit verifies that Init, on first start, queries unfinished downloads
// and creates a monitor for each via the node pool.
func TestInit(t *testing.T) {
	a := assert.New(t)
	mockPool := &mocks.NodePoolMock{}
	mockPool.On("GetNodeByID", testMock.Anything).Return(nil)
	mockQueue := mq.NewMQ()

	// One unfinished download row triggers one monitor creation.
	mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
	Init(false, mockPool, mockQueue)
	a.NoError(mock.ExpectationsWereMet())
	mockPool.AssertExpectations(t)
}
|
||||
|
||||
// TestTestRPCConnection covers the two failure paths of TestRPCConnection:
// an unparsable server URL and an unreachable RPC endpoint.
func TestTestRPCConnection(t *testing.T) {
	a := assert.New(t)

	// url not legal
	{
		res, err := TestRPCConnection(string([]byte{0x7f}), "", 10)
		a.Error(err)
		a.Empty(res.Version)
	}

	// rpc failed
	{
		res, err := TestRPCConnection("ws://0.0.0.0", "", 0)
		a.Error(err)
		a.Empty(res.Version)
	}
}
|
||||
|
||||
// TestGetLoadBalancer checks that GetLoadBalancer is safe to call even before
// Init has installed a balancer (it may return the zero value, not panic).
func TestGetLoadBalancer(t *testing.T) {
	a := assert.New(t)
	a.NotPanics(func() {
		GetLoadBalancer()
	})
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
// Aria2 is the interface implemented by offline-download (aria2) handlers.
type Aria2 interface {
	// Init initializes the client connection.
	Init() error
	// CreateTask creates a new download task and returns its identifier.
	CreateTask(task *model.Download, options map[string]interface{}) (string, error)
	// Status returns the status information of the given task.
	Status(task *model.Download) (rpc.StatusInfo, error)
	// Cancel cancels the given task.
	Cancel(task *model.Download) error
	// Select chooses which files of a multi-file task to download.
	Select(task *model.Download, files []int) error
	// GetConfig returns the offline-download configuration.
	GetConfig() model.Aria2Option
	// DeleteTempFile removes the temporary files of a download task.
	DeleteTempFile(*model.Download) error
}
|
||||
|
||||
// Task source types.
const (
	// URLTask is a task added from a plain URL.
	URLTask = iota
	// TorrentTask is a task added from a torrent file.
	TorrentTask
)
|
||||
|
||||
// Download task status values. The numeric order is persisted in the
// database, so new values must only be appended.
const (
	// Ready - queued, waiting to start.
	Ready = iota
	// Downloading - transfer in progress.
	Downloading
	// Paused - paused.
	Paused
	// Error - failed.
	Error
	// Complete - finished successfully.
	Complete
	// Canceled - canceled/stopped.
	Canceled
	// Unknown - unrecognized status reported by aria2.
	Unknown
	// Seeding - torrent finished downloading and is now seeding.
	Seeding
)
|
||||
|
||||
var (
	// ErrNotEnabled is returned when the offline download feature is disabled.
	ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "not enabled", nil)
	// ErrUserNotFound is returned when the creator of a download task cannot be found.
	ErrUserNotFound = serializer.NewError(serializer.CodeUserNotFound, "", nil)
)
|
||||
|
||||
// DummyAria2 未开启Aria2功能时使用的默认处理器
|
||||
type DummyAria2 struct {
|
||||
}
|
||||
|
||||
func (instance *DummyAria2) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTask 创建新任务,此处直接返回未开启错误
|
||||
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) {
|
||||
return "", ErrNotEnabled
|
||||
}
|
||||
|
||||
// Status 返回未开启错误
|
||||
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
return rpc.StatusInfo{}, ErrNotEnabled
|
||||
}
|
||||
|
||||
// Cancel 返回未开启错误
|
||||
func (instance *DummyAria2) Cancel(task *model.Download) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// Select 返回未开启错误
|
||||
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// GetConfig 返回空的
|
||||
func (instance *DummyAria2) GetConfig() model.Aria2Option {
|
||||
return model.Aria2Option{}
|
||||
}
|
||||
|
||||
// GetConfig 返回空的
|
||||
func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// GetStatus 将给定的状态字符串转换为状态标识数字
|
||||
func GetStatus(status rpc.StatusInfo) int {
|
||||
switch status.Status {
|
||||
case "complete":
|
||||
return Complete
|
||||
case "active":
|
||||
if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength {
|
||||
return Seeding
|
||||
}
|
||||
return Downloading
|
||||
case "waiting":
|
||||
return Ready
|
||||
case "paused":
|
||||
return Paused
|
||||
case "error":
|
||||
return Error
|
||||
case "removed":
|
||||
return Canceled
|
||||
default:
|
||||
return Unknown
|
||||
}
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestDummyAria2 verifies that every DummyAria2 operation except Init fails,
// since the handler represents a disabled Aria2 feature.
func TestDummyAria2(t *testing.T) {
	a := assert.New(t)
	d := &DummyAria2{}

	a.NoError(d.Init())

	res, err := d.CreateTask(&model.Download{}, map[string]interface{}{})
	a.Empty(res)
	a.Error(err)

	_, err = d.Status(&model.Download{})
	a.Error(err)

	err = d.Cancel(&model.Download{})
	a.Error(err)

	err = d.Select(&model.Download{}, []int{})
	a.Error(err)

	configRes := d.GetConfig()
	a.NotNil(configRes)

	err = d.DeleteTempFile(&model.Download{})
	a.Error(err)
}
|
||||
|
||||
// TestGetStatus exercises every mapping in GetStatus, including the
// active-state distinction between plain downloading and torrent seeding.
func TestGetStatus(t *testing.T) {
	a := assert.New(t)

	a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete)
	// active, no torrent mode: downloading
	a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
		BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading)
	// active torrent, not fully downloaded: still downloading
	a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
		BitTorrent: rpc.BitTorrentInfo{Mode: "single"},
		TotalLength: "100", CompletedLength: "50"}), Downloading)
	// active torrent, fully downloaded: seeding
	a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
		BitTorrent: rpc.BitTorrentInfo{Mode: "multi"},
		TotalLength: "100", CompletedLength: "100"}), Seeding)
	a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready)
	a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused)
	a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error)
	a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled)
	a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown)
}
|
||||
@@ -1,314 +0,0 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// Monitor watches the status of one offline download task and drives its
// lifecycle (status refresh, transfer submission, cleanup).
type Monitor struct {
	// Task is the download task being monitored.
	Task *model.Download
	// Interval is the polling period between status refreshes.
	Interval time.Duration

	// notifier receives push messages published for this task's GID.
	notifier <-chan mq.Message
	// node is the cluster node executing the download.
	node cluster.Node
	// retried counts consecutive failed status queries.
	retried int
}
|
||||
|
||||
// MAX_RETRY is the number of consecutive failed status queries tolerated
// before a task is marked as failed.
// NOTE(review): Go naming convention would prefer MaxRetry; name kept for
// compatibility with existing references (e.g. tests in this package).
var MAX_RETRY = 10
|
||||
|
||||
// NewMonitor 新建离线下载状态监控
|
||||
func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
|
||||
monitor := &Monitor{
|
||||
Task: task,
|
||||
notifier: make(chan mq.Message),
|
||||
node: pool.GetNodeByID(task.GetNodeID()),
|
||||
}
|
||||
|
||||
if monitor.node != nil {
|
||||
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
|
||||
go monitor.Loop(mqClient)
|
||||
|
||||
monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0)
|
||||
} else {
|
||||
monitor.setErrorStatus(errors.New("node not avaliable"))
|
||||
}
|
||||
}
|
||||
|
||||
// Loop runs the monitoring loop: it refreshes the task status either when a
// push notification arrives or when the polling interval elapses, and exits
// once Update reports that monitoring should stop. It unsubscribes from the
// task's message queue on exit.
func (monitor *Monitor) Loop(mqClient mq.MQ) {
	defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier)

	// Use a tiny interval for the first iteration so the status is refreshed
	// immediately; later iterations use the configured Interval.
	interval := 50 * time.Millisecond

	for {
		select {
		case <-monitor.notifier:
			if monitor.Update() {
				return
			}
		case <-time.After(interval):
			interval = monitor.Interval
			if monitor.Update() {
				return
			}
		}
	}
}
|
||||
|
||||
// Update 更新状态,返回值表示是否退出监控
|
||||
func (monitor *Monitor) Update() bool {
|
||||
status, err := monitor.node.GetAria2Instance().Status(monitor.Task)
|
||||
|
||||
if err != nil {
|
||||
monitor.retried++
|
||||
util.Log().Warning("Cannot get status of download task %q: %s", monitor.Task.GID, err)
|
||||
|
||||
// 十次重试后认定为任务失败
|
||||
if monitor.retried > MAX_RETRY {
|
||||
util.Log().Warning("Cannot get status of download task %q,exceed maximum retry threshold: %s",
|
||||
monitor.Task.GID, err)
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
monitor.retried = 0
|
||||
|
||||
// 磁力链下载需要跟随
|
||||
if len(status.FollowedBy) > 0 {
|
||||
util.Log().Debug("Redirected download task from %q to %q.", monitor.Task.GID, status.FollowedBy[0])
|
||||
monitor.Task.GID = status.FollowedBy[0]
|
||||
monitor.Task.Save()
|
||||
return false
|
||||
}
|
||||
|
||||
// 更新任务信息
|
||||
if err := monitor.UpdateTaskInfo(status); err != nil {
|
||||
util.Log().Warning("Failed to update status of download task %q: %s", monitor.Task.GID, err)
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
util.Log().Debug("Remote download %q status updated to %q.", status.Gid, status.Status)
|
||||
|
||||
switch common.GetStatus(status) {
|
||||
case common.Complete, common.Seeding:
|
||||
return monitor.Complete(task.TaskPoll)
|
||||
case common.Error:
|
||||
return monitor.Error(status)
|
||||
case common.Downloading, common.Ready, common.Paused:
|
||||
return false
|
||||
case common.Canceled:
|
||||
monitor.Task.Status = common.Canceled
|
||||
monitor.Task.Save()
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
default:
|
||||
util.Log().Warning("Download task %q returns unknown status %q.", monitor.Task.GID, status.Status)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTaskInfo 更新数据库中的任务信息
|
||||
func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
|
||||
originSize := monitor.Task.TotalSize
|
||||
|
||||
monitor.Task.GID = status.Gid
|
||||
monitor.Task.Status = common.GetStatus(status)
|
||||
|
||||
// 文件大小、已下载大小
|
||||
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
|
||||
if err != nil {
|
||||
total = 0
|
||||
}
|
||||
downloaded, err := strconv.ParseUint(status.CompletedLength, 10, 64)
|
||||
if err != nil {
|
||||
downloaded = 0
|
||||
}
|
||||
monitor.Task.TotalSize = total
|
||||
monitor.Task.DownloadedSize = downloaded
|
||||
monitor.Task.GID = status.Gid
|
||||
monitor.Task.Parent = status.Dir
|
||||
|
||||
// 下载速度
|
||||
speed, err := strconv.Atoi(status.DownloadSpeed)
|
||||
if err != nil {
|
||||
speed = 0
|
||||
}
|
||||
|
||||
monitor.Task.Speed = speed
|
||||
attrs, _ := json.Marshal(status)
|
||||
monitor.Task.Attrs = string(attrs)
|
||||
|
||||
if err := monitor.Task.Save(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if originSize != monitor.Task.TotalSize {
|
||||
// 文件大小更新后,对文件限制等进行校验
|
||||
if err := monitor.ValidateFile(); err != nil {
|
||||
// 验证失败时取消任务
|
||||
monitor.node.GetAria2Instance().Cancel(monitor.Task)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateFile checks the download task against its owner's capacity and
// per-file constraints, reusing the filesystem upload-validation hooks.
// It returns an error when the owner cannot be found, the filesystem cannot
// be created, capacity is exceeded, or a selected file violates policy.
func (monitor *Monitor) ValidateFile() error {
	// Resolve the task's owner.
	user := monitor.Task.GetOwner()
	if user == nil {
		return common.ErrUserNotFound
	}

	// Build a filesystem for the owner.
	fs, err := filesystem.NewFileSystem(user)
	if err != nil {
		return err
	}
	defer fs.Recycle()

	// Validate the task's total size against the user's remaining capacity.
	file := &fsctx.FileStream{
		Size: monitor.Task.TotalSize,
	}

	if err := filesystem.HookValidateCapacity(context.Background(), fs, file); err != nil {
		return err
	}

	// Validate each selected file individually (e.g. per-file size limit).
	for _, fileInfo := range monitor.Task.StatusInfo.Files {
		if fileInfo.Selected == "true" {
			// Malformed lengths silently parse to 0 here.
			fileSize, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
			file := &fsctx.FileStream{
				Size: fileSize,
				Name: filepath.Base(fileInfo.Path),
			}
			if err := filesystem.HookValidateFile(context.Background(), fs, file); err != nil {
				return err
			}
		}

	}

	return nil
}
|
||||
|
||||
// Error handles a failed download: it records the aria2-reported error
// message on the task and removes temporary files. It always returns true,
// stopping the monitor.
func (monitor *Monitor) Error(status rpc.StatusInfo) bool {
	monitor.setErrorStatus(errors.New(status.ErrorMessage))

	// Clean up temporary files.
	monitor.RemoveTempFolder()

	return true
}
|
||||
|
||||
// RemoveTempFolder deletes the task's temporary download files on its node.
// NOTE(review): DeleteTempFile's error is discarded — presumably best-effort
// cleanup; confirm this is intentional.
func (monitor *Monitor) RemoveTempFolder() {
	monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task)
}
|
||||
|
||||
// Complete handles a finished (or seeding) download. It reports whether
// monitoring should stop.
func (monitor *Monitor) Complete(pool task.Pool) bool {
	// Transfer not yet submitted: create and submit the transfer task.
	if monitor.Task.TaskID == 0 {
		return monitor.transfer(pool)
	}

	// Seeding finished (aria2 now reports "complete"): once the transfer
	// task has ended, recycle the download directory.
	if common.GetStatus(monitor.Task.StatusInfo) == common.Complete {
		transferTask, err := model.GetTasksByID(monitor.Task.TaskID)
		if err != nil {
			monitor.setErrorStatus(err)
			monitor.RemoveTempFolder()
			return true
		}

		// NOTE(review): `Status >= task.Error` is taken to mean the transfer
		// ended (error or later states) — confirm against task status order.
		if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error {
			job, err := task.NewRecycleTask(monitor.Task)
			if err != nil {
				monitor.setErrorStatus(err)
				monitor.RemoveTempFolder()
				return true
			}

			// Submit the recycle task.
			pool.Submit(job)

			return true
		}
	}

	return false
}
|
||||
|
||||
func (monitor *Monitor) transfer(pool task.Pool) bool {
|
||||
// 创建中转任务
|
||||
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
|
||||
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
|
||||
for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ {
|
||||
fileInfo := monitor.Task.StatusInfo.Files[i]
|
||||
if fileInfo.Selected == "true" {
|
||||
file = append(file, fileInfo.Path)
|
||||
size, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
|
||||
sizes[fileInfo.Path] = size
|
||||
}
|
||||
}
|
||||
|
||||
job, err := task.NewTransferTask(
|
||||
monitor.Task.UserID,
|
||||
file,
|
||||
monitor.Task.Dst,
|
||||
monitor.Task.Parent,
|
||||
true,
|
||||
monitor.node.ID(),
|
||||
sizes,
|
||||
)
|
||||
if err != nil {
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
// 提交中转任务
|
||||
pool.Submit(job)
|
||||
|
||||
// 更新任务ID
|
||||
monitor.Task.TaskID = job.Model().ID
|
||||
monitor.Task.Save()
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// setErrorStatus marks the task as failed with the given error message and
// persists the change. The Save error is discarded (best-effort).
func (monitor *Monitor) setErrorStatus(err error) {
	monitor.Task.Status = common.Error
	monitor.Task.Error = err.Error()
	monitor.Task.Save()
}
|
||||
@@ -1,447 +0,0 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
// TestMain sets up a sqlmock-backed gorm DB shared by all tests in this
// package, then runs the test suite.
func TestMain(m *testing.M) {
	var db *sql.DB
	var err error
	db, mock, err = sqlmock.New()
	if err != nil {
		panic("An error was not expected when opening a stub database connection")
	}
	model.DB, _ = gorm.Open("mysql", db)
	// NOTE(review): this defer runs only after m.Run() returns, just before
	// TestMain exits — cleanup ordering is fine but largely cosmetic here.
	defer db.Close()
	m.Run()
}
|
||||
|
||||
// TestNewMonitor covers both NewMonitor paths: the node being unavailable
// (task is marked failed and persisted) and a successful monitor start.
func TestNewMonitor(t *testing.T) {
	a := assert.New(t)
	mockMQ := mq.NewMQ()

	// node not available
	{
		mockPool := &mocks.NodePoolMock{}
		mockPool.On("GetNodeByID", uint(1)).Return(nil)
		// setErrorStatus persists the failure via an UPDATE.
		mock.ExpectBegin()
		mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
		mock.ExpectCommit()

		task := &model.Download{
			Model: gorm.Model{ID: 1},
		}
		NewMonitor(task, mockPool, mockMQ)
		mockPool.AssertExpectations(t)
		a.NoError(mock.ExpectationsWereMet())
		a.NotEmpty(task.Error)
	}

	// success
	{
		mockNode := &mocks.NodeMock{}
		mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
		mockPool := &mocks.NodePoolMock{}
		mockPool.On("GetNodeByID", uint(1)).Return(mockNode)

		task := &model.Download{
			Model: gorm.Model{ID: 1},
		}
		NewMonitor(task, mockPool, mockMQ)
		mockNode.AssertExpectations(t)
		mockPool.AssertExpectations(t)
	}

}
|
||||
|
||||
func TestMonitor_Loop(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockMQ := mq.NewMQ()
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
m := &Monitor{
|
||||
retried: MAX_RETRY,
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
notifier: mockMQ.Subscribe("test", 1),
|
||||
}
|
||||
|
||||
// into interval loop
|
||||
{
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
m.Loop(mockMQ)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
// into notifier loop
|
||||
{
|
||||
m.Task.Error = ""
|
||||
mockMQ.Publish("test", mq.Message{})
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
m.Loop(mockMQ)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateFailedAfterRetry(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
for i := 0; i < MAX_RETRY; i++ {
|
||||
a.False(m.Update())
|
||||
}
|
||||
|
||||
mockNode.AssertExpectations(t)
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
// TestMonitor_UpdateMagentoFollow checks that a status carrying FollowedBy
// makes the monitor switch to the follow-up GID and keep monitoring.
// NOTE(review): "Magento" in the name is likely a typo for "Magnet"
// (magnet-link follow behavior); name kept as-is.
func TestMonitor_UpdateMagentoFollow(t *testing.T) {
	a := assert.New(t)
	mockAria2 := &mocks.Aria2Mock{}
	mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
		FollowedBy: []string{"next"},
	}, nil)
	mockNode := &mocks.NodeMock{}
	mockNode.On("GetAria2Instance").Return(mockAria2)
	m := &Monitor{
		node: mockNode,
		Task: &model.Download{Model: gorm.Model{ID: 1}},
	}
	// The GID switch is persisted via Save.
	mock.ExpectBegin()
	mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()

	a.False(m.Update())
	a.NoError(mock.ExpectationsWereMet())
	a.Equal("next", m.Task.GID)
	mockAria2.AssertExpectations(t)
}
|
||||
|
||||
func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateCompleted(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "complete",
|
||||
}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
mockNode.On("ID").Return(uint(1))
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateError(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "error",
|
||||
ErrorMessage: "error",
|
||||
}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateActive(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "active",
|
||||
}, nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.False(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateRemoved(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "removed",
|
||||
}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.Equal(common.Canceled, m.Task.Status)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateUnknown(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "unknown",
|
||||
}, nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateTaskInfoValidateFailed(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
status := rpc.StatusInfo{
|
||||
Status: "completed",
|
||||
TotalLength: "100",
|
||||
CompletedLength: "50",
|
||||
DownloadSpeed: "20",
|
||||
}
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
err := m.UpdateTaskInfo(status)
|
||||
a.Error(err)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_ValidateFile(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &Monitor{
|
||||
Task: &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
TotalSize: 100,
|
||||
},
|
||||
}
|
||||
|
||||
// failed to create filesystem
|
||||
{
|
||||
m.Task.User = &model.User{
|
||||
Policy: model.Policy{
|
||||
Type: "random",
|
||||
},
|
||||
}
|
||||
a.Equal(filesystem.ErrUnknownPolicyType, m.ValidateFile())
|
||||
}
|
||||
|
||||
// User capacity not enough
|
||||
{
|
||||
m.Task.User = &model.User{
|
||||
Group: model.Group{
|
||||
MaxStorage: 99,
|
||||
},
|
||||
Policy: model.Policy{
|
||||
Type: "local",
|
||||
},
|
||||
}
|
||||
a.Equal(filesystem.ErrInsufficientCapacity, m.ValidateFile())
|
||||
}
|
||||
|
||||
// single file too big
|
||||
{
|
||||
m.Task.StatusInfo.Files = []rpc.FileInfo{
|
||||
{
|
||||
Length: "100",
|
||||
Selected: "true",
|
||||
},
|
||||
}
|
||||
m.Task.User = &model.User{
|
||||
Group: model.Group{
|
||||
MaxStorage: 100,
|
||||
},
|
||||
Policy: model.Policy{
|
||||
Type: "local",
|
||||
MaxSize: 99,
|
||||
},
|
||||
}
|
||||
a.Equal(filesystem.ErrFileSizeTooBig, m.ValidateFile())
|
||||
}
|
||||
|
||||
// all pass
|
||||
{
|
||||
m.Task.StatusInfo.Files = []rpc.FileInfo{
|
||||
{
|
||||
Length: "100",
|
||||
Selected: "true",
|
||||
},
|
||||
}
|
||||
m.Task.User = &model.User{
|
||||
Group: model.Group{
|
||||
MaxStorage: 100,
|
||||
},
|
||||
Policy: model.Policy{
|
||||
Type: "local",
|
||||
MaxSize: 100,
|
||||
},
|
||||
}
|
||||
a.NoError(m.ValidateFile())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitor_Complete(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("ID").Return(uint(1))
|
||||
mockPool := &mocks.TaskPoolMock{}
|
||||
mockPool.On("Submit", testMock.Anything)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
TotalSize: 100,
|
||||
UserID: 9414,
|
||||
},
|
||||
}
|
||||
m.Task.StatusInfo.Files = []rpc.FileInfo{
|
||||
{
|
||||
Length: "100",
|
||||
Selected: "true",
|
||||
},
|
||||
}
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4))
|
||||
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.False(m.Complete(mockPool))
|
||||
m.Task.StatusInfo.Status = "complete"
|
||||
a.True(m.Complete(mockPool))
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockNode.AssertExpectations(t)
|
||||
mockPool.AssertExpectations(t)
|
||||
}
|
||||
161
pkg/auth/auth.go
161
pkg/auth/auth.go
@@ -2,18 +2,18 @@ package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,37 +23,59 @@ var (
|
||||
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "signature expired", nil)
|
||||
)
|
||||
|
||||
const CrHeaderPrefix = "X-Cr-"
|
||||
const (
|
||||
TokenHeaderPrefixCr = "Bearer Cr "
|
||||
)
|
||||
|
||||
// General 通用的认证接口
|
||||
// Deprecated
|
||||
var General Auth
|
||||
|
||||
// Auth 鉴权认证
|
||||
type Auth interface {
|
||||
// 对给定Body进行签名,expires为0表示永不过期
|
||||
Sign(body string, expires int64) string
|
||||
// 对给定Body和Sign进行检查
|
||||
Check(body string, sign string) error
|
||||
}
|
||||
type (
|
||||
// Auth 鉴权认证
|
||||
Auth interface {
|
||||
// 对给定Body进行签名,expires为0表示永不过期
|
||||
Sign(body string, expires int64) string
|
||||
// 对给定Body和Sign进行检查
|
||||
Check(body string, sign string) error
|
||||
}
|
||||
)
|
||||
|
||||
// SignRequest 对PUT\POST等复杂HTTP请求签名,只会对URI部分、
|
||||
// 请求正文、`X-Cr-`开头的header进行签名
|
||||
func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request {
|
||||
func SignRequest(ctx context.Context, instance Auth, r *http.Request, expires *time.Time) *http.Request {
|
||||
// 处理有效期
|
||||
expireTime := int64(0)
|
||||
if expires != nil {
|
||||
expireTime = expires.Unix()
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign := instance.Sign(getSignContent(ctx, r), expireTime)
|
||||
|
||||
// 将签名加到请求Header中
|
||||
r.Header["Authorization"] = []string{TokenHeaderPrefixCr + sign}
|
||||
return r
|
||||
}
|
||||
|
||||
// SignRequestDeprecated 对PUT\POST等复杂HTTP请求签名,只会对URI部分、
|
||||
// 请求正文、`X-Cr-`开头的header进行签名
|
||||
func SignRequestDeprecated(instance Auth, r *http.Request, expires int64) *http.Request {
|
||||
// 处理有效期
|
||||
if expires > 0 {
|
||||
expires += time.Now().Unix()
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign := instance.Sign(getSignContent(r), expires)
|
||||
sign := instance.Sign(getSignContent(context.Background(), r), expires)
|
||||
|
||||
// 将签名加到请求Header中
|
||||
r.Header["Authorization"] = []string{"Bearer " + sign}
|
||||
r.Header["Authorization"] = []string{TokenHeaderPrefixCr + sign}
|
||||
return r
|
||||
}
|
||||
|
||||
// CheckRequest 对复杂请求进行签名验证
|
||||
func CheckRequest(instance Auth, r *http.Request) error {
|
||||
func CheckRequest(ctx context.Context, instance Auth, r *http.Request) error {
|
||||
var (
|
||||
sign []string
|
||||
ok bool
|
||||
@@ -61,41 +83,71 @@ func CheckRequest(instance Auth, r *http.Request) error {
|
||||
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
|
||||
return ErrAuthHeaderMissing
|
||||
}
|
||||
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
|
||||
sign[0] = strings.TrimPrefix(sign[0], TokenHeaderPrefixCr)
|
||||
|
||||
return instance.Check(getSignContent(ctx, r), sign[0])
|
||||
}
|
||||
|
||||
func isUploadDataRequest(r *http.Request) bool {
|
||||
return strings.Contains(r.URL.Path, constants.APIPrefix+"/slave/upload/") && r.Method != http.MethodPut
|
||||
|
||||
return instance.Check(getSignContent(r), sign[0])
|
||||
}
|
||||
|
||||
// getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果请求 path 为从机上传 API,
|
||||
// 则不对正文签名。返回待签名/验证的字符串
|
||||
func getSignContent(r *http.Request) (rawSignString string) {
|
||||
func getSignContent(ctx context.Context, r *http.Request) (rawSignString string) {
|
||||
// 读取所有body正文
|
||||
var body = []byte{}
|
||||
if !strings.Contains(r.URL.Path, "/api/v3/slave/upload/") {
|
||||
if !isUploadDataRequest(r) {
|
||||
if r.Body != nil {
|
||||
body, _ = ioutil.ReadAll(r.Body)
|
||||
body, _ = io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
}
|
||||
|
||||
// 决定要签名的header
|
||||
var signedHeader []string
|
||||
for k, _ := range r.Header {
|
||||
if strings.HasPrefix(k, CrHeaderPrefix) && k != CrHeaderPrefix+"Filename" {
|
||||
if strings.HasPrefix(k, constants.CrHeaderPrefix) && k != constants.CrHeaderPrefix+"Filename" {
|
||||
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
|
||||
}
|
||||
}
|
||||
sort.Strings(signedHeader)
|
||||
|
||||
// 读取所有待签名Header
|
||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))
|
||||
rawSignString = serializer.NewRequestSignString(getUrlSignContent(ctx, r.URL), strings.Join(signedHeader, "&"), string(body))
|
||||
|
||||
return rawSignString
|
||||
}
|
||||
|
||||
// SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证
|
||||
func SignURI(instance Auth, uri string, expires int64) (*url.URL, error) {
|
||||
// SignURI 对URI进行签名
|
||||
func SignURI(ctx context.Context, instance Auth, uri string, expires *time.Time) (*url.URL, error) {
|
||||
// 处理有效期
|
||||
expireTime := int64(0)
|
||||
if expires != nil {
|
||||
expireTime = expires.Unix()
|
||||
}
|
||||
|
||||
base, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign := instance.Sign(getUrlSignContent(ctx, base), expireTime)
|
||||
|
||||
// 将签名加到URI中
|
||||
queries := base.Query()
|
||||
queries.Set("sign", sign)
|
||||
base.RawQuery = queries.Encode()
|
||||
|
||||
return base, nil
|
||||
}
|
||||
|
||||
// SignURIDeprecated 对URI进行签名,签名只针对Path部分,query部分不做验证
|
||||
// Deprecated
|
||||
func SignURIDeprecated(instance Auth, uri string, expires int64) (*url.URL, error) {
|
||||
// 处理有效期
|
||||
if expires != 0 {
|
||||
expires += time.Now().Unix()
|
||||
@@ -118,28 +170,55 @@ func SignURI(instance Auth, uri string, expires int64) (*url.URL, error) {
|
||||
}
|
||||
|
||||
// CheckURI 对URI进行鉴权
|
||||
func CheckURI(instance Auth, url *url.URL) error {
|
||||
func CheckURI(ctx context.Context, instance Auth, url *url.URL) error {
|
||||
//获取待验证的签名正文
|
||||
queries := url.Query()
|
||||
sign := queries.Get("sign")
|
||||
queries.Del("sign")
|
||||
url.RawQuery = queries.Encode()
|
||||
|
||||
return instance.Check(url.Path, sign)
|
||||
return instance.Check(getUrlSignContent(ctx, url), sign)
|
||||
}
|
||||
|
||||
// Init 初始化通用鉴权器
|
||||
func Init() {
|
||||
var secretKey string
|
||||
if conf.SystemConfig.Mode == "master" {
|
||||
secretKey = model.GetSettingByName("secret_key")
|
||||
} else {
|
||||
secretKey = conf.SlaveConfig.Secret
|
||||
if secretKey == "" {
|
||||
util.Log().Panic("SlaveSecret is not set, please specify it in config file.")
|
||||
func RedactSensitiveValues(errorMessage string) string {
|
||||
// Regular expression to match URLs
|
||||
urlRegex := regexp.MustCompile(`https?://[^\s]+`)
|
||||
// Find all URLs in the error message
|
||||
urls := urlRegex.FindAllString(errorMessage, -1)
|
||||
|
||||
for _, urlStr := range urls {
|
||||
// Parse the URL
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get the query parameters
|
||||
queryParams := parsedURL.Query()
|
||||
|
||||
// Redact the 'sign' parameter if it exists
|
||||
if _, exists := queryParams["sign"]; exists {
|
||||
queryParams.Set("sign", "REDACTED")
|
||||
parsedURL.RawQuery = queryParams.Encode()
|
||||
}
|
||||
|
||||
// Replace the original URL with the redacted one in the error message
|
||||
errorMessage = strings.Replace(errorMessage, urlStr, parsedURL.String(), -1)
|
||||
}
|
||||
General = HMACAuth{
|
||||
SecretKey: []byte(secretKey),
|
||||
}
|
||||
|
||||
return errorMessage
|
||||
}
|
||||
|
||||
func getUrlSignContent(ctx context.Context, url *url.URL) string {
|
||||
// host := url.Host
|
||||
// if host == "" {
|
||||
// reqInfo := requestinfo.RequestInfoFromContext(ctx)
|
||||
// if reqInfo != nil {
|
||||
// host = reqInfo.Host
|
||||
// }
|
||||
// }
|
||||
// host = strings.TrimSuffix(host, "/")
|
||||
// // remove port if it exists
|
||||
// host = strings.Split(host, ":")[0]
|
||||
return url.Path
|
||||
}
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSignURI(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 成功
|
||||
{
|
||||
sign, err := SignURI(General, "/api/v3/something?id=1", 0)
|
||||
asserts.NoError(err)
|
||||
queries := sign.Query()
|
||||
asserts.Equal("1", queries.Get("id"))
|
||||
asserts.NotEmpty(queries.Get("sign"))
|
||||
}
|
||||
|
||||
// URI解码失败
|
||||
{
|
||||
sign, err := SignURI(General, "://dg.;'f]gh./'", 0)
|
||||
asserts.Error(err)
|
||||
asserts.Nil(sign)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckURI(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 成功
|
||||
{
|
||||
sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", 10)
|
||||
asserts.NoError(err)
|
||||
asserts.NoError(CheckURI(General, sign))
|
||||
}
|
||||
|
||||
// 过期
|
||||
{
|
||||
sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", -1)
|
||||
asserts.NoError(err)
|
||||
asserts.Error(CheckURI(General, sign))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignRequest(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 非上传请求
|
||||
{
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/slave/upload", strings.NewReader("I am body."))
|
||||
asserts.NoError(err)
|
||||
req = SignRequest(General, req, 0)
|
||||
asserts.NotEmpty(req.Header["Authorization"])
|
||||
}
|
||||
|
||||
// 上传请求
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/slave/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
|
||||
req = SignRequest(General, req, 10)
|
||||
asserts.NotEmpty(req.Header["Authorization"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRequest(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 缺少请求头
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
err = CheckRequest(General, req)
|
||||
asserts.Error(err)
|
||||
asserts.Equal(ErrAuthHeaderMissing, err)
|
||||
}
|
||||
|
||||
// 非上传请求 验证成功
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req = SignRequest(General, req, 0)
|
||||
err = CheckRequest(General, req)
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
// 上传请求 验证成功
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
|
||||
req = SignRequest(General, req, 0)
|
||||
err = CheckRequest(General, req)
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
// 非上传请求 失败
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req = SignRequest(General, req, 0)
|
||||
req.Body = ioutil.NopCloser(strings.NewReader("2333"))
|
||||
err = CheckRequest(General, req)
|
||||
asserts.Error(err)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
)
|
||||
|
||||
// HMACAuth HMAC算法鉴权
|
||||
@@ -39,7 +41,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
|
||||
// 验证是否过期
|
||||
expires, err := strconv.ParseInt(signSlice[len(signSlice)-1], 10, 64)
|
||||
if err != nil {
|
||||
return ErrAuthFailed.WithError(err)
|
||||
return serializer.NewError(serializer.CodeInvalidSign, "sign expired", nil)
|
||||
}
|
||||
// 如果签名过期
|
||||
if expires < time.Now().Unix() && expires != 0 {
|
||||
@@ -48,7 +50,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
|
||||
|
||||
// 验证签名
|
||||
if auth.Sign(body, expires) != sign {
|
||||
return ErrAuthFailed
|
||||
return serializer.NewError(serializer.CodeInvalidSign, "invalid sign", nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// 设置gin为测试模式
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 初始化sqlmock
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
|
||||
mockDB, _ := gorm.Open("mysql", db)
|
||||
model.DB = mockDB
|
||||
defer db.Close()
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestHMACAuth_Sign(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
auth := HMACAuth{
|
||||
SecretKey: []byte(util.RandStringRunes(256)),
|
||||
}
|
||||
|
||||
asserts.NotEmpty(auth.Sign("content", 0))
|
||||
}
|
||||
|
||||
func TestHMACAuth_Check(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
auth := HMACAuth{
|
||||
SecretKey: []byte(util.RandStringRunes(256)),
|
||||
}
|
||||
|
||||
// 正常,永不过期
|
||||
{
|
||||
sign := auth.Sign("content", 0)
|
||||
asserts.NoError(auth.Check("content", sign))
|
||||
}
|
||||
|
||||
// 过期
|
||||
{
|
||||
sign := auth.Sign("content", 1)
|
||||
asserts.Error(auth.Check("content", sign))
|
||||
}
|
||||
|
||||
// 签名格式错误
|
||||
{
|
||||
sign := auth.Sign("content", 1)
|
||||
asserts.Error(auth.Check("content", sign+":"))
|
||||
}
|
||||
|
||||
// 过期日期格式错误
|
||||
{
|
||||
asserts.Error(auth.Check("content", "ErrAuthFailed:ErrAuthFailed"))
|
||||
}
|
||||
|
||||
// 签名有误
|
||||
{
|
||||
asserts.Error(auth.Check("content", fmt.Sprintf("sign:%d", time.Now().Unix()+10)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312"))
|
||||
Init()
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
|
||||
// slave模式
|
||||
conf.SystemConfig.Mode = "slave"
|
||||
asserts.Panics(func() {
|
||||
Init()
|
||||
})
|
||||
}
|
||||
200
pkg/auth/jwt.go
Normal file
200
pkg/auth/jwt.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type TokenAuth interface {
|
||||
// Issue issues a new pair of credentials for the given user.
|
||||
Issue(ctx context.Context, u *ent.User) (*Token, error)
|
||||
// VerifyAndRetrieveUser verifies the given token and inject the user into current context.
|
||||
// Returns if upper caller should continue process other session provider.
|
||||
VerifyAndRetrieveUser(c *gin.Context) (bool, error)
|
||||
// Refresh refreshes the given refresh token and returns a new pair of credentials.
|
||||
Refresh(ctx context.Context, refreshToken string) (*Token, error)
|
||||
}
|
||||
|
||||
// Token stores token pair for authentication
|
||||
type Token struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccessExpires time.Time `json:"access_expires"`
|
||||
RefreshExpires time.Time `json:"refresh_expires"`
|
||||
|
||||
UID int `json:"-"`
|
||||
}
|
||||
|
||||
type (
|
||||
TokenType string
|
||||
TokenIDContextKey struct{}
|
||||
)
|
||||
|
||||
var (
|
||||
TokenTypeAccess = TokenType("access")
|
||||
TokenTypeRefresh = TokenType("refresh")
|
||||
|
||||
ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
)
|
||||
|
||||
const (
|
||||
AuthorizationHeader = "Authorization"
|
||||
TokenHeaderPrefix = "Bearer "
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
TokenType TokenType `json:"token_type"`
|
||||
jwt.RegisteredClaims
|
||||
StateHash []byte `json:"state_hash,omitempty"`
|
||||
}
|
||||
|
||||
// NewTokenAuth creates a new token based auth provider.
|
||||
func NewTokenAuth(idEncoder hashid.Encoder, s setting.Provider, secret []byte, userClient inventory.UserClient, l logging.Logger) TokenAuth {
|
||||
return &tokenAuth{
|
||||
idEncoder: idEncoder,
|
||||
s: s,
|
||||
secret: secret,
|
||||
userClient: userClient,
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
|
||||
type tokenAuth struct {
|
||||
l logging.Logger
|
||||
idEncoder hashid.Encoder
|
||||
s setting.Provider
|
||||
secret []byte
|
||||
userClient inventory.UserClient
|
||||
}
|
||||
|
||||
func (t *tokenAuth) Refresh(ctx context.Context, refreshToken string) (*Token, error) {
|
||||
token, err := jwt.ParseWithClaims(refreshToken, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return t.secret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || claims.TokenType != TokenTypeRefresh {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
uid, err := t.idEncoder.Decode(claims.Subject, hashid.UserID)
|
||||
if err != nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
expectedUser, err := t.userClient.GetActiveByID(ctx, uid)
|
||||
if err != nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
// Check if user changed password or revoked session
|
||||
expectedHash := t.hashUserState(ctx, expectedUser)
|
||||
if !bytes.Equal(claims.StateHash, expectedHash[:]) {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
return t.Issue(ctx, expectedUser)
|
||||
}
|
||||
|
||||
func (t *tokenAuth) VerifyAndRetrieveUser(c *gin.Context) (bool, error) {
|
||||
headerVal := c.GetHeader(AuthorizationHeader)
|
||||
if strings.HasPrefix(headerVal, TokenHeaderPrefixCr) {
|
||||
// This is an HMAC auth header, skip JWT verification
|
||||
return false, nil
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(headerVal, TokenHeaderPrefix)
|
||||
if tokenString == "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return t.secret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.l.Warning("Failed to parse jwt token: %s", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || claims.TokenType != TokenTypeAccess {
|
||||
return false, serializer.NewError(serializer.CodeCredentialInvalid, "Invalid token type", nil)
|
||||
}
|
||||
|
||||
uid, err := t.idEncoder.Decode(claims.Subject, hashid.UserID)
|
||||
if err != nil {
|
||||
return false, serializer.NewError(serializer.CodeNotFound, "User not found", err)
|
||||
}
|
||||
|
||||
util.WithValue(c, inventory.UserIDCtx{}, uid)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t *tokenAuth) Issue(ctx context.Context, u *ent.User) (*Token, error) {
|
||||
uidEncoded := hashid.EncodeUserID(t.idEncoder, u.ID)
|
||||
tokenSettings := t.s.TokenAuth(ctx)
|
||||
issueDate := time.Now()
|
||||
accessTokenExpired := time.Now().Add(tokenSettings.AccessTokenTTL)
|
||||
refreshTokenExpired := time.Now().Add(tokenSettings.RefreshTokenTTL)
|
||||
|
||||
accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
|
||||
TokenType: TokenTypeAccess,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: uidEncoded,
|
||||
NotBefore: jwt.NewNumericDate(issueDate),
|
||||
ExpiresAt: jwt.NewNumericDate(accessTokenExpired),
|
||||
},
|
||||
}).SignedString(t.secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("faield to sign access token: %w", err)
|
||||
}
|
||||
|
||||
userHash := t.hashUserState(ctx, u)
|
||||
refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
|
||||
TokenType: TokenTypeRefresh,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: uidEncoded,
|
||||
NotBefore: jwt.NewNumericDate(issueDate),
|
||||
ExpiresAt: jwt.NewNumericDate(refreshTokenExpired),
|
||||
},
|
||||
StateHash: userHash[:],
|
||||
}).SignedString(t.secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("faield to sign refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &Token{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
AccessExpires: accessTokenExpired,
|
||||
RefreshExpires: refreshTokenExpired,
|
||||
UID: u.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// hashUserState returns a hash string for user state for critical fields, it is used
|
||||
// to detect refresh token revocation after user changed password.
|
||||
func (t *tokenAuth) hashUserState(ctx context.Context, u *ent.User) [32]byte {
|
||||
return sha256.Sum256([]byte(fmt.Sprintf("%s/%s/%s", u.Email, u.Password, t.s.SiteBasic(ctx).ID)))
|
||||
}
|
||||
25
pkg/auth/requestinfo/requestinfo.go
Normal file
25
pkg/auth/requestinfo/requestinfo.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package requestinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// RequestInfoCtx context key for RequestInfo
|
||||
type RequestInfoCtx struct{}
|
||||
|
||||
// RequestInfoFromContext retrieves RequestInfo from context
|
||||
func RequestInfoFromContext(ctx context.Context) *RequestInfo {
|
||||
v, ok := ctx.Value(RequestInfoCtx{}).(*RequestInfo)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// RequestInfo store request info for audit
|
||||
type RequestInfo struct {
|
||||
Host string
|
||||
IP string
|
||||
UserAgent string
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/duo-labs/webauthn/webauthn"
|
||||
)
|
||||
|
||||
// NewAuthnInstance 新建Authn实例
|
||||
func NewAuthnInstance() (*webauthn.WebAuthn, error) {
|
||||
base := model.GetSiteURL()
|
||||
return webauthn.New(&webauthn.Config{
|
||||
RPDisplayName: model.GetSettingByName("siteName"), // Display Name for your site
|
||||
RPID: base.Hostname(), // Generally the FQDN for your site
|
||||
RPOrigin: base.String(), // The origin URL for WebAuthn requests
|
||||
})
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
cache.Set("setting_siteURL", "http://cloudreve.org", 0)
|
||||
cache.Set("setting_siteName", "Cloudreve", 0)
|
||||
res, err := NewAuthnInstance()
|
||||
asserts.NotNil(res)
|
||||
asserts.NoError(err)
|
||||
}
|
||||
86
pkg/boolset/boolset.go
Normal file
86
pkg/boolset/boolset.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package boolset
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrValueNotSupported = errors.New("value not supported")
|
||||
)
|
||||
|
||||
type BooleanSet []byte
|
||||
|
||||
// FromString convert from base64 encoded boolset.
|
||||
func FromString(data string) (*BooleanSet, error) {
|
||||
raw, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b := BooleanSet(raw)
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
func (b *BooleanSet) UnmarshalBinary(data []byte) error {
|
||||
*b = make(BooleanSet, len(data))
|
||||
copy(*b, data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BooleanSet) MarshalBinary() (data []byte, err error) {
|
||||
return *b, nil
|
||||
}
|
||||
|
||||
func (b *BooleanSet) String() (data string, err error) {
|
||||
raw, err := b.MarshalBinary()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(raw), nil
|
||||
}
|
||||
|
||||
func (b *BooleanSet) Enabled(flag int) bool {
|
||||
if flag >= len(*b)*8 {
|
||||
return false
|
||||
}
|
||||
|
||||
return (*b)[flag/8]&(1<<uint(flag%8)) != 0
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer method.
|
||||
func (b *BooleanSet) Value() (driver.Value, error) {
|
||||
return b.MarshalBinary()
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner method.
|
||||
func (b *BooleanSet) Scan(src any) error {
|
||||
srcByte, ok := src.([]byte)
|
||||
if !ok {
|
||||
return ErrValueNotSupported
|
||||
}
|
||||
return b.UnmarshalBinary(srcByte)
|
||||
}
|
||||
|
||||
// Sets set BooleanSet values in batch.
|
||||
func Sets[T constraints.Integer](val map[T]bool, bs *BooleanSet) {
|
||||
for flag, v := range val {
|
||||
Set(flag, v, bs)
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a BooleanSet value.
|
||||
func Set[T constraints.Integer](flag T, enabled bool, bs *BooleanSet) {
|
||||
if len(*bs) < int(flag/8)+1 {
|
||||
*bs = append(*bs, make([]byte, int(flag/8)+1-len(*bs))...)
|
||||
}
|
||||
|
||||
if enabled {
|
||||
(*bs)[flag/8] |= 1 << uint(flag%8)
|
||||
} else {
|
||||
(*bs)[flag/8] &= ^(1 << uint(flag%8))
|
||||
}
|
||||
}
|
||||
1
pkg/boolset/boolset_test.go
Normal file
1
pkg/boolset/boolset_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package boolset
|
||||
85
pkg/cache/driver.go
vendored
85
pkg/cache/driver.go
vendored
@@ -2,102 +2,35 @@ package cache
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(map[string]itemWithTTL{})
|
||||
}
|
||||
|
||||
// Store 缓存存储器
|
||||
var Store Driver = NewMemoStore()
|
||||
|
||||
// Init 初始化缓存
|
||||
func Init() {
|
||||
if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode {
|
||||
Store = NewRedisStore(
|
||||
10,
|
||||
conf.RedisConfig.Network,
|
||||
conf.RedisConfig.Server,
|
||||
conf.RedisConfig.User,
|
||||
conf.RedisConfig.Password,
|
||||
conf.RedisConfig.DB,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Restore restores cache from given disk file
|
||||
func Restore(persistFile string) {
|
||||
if err := Store.Restore(persistFile); err != nil {
|
||||
util.Log().Warning("Failed to restore cache from disk: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func InitSlaveOverwrites() {
|
||||
err := Store.Sets(conf.OptionOverwrite, "setting_")
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to overwrite database setting: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Driver 键值缓存存储容器
|
||||
type Driver interface {
|
||||
// 设置值,ttl为过期时间,单位为秒
|
||||
Set(key string, value interface{}, ttl int) error
|
||||
Set(key string, value any, ttl int) error
|
||||
|
||||
// 取值,并返回是否成功
|
||||
Get(key string) (interface{}, bool)
|
||||
Get(key string) (any, bool)
|
||||
|
||||
// 批量取值,返回成功取值的map即不存在的值
|
||||
Gets(keys []string, prefix string) (map[string]interface{}, []string)
|
||||
Gets(keys []string, prefix string) (map[string]any, []string)
|
||||
|
||||
// 批量设置值,所有的key都会加上prefix前缀
|
||||
Sets(values map[string]interface{}, prefix string) error
|
||||
Sets(values map[string]any, prefix string) error
|
||||
|
||||
// 删除值
|
||||
Delete(keys []string, prefix string) error
|
||||
// Delete values by [Prefix + key]. If no ket is presented, all keys with given prefix will be deleted.
|
||||
Delete(prefix string, keys ...string) error
|
||||
|
||||
// Save in-memory cache to disk
|
||||
Persist(path string) error
|
||||
|
||||
// Restore cache from disk
|
||||
Restore(path string) error
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
func Set(key string, value interface{}, ttl int) error {
|
||||
return Store.Set(key, value, ttl)
|
||||
}
|
||||
|
||||
// Get 获取缓存值
|
||||
func Get(key string) (interface{}, bool) {
|
||||
return Store.Get(key)
|
||||
}
|
||||
|
||||
// Deletes 删除值
|
||||
func Deletes(keys []string, prefix string) error {
|
||||
return Store.Delete(keys, prefix)
|
||||
}
|
||||
|
||||
// GetSettings 根据名称批量获取设置项缓存
|
||||
func GetSettings(keys []string, prefix string) (map[string]string, []string) {
|
||||
raw, miss := Store.Gets(keys, prefix)
|
||||
|
||||
res := make(map[string]string, len(raw))
|
||||
for k, v := range raw {
|
||||
res[k] = v.(string)
|
||||
}
|
||||
|
||||
return res, miss
|
||||
}
|
||||
|
||||
// SetSettings 批量设置站点设置缓存
|
||||
func SetSettings(values map[string]string, prefix string) error {
|
||||
var toBeSet = make(map[string]interface{}, len(values))
|
||||
for key, value := range values {
|
||||
toBeSet[key] = interface{}(value)
|
||||
}
|
||||
return Store.Sets(toBeSet, prefix)
|
||||
|
||||
// Remove all entries
|
||||
DeleteAll() error
|
||||
}
|
||||
|
||||
8
pkg/cache/driver_test.go
vendored
8
pkg/cache/driver_test.go
vendored
@@ -59,11 +59,3 @@ func TestInit(t *testing.T) {
|
||||
Init()
|
||||
})
|
||||
}
|
||||
|
||||
func TestInitSlaveOverwrites(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
asserts.NotPanics(func() {
|
||||
InitSlaveOverwrites()
|
||||
})
|
||||
}
|
||||
|
||||
52
pkg/cache/memo.go
vendored
52
pkg/cache/memo.go
vendored
@@ -3,11 +3,12 @@ package cache
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// MemoStore 内存存储驱动
|
||||
@@ -35,7 +36,7 @@ func newItem(value interface{}, expires int) itemWithTTL {
|
||||
}
|
||||
|
||||
// getValue 从itemWithTTL中取值
|
||||
func getValue(item interface{}, ok bool) (interface{}, bool) {
|
||||
func getValue(item any, ok bool) (any, bool) {
|
||||
if !ok {
|
||||
return nil, ok
|
||||
}
|
||||
@@ -55,7 +56,7 @@ func getValue(item interface{}, ok bool) (interface{}, bool) {
|
||||
|
||||
// GarbageCollect 回收已过期的缓存
|
||||
func (store *MemoStore) GarbageCollect() {
|
||||
store.Store.Range(func(key, value interface{}) bool {
|
||||
store.Store.Range(func(key, value any) bool {
|
||||
if item, ok := value.(itemWithTTL); ok {
|
||||
if item.Expires > 0 && item.Expires < time.Now().Unix() {
|
||||
util.Log().Debug("Cache %q is garbage collected.", key.(string))
|
||||
@@ -67,25 +68,33 @@ func (store *MemoStore) GarbageCollect() {
|
||||
}
|
||||
|
||||
// NewMemoStore 新建内存存储
|
||||
func NewMemoStore() *MemoStore {
|
||||
return &MemoStore{
|
||||
func NewMemoStore(persistFile string, l logging.Logger) *MemoStore {
|
||||
store := &MemoStore{
|
||||
Store: &sync.Map{},
|
||||
}
|
||||
|
||||
if persistFile != "" {
|
||||
if err := store.Restore(persistFile); err != nil {
|
||||
l.Warning("Failed to restore cache from disk: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// Set 存储值
|
||||
func (store *MemoStore) Set(key string, value interface{}, ttl int) error {
|
||||
func (store *MemoStore) Set(key string, value any, ttl int) error {
|
||||
store.Store.Store(key, newItem(value, ttl))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 取值
|
||||
func (store *MemoStore) Get(key string) (interface{}, bool) {
|
||||
func (store *MemoStore) Get(key string) (any, bool) {
|
||||
return getValue(store.Store.Load(key))
|
||||
}
|
||||
|
||||
// Gets 批量取值
|
||||
func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
|
||||
func (store *MemoStore) Gets(keys []string, prefix string) (map[string]any, []string) {
|
||||
var res = make(map[string]interface{})
|
||||
var notFound = make([]string, 0, len(keys))
|
||||
|
||||
@@ -101,7 +110,7 @@ func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface
|
||||
}
|
||||
|
||||
// Sets 批量设置值
|
||||
func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error {
|
||||
func (store *MemoStore) Sets(values map[string]any, prefix string) error {
|
||||
for key, value := range values {
|
||||
store.Store.Store(prefix+key, newItem(value, 0))
|
||||
}
|
||||
@@ -109,17 +118,27 @@ func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error
|
||||
}
|
||||
|
||||
// Delete 批量删除值
|
||||
func (store *MemoStore) Delete(keys []string, prefix string) error {
|
||||
func (store *MemoStore) Delete(prefix string, keys ...string) error {
|
||||
for _, key := range keys {
|
||||
store.Store.Delete(prefix + key)
|
||||
}
|
||||
|
||||
// No key is presented, delete all entries with given prefix.
|
||||
if len(keys) == 0 {
|
||||
store.Store.Range(func(key, value any) bool {
|
||||
if k, ok := key.(string); ok && strings.HasPrefix(k, prefix) {
|
||||
store.Store.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist write memory store into cache
|
||||
func (store *MemoStore) Persist(path string) error {
|
||||
persisted := make(map[string]itemWithTTL)
|
||||
store.Store.Range(func(key, value interface{}) bool {
|
||||
store.Store.Range(func(key, value any) bool {
|
||||
v, ok := store.Store.Load(key)
|
||||
if _, ok := getValue(v, ok); ok {
|
||||
persisted[key.(string)] = v.(itemWithTTL)
|
||||
@@ -173,3 +192,12 @@ func (store *MemoStore) Restore(path string) error {
|
||||
util.Log().Info("Restored %d items from %q into memory cache.", loaded, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoStore) DeleteAll() error {
|
||||
store.Store.Range(func(key any, value any) bool {
|
||||
store.Store.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
44
pkg/cache/memo_test.go
vendored
44
pkg/cache/memo_test.go
vendored
@@ -2,7 +2,6 @@ package cache
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -146,46 +145,3 @@ func TestMemoStore_GarbageCollect(t *testing.T) {
|
||||
_, ok := store.Get("test")
|
||||
asserts.False(ok)
|
||||
}
|
||||
|
||||
func TestMemoStore_PersistFailed(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
store := NewMemoStore()
|
||||
type testStruct struct{ v string }
|
||||
store.Set("test", 1, 0)
|
||||
store.Set("test2", testStruct{v: "test"}, 0)
|
||||
err := store.Persist(filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed"))
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
func TestMemoStore_PersistAndRestore(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
store := NewMemoStore()
|
||||
store.Set("test", 1, 0)
|
||||
// already expired
|
||||
store.Store.Store("test2", itemWithTTL{Value: "test", Expires: 1})
|
||||
// expired after persist
|
||||
store.Set("test3", 1, 1)
|
||||
temp := filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed")
|
||||
|
||||
// Persist
|
||||
err := store.Persist(temp)
|
||||
a.NoError(err)
|
||||
a.FileExists(temp)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
// Restore
|
||||
store2 := NewMemoStore()
|
||||
err = store2.Restore(temp)
|
||||
a.NoError(err)
|
||||
test, testOk := store2.Get("test")
|
||||
a.EqualValues(1, test)
|
||||
a.True(testOk)
|
||||
test2, test2Ok := store2.Get("test2")
|
||||
a.Nil(test2)
|
||||
a.False(test2Ok)
|
||||
test3, test3Ok := store2.Get("test3")
|
||||
a.Nil(test3)
|
||||
a.False(test3Ok)
|
||||
|
||||
a.NoFileExists(temp)
|
||||
}
|
||||
|
||||
46
pkg/cache/redis.go
vendored
46
pkg/cache/redis.go
vendored
@@ -3,10 +3,10 @@ package cache
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ type item struct {
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func serializer(value interface{}) ([]byte, error) {
|
||||
func serializer(value any) ([]byte, error) {
|
||||
var buffer bytes.Buffer
|
||||
enc := gob.NewEncoder(&buffer)
|
||||
storeValue := item{
|
||||
@@ -32,7 +32,7 @@ func serializer(value interface{}) ([]byte, error) {
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func deserializer(value []byte) (interface{}, error) {
|
||||
func deserializer(value []byte) (any, error) {
|
||||
var res item
|
||||
buffer := bytes.NewReader(value)
|
||||
dec := gob.NewDecoder(buffer)
|
||||
@@ -44,7 +44,7 @@ func deserializer(value []byte) (interface{}, error) {
|
||||
}
|
||||
|
||||
// NewRedisStore 创建新的redis存储
|
||||
func NewRedisStore(size int, network, address, user, password, database string) *RedisStore {
|
||||
func NewRedisStore(l logging.Logger, size int, network, address, user, password, database string) *RedisStore {
|
||||
return &RedisStore{
|
||||
pool: &redis.Pool{
|
||||
MaxIdle: size,
|
||||
@@ -63,11 +63,11 @@ func NewRedisStore(size int, network, address, user, password, database string)
|
||||
network,
|
||||
address,
|
||||
redis.DialDatabase(db),
|
||||
redis.DialUsername(user),
|
||||
redis.DialPassword(password),
|
||||
redis.DialUsername(user),
|
||||
)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to create Redis connection: %s", err)
|
||||
l.Panic("Failed to create Redis connection: %s", err)
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
@@ -76,7 +76,7 @@ func NewRedisStore(size int, network, address, user, password, database string)
|
||||
}
|
||||
|
||||
// Set 存储值
|
||||
func (store *RedisStore) Set(key string, value interface{}, ttl int) error {
|
||||
func (store *RedisStore) Set(key string, value any, ttl int) error {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
|
||||
@@ -103,7 +103,7 @@ func (store *RedisStore) Set(key string, value interface{}, ttl int) error {
|
||||
}
|
||||
|
||||
// Get 取值
|
||||
func (store *RedisStore) Get(key string) (interface{}, bool) {
|
||||
func (store *RedisStore) Get(key string) (any, bool) {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
if rc.Err() != nil {
|
||||
@@ -125,7 +125,7 @@ func (store *RedisStore) Get(key string) (interface{}, bool) {
|
||||
}
|
||||
|
||||
// Gets 批量取值
|
||||
func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
|
||||
func (store *RedisStore) Gets(keys []string, prefix string) (map[string]any, []string) {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
if rc.Err() != nil {
|
||||
@@ -142,7 +142,7 @@ func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interfac
|
||||
return nil, keys
|
||||
}
|
||||
|
||||
var res = make(map[string]interface{})
|
||||
var res = make(map[string]any)
|
||||
var missed = make([]string, 0, len(keys))
|
||||
|
||||
for key, value := range v {
|
||||
@@ -158,13 +158,13 @@ func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interfac
|
||||
}
|
||||
|
||||
// Sets 批量设置值
|
||||
func (store *RedisStore) Sets(values map[string]interface{}, prefix string) error {
|
||||
func (store *RedisStore) Sets(values map[string]any, prefix string) error {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
if rc.Err() != nil {
|
||||
return rc.Err()
|
||||
}
|
||||
var setValues = make(map[string]interface{})
|
||||
var setValues = make(map[string]any)
|
||||
|
||||
// 编码待设置值
|
||||
for key, value := range values {
|
||||
@@ -184,7 +184,7 @@ func (store *RedisStore) Sets(values map[string]interface{}, prefix string) erro
|
||||
}
|
||||
|
||||
// Delete 批量删除给定的键
|
||||
func (store *RedisStore) Delete(keys []string, prefix string) error {
|
||||
func (store *RedisStore) Delete(prefix string, keys ...string) error {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
if rc.Err() != nil {
|
||||
@@ -196,10 +196,24 @@ func (store *RedisStore) Delete(keys []string, prefix string) error {
|
||||
keys[i] = prefix + keys[i]
|
||||
}
|
||||
|
||||
_, err := rc.Do("DEL", redis.Args{}.AddFlat(keys)...)
|
||||
if err != nil {
|
||||
return err
|
||||
// No key is presented, delete all keys with given prefix
|
||||
if len(keys) == 0 {
|
||||
// Fetch all key with given prefix
|
||||
allPrefixKeys, err := redis.Strings(rc.Do("KEYS", prefix+"*"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keys = allPrefixKeys
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
_, err := rc.Do("DEL", redis.Args{}.AddFlat(keys)...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
10
pkg/cache/redis_test.go
vendored
10
pkg/cache/redis_test.go
vendored
@@ -13,16 +13,16 @@ import (
|
||||
func TestNewRedisStore(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
store := NewRedisStore(10, "tcp", "", "", "", "0")
|
||||
store := NewRedisStore(10, "tcp", "", "", "0")
|
||||
asserts.NotNil(store)
|
||||
|
||||
asserts.Panics(func() {
|
||||
store.pool.Dial()
|
||||
})
|
||||
conn, err := store.pool.Dial()
|
||||
asserts.Nil(conn)
|
||||
asserts.Error(err)
|
||||
|
||||
testConn := redigomock.NewConn()
|
||||
cmd := testConn.Command("PING").Expect("PONG")
|
||||
err := store.pool.TestOnBorrow(testConn, time.Now())
|
||||
err = store.pool.TestOnBorrow(testConn, time.Now())
|
||||
if testConn.Stats(cmd) != 1 {
|
||||
fmt.Println("Command was not used")
|
||||
return
|
||||
|
||||
@@ -1,209 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/jinzhu/gorm"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var DefaultController Controller
|
||||
|
||||
// Controller controls communications between master and slave
|
||||
type Controller interface {
|
||||
// Handle heartbeat sent from master
|
||||
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
|
||||
|
||||
// Get Aria2 Instance by master node ID
|
||||
GetAria2Instance(string) (common.Aria2, error)
|
||||
|
||||
// Send event change message to master node
|
||||
SendNotification(string, string, mq.Message) error
|
||||
|
||||
// Submit async task into task pool
|
||||
SubmitTask(string, interface{}, string, func(interface{})) error
|
||||
|
||||
// Get master node info
|
||||
GetMasterInfo(string) (*MasterInfo, error)
|
||||
|
||||
// Get master Oauth based policy credential
|
||||
GetPolicyOauthToken(string, uint) (string, error)
|
||||
}
|
||||
|
||||
type slaveController struct {
|
||||
masters map[string]MasterInfo
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// info of master node
|
||||
type MasterInfo struct {
|
||||
ID string
|
||||
TTL int
|
||||
URL *url.URL
|
||||
// used to invoke aria2 rpc calls
|
||||
Instance Node
|
||||
Client request.Client
|
||||
|
||||
jobTracker map[string]bool
|
||||
}
|
||||
|
||||
func InitController() {
|
||||
DefaultController = &slaveController{
|
||||
masters: make(map[string]MasterInfo),
|
||||
}
|
||||
gob.Register(rpc.StatusInfo{})
|
||||
}
|
||||
|
||||
func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
req.Node.AfterFind()
|
||||
|
||||
// close old node if exist
|
||||
origin, ok := c.masters[req.SiteID]
|
||||
|
||||
if (ok && req.IsUpdate) || !ok {
|
||||
if ok {
|
||||
origin.Instance.Kill()
|
||||
}
|
||||
|
||||
masterUrl, err := url.Parse(req.SiteURL)
|
||||
if err != nil {
|
||||
return serializer.NodePingResp{}, err
|
||||
}
|
||||
|
||||
c.masters[req.SiteID] = MasterInfo{
|
||||
ID: req.SiteID,
|
||||
URL: masterUrl,
|
||||
TTL: req.CredentialTTL,
|
||||
Client: request.NewClient(
|
||||
request.WithEndpoint(masterUrl.String()),
|
||||
request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)),
|
||||
request.WithCredential(auth.HMACAuth{
|
||||
SecretKey: []byte(req.Node.MasterKey),
|
||||
}, int64(req.CredentialTTL)),
|
||||
),
|
||||
jobTracker: make(map[string]bool),
|
||||
Instance: NewNodeFromDBModel(&model.Node{
|
||||
Model: gorm.Model{ID: req.Node.ID},
|
||||
MasterKey: req.Node.MasterKey,
|
||||
Type: model.MasterNodeType,
|
||||
Aria2Enabled: req.Node.Aria2Enabled,
|
||||
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
return serializer.NodePingResp{}, nil
|
||||
}
|
||||
|
||||
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
return node.Instance.GetAria2Instance(), nil
|
||||
}
|
||||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
||||
func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error {
|
||||
c.lock.RLock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
body := bytes.Buffer{}
|
||||
enc := gob.NewEncoder(&body)
|
||||
if err := enc.Encode(&msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := node.Client.Request(
|
||||
"PUT",
|
||||
fmt.Sprintf("/api/v3/slave/notification/%s", subject),
|
||||
&body,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
c.lock.RUnlock()
|
||||
return ErrMasterNotFound
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
if _, ok := node.jobTracker[hash]; ok {
|
||||
// 任务已存在,直接返回
|
||||
return nil
|
||||
}
|
||||
|
||||
node.jobTracker[hash] = true
|
||||
submitter(job)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrMasterNotFound
|
||||
}
|
||||
|
||||
// GetMasterInfo 获取主机节点信息
|
||||
func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
||||
// GetPolicyOauthToken 获取主机存储策略 Oauth 凭证
|
||||
func (c *slaveController) GetPolicyOauthToken(id string, policyID uint) (string, error) {
|
||||
c.lock.RLock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
res, err := node.Client.Request(
|
||||
"GET",
|
||||
fmt.Sprintf("/api/v3/slave/credential/%d", policyID),
|
||||
nil,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return "", serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return res.Data.(string), nil
|
||||
}
|
||||
|
||||
c.lock.RUnlock()
|
||||
return "", ErrMasterNotFound
|
||||
}
|
||||
@@ -1,385 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInitController(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
InitController()
|
||||
})
|
||||
}
|
||||
|
||||
func TestSlaveController_HandleHeartBeat(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: make(map[string]MasterInfo),
|
||||
}
|
||||
|
||||
// first heart beat
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
|
||||
_, err = c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "2",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
|
||||
a.Len(c.masters, 2)
|
||||
}
|
||||
|
||||
// second heart beat, no fresh
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
SiteURL: "http://127.0.0.1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Empty(c.masters["1"].URL)
|
||||
}
|
||||
|
||||
// second heart beat, fresh
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
IsUpdate: true,
|
||||
SiteURL: "http://127.0.0.1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||
}
|
||||
|
||||
// second heart beat, fresh, url illegal
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
IsUpdate: true,
|
||||
SiteURL: string([]byte{0x7f}),
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.Error(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||
}
|
||||
}
|
||||
|
||||
type nodeMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (n nodeMock) Init(node *model.Node) {
|
||||
n.Called(node)
|
||||
}
|
||||
|
||||
func (n nodeMock) IsFeatureEnabled(feature string) bool {
|
||||
args := n.Called(feature)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
n.Called(callback)
|
||||
}
|
||||
|
||||
func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
args := n.Called(req)
|
||||
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
|
||||
}
|
||||
|
||||
func (n nodeMock) IsActive() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) GetAria2Instance() common.Aria2 {
|
||||
args := n.Called()
|
||||
return args.Get(0).(common.Aria2)
|
||||
}
|
||||
|
||||
func (n nodeMock) ID() uint {
|
||||
args := n.Called()
|
||||
return args.Get(0).(uint)
|
||||
}
|
||||
|
||||
func (n nodeMock) Kill() {
|
||||
n.Called()
|
||||
}
|
||||
|
||||
func (n nodeMock) IsMater() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) MasterAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n nodeMock) SlaveAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n nodeMock) DBModel() *model.Node {
|
||||
args := n.Called()
|
||||
return args.Get(0).(*model.Node)
|
||||
}
|
||||
|
||||
func TestSlaveController_GetAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockNode := &nodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Instance: mockNode},
|
||||
},
|
||||
}
|
||||
|
||||
// node node found
|
||||
{
|
||||
res, err := c.GetAria2Instance("2")
|
||||
a.Nil(res)
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
}
|
||||
|
||||
// node found
|
||||
{
|
||||
res, err := c.GetAria2Instance("1")
|
||||
a.NotNil(res)
|
||||
a.NoError(err)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type requestMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
|
||||
return r.Called(method, target, body, opts).Get(0).(*request.Response)
|
||||
}
|
||||
|
||||
func TestSlaveController_SendNotification(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
|
||||
}
|
||||
|
||||
// gob encode error
|
||||
{
|
||||
type randomType struct{}
|
||||
a.Error(c.SendNotification("1", "", mq.Message{
|
||||
Content: randomType{},
|
||||
}))
|
||||
}
|
||||
|
||||
// return none 200
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{StatusCode: http.StatusConflict},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.Error(c.SendNotification("1", "s1", mq.Message{}))
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.NoError(c.SendNotification("1", "s3", mq.Message{}))
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveController_SubmitTask(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {
|
||||
jobTracker: map[string]bool{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil))
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
submitted := false
|
||||
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
|
||||
submitted = true
|
||||
}))
|
||||
a.True(submitted)
|
||||
}
|
||||
|
||||
// job already submitted
|
||||
{
|
||||
submitted := false
|
||||
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
|
||||
submitted = true
|
||||
}))
|
||||
a.False(submitted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveController_GetMasterInfo(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
res, err := c.GetMasterInfo("2")
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
a.Nil(res)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
res, err := c.GetMasterInfo("1")
|
||||
a.NoError(err)
|
||||
a.NotNil(res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveController_GetOneDriveToken(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
res, err := c.GetPolicyOauthToken("2", 1)
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
a.Empty(res)
|
||||
}
|
||||
|
||||
// return none 200
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{StatusCode: http.StatusConflict},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
res, err := c.GetPolicyOauthToken("1", 1)
|
||||
a.Error(err)
|
||||
a.Empty(res)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
res, err := c.GetPolicyOauthToken("1", 1)
|
||||
a.Equal(1, err.(serializer.AppError).Code)
|
||||
a.Empty(res)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
res, err := c.GetPolicyOauthToken("1", 1)
|
||||
a.NoError(err)
|
||||
a.Equal("expected", res)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
|
||||
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
|
||||
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "Unknown master node id", nil)
|
||||
)
|
||||
@@ -1,272 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gofrs/uuid"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
deleteTempFileDuration = 60 * time.Second
|
||||
statusRetryDuration = 10 * time.Second
|
||||
)
|
||||
|
||||
type MasterNode struct {
|
||||
Model *model.Node
|
||||
aria2RPC rpcService
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// RPCService 通过RPC服务的Aria2任务管理器
|
||||
type rpcService struct {
|
||||
Caller rpc.Client
|
||||
Initialized bool
|
||||
|
||||
retryDuration time.Duration
|
||||
deletePaddingDuration time.Duration
|
||||
parent *MasterNode
|
||||
options *clientOptions
|
||||
}
|
||||
|
||||
type clientOptions struct {
|
||||
Options map[string]interface{} // 创建下载时额外添加的设置
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *MasterNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
node.aria2RPC.parent = node
|
||||
node.aria2RPC.retryDuration = statusRetryDuration
|
||||
node.aria2RPC.deletePaddingDuration = deleteTempFileDuration
|
||||
node.lock.Unlock()
|
||||
|
||||
node.lock.RLock()
|
||||
if node.Model.Aria2Enabled {
|
||||
node.lock.RUnlock()
|
||||
node.aria2RPC.Init()
|
||||
return
|
||||
}
|
||||
node.lock.RUnlock()
|
||||
}
|
||||
|
||||
func (node *MasterNode) ID() uint {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model.ID
|
||||
}
|
||||
|
||||
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
return &serializer.NodePingResp{}, nil
|
||||
}
|
||||
|
||||
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||
func (node *MasterNode) IsFeatureEnabled(feature string) bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
switch feature {
|
||||
case "aria2":
|
||||
return node.Model.Aria2Enabled
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (node *MasterNode) MasterAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||
}
|
||||
|
||||
func (node *MasterNode) SlaveAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
|
||||
}
|
||||
|
||||
// SubscribeStatusChange 订阅节点状态更改
|
||||
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
}
|
||||
|
||||
// IsActive 返回节点是否在线
|
||||
func (node *MasterNode) IsActive() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Kill 结束aria2请求
|
||||
func (node *MasterNode) Kill() {
|
||||
if node.aria2RPC.Caller != nil {
|
||||
node.aria2RPC.Caller.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// GetAria2Instance 获取主机Aria2实例
|
||||
func (node *MasterNode) GetAria2Instance() common.Aria2 {
|
||||
node.lock.RLock()
|
||||
|
||||
if !node.Model.Aria2Enabled {
|
||||
node.lock.RUnlock()
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
if !node.aria2RPC.Initialized {
|
||||
node.lock.RUnlock()
|
||||
node.aria2RPC.Init()
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
defer node.lock.RUnlock()
|
||||
return &node.aria2RPC
|
||||
}
|
||||
|
||||
func (node *MasterNode) IsMater() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (node *MasterNode) DBModel() *model.Node {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model
|
||||
}
|
||||
|
||||
func (r *rpcService) Init() error {
|
||||
r.parent.lock.Lock()
|
||||
defer r.parent.lock.Unlock()
|
||||
r.Initialized = false
|
||||
|
||||
// 客户端已存在,则关闭先前连接
|
||||
if r.Caller != nil {
|
||||
r.Caller.Close()
|
||||
}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to parse Aria2 RPC server URL: %s", err)
|
||||
return err
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
// 加载自定义下载配置
|
||||
var globalOptions map[string]interface{}
|
||||
if r.parent.Model.Aria2OptionsSerialized.Options != "" {
|
||||
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to parse aria2 options: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.options = &clientOptions{
|
||||
Options: globalOptions,
|
||||
}
|
||||
timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
|
||||
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ)
|
||||
|
||||
r.Caller = caller
|
||||
r.Initialized = err == nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
|
||||
r.parent.lock.RLock()
|
||||
// 生成存储路径
|
||||
guid, _ := uuid.NewV4()
|
||||
path := filepath.Join(
|
||||
r.parent.Model.Aria2OptionsSerialized.TempPath,
|
||||
"aria2",
|
||||
guid.String(),
|
||||
)
|
||||
r.parent.lock.RUnlock()
|
||||
|
||||
// 创建下载任务
|
||||
options := map[string]interface{}{
|
||||
"dir": path,
|
||||
}
|
||||
for k, v := range r.options.Options {
|
||||
options[k] = v
|
||||
}
|
||||
for k, v := range groupOptions {
|
||||
options[k] = v
|
||||
}
|
||||
|
||||
gid, err := r.Caller.AddURI(task.Source, options)
|
||||
if err != nil || gid == "" {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return gid, nil
|
||||
}
|
||||
|
||||
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
res, err := r.Caller.TellStatus(task.GID)
|
||||
if err != nil {
|
||||
// 失败后重试
|
||||
util.Log().Debug("Failed to get download task status, please retry later: %s", err)
|
||||
time.Sleep(r.retryDuration)
|
||||
res, err = r.Caller.TellStatus(task.GID)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *rpcService) Cancel(task *model.Download) error {
|
||||
// 取消下载任务
|
||||
_, err := r.Caller.Remove(task.GID)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to cancel task %q: %s", task.GID, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) Select(task *model.Download, files []int) error {
|
||||
var selected = make([]string, len(files))
|
||||
for i := 0; i < len(files); i++ {
|
||||
selected[i] = strconv.Itoa(files[i])
|
||||
}
|
||||
_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) GetConfig() model.Aria2Option {
|
||||
r.parent.lock.RLock()
|
||||
defer r.parent.lock.RUnlock()
|
||||
|
||||
return r.parent.Model.Aria2OptionsSerialized
|
||||
}
|
||||
|
||||
func (s *rpcService) DeleteTempFile(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
// 避免被aria2占用,异步执行删除
|
||||
go func(d time.Duration, src string) {
|
||||
time.Sleep(d)
|
||||
err := os.RemoveAll(src)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to delete temp download folder: %q: %s", src, err)
|
||||
}
|
||||
}(s.deletePaddingDuration, task.Parent)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,186 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMasterNode_Init(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{}
|
||||
m.Init(&model.Node{Status: model.NodeSuspend})
|
||||
a.Equal(model.NodeSuspend, m.DBModel().Status)
|
||||
m.Init(&model.Node{Aria2Enabled: true})
|
||||
}
|
||||
|
||||
func TestMasterNode_DummyMethods(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
m.Model.ID = 5
|
||||
a.Equal(m.Model.ID, m.ID())
|
||||
|
||||
res, err := m.Ping(&serializer.NodePingReq{})
|
||||
a.NoError(err)
|
||||
a.NotNil(res)
|
||||
|
||||
a.True(m.IsActive())
|
||||
a.True(m.IsMater())
|
||||
|
||||
m.SubscribeStatusChange(func(isActive bool, id uint) {})
|
||||
}
|
||||
|
||||
func TestMasterNode_IsFeatureEnabled(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.False(m.IsFeatureEnabled("aria2"))
|
||||
a.False(m.IsFeatureEnabled("random"))
|
||||
m.Model.Aria2Enabled = true
|
||||
a.True(m.IsFeatureEnabled("aria2"))
|
||||
}
|
||||
|
||||
func TestMasterNode_AuthInstance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.NotNil(m.MasterAuthInstance())
|
||||
a.NotNil(m.SlaveAuthInstance())
|
||||
}
|
||||
|
||||
func TestMasterNode_Kill(t *testing.T) {
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
m.Kill()
|
||||
|
||||
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
|
||||
m.aria2RPC.Caller = caller
|
||||
m.Kill()
|
||||
}
|
||||
|
||||
func TestMasterNode_GetAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
aria2RPC: rpcService{},
|
||||
}
|
||||
|
||||
m.aria2RPC.parent = m
|
||||
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
m.Model.Aria2Enabled = true
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
m.aria2RPC.Initialized = true
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
}
|
||||
|
||||
func TestRpcService_Init(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{
|
||||
Aria2OptionsSerialized: model.Aria2Option{
|
||||
Options: "{",
|
||||
},
|
||||
},
|
||||
aria2RPC: rpcService{},
|
||||
}
|
||||
m.aria2RPC.parent = m
|
||||
|
||||
// failed to decode address
|
||||
{
|
||||
m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f})
|
||||
a.Error(m.aria2RPC.Init())
|
||||
}
|
||||
|
||||
// failed to decode options
|
||||
{
|
||||
m.Model.Aria2OptionsSerialized.Server = ""
|
||||
a.Error(m.aria2RPC.Init())
|
||||
}
|
||||
|
||||
// failed to initialized
|
||||
{
|
||||
m.Model.Aria2OptionsSerialized.Server = ""
|
||||
m.Model.Aria2OptionsSerialized.Options = "{}"
|
||||
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
|
||||
m.aria2RPC.Caller = caller
|
||||
a.Error(m.aria2RPC.Init())
|
||||
a.False(m.aria2RPC.Initialized)
|
||||
}
|
||||
}
|
||||
|
||||
func getTestRPCNode() *MasterNode {
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{
|
||||
Aria2OptionsSerialized: model.Aria2Option{},
|
||||
},
|
||||
aria2RPC: rpcService{
|
||||
options: &clientOptions{
|
||||
Options: map[string]interface{}{"1": "1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
m.aria2RPC.parent = m
|
||||
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
|
||||
m.aria2RPC.Caller = caller
|
||||
return m
|
||||
}
|
||||
|
||||
func TestRpcService_CreateTask(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"})
|
||||
a.Error(err)
|
||||
a.Empty(res)
|
||||
}
|
||||
|
||||
func TestRpcService_Status(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
res, err := m.aria2RPC.Status(&model.Download{})
|
||||
a.Error(err)
|
||||
a.Empty(res)
|
||||
}
|
||||
|
||||
func TestRpcService_Cancel(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
a.Error(m.aria2RPC.Cancel(&model.Download{}))
|
||||
}
|
||||
|
||||
func TestRpcService_Select(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
a.NotNil(m.aria2RPC.GetConfig())
|
||||
a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3}))
|
||||
}
|
||||
|
||||
func TestRpcService_DeleteTempFile(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
fdName := "TestRpcService_DeleteTempFile"
|
||||
a.NoError(os.Mkdir(fdName, 0644))
|
||||
|
||||
a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName}))
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
a.False(util.Exists(fdName))
|
||||
}
|
||||
@@ -1,60 +1,413 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/node"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/task"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/qbittorrent"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/slave"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/queue"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Node interface {
|
||||
// Init a node from database model
|
||||
Init(node *model.Node)
|
||||
type (
|
||||
Node interface {
|
||||
fs.StatelessUploadManager
|
||||
ID() int
|
||||
Name() string
|
||||
IsMaster() bool
|
||||
// CreateTask creates a task on the node. It does not have effect on master node.
|
||||
CreateTask(ctx context.Context, taskType string, state string) (int, error)
|
||||
// GetTask returns the task summary of the task with the given id.
|
||||
GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error)
|
||||
// CleanupFolders cleans up the given folders on the node.
|
||||
CleanupFolders(ctx context.Context, folders ...string) error
|
||||
// AuthInstance returns the auth instance for the node.
|
||||
AuthInstance() auth.Auth
|
||||
// CreateDownloader creates a downloader instance from the node for remote download tasks.
|
||||
CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error)
|
||||
// Settings returns the settings of the node.
|
||||
Settings(ctx context.Context) *types.NodeSetting
|
||||
}
|
||||
|
||||
// Check if given feature is enabled
|
||||
IsFeatureEnabled(feature string) bool
|
||||
// Request body for creating tasks on slave node
|
||||
CreateSlaveTask struct {
|
||||
Type string `json:"type"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// Subscribe node status change to a callback function
|
||||
SubscribeStatusChange(callback func(isActive bool, id uint))
|
||||
// Request body for cleaning up folders on slave node
|
||||
FolderCleanup struct {
|
||||
Path []string `json:"path" binding:"required"`
|
||||
}
|
||||
|
||||
// Ping the node
|
||||
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
|
||||
SlaveTaskSummary struct {
|
||||
Status task.Status `json:"status"`
|
||||
Error string `json:"error"`
|
||||
PrivateState string `json:"private_state"`
|
||||
Progress queue.Progresses `json:"progress,omitempty"`
|
||||
}
|
||||
|
||||
// Returns if the node is active
|
||||
IsActive() bool
|
||||
MasterSiteUrlCtx struct{}
|
||||
MasterSiteVersionCtx struct{}
|
||||
MasterSiteIDCtx struct{}
|
||||
SlaveNodeIDCtx struct{}
|
||||
masterNode struct {
|
||||
nodeBase
|
||||
client request.Client
|
||||
}
|
||||
)
|
||||
|
||||
// Get instances for aria2 calls
|
||||
GetAria2Instance() common.Aria2
|
||||
|
||||
// Returns unique id of this node
|
||||
ID() uint
|
||||
|
||||
// Kill node and recycle resources
|
||||
Kill()
|
||||
|
||||
// Returns if current node is master node
|
||||
IsMater() bool
|
||||
|
||||
// Get auth instance used to check RPC call from slave to master
|
||||
MasterAuthInstance() auth.Auth
|
||||
|
||||
// Get auth instance used to check RPC call from master to slave
|
||||
SlaveAuthInstance() auth.Auth
|
||||
|
||||
// Get node DB model
|
||||
DBModel() *model.Node
|
||||
func newNode(ctx context.Context, model *ent.Node, config conf.ConfigProvider, settings setting.Provider) Node {
|
||||
if model.Type == node.TypeMaster {
|
||||
return newMasterNode(model, config, settings)
|
||||
}
|
||||
return newSlaveNode(ctx, model, config, settings)
|
||||
}
|
||||
|
||||
// Create new node from DB model
|
||||
func NewNodeFromDBModel(node *model.Node) Node {
|
||||
switch node.Type {
|
||||
case model.SlaveNodeType:
|
||||
slave := &SlaveNode{}
|
||||
slave.Init(node)
|
||||
return slave
|
||||
default:
|
||||
master := &MasterNode{}
|
||||
master.Init(node)
|
||||
return master
|
||||
func newMasterNode(model *ent.Node, config conf.ConfigProvider, settings setting.Provider) *masterNode {
|
||||
n := &masterNode{
|
||||
nodeBase: nodeBase{
|
||||
model: model,
|
||||
},
|
||||
}
|
||||
|
||||
if config.System().Mode == conf.SlaveMode {
|
||||
n.client = request.NewClient(config,
|
||||
request.WithCorrelationID(),
|
||||
request.WithCredential(auth.HMACAuth{
|
||||
[]byte(config.Slave().Secret),
|
||||
}, int64(config.Slave().SignatureTTL)),
|
||||
)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (b *masterNode) PrepareUpload(ctx context.Context, args *fs.StatelessPrepareUploadService) (*fs.StatelessPrepareUploadResponse, error) {
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "prepare")
|
||||
resp, err := b.client.Request(
|
||||
"PUT",
|
||||
requestDst.String(),
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithSlaveMeta(NodeIdFromContext(ctx)),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return nil, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
uploadRequest := &fs.StatelessPrepareUploadResponse{}
|
||||
resp.GobDecode(uploadRequest)
|
||||
|
||||
return uploadRequest, nil
|
||||
}
|
||||
|
||||
func (b *masterNode) CompleteUpload(ctx context.Context, args *fs.StatelessCompleteUploadService) error {
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "complete")
|
||||
resp, err := b.client.Request(
|
||||
"POST",
|
||||
requestDst.String(),
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithSlaveMeta(NodeIdFromContext(ctx)),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *masterNode) OnUploadFailed(ctx context.Context, args *fs.StatelessOnUploadFailedService) error {
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "failed")
|
||||
resp, err := b.client.Request(
|
||||
"POST",
|
||||
requestDst.String(),
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithSlaveMeta(NodeIdFromContext(ctx)),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *masterNode) CreateFile(ctx context.Context, args *fs.StatelessCreateFileService) error {
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
requestDst := routes.MasterStatelessUrl(MasterSiteUrlFromContext(ctx), "create")
|
||||
resp, err := b.client.Request(
|
||||
"POST",
|
||||
requestDst.String(),
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithSlaveMeta(NodeIdFromContext(ctx)),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (b *masterNode) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) {
|
||||
return NewDownloader(ctx, c, settings, b.Settings(ctx))
|
||||
}
|
||||
|
||||
// NewDownloader creates a new downloader instance from the node for remote download tasks.
|
||||
func NewDownloader(ctx context.Context, c request.Client, settings setting.Provider, options *types.NodeSetting) (downloader.Downloader, error) {
|
||||
if options.Provider == types.DownloaderProviderQBittorrent {
|
||||
return qbittorrent.NewClient(logging.FromContext(ctx), c, settings, options.QBittorrentSetting)
|
||||
} else if options.Provider == types.DownloaderProviderAria2 {
|
||||
return aria2.New(logging.FromContext(ctx), settings, options.Aria2Setting), nil
|
||||
} else if options.Provider == "" {
|
||||
return nil, errors.New("downloader not configured for this node")
|
||||
} else {
|
||||
return nil, errors.New("unknown downloader provider")
|
||||
}
|
||||
}
|
||||
|
||||
type slaveNode struct {
|
||||
nodeBase
|
||||
client request.Client
|
||||
}
|
||||
|
||||
func newSlaveNode(ctx context.Context, model *ent.Node, config conf.ConfigProvider, settings setting.Provider) *slaveNode {
|
||||
siteBasic := settings.SiteBasic(ctx)
|
||||
return &slaveNode{
|
||||
nodeBase: nodeBase{
|
||||
model: model,
|
||||
},
|
||||
client: request.NewClient(config,
|
||||
request.WithCorrelationID(),
|
||||
request.WithSlaveMeta(model.ID),
|
||||
request.WithMasterMeta(siteBasic.ID, settings.SiteURL(setting.UseFirstSiteUrl(ctx)).String()),
|
||||
request.WithCredential(auth.HMACAuth{[]byte(model.SlaveKey)}, int64(settings.SlaveRequestSignTTL(ctx))),
|
||||
request.WithEndpoint(model.Server)),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *slaveNode) CreateTask(ctx context.Context, taskType string, state string) (int, error) {
|
||||
reqBody, err := json.Marshal(&CreateSlaveTask{
|
||||
Type: taskType,
|
||||
State: state,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
resp, err := n.client.Request(
|
||||
"PUT",
|
||||
constants.APIPrefixSlave+"/task",
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return 0, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
taskId := 0
|
||||
if resp.GobDecode(&taskId); taskId > 0 {
|
||||
return taskId, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("unexpected response data: %v", resp.Data)
|
||||
}
|
||||
|
||||
func (n *slaveNode) GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error) {
|
||||
resp, err := n.client.Request(
|
||||
"GET",
|
||||
routes.SlaveGetTaskRoute(id, clearOnComplete),
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return nil, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
summary := &SlaveTaskSummary{}
|
||||
resp.GobDecode(summary)
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
func (b *slaveNode) CleanupFolders(ctx context.Context, folders ...string) error {
|
||||
args := &FolderCleanup{
|
||||
Path: folders,
|
||||
}
|
||||
reqBody, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
resp, err := b.client.Request(
|
||||
"POST",
|
||||
constants.APIPrefixSlave+"/task/cleanup",
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *slaveNode) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) {
|
||||
return slave.NewSlaveDownloader(b.client, b.Settings(ctx)), nil
|
||||
}
|
||||
|
||||
type nodeBase struct {
|
||||
model *ent.Node
|
||||
}
|
||||
|
||||
func (b *nodeBase) ID() int {
|
||||
return b.model.ID
|
||||
}
|
||||
|
||||
func (b *nodeBase) Name() string {
|
||||
return b.model.Name
|
||||
}
|
||||
|
||||
func (b *nodeBase) IsMaster() bool {
|
||||
return b.model.Type == node.TypeMaster
|
||||
}
|
||||
|
||||
func (b *nodeBase) CreateTask(ctx context.Context, taskType string, state string) (int, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) AuthInstance() auth.Auth {
|
||||
return auth.HMACAuth{[]byte(b.model.SlaveKey)}
|
||||
}
|
||||
|
||||
func (b *nodeBase) GetTask(ctx context.Context, id int, clearOnComplete bool) (*SlaveTaskSummary, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) CleanupFolders(ctx context.Context, folders ...string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) PrepareUpload(ctx context.Context, args *fs.StatelessPrepareUploadService) (*fs.StatelessPrepareUploadResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) CompleteUpload(ctx context.Context, args *fs.StatelessCompleteUploadService) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) OnUploadFailed(ctx context.Context, args *fs.StatelessOnUploadFailedService) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) CreateFile(ctx context.Context, args *fs.StatelessCreateFileService) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) CreateDownloader(ctx context.Context, c request.Client, settings setting.Provider) (downloader.Downloader, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *nodeBase) Settings(ctx context.Context) *types.NodeSetting {
|
||||
return b.model.Settings
|
||||
}
|
||||
|
||||
func NodeIdFromContext(ctx context.Context) int {
|
||||
nodeIdStr, ok := ctx.Value(SlaveNodeIDCtx{}).(string)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
nodeId, _ := strconv.Atoi(nodeIdStr)
|
||||
return nodeId
|
||||
}
|
||||
|
||||
func MasterSiteUrlFromContext(ctx context.Context) string {
|
||||
if u, ok := ctx.Value(MasterSiteUrlCtx{}).(string); ok {
|
||||
return u
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewNodeFromDBModel(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{
|
||||
Type: model.SlaveNodeType,
|
||||
}))
|
||||
a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{
|
||||
Type: model.MasterNodeType,
|
||||
}))
|
||||
}
|
||||
@@ -1,190 +1,203 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/node"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var Default *NodePool
|
||||
|
||||
// 需要分类的节点组
|
||||
var featureGroup = []string{"aria2"}
|
||||
|
||||
// Pool 节点池
|
||||
type Pool interface {
|
||||
// Returns active node selected by given feature and load balancer
|
||||
BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node)
|
||||
|
||||
// Returns node by ID
|
||||
GetNodeByID(id uint) Node
|
||||
|
||||
// Add given node into pool. If node existed, refresh node.
|
||||
Add(node *model.Node)
|
||||
|
||||
// Delete and kill node from pool by given node id
|
||||
Delete(id uint)
|
||||
type NodePool interface {
|
||||
// Upsert updates or inserts a node into the pool.
|
||||
Upsert(ctx context.Context, node *ent.Node)
|
||||
// Get returns a node with the given capability and preferred node id. `allowed` is a list of allowed node ids.
|
||||
// If `allowed` is empty, all nodes with the capability are considered.
|
||||
Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error)
|
||||
}
|
||||
|
||||
// NodePool 通用节点池
|
||||
type NodePool struct {
|
||||
active map[uint]Node
|
||||
inactive map[uint]Node
|
||||
type (
|
||||
weightedNodePool struct {
|
||||
lock sync.RWMutex
|
||||
|
||||
featureMap map[string][]Node
|
||||
conf conf.ConfigProvider
|
||||
settings setting.Provider
|
||||
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// Init 初始化从机节点池
|
||||
func Init() {
|
||||
Default = &NodePool{}
|
||||
Default.Init()
|
||||
if err := Default.initFromDB(); err != nil {
|
||||
util.Log().Warning("Failed to initialize node pool: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (pool *NodePool) Init() {
|
||||
pool.lock.Lock()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
pool.featureMap = make(map[string][]Node)
|
||||
pool.active = make(map[uint]Node)
|
||||
pool.inactive = make(map[uint]Node)
|
||||
}
|
||||
|
||||
func (pool *NodePool) buildIndexMap() {
|
||||
pool.lock.Lock()
|
||||
for _, feature := range featureGroup {
|
||||
pool.featureMap[feature] = make([]Node, 0)
|
||||
nodes map[types.NodeCapability][]*nodeItem
|
||||
}
|
||||
|
||||
for _, v := range pool.active {
|
||||
for _, feature := range featureGroup {
|
||||
if v.IsFeatureEnabled(feature) {
|
||||
pool.featureMap[feature] = append(pool.featureMap[feature], v)
|
||||
nodeItem struct {
|
||||
node Node
|
||||
weight int
|
||||
current int
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoAvailableNode = fmt.Errorf("no available node found")
|
||||
|
||||
supportedCapabilities = []types.NodeCapability{
|
||||
types.NodeCapabilityNone,
|
||||
types.NodeCapabilityCreateArchive,
|
||||
types.NodeCapabilityExtractArchive,
|
||||
types.NodeCapabilityRemoteDownload,
|
||||
}
|
||||
)
|
||||
|
||||
func NewNodePool(ctx context.Context, l logging.Logger, config conf.ConfigProvider, settings setting.Provider,
|
||||
client inventory.NodeClient) (NodePool, error) {
|
||||
nodes, err := client.ListActiveNodes(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list active nodes: %w", err)
|
||||
}
|
||||
|
||||
pool := &weightedNodePool{
|
||||
nodes: make(map[types.NodeCapability][]*nodeItem),
|
||||
conf: config,
|
||||
settings: settings,
|
||||
}
|
||||
for _, node := range nodes {
|
||||
for _, capability := range supportedCapabilities {
|
||||
// If current capability is enabled, add it to pool slot.
|
||||
if capability == types.NodeCapabilityNone ||
|
||||
(node.Capabilities != nil && node.Capabilities.Enabled(int(capability))) {
|
||||
if _, ok := pool.nodes[capability]; !ok {
|
||||
pool.nodes[capability] = make([]*nodeItem, 0)
|
||||
}
|
||||
|
||||
l.Debug("Add node %q to capability slot %d with weight %d", node.Name, capability, node.Weight)
|
||||
pool.nodes[capability] = append(pool.nodes[capability], &nodeItem{
|
||||
node: newNode(ctx, node, config, settings),
|
||||
weight: node.Weight,
|
||||
current: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
pool.lock.Unlock()
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func (pool *NodePool) GetNodeByID(id uint) Node {
|
||||
pool.lock.RLock()
|
||||
defer pool.lock.RUnlock()
|
||||
func (p *weightedNodePool) Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error) {
|
||||
l := logging.FromContext(ctx)
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
|
||||
if node, ok := pool.active[id]; ok {
|
||||
return node
|
||||
nodes, ok := p.nodes[capability]
|
||||
if !ok || len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no node found with capability %d: %w", capability, ErrNoAvailableNode)
|
||||
}
|
||||
|
||||
return pool.inactive[id]
|
||||
}
|
||||
var selected *nodeItem
|
||||
|
||||
func (pool *NodePool) nodeStatusChange(isActive bool, id uint) {
|
||||
util.Log().Debug("Slave node [ID=%d] status changed to [Active=%t].", id, isActive)
|
||||
var node Node
|
||||
pool.lock.Lock()
|
||||
if n, ok := pool.inactive[id]; ok {
|
||||
node = n
|
||||
delete(pool.inactive, id)
|
||||
} else {
|
||||
node = pool.active[id]
|
||||
delete(pool.active, id)
|
||||
}
|
||||
|
||||
if isActive {
|
||||
pool.active[id] = node
|
||||
} else {
|
||||
pool.inactive[id] = node
|
||||
}
|
||||
pool.lock.Unlock()
|
||||
|
||||
pool.buildIndexMap()
|
||||
}
|
||||
|
||||
func (pool *NodePool) initFromDB() error {
|
||||
nodes, err := model.GetNodesByStatus(model.NodeActive)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pool.lock.Lock()
|
||||
for i := 0; i < len(nodes); i++ {
|
||||
pool.add(&nodes[i])
|
||||
}
|
||||
pool.lock.Unlock()
|
||||
|
||||
pool.buildIndexMap()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pool *NodePool) add(node *model.Node) {
|
||||
newNode := NewNodeFromDBModel(node)
|
||||
if newNode.IsActive() {
|
||||
pool.active[node.ID] = newNode
|
||||
} else {
|
||||
pool.inactive[node.ID] = newNode
|
||||
}
|
||||
|
||||
// 订阅节点状态变更
|
||||
newNode.SubscribeStatusChange(func(isActive bool, id uint) {
|
||||
pool.nodeStatusChange(isActive, id)
|
||||
})
|
||||
}
|
||||
|
||||
func (pool *NodePool) Add(node *model.Node) {
|
||||
pool.lock.Lock()
|
||||
defer pool.buildIndexMap()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
var (
|
||||
old Node
|
||||
ok bool
|
||||
)
|
||||
if old, ok = pool.active[node.ID]; !ok {
|
||||
old, ok = pool.inactive[node.ID]
|
||||
}
|
||||
if old != nil {
|
||||
go old.Init(node)
|
||||
return
|
||||
}
|
||||
|
||||
pool.add(node)
|
||||
}
|
||||
|
||||
func (pool *NodePool) Delete(id uint) {
|
||||
pool.lock.Lock()
|
||||
defer pool.buildIndexMap()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
if node, ok := pool.active[id]; ok {
|
||||
node.Kill()
|
||||
delete(pool.active, id)
|
||||
return
|
||||
}
|
||||
|
||||
if node, ok := pool.inactive[id]; ok {
|
||||
node.Kill()
|
||||
delete(pool.inactive, id)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点
|
||||
func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) {
|
||||
pool.lock.RLock()
|
||||
defer pool.lock.RUnlock()
|
||||
if nodes, ok := pool.featureMap[feature]; ok {
|
||||
err, res := lb.NextPeer(nodes)
|
||||
if err == nil {
|
||||
return nil, res.(Node)
|
||||
if preferred > 0 {
|
||||
// First try to find the preferred node.
|
||||
for _, n := range nodes {
|
||||
if n.node.ID() == preferred {
|
||||
selected = n
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return err, nil
|
||||
if selected == nil {
|
||||
l.Debug("Preferred node %d not found, fallback to select a node with the least current weight", preferred)
|
||||
}
|
||||
}
|
||||
|
||||
return ErrFeatureNotExist, nil
|
||||
if selected == nil {
|
||||
// If no preferred one, or the preferred one is not available, select a node with the least current weight.
|
||||
|
||||
// Total weight of all items.
|
||||
var total int
|
||||
|
||||
// Loop through the list of items and add the item's weight to the current weight.
|
||||
// Also increment the total weight counter.
|
||||
var maxNode *nodeItem
|
||||
for _, item := range nodes {
|
||||
item.current += max(1, item.weight)
|
||||
total += max(1, item.weight)
|
||||
|
||||
// Select the item with max weight.
|
||||
if maxNode == nil || item.current > maxNode.current {
|
||||
maxNode = item
|
||||
}
|
||||
}
|
||||
|
||||
// Select the item with the max weight.
|
||||
selected = maxNode
|
||||
if selected == nil {
|
||||
return nil, fmt.Errorf("no node found with capability %d: %w", capability, ErrNoAvailableNode)
|
||||
}
|
||||
|
||||
l.Debug("Selected node %q with weight=%d, current=%d, total=%d", selected.node.Name(), selected.weight, maxNode.current, total)
|
||||
|
||||
// Reduce the current weight of the selected item by the total weight.
|
||||
maxNode.current -= total
|
||||
}
|
||||
|
||||
return selected.node, nil
|
||||
}
|
||||
|
||||
func (p *weightedNodePool) Upsert(ctx context.Context, n *ent.Node) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
for _, capability := range supportedCapabilities {
|
||||
_, index, found := lo.FindIndexOf(p.nodes[capability], func(i *nodeItem) bool {
|
||||
return i.node.ID() == n.ID
|
||||
})
|
||||
if capability == types.NodeCapabilityNone ||
|
||||
(n.Capabilities != nil && n.Capabilities.Enabled(int(capability))) {
|
||||
if n.Status != node.StatusActive && found {
|
||||
// Remove inactive node
|
||||
p.nodes[capability] = append(p.nodes[capability][:index], p.nodes[capability][index+1:]...)
|
||||
continue
|
||||
}
|
||||
|
||||
if found {
|
||||
p.nodes[capability][index].node = newNode(ctx, n, p.conf, p.settings)
|
||||
} else {
|
||||
p.nodes[capability] = append(p.nodes[capability], &nodeItem{
|
||||
node: newNode(ctx, n, p.conf, p.settings),
|
||||
weight: n.Weight,
|
||||
current: 0,
|
||||
})
|
||||
}
|
||||
} else if found {
|
||||
// Capability changed, remove the old node.
|
||||
p.nodes[capability] = append(p.nodes[capability][:index], p.nodes[capability][index+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type slaveDummyNodePool struct {
|
||||
conf conf.ConfigProvider
|
||||
settings setting.Provider
|
||||
masterNode Node
|
||||
}
|
||||
|
||||
func NewSlaveDummyNodePool(ctx context.Context, config conf.ConfigProvider, settings setting.Provider) NodePool {
|
||||
return &slaveDummyNodePool{
|
||||
conf: config,
|
||||
settings: settings,
|
||||
masterNode: newNode(ctx, &ent.Node{
|
||||
ID: 0,
|
||||
Name: "Master",
|
||||
Type: node.TypeMaster,
|
||||
}, config, settings),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slaveDummyNodePool) Upsert(ctx context.Context, node *ent.Node) {
|
||||
}
|
||||
|
||||
func (s *slaveDummyNodePool) Get(ctx context.Context, capability types.NodeCapability, preferred int) (Node, error) {
|
||||
return s.masterNode, nil
|
||||
}
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
// TestMain 初始化数据库Mock
|
||||
func TestMain(m *testing.M) {
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
model.DB, _ = gorm.Open("mysql", db)
|
||||
defer db.Close()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestInitFailed(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
|
||||
Init()
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestInitSuccess(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "aria2_enabled", "type"}).AddRow(1, true, model.MasterNodeType))
|
||||
Init()
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestNodePool_GetNodeByID(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
p.Init()
|
||||
mockNode := &nodeMock{}
|
||||
|
||||
// inactive
|
||||
{
|
||||
p.inactive[1] = mockNode
|
||||
a.Equal(mockNode, p.GetNodeByID(1))
|
||||
}
|
||||
|
||||
// active
|
||||
{
|
||||
delete(p.inactive, 1)
|
||||
p.active[1] = mockNode
|
||||
a.Equal(mockNode, p.GetNodeByID(1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodePool_NodeStatusChange(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
n := &MasterNode{Model: &model.Node{}}
|
||||
p.Init()
|
||||
p.inactive[1] = n
|
||||
|
||||
p.nodeStatusChange(true, 1)
|
||||
a.Len(p.inactive, 0)
|
||||
a.Equal(n, p.active[1])
|
||||
|
||||
p.nodeStatusChange(false, 1)
|
||||
a.Len(p.active, 0)
|
||||
a.Equal(n, p.inactive[1])
|
||||
|
||||
p.nodeStatusChange(false, 1)
|
||||
a.Len(p.active, 0)
|
||||
a.Equal(n, p.inactive[1])
|
||||
}
|
||||
|
||||
func TestNodePool_Add(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
p.Init()
|
||||
|
||||
// new node
|
||||
{
|
||||
p.Add(&model.Node{})
|
||||
a.Len(p.active, 1)
|
||||
}
|
||||
|
||||
// old node
|
||||
{
|
||||
p.inactive[0] = p.active[0]
|
||||
delete(p.active, 0)
|
||||
p.Add(&model.Node{})
|
||||
a.Len(p.active, 0)
|
||||
a.Len(p.inactive, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodePool_Delete(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
p.Init()
|
||||
|
||||
// active
|
||||
{
|
||||
mockNode := &nodeMock{}
|
||||
mockNode.On("Kill")
|
||||
p.active[0] = mockNode
|
||||
p.Delete(0)
|
||||
a.Len(p.active, 0)
|
||||
a.Len(p.inactive, 0)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
p.Init()
|
||||
|
||||
// inactive
|
||||
{
|
||||
mockNode := &nodeMock{}
|
||||
mockNode.On("Kill")
|
||||
p.inactive[0] = mockNode
|
||||
p.Delete(0)
|
||||
a.Len(p.active, 0)
|
||||
a.Len(p.inactive, 0)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodePool_BalanceNodeByFeature(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
p := &NodePool{}
|
||||
p.Init()
|
||||
|
||||
// success
|
||||
{
|
||||
p.featureMap["test"] = []Node{&MasterNode{}}
|
||||
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
|
||||
a.NoError(err)
|
||||
a.Equal(p.featureMap["test"][0], res)
|
||||
}
|
||||
|
||||
// NoNodes
|
||||
{
|
||||
p.featureMap["test"] = []Node{}
|
||||
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
|
||||
a.Error(err)
|
||||
a.Nil(res)
|
||||
}
|
||||
|
||||
// No match feature
|
||||
{
|
||||
err, res := p.BalanceNodeByFeature("test2", balancer.NewBalancer("round-robin"))
|
||||
a.Error(err)
|
||||
a.Nil(res)
|
||||
}
|
||||
}
|
||||
207
pkg/cluster/routes/routes.go
Normal file
207
pkg/cluster/routes/routes.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
)
|
||||
|
||||
const (
|
||||
IsDownloadQuery = "download"
|
||||
IsThumbQuery = "thumb"
|
||||
SlaveClearTaskRegistryQuery = "deleteOnComplete"
|
||||
)
|
||||
|
||||
var (
|
||||
masterPing *url.URL
|
||||
masterUserActivate *url.URL
|
||||
masterUserReset *url.URL
|
||||
masterHome *url.URL
|
||||
)
|
||||
|
||||
func init() {
|
||||
masterPing, _ = url.Parse(constants.APIPrefix + "/site/ping")
|
||||
masterUserActivate, _ = url.Parse("/session/activate")
|
||||
masterUserReset, _ = url.Parse("/session/reset")
|
||||
}
|
||||
|
||||
func FrontendHomeUrl(base *url.URL, path string) *url.URL {
|
||||
route, _ := url.Parse(fmt.Sprintf("/home"))
|
||||
q := route.Query()
|
||||
q.Set("path", path)
|
||||
route.RawQuery = q.Encode()
|
||||
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
func MasterPingUrl(base *url.URL) *url.URL {
|
||||
return base.ResolveReference(masterPing)
|
||||
}
|
||||
|
||||
func MasterSlaveCallbackUrl(base *url.URL, driver, id, secret string) *url.URL {
|
||||
apiBaseURI, _ := url.Parse(path.Join(constants.APIPrefix+"/callback", driver, id, secret))
|
||||
return base.ResolveReference(apiBaseURI)
|
||||
}
|
||||
|
||||
func MasterUserActivateAPIUrl(base *url.URL, uid string) *url.URL {
|
||||
route, _ := url.Parse(constants.APIPrefix + "/user/activate/" + uid)
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
func MasterUserActivateUrl(base *url.URL) *url.URL {
|
||||
return base.ResolveReference(masterUserActivate)
|
||||
}
|
||||
|
||||
func MasterUserResetUrl(base *url.URL) *url.URL {
|
||||
return base.ResolveReference(masterUserReset)
|
||||
}
|
||||
|
||||
func MasterShareUrl(base *url.URL, id, password string) *url.URL {
|
||||
p := "/s/" + id
|
||||
if password != "" {
|
||||
p += ("/" + password)
|
||||
}
|
||||
route, _ := url.Parse(p)
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
func MasterDirectLink(base *url.URL, id, name string) *url.URL {
|
||||
p := path.Join("/f", id, url.PathEscape(name))
|
||||
route, _ := url.Parse(p)
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
// MasterShareLongUrl generates a long share URL for redirect.
|
||||
func MasterShareLongUrl(id, password string) *url.URL {
|
||||
base, _ := url.Parse("/home")
|
||||
q := base.Query()
|
||||
|
||||
q.Set("path", fs.NewShareUri(id, password))
|
||||
base.RawQuery = q.Encode()
|
||||
return base
|
||||
}
|
||||
|
||||
func MasterArchiveDownloadUrl(base *url.URL, sessionID string) *url.URL {
|
||||
routes, err := url.Parse(path.Join(constants.APIPrefix, "file", "archive", sessionID, "archive.zip"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return base.ResolveReference(routes)
|
||||
}
|
||||
|
||||
func MasterPolicyOAuthCallback(base *url.URL) *url.URL {
|
||||
if base.Scheme != "https" {
|
||||
base.Scheme = "https"
|
||||
}
|
||||
routes, err := url.Parse("/admin/policy/oauth")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return base.ResolveReference(routes)
|
||||
}
|
||||
|
||||
func MasterGetCredentialUrl(base, key string) *url.URL {
|
||||
masterBase, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
routes, err := url.Parse(path.Join(constants.APIPrefixSlave, "credential", key))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return masterBase.ResolveReference(routes)
|
||||
}
|
||||
|
||||
func MasterStatelessUrl(base, method string) *url.URL {
|
||||
masterBase, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
routes, err := url.Parse(path.Join(constants.APIPrefixSlave, "statelessUpload", method))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return masterBase.ResolveReference(routes)
|
||||
}
|
||||
|
||||
func SlaveUploadUrl(base *url.URL, sessionID string) *url.URL {
|
||||
base.Path = path.Join(base.Path, constants.APIPrefixSlave, "/upload", sessionID)
|
||||
return base
|
||||
}
|
||||
|
||||
func MasterFileContentUrl(base *url.URL, entityId, name string, download, thumb bool, speed int64) *url.URL {
|
||||
name = url.PathEscape(name)
|
||||
|
||||
route, _ := url.Parse(constants.APIPrefix + fmt.Sprintf("/file/content/%s/%d/%s", entityId, speed, name))
|
||||
if base != nil {
|
||||
route = base.ResolveReference(route)
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
if download {
|
||||
values.Set(IsDownloadQuery, "true")
|
||||
}
|
||||
|
||||
if thumb {
|
||||
values.Set(IsThumbQuery, "true")
|
||||
}
|
||||
|
||||
route.RawQuery = values.Encode()
|
||||
return route
|
||||
}
|
||||
|
||||
func MasterWopiSrc(base *url.URL, sessionId string) *url.URL {
|
||||
route, _ := url.Parse(constants.APIPrefix + "/file/wopi/" + sessionId)
|
||||
return base.ResolveReference(route)
|
||||
}
|
||||
|
||||
func SlaveFileContentUrl(base *url.URL, srcPath, name string, download bool, speed int64, nodeId int) *url.URL {
|
||||
srcPath = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(srcPath)))
|
||||
name = url.PathEscape(name)
|
||||
route, _ := url.Parse(constants.APIPrefixSlave + fmt.Sprintf("/file/content/%d/%s/%d/%s", nodeId, srcPath, speed, name))
|
||||
base = base.ResolveReference(route)
|
||||
|
||||
values := url.Values{}
|
||||
if download {
|
||||
values.Set(IsDownloadQuery, "true")
|
||||
}
|
||||
|
||||
base.RawQuery = values.Encode()
|
||||
return base
|
||||
}
|
||||
|
||||
func SlaveMediaMetaRoute(src, ext string) string {
|
||||
src = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(src)))
|
||||
return fmt.Sprintf("file/meta/%s/%s", src, url.PathEscape(ext))
|
||||
}
|
||||
|
||||
func SlaveThumbUrl(base *url.URL, srcPath, ext string) *url.URL {
|
||||
srcPath = url.PathEscape(base64.URLEncoding.EncodeToString([]byte(srcPath)))
|
||||
ext = url.PathEscape(ext)
|
||||
route, _ := url.Parse(constants.APIPrefixSlave + fmt.Sprintf("/file/thumb/%s/%s", srcPath, ext))
|
||||
base = base.ResolveReference(route)
|
||||
return base
|
||||
}
|
||||
|
||||
func SlaveGetTaskRoute(id int, deleteOnComplete bool) string {
|
||||
p := constants.APIPrefixSlave + "/task/" + strconv.Itoa(id)
|
||||
if deleteOnComplete {
|
||||
p += "?" + SlaveClearTaskRegistryQuery + "=true"
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func SlavePingRoute(base *url.URL) string {
|
||||
route, _ := url.Parse(constants.APIPrefixSlave + "/ping")
|
||||
return base.ResolveReference(route).String()
|
||||
}
|
||||
@@ -1,451 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SlaveNode struct {
|
||||
Model *model.Node
|
||||
Active bool
|
||||
|
||||
caller slaveCaller
|
||||
callback func(bool, uint)
|
||||
close chan bool
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
type slaveCaller struct {
|
||||
parent *SlaveNode
|
||||
Client request.Client
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *SlaveNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
|
||||
// Init http request client
|
||||
var endpoint *url.URL
|
||||
if serverURL, err := url.Parse(node.Model.Server); err == nil {
|
||||
var controller *url.URL
|
||||
controller, _ = url.Parse("/api/v3/slave/")
|
||||
endpoint = serverURL.ResolveReference(controller)
|
||||
}
|
||||
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
node.caller.Client = request.NewClient(
|
||||
request.WithMasterMeta(),
|
||||
request.WithTimeout(time.Duration(signTTL)*time.Second),
|
||||
request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)),
|
||||
request.WithEndpoint(endpoint.String()),
|
||||
)
|
||||
|
||||
node.caller.parent = node
|
||||
if node.close != nil {
|
||||
node.lock.Unlock()
|
||||
node.close <- true
|
||||
go node.StartPingLoop()
|
||||
} else {
|
||||
node.Active = true
|
||||
node.lock.Unlock()
|
||||
go node.StartPingLoop()
|
||||
}
|
||||
}
|
||||
|
||||
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
switch feature {
|
||||
case "aria2":
|
||||
return node.Model.Aria2Enabled
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SubscribeStatusChange 订阅节点状态更改
|
||||
func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) {
|
||||
node.lock.Lock()
|
||||
node.callback = callback
|
||||
node.lock.Unlock()
|
||||
}
|
||||
|
||||
// Ping 从机节点,返回从机负载
|
||||
func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
reqBodyEncoded, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
|
||||
resp, err := node.caller.Client.Request(
|
||||
"POST",
|
||||
"heartbeat",
|
||||
bodyReader,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return nil, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
var res serializer.NodePingResp
|
||||
|
||||
if resStr, ok := resp.Data.(string); ok {
|
||||
err = json.Unmarshal([]byte(resStr), &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// IsActive 返回节点是否在线
|
||||
func (node *SlaveNode) IsActive() bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Active
|
||||
}
|
||||
|
||||
// Kill 结束节点内相关循环
|
||||
func (node *SlaveNode) Kill() {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
if node.close != nil {
|
||||
close(node.close)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAria2Instance 获取从机Aria2实例
|
||||
func (node *SlaveNode) GetAria2Instance() common.Aria2 {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
if !node.Model.Aria2Enabled {
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
return &node.caller
|
||||
}
|
||||
|
||||
func (node *SlaveNode) ID() uint {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model.ID
|
||||
}
|
||||
|
||||
func (node *SlaveNode) StartPingLoop() {
|
||||
node.lock.Lock()
|
||||
node.close = make(chan bool)
|
||||
node.lock.Unlock()
|
||||
|
||||
tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second
|
||||
recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second
|
||||
pingTicker := time.Duration(0)
|
||||
|
||||
util.Log().Debug("Slave node %q heartbeat loop started.", node.Model.Name)
|
||||
retry := 0
|
||||
recoverMode := false
|
||||
isFirstLoop := true
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-time.After(pingTicker):
|
||||
if pingTicker == 0 {
|
||||
pingTicker = tickDuration
|
||||
}
|
||||
|
||||
util.Log().Debug("Slave node %q send ping.", node.Model.Name)
|
||||
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
|
||||
isFirstLoop = false
|
||||
|
||||
if err != nil {
|
||||
util.Log().Debug("Error while ping slave node %q: %s", node.Model.Name, err)
|
||||
retry++
|
||||
if retry >= model.GetIntSetting("slave_node_retry", 3) {
|
||||
util.Log().Debug("Retry threshold for pinging slave node %q exceeded, mark it as offline.", node.Model.Name)
|
||||
node.changeStatus(false)
|
||||
|
||||
if !recoverMode {
|
||||
// 启动恢复监控循环
|
||||
util.Log().Debug("Slave node %q entered recovery mode.", node.Model.Name)
|
||||
pingTicker = recoverDuration
|
||||
recoverMode = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if recoverMode {
|
||||
util.Log().Debug("Slave node %q recovered.", node.Model.Name)
|
||||
pingTicker = tickDuration
|
||||
recoverMode = false
|
||||
isFirstLoop = true
|
||||
}
|
||||
|
||||
util.Log().Debug("Status of slave node %q: %s", node.Model.Name, res)
|
||||
node.changeStatus(true)
|
||||
retry = 0
|
||||
}
|
||||
|
||||
case <-node.close:
|
||||
util.Log().Debug("Slave node %q received shutdown signal.", node.Model.Name)
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) IsMater() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (node *SlaveNode) MasterAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) SlaveAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) DBModel() *model.Node {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model
|
||||
}
|
||||
|
||||
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
|
||||
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
|
||||
return &serializer.NodePingReq{
|
||||
SiteURL: model.GetSiteURL().String(),
|
||||
IsUpdate: isUpdate,
|
||||
SiteID: model.GetSettingByName("siteID"),
|
||||
Node: node.Model,
|
||||
CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
|
||||
}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) changeStatus(isActive bool) {
|
||||
node.lock.RLock()
|
||||
id := node.Model.ID
|
||||
if isActive != node.Active {
|
||||
node.lock.RUnlock()
|
||||
node.lock.Lock()
|
||||
node.Active = isActive
|
||||
node.lock.Unlock()
|
||||
node.callback(isActive, id)
|
||||
} else {
|
||||
node.lock.RUnlock()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendAria2Call send remote aria2 call to slave node
|
||||
func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
|
||||
reqReader, err := getAria2RequestBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Client.Request(
|
||||
"POST",
|
||||
"aria2/"+scope,
|
||||
reqReader,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
}
|
||||
|
||||
func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
GroupOptions: options,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "task")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return "", serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return res.Data.(string), err
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "status")
|
||||
if err != nil {
|
||||
return rpc.StatusInfo{}, err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
var status rpc.StatusInfo
|
||||
res.GobDecode(&status)
|
||||
|
||||
return status, err
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Cancel(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "cancel")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Select(task *model.Download, files []int) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
Files: files,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "select")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveCaller) GetConfig() model.Aria2Option {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
return s.parent.Model.Aria2OptionsSerialized
|
||||
}
|
||||
|
||||
func (s *slaveCaller) DeleteTempFile(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "delete")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
|
||||
reqBodyEncoded, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return strings.NewReader(string(reqBodyEncoded)), nil
|
||||
}
|
||||
|
||||
// RemoteCallback 发送远程存储策略上传回调请求
|
||||
func RemoteCallback(url string, body serializer.UploadCallback) error {
|
||||
callbackBody, err := json.Marshal(struct {
|
||||
Data serializer.UploadCallback `json:"data"`
|
||||
}{
|
||||
Data: body,
|
||||
})
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeCallbackError, "Failed to encode callback content", err)
|
||||
}
|
||||
|
||||
resp := request.GeneralClient.Request(
|
||||
"POST",
|
||||
url,
|
||||
bytes.NewReader(callbackBody),
|
||||
request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
|
||||
request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
|
||||
)
|
||||
|
||||
if resp.Err != nil {
|
||||
return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err)
|
||||
}
|
||||
|
||||
// 解析回调服务端响应
|
||||
response, err := resp.DecodeResponse()
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode)
|
||||
return serializer.NewError(serializer.CodeCallbackError, msg, err)
|
||||
}
|
||||
|
||||
if response.Code != 0 {
|
||||
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,559 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSlaveNode_InitAndKill(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
n := &SlaveNode{
|
||||
callback: func(b bool, u uint) {
|
||||
|
||||
},
|
||||
}
|
||||
|
||||
a.NotPanics(func() {
|
||||
n.Init(&model.Node{})
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
n.Init(&model.Node{})
|
||||
n.Kill()
|
||||
})
|
||||
}
|
||||
|
||||
func TestSlaveNode_DummyMethods(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
m.Model.ID = 5
|
||||
a.Equal(m.Model.ID, m.ID())
|
||||
a.Equal(m.Model.ID, m.DBModel().ID)
|
||||
|
||||
a.False(m.IsActive())
|
||||
a.False(m.IsMater())
|
||||
|
||||
m.SubscribeStatusChange(func(isActive bool, id uint) {})
|
||||
}
|
||||
|
||||
func TestSlaveNode_IsFeatureEnabled(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.False(m.IsFeatureEnabled("aria2"))
|
||||
a.False(m.IsFeatureEnabled("random"))
|
||||
m.Model.Aria2Enabled = true
|
||||
a.True(m.IsFeatureEnabled("aria2"))
|
||||
}
|
||||
|
||||
func TestSlaveNode_Ping(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
// master return error code
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.Ping(&serializer.NodePingReq{})
|
||||
a.Error(err)
|
||||
a.Nil(res)
|
||||
a.Equal(1, err.(serializer.AppError).Code)
|
||||
}
|
||||
|
||||
// return unexpected json
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"233\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.Ping(&serializer.NodePingReq{})
|
||||
a.Error(err)
|
||||
a.Nil(res)
|
||||
}
|
||||
|
||||
// return success
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.Ping(&serializer.NodePingReq{})
|
||||
a.NoError(err)
|
||||
a.NotNil(res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveNode_GetAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
m.Model.Aria2Enabled = true
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
}
|
||||
|
||||
func TestSlaveNode_StartPingLoop(t *testing.T) {
|
||||
callbackCount := 0
|
||||
finishedChan := make(chan struct{})
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m := &SlaveNode{
|
||||
Active: true,
|
||||
Model: &model.Node{},
|
||||
callback: func(b bool, u uint) {
|
||||
callbackCount++
|
||||
if callbackCount == 2 {
|
||||
close(finishedChan)
|
||||
}
|
||||
if callbackCount == 1 {
|
||||
mockRequest.AssertExpectations(t)
|
||||
mockRequest = requestMock{}
|
||||
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
|
||||
},
|
||||
})
|
||||
}
|
||||
},
|
||||
}
|
||||
cache.Set("setting_slave_ping_interval", "0", 0)
|
||||
cache.Set("setting_slave_recover_interval", "0", 0)
|
||||
cache.Set("setting_slave_node_retry", "1", 0)
|
||||
|
||||
m.caller.Client = &mockRequest
|
||||
go func() {
|
||||
select {
|
||||
case <-finishedChan:
|
||||
m.Kill()
|
||||
}
|
||||
}()
|
||||
m.StartPingLoop()
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestSlaveNode_AuthInstance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.NotNil(m.MasterAuthInstance())
|
||||
a.NotNil(m.SlaveAuthInstance())
|
||||
}
|
||||
|
||||
func TestSlaveNode_ChangeStatus(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
isActive := false
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
callback: func(b bool, u uint) {
|
||||
isActive = b
|
||||
},
|
||||
}
|
||||
|
||||
a.NotPanics(func() {
|
||||
m.changeStatus(false)
|
||||
})
|
||||
m.changeStatus(true)
|
||||
a.True(isActive)
|
||||
}
|
||||
|
||||
func getTestRPCNodeSlave() *SlaveNode {
|
||||
m := &SlaveNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
m.caller.parent = m
|
||||
return m
|
||||
}
|
||||
|
||||
func TestSlaveCaller_CreateTask(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.CreateTask(&model.Download{}, nil)
|
||||
a.Empty(res)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.CreateTask(&model.Download{}, nil)
|
||||
a.Empty(res)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.CreateTask(&model.Download{}, nil)
|
||||
a.Equal("res", res)
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveCaller_Status(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.Status(&model.Download{})
|
||||
a.Empty(res.Status)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.Status(&model.Download{})
|
||||
a.Empty(res.Status)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"re456456s\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
res, err := m.caller.Status(&model.Download{})
|
||||
a.Empty(res.Status)
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveCaller_Cancel(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Cancel(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Cancel(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Cancel(&model.Download{})
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveCaller_Select(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
m.caller.Init()
|
||||
m.caller.GetConfig()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Select(&model.Download{}, nil)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Select(&model.Download{}, nil)
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.Select(&model.Download{}, nil)
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveCaller_DeleteTempFile(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNodeSlave()
|
||||
m.caller.Init()
|
||||
m.caller.GetConfig()
|
||||
|
||||
// master return 404
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.DeleteTempFile(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.DeleteTempFile(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
// master return success
|
||||
{
|
||||
mockRequest := requestMock{}
|
||||
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
|
||||
},
|
||||
})
|
||||
m.caller.Client = mockRequest
|
||||
err := m.caller.DeleteTempFile(&model.Download{})
|
||||
a.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteCallback(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
// 回调成功
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
mockResp, _ := json.Marshal(serializer.Response{Code: 0})
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
|
||||
},
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.NoError(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 服务端返回业务错误
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
mockResp, _ := json.Marshal(serializer.Response{Code: 401})
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
|
||||
},
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.EqualValues(401, resp.(serializer.AppError).Code)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 无法解析回调响应
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
|
||||
},
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.Error(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// HTTP状态码非200
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 404,
|
||||
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
|
||||
},
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.Error(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 无法发起回调
|
||||
{
|
||||
clientMock := requestmock.RequestMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"POST",
|
||||
"http://test/test/url",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: errors.New("error"),
|
||||
})
|
||||
request.GeneralClient = clientMock
|
||||
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
|
||||
asserts.Error(resp)
|
||||
clientMock.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
255
pkg/conf/conf.go
255
pkg/conf/conf.go
@@ -1,75 +1,135 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/go-ini/ini"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// database 数据库
|
||||
type database struct {
|
||||
Type string
|
||||
User string
|
||||
Password string
|
||||
Host string
|
||||
Name string
|
||||
TablePrefix string
|
||||
DBFile string
|
||||
Port int
|
||||
Charset string
|
||||
UnixSocket bool
|
||||
const (
|
||||
envConfOverrideKey = "CR_CONF_"
|
||||
)
|
||||
|
||||
type ConfigProvider interface {
|
||||
Database() *Database
|
||||
System() *System
|
||||
SSL() *SSL
|
||||
Unix() *Unix
|
||||
Slave() *Slave
|
||||
Redis() *Redis
|
||||
Cors() *Cors
|
||||
OptionOverwrite() map[string]any
|
||||
}
|
||||
|
||||
// system 系统通用配置
|
||||
type system struct {
|
||||
Mode string `validate:"eq=master|eq=slave"`
|
||||
Listen string `validate:"required"`
|
||||
Debug bool
|
||||
SessionSecret string
|
||||
HashIDSalt string
|
||||
GracePeriod int `validate:"gte=0"`
|
||||
ProxyHeader string `validate:"required_with=Listen"`
|
||||
// NewIniConfigProvider initializes a new Ini config file provider. A default config file
|
||||
// will be created if the given path does not exist.
|
||||
func NewIniConfigProvider(configPath string, l logging.Logger) (ConfigProvider, error) {
|
||||
if configPath == "" || !util.Exists(configPath) {
|
||||
l.Info("Config file %q not found, creating a new one.", configPath)
|
||||
// 创建初始配置文件
|
||||
confContent := util.Replace(map[string]string{
|
||||
"{SessionSecret}": util.RandStringRunes(64),
|
||||
}, defaultConf)
|
||||
f, err := util.CreatNestedFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create config file: %w", err)
|
||||
}
|
||||
|
||||
// 写入配置文件
|
||||
_, err = f.WriteString(confContent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write config file: %w", err)
|
||||
}
|
||||
|
||||
f.Close()
|
||||
}
|
||||
|
||||
cfg, err := ini.Load(configPath, []byte(getOverrideConfFromEnv(l)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file %q: %w", configPath, err)
|
||||
}
|
||||
|
||||
provider := &iniConfigProvider{
|
||||
database: *DatabaseConfig,
|
||||
system: *SystemConfig,
|
||||
ssl: *SSLConfig,
|
||||
unix: *UnixConfig,
|
||||
slave: *SlaveConfig,
|
||||
redis: *RedisConfig,
|
||||
cors: *CORSConfig,
|
||||
optionOverwrite: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
sections := map[string]interface{}{
|
||||
"Database": &provider.database,
|
||||
"System": &provider.system,
|
||||
"SSL": &provider.ssl,
|
||||
"UnixSocket": &provider.unix,
|
||||
"Redis": &provider.redis,
|
||||
"CORS": &provider.cors,
|
||||
"Slave": &provider.slave,
|
||||
}
|
||||
for sectionName, sectionStruct := range sections {
|
||||
err = mapSection(cfg, sectionName, sectionStruct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config section %q: %w", sectionName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 映射数据库配置覆盖
|
||||
for _, key := range cfg.Section("OptionOverwrite").Keys() {
|
||||
provider.optionOverwrite[key.Name()] = key.Value()
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
type ssl struct {
|
||||
CertPath string `validate:"omitempty,required"`
|
||||
KeyPath string `validate:"omitempty,required"`
|
||||
Listen string `validate:"required"`
|
||||
type iniConfigProvider struct {
|
||||
database Database
|
||||
system System
|
||||
ssl SSL
|
||||
unix Unix
|
||||
slave Slave
|
||||
redis Redis
|
||||
cors Cors
|
||||
optionOverwrite map[string]any
|
||||
}
|
||||
|
||||
type unix struct {
|
||||
Listen string
|
||||
Perm uint32
|
||||
func (i *iniConfigProvider) Database() *Database {
|
||||
return &i.database
|
||||
}
|
||||
|
||||
// slave 作为slave存储端配置
|
||||
type slave struct {
|
||||
Secret string `validate:"omitempty,gte=64"`
|
||||
CallbackTimeout int `validate:"omitempty,gte=1"`
|
||||
SignatureTTL int `validate:"omitempty,gte=1"`
|
||||
func (i *iniConfigProvider) System() *System {
|
||||
return &i.system
|
||||
}
|
||||
|
||||
// redis 配置
|
||||
type redis struct {
|
||||
Network string
|
||||
Server string
|
||||
User string
|
||||
Password string
|
||||
DB string
|
||||
func (i *iniConfigProvider) SSL() *SSL {
|
||||
return &i.ssl
|
||||
}
|
||||
|
||||
// 跨域配置
|
||||
type cors struct {
|
||||
AllowOrigins []string
|
||||
AllowMethods []string
|
||||
AllowHeaders []string
|
||||
AllowCredentials bool
|
||||
ExposeHeaders []string
|
||||
SameSite string
|
||||
Secure bool
|
||||
func (i *iniConfigProvider) Unix() *Unix {
|
||||
return &i.unix
|
||||
}
|
||||
|
||||
var cfg *ini.File
|
||||
func (i *iniConfigProvider) Slave() *Slave {
|
||||
return &i.slave
|
||||
}
|
||||
|
||||
func (i *iniConfigProvider) Redis() *Redis {
|
||||
return &i.redis
|
||||
}
|
||||
|
||||
func (i *iniConfigProvider) Cors() *Cors {
|
||||
return &i.cors
|
||||
}
|
||||
|
||||
func (i *iniConfigProvider) OptionOverwrite() map[string]any {
|
||||
return i.optionOverwrite
|
||||
}
|
||||
|
||||
const defaultConf = `[System]
|
||||
Debug = false
|
||||
@@ -79,67 +139,8 @@ SessionSecret = {SessionSecret}
|
||||
HashIDSalt = {HashIDSalt}
|
||||
`
|
||||
|
||||
// Init 初始化配置文件
|
||||
func Init(path string) {
|
||||
var err error
|
||||
|
||||
if path == "" || !util.Exists(path) {
|
||||
// 创建初始配置文件
|
||||
confContent := util.Replace(map[string]string{
|
||||
"{SessionSecret}": util.RandStringRunes(64),
|
||||
"{HashIDSalt}": util.RandStringRunes(64),
|
||||
}, defaultConf)
|
||||
f, err := util.CreatNestedFile(path)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to create config file: %s", err)
|
||||
}
|
||||
|
||||
// 写入配置文件
|
||||
_, err = f.WriteString(confContent)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to write config file: %s", err)
|
||||
}
|
||||
|
||||
f.Close()
|
||||
}
|
||||
|
||||
cfg, err = ini.Load(path)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to parse config file %q: %s", path, err)
|
||||
}
|
||||
|
||||
sections := map[string]interface{}{
|
||||
"Database": DatabaseConfig,
|
||||
"System": SystemConfig,
|
||||
"SSL": SSLConfig,
|
||||
"UnixSocket": UnixConfig,
|
||||
"Redis": RedisConfig,
|
||||
"CORS": CORSConfig,
|
||||
"Slave": SlaveConfig,
|
||||
}
|
||||
for sectionName, sectionStruct := range sections {
|
||||
err = mapSection(sectionName, sectionStruct)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to parse config section %q: %s", sectionName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 映射数据库配置覆盖
|
||||
for _, key := range cfg.Section("OptionOverwrite").Keys() {
|
||||
OptionOverwrite[key.Name()] = key.Value()
|
||||
}
|
||||
|
||||
// 重设log等级
|
||||
if !SystemConfig.Debug {
|
||||
util.Level = util.LevelInformational
|
||||
util.GloablLogger = nil
|
||||
util.Log()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// mapSection 将配置文件的 Section 映射到结构体上
|
||||
func mapSection(section string, confStruct interface{}) error {
|
||||
func mapSection(cfg *ini.File, section string, confStruct interface{}) error {
|
||||
err := cfg.Section(section).MapTo(confStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -154,3 +155,35 @@ func mapSection(section string, confStruct interface{}) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getOverrideConfFromEnv(l logging.Logger) string {
|
||||
confMaps := make(map[string]map[string]string)
|
||||
for _, env := range os.Environ() {
|
||||
if !strings.HasPrefix(env, envConfOverrideKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
// split by key=value and get key
|
||||
kv := strings.SplitN(env, "=", 2)
|
||||
configKey := strings.TrimPrefix(kv[0], envConfOverrideKey)
|
||||
configValue := kv[1]
|
||||
sectionKey := strings.SplitN(configKey, ".", 2)
|
||||
if confMaps[sectionKey[0]] == nil {
|
||||
confMaps[sectionKey[0]] = make(map[string]string)
|
||||
}
|
||||
|
||||
confMaps[sectionKey[0]][sectionKey[1]] = configValue
|
||||
l.Info("Override config %q = %q", configKey, configValue)
|
||||
}
|
||||
|
||||
// generate ini content
|
||||
var sb strings.Builder
|
||||
for section, kvs := range confMaps {
|
||||
sb.WriteString(fmt.Sprintf("[%s]\n", section))
|
||||
for k, v := range kvs {
|
||||
sb.WriteString(fmt.Sprintf("%s = %s\n", k, v))
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// 测试Init日志路径错误
|
||||
@@ -15,10 +14,10 @@ func TestInitPanic(t *testing.T) {
|
||||
|
||||
// 日志路径不存在时
|
||||
asserts.NotPanics(func() {
|
||||
Init("not/exist/path/conf.ini")
|
||||
Init("not/exist/path")
|
||||
})
|
||||
|
||||
asserts.True(util.Exists("not/exist/path/conf.ini"))
|
||||
asserts.True(util.Exists("conf.ini"))
|
||||
|
||||
}
|
||||
|
||||
@@ -56,11 +55,7 @@ User = root
|
||||
Password = root
|
||||
Host = 127.0.0.1:3306
|
||||
Name = v3
|
||||
TablePrefix = v3_
|
||||
|
||||
[OptionOverwrite]
|
||||
key=value
|
||||
`
|
||||
TablePrefix = v3_`
|
||||
err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
|
||||
defer func() { err = os.Remove("testConf.ini") }()
|
||||
if err != nil {
|
||||
@@ -69,7 +64,6 @@ key=value
|
||||
asserts.NotPanics(func() {
|
||||
Init("testConf.ini")
|
||||
})
|
||||
asserts.Equal(OptionOverwrite["key"], "value")
|
||||
}
|
||||
|
||||
func TestMapSection(t *testing.T) {
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
package conf
|
||||
|
||||
// RedisConfig Redis服务器配置
|
||||
var RedisConfig = &redis{
|
||||
Network: "tcp",
|
||||
Server: "",
|
||||
Password: "",
|
||||
DB: "0",
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
var DatabaseConfig = &database{
|
||||
Type: "UNSET",
|
||||
Charset: "utf8",
|
||||
DBFile: "cloudreve.db",
|
||||
Port: 3306,
|
||||
UnixSocket: false,
|
||||
}
|
||||
|
||||
// SystemConfig 系统公用配置
|
||||
var SystemConfig = &system{
|
||||
Debug: false,
|
||||
Mode: "master",
|
||||
Listen: ":5212",
|
||||
ProxyHeader: "X-Forwarded-For",
|
||||
}
|
||||
|
||||
// CORSConfig 跨域配置
|
||||
var CORSConfig = &cors{
|
||||
AllowOrigins: []string{"UNSET"},
|
||||
AllowMethods: []string{"PUT", "POST", "GET", "OPTIONS"},
|
||||
AllowHeaders: []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Cr-Path", "X-Cr-FileName"},
|
||||
AllowCredentials: false,
|
||||
ExposeHeaders: nil,
|
||||
SameSite: "Default",
|
||||
Secure: false,
|
||||
}
|
||||
|
||||
// SlaveConfig 从机配置
|
||||
var SlaveConfig = &slave{
|
||||
CallbackTimeout: 20,
|
||||
SignatureTTL: 60,
|
||||
}
|
||||
|
||||
var SSLConfig = &ssl{
|
||||
Listen: ":443",
|
||||
CertPath: "",
|
||||
KeyPath: "",
|
||||
}
|
||||
|
||||
var UnixConfig = &unix{
|
||||
Listen: "",
|
||||
}
|
||||
|
||||
var OptionOverwrite = map[string]interface{}{}
|
||||
138
pkg/conf/types.go
Normal file
138
pkg/conf/types.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package conf
|
||||
|
||||
import "github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
|
||||
type DBType string
|
||||
|
||||
var (
|
||||
SQLiteDB DBType = "sqlite"
|
||||
SQLite3DB DBType = "sqlite3"
|
||||
MySqlDB DBType = "mysql"
|
||||
MsSqlDB DBType = "mssql"
|
||||
PostgresDB DBType = "postgres"
|
||||
)
|
||||
|
||||
// Database 数据库
|
||||
type Database struct {
|
||||
Type DBType
|
||||
User string
|
||||
Password string
|
||||
Host string
|
||||
Name string
|
||||
TablePrefix string
|
||||
DBFile string
|
||||
Port int
|
||||
Charset string
|
||||
UnixSocket bool
|
||||
}
|
||||
|
||||
type SysMode string
|
||||
|
||||
var (
|
||||
MasterMode SysMode = "master"
|
||||
SlaveMode SysMode = "slave"
|
||||
)
|
||||
|
||||
// System 系统通用配置
|
||||
type System struct {
|
||||
Mode SysMode `validate:"eq=master|eq=slave"`
|
||||
Listen string `validate:"required"`
|
||||
Debug bool
|
||||
SessionSecret string
|
||||
HashIDSalt string // deprecated
|
||||
GracePeriod int `validate:"gte=0"`
|
||||
ProxyHeader string `validate:"required_with=Listen"`
|
||||
LogLevel string `validate:"oneof=debug info warning error"`
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
CertPath string `validate:"omitempty,required"`
|
||||
KeyPath string `validate:"omitempty,required"`
|
||||
Listen string `validate:"required"`
|
||||
}
|
||||
|
||||
type Unix struct {
|
||||
Listen string
|
||||
Perm uint32
|
||||
}
|
||||
|
||||
// Slave 作为slave存储端配置
|
||||
type Slave struct {
|
||||
Secret string `validate:"omitempty,gte=64"`
|
||||
CallbackTimeout int `validate:"omitempty,gte=1"`
|
||||
SignatureTTL int `validate:"omitempty,gte=1"`
|
||||
}
|
||||
|
||||
// Redis 配置
|
||||
type Redis struct {
|
||||
Network string
|
||||
Server string
|
||||
User string
|
||||
Password string
|
||||
DB string
|
||||
}
|
||||
|
||||
// 跨域配置
|
||||
type Cors struct {
|
||||
AllowOrigins []string
|
||||
AllowMethods []string
|
||||
AllowHeaders []string
|
||||
AllowCredentials bool
|
||||
ExposeHeaders []string
|
||||
SameSite string
|
||||
Secure bool
|
||||
}
|
||||
|
||||
// RedisConfig Redis服务器配置
|
||||
var RedisConfig = &Redis{
|
||||
Network: "tcp",
|
||||
Server: "",
|
||||
Password: "",
|
||||
DB: "0",
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
var DatabaseConfig = &Database{
|
||||
Charset: "utf8mb4",
|
||||
DBFile: util.DataPath("cloudreve.db"),
|
||||
Port: 3306,
|
||||
UnixSocket: false,
|
||||
}
|
||||
|
||||
// SystemConfig 系统公用配置
|
||||
var SystemConfig = &System{
|
||||
Debug: false,
|
||||
Mode: MasterMode,
|
||||
Listen: ":5212",
|
||||
ProxyHeader: "X-Forwarded-For",
|
||||
LogLevel: "info",
|
||||
}
|
||||
|
||||
// CORSConfig 跨域配置
|
||||
var CORSConfig = &Cors{
|
||||
AllowOrigins: []string{"UNSET"},
|
||||
AllowMethods: []string{"PUT", "POST", "GET", "OPTIONS"},
|
||||
AllowHeaders: []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Cr-Path", "X-Cr-FileName"},
|
||||
AllowCredentials: false,
|
||||
ExposeHeaders: nil,
|
||||
SameSite: "Default",
|
||||
Secure: false,
|
||||
}
|
||||
|
||||
// SlaveConfig 从机配置
|
||||
var SlaveConfig = &Slave{
|
||||
CallbackTimeout: 20,
|
||||
SignatureTTL: 600,
|
||||
}
|
||||
|
||||
var SSLConfig = &SSL{
|
||||
Listen: ":443",
|
||||
CertPath: "",
|
||||
KeyPath: "",
|
||||
}
|
||||
|
||||
var UnixConfig = &Unix{
|
||||
Listen: "",
|
||||
}
|
||||
|
||||
var OptionOverwrite = map[string]interface{}{}
|
||||
@@ -1,16 +0,0 @@
|
||||
package conf
|
||||
|
||||
// BackendVersion 当前后端版本号
|
||||
var BackendVersion = "3.8.3"
|
||||
|
||||
// RequiredDBVersion 与当前版本匹配的数据库版本
|
||||
var RequiredDBVersion = "3.8.1"
|
||||
|
||||
// RequiredStaticVersion 与当前版本匹配的静态资源版本
|
||||
var RequiredStaticVersion = "3.8.3"
|
||||
|
||||
// IsPro 是否为Pro版本
|
||||
var IsPro = "false"
|
||||
|
||||
// LastCommit 最后commit id
|
||||
var LastCommit = "a11f819"
|
||||
238
pkg/credmanager/credmanager.go
Normal file
238
pkg/credmanager/credmanager.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package credmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
)
|
||||
|
||||
type (
|
||||
// CredManager is a centralized for all Oauth tokens that requires periodic refresh
|
||||
// It is primarily used by OneDrive storage policy.
|
||||
CredManager interface {
|
||||
// Obtain gets a credential from the manager, refresh it if it's expired
|
||||
Obtain(ctx context.Context, key string) (Credential, error)
|
||||
// Upsert inserts or updates a credential in the manager
|
||||
Upsert(ctx context.Context, cred ...Credential) error
|
||||
RefreshAll(ctx context.Context)
|
||||
}
|
||||
|
||||
Credential interface {
|
||||
String() string
|
||||
Refresh(ctx context.Context) (Credential, error)
|
||||
Key() string
|
||||
Expiry() time.Time
|
||||
RefreshedAt() *time.Time
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(CredentialResponse{})
|
||||
}
|
||||
|
||||
func New(kv cache.Driver) CredManager {
|
||||
return &credManager{
|
||||
kv: kv,
|
||||
locks: make(map[string]*sync.Mutex),
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
credManager struct {
|
||||
kv cache.Driver
|
||||
mu sync.RWMutex
|
||||
|
||||
locks map[string]*sync.Mutex
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotFound = errors.New("credential not found")
|
||||
)
|
||||
|
||||
func (m *credManager) Upsert(ctx context.Context, cred ...Credential) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
l := logging.FromContext(ctx)
|
||||
for _, c := range cred {
|
||||
l.Info("CredManager: Upsert credential for key %q...", c.Key())
|
||||
if err := m.kv.Set(c.Key(), c, 0); err != nil {
|
||||
return fmt.Errorf("failed to update credential in KV for key %q: %w", c.Key(), err)
|
||||
}
|
||||
|
||||
if _, ok := m.locks[c.Key()]; !ok {
|
||||
m.locks[c.Key()] = &sync.Mutex{}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *credManager) Obtain(ctx context.Context, key string) (Credential, error) {
|
||||
m.mu.RLock()
|
||||
itemRaw, ok := m.kv.Get(key)
|
||||
if !ok {
|
||||
m.mu.RUnlock()
|
||||
return nil, fmt.Errorf("credential not found for key %q: %w", key, ErrNotFound)
|
||||
}
|
||||
|
||||
l := logging.FromContext(ctx)
|
||||
|
||||
item := itemRaw.(Credential)
|
||||
if _, ok := m.locks[key]; !ok {
|
||||
m.locks[key] = &sync.Mutex{}
|
||||
}
|
||||
m.locks[key].Lock()
|
||||
defer m.locks[key].Unlock()
|
||||
m.mu.RUnlock()
|
||||
|
||||
if item.Expiry().After(time.Now()) {
|
||||
// Credential is still valid
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// Credential is expired, refresh it
|
||||
l.Info("Refreshing credential for key %q...", key)
|
||||
newCred, err := item.Refresh(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh credential for key %q: %w", key, err)
|
||||
}
|
||||
|
||||
l.Info("New credential for key %q is obtained, expire at %s", key, newCred.Expiry().String())
|
||||
if err := m.kv.Set(key, newCred, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to update credential in KV for key %q: %w", key, err)
|
||||
}
|
||||
|
||||
return newCred, nil
|
||||
}
|
||||
|
||||
func (m *credManager) RefreshAll(ctx context.Context) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
l := logging.FromContext(ctx)
|
||||
for key := range m.locks {
|
||||
l.Info("Refreshing credential for key %q...", key)
|
||||
m.locks[key].Lock()
|
||||
defer m.locks[key].Unlock()
|
||||
|
||||
itemRaw, ok := m.kv.Get(key)
|
||||
if !ok {
|
||||
l.Warning("Credential not found for key %q", key)
|
||||
continue
|
||||
}
|
||||
|
||||
item := itemRaw.(Credential)
|
||||
newCred, err := item.Refresh(ctx)
|
||||
if err != nil {
|
||||
l.Warning("Failed to refresh credential for key %q: %s", key, err)
|
||||
continue
|
||||
}
|
||||
|
||||
l.Info("New credential for key %q is obtained, expire at %s", key, newCred.Expiry().String())
|
||||
if err := m.kv.Set(key, newCred, 0); err != nil {
|
||||
l.Warning("Failed to update credential in KV for key %q: %s", key, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type (
	// slaveCredManager is the CredManager implementation used on slave
	// nodes; it caches credentials locally and fetches missing ones from
	// the master node over HTTP.
	slaveCredManager struct {
		kv     cache.Driver   // local credential cache
		client request.Client // HMAC-signed client for talking to master
	}

	// CredentialResponse is the wire format for a credential handed out by
	// the master node.
	CredentialResponse struct {
		Token    string    `json:"token"`
		ExpireAt time.Time `json:"expire_at"`
	}
)
|
||||
|
||||
func NewSlaveManager(kv cache.Driver, config conf.ConfigProvider) CredManager {
|
||||
return &slaveCredManager{
|
||||
kv: kv,
|
||||
client: request.NewClient(
|
||||
config,
|
||||
request.WithCredential(auth.HMACAuth{
|
||||
[]byte(config.Slave().Secret),
|
||||
}, int64(config.Slave().SignatureTTL)),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the raw token so the credential can be used directly as a
// header or query value.
func (c CredentialResponse) String() string {
	return c.Token
}
|
||||
|
||||
// Refresh is a no-op: slave nodes do not refresh master-issued credentials
// themselves; expired entries are re-requested from master via the cache
// TTL set in requestCredFromMaster.
func (c CredentialResponse) Refresh(ctx context.Context) (Credential, error) {
	return c, nil
}
|
||||
|
||||
// Key returns an empty string; master-issued credentials are cached under a
// caller-supplied key and carry no key of their own.
func (c CredentialResponse) Key() string {
	return ""
}
|
||||
|
||||
// Expiry reports when the token handed out by master stops being valid.
func (c CredentialResponse) Expiry() time.Time {
	return c.ExpireAt
}
|
||||
|
||||
// RefreshedAt is unknown for master-issued credentials, hence always nil.
func (c CredentialResponse) RefreshedAt() *time.Time {
	return nil
}
|
||||
|
||||
// Upsert is a no-op on slave nodes: credentials are owned by the master and
// only mirrored into the local cache by requestCredFromMaster.
func (m *slaveCredManager) Upsert(ctx context.Context, cred ...Credential) error {
	return nil
}
|
||||
|
||||
func (m *slaveCredManager) Obtain(ctx context.Context, key string) (Credential, error) {
|
||||
itemRaw, ok := m.kv.Get(key)
|
||||
if !ok {
|
||||
return m.requestCredFromMaster(ctx, key)
|
||||
}
|
||||
|
||||
return itemRaw.(Credential), nil
|
||||
}
|
||||
|
||||
// RefreshAll is a no-op on slave nodes; staleness is handled by the cache
// TTL set in requestCredFromMaster.
func (m *slaveCredManager) RefreshAll(ctx context.Context) {}
|
||||
|
||||
// requestCredFromMaster fetches the credential for key from the master node
// and mirrors it into the local KV cache with a TTL matching its expiry
// (minimum 1 second so a just-expired credential is not cached forever).
func (m *slaveCredManager) requestCredFromMaster(ctx context.Context, key string) (Credential, error) {
	l := logging.FromContext(ctx)
	l.Info("SlaveCredManager: Requesting credential for key %q from master...", key)

	requestDst := routes.MasterGetCredentialUrl(cluster.MasterSiteUrlFromContext(ctx), key)
	resp, err := m.client.Request(
		http.MethodGet,
		requestDst.String(),
		nil,
		request.WithContext(ctx),
		request.WithLogger(l),
		request.WithSlaveMeta(cluster.NodeIdFromContext(ctx)),
		request.WithCorrelationID(),
	).CheckHTTPResponse(http.StatusOK).DecodeResponse()
	if err != nil {
		return nil, fmt.Errorf("failed to request credential from master: %w", err)
	}

	cred := &CredentialResponse{}
	// NOTE(review): the GobDecode result is silently discarded, and &cred
	// is a **CredentialResponse — confirm decoding succeeds and check the
	// returned error; a failed decode would cache a zero-value credential.
	resp.GobDecode(&cred)

	if err := m.kv.Set(key, *cred, max(int(time.Until(cred.Expiry()).Seconds()), 1)); err != nil {
		return nil, fmt.Errorf("failed to update credential in KV for key %q: %w", key, err)
	}

	return cred, nil
}
|
||||
@@ -1,99 +0,0 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
func garbageCollect() {
|
||||
// 清理打包下载产生的临时文件
|
||||
collectArchiveFile()
|
||||
|
||||
// 清理过期的内置内存缓存
|
||||
if store, ok := cache.Store.(*cache.MemoStore); ok {
|
||||
collectCache(store)
|
||||
}
|
||||
|
||||
util.Log().Info("Crontab job \"cron_garbage_collect\" complete.")
|
||||
}
|
||||
|
||||
func collectArchiveFile() {
|
||||
// 读取有效期、目录设置
|
||||
tempPath := util.RelativePath(model.GetSettingByName("temp_path"))
|
||||
expires := model.GetIntSetting("download_timeout", 30)
|
||||
|
||||
// 列出文件
|
||||
root := filepath.Join(tempPath, "archive")
|
||||
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
|
||||
if err == nil && !info.IsDir() &&
|
||||
strings.HasPrefix(filepath.Base(path), "archive_") &&
|
||||
time.Now().Sub(info.ModTime()).Seconds() > float64(expires) {
|
||||
util.Log().Debug("Delete expired batch download temp file %q.", path)
|
||||
// 删除符合条件的文件
|
||||
if err := os.Remove(path); err != nil {
|
||||
util.Log().Debug("Failed to delete temp file %q: %s", path, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
util.Log().Debug("Crontab job cannot list temp batch download folder: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// collectCache triggers one garbage-collection pass on the builtin memory
// cache store.
func collectCache(store *cache.MemoStore) {
	util.Log().Debug("Cleanup memory cache.")
	store.GarbageCollect()
}
|
||||
|
||||
func uploadSessionCollect() {
|
||||
placeholders := model.GetUploadPlaceholderFiles(0)
|
||||
|
||||
// 将过期的上传会话按照用户分组
|
||||
userToFiles := make(map[uint][]uint)
|
||||
for _, file := range placeholders {
|
||||
_, sessionExist := cache.Get(filesystem.UploadSessionCachePrefix + *file.UploadSessionID)
|
||||
if sessionExist {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := userToFiles[file.UserID]; !ok {
|
||||
userToFiles[file.UserID] = make([]uint, 0)
|
||||
}
|
||||
|
||||
userToFiles[file.UserID] = append(userToFiles[file.UserID], file.ID)
|
||||
}
|
||||
|
||||
// 删除过期的会话
|
||||
for uid, filesIDs := range userToFiles {
|
||||
user, err := model.GetUserByID(uid)
|
||||
if err != nil {
|
||||
util.Log().Warning("Owner of the upload session cannot be found: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fs, err := filesystem.NewFileSystem(&user)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to initialize filesystem: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err = fs.Delete(context.Background(), []uint{}, filesIDs, false, false); err != nil {
|
||||
util.Log().Warning("Failed to delete upload session: %s", err)
|
||||
}
|
||||
|
||||
fs.Recycle()
|
||||
}
|
||||
|
||||
util.Log().Info("Crontab job \"cron_recycle_upload_session\" complete.")
|
||||
}
|
||||
73
pkg/crontab/crontab.go
Normal file
73
pkg/crontab/crontab.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
type (
	// CronTaskFunc is the body of a cron job; it receives a context
	// decorated with logger, correlation ID, and the anonymous user.
	CronTaskFunc func(ctx context.Context)
	// cornRegistration records one registered cron job.
	// NOTE(review): "corn" looks like a typo for "cron", and the config
	// field is never populated by Register — confirm both.
	cornRegistration struct {
		t      setting.CronType
		config string
		fn     CronTaskFunc
	}
)
|
||||
|
||||
var (
	// registrations collects cron jobs registered via Register before
	// NewCron builds the scheduler. Access is unsynchronized — presumably
	// registration happens during package init; confirm before calling
	// Register concurrently.
	registrations []cornRegistration
)
|
||||
|
||||
// Register registers a cron task.
|
||||
func Register(t setting.CronType, fn CronTaskFunc) {
|
||||
registrations = append(registrations, cornRegistration{
|
||||
t: t,
|
||||
fn: fn,
|
||||
})
|
||||
}
|
||||
|
||||
// NewCron constructs a new cron instance with given dependency.
|
||||
func NewCron(ctx context.Context, dep dependency.Dep) (*cron.Cron, error) {
|
||||
settings := dep.SettingProvider()
|
||||
userClient := dep.UserClient()
|
||||
anonymous, err := userClient.AnonymousUser(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cron: faield to get anonymous user: %w", err)
|
||||
}
|
||||
|
||||
l := dep.Logger()
|
||||
l.Info("Initialize crontab jobs...")
|
||||
c := cron.New()
|
||||
|
||||
for _, r := range registrations {
|
||||
cronConfig := settings.Cron(ctx, r.t)
|
||||
if _, err := c.AddFunc(cronConfig, taskWrapper(string(r.t), cronConfig, anonymous, dep, r.fn)); err != nil {
|
||||
l.Warning("Failed to start crontab job %q: %s", cronConfig, err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// taskWrapper adapts a CronTaskFunc to the plain func() robfig/cron expects,
// giving each run a fresh correlation ID, a prefixed logger, and the given
// user injected into the context.
func taskWrapper(name, config string, user *ent.User, dep dependency.Dep, task CronTaskFunc) func() {
	l := dep.Logger()
	// Logged once at scheduling time, not on every execution.
	l.Info("Cron task %s started with config %q", name, config)
	return func() {
		cid := uuid.Must(uuid.NewV4())
		l.Info("Executing Cron task %q with Cid %q", name, cid)
		ctx := context.Background()
		// Shadow l with a per-run logger carrying the correlation ID.
		l := dep.Logger().CopyWithPrefix(fmt.Sprintf("[Cid: %s Cron: %s]", cid, name))
		ctx = dep.ForkWithLogger(ctx, l)
		ctx = context.WithValue(ctx, logging.CorrelationIDCtx{}, cid)
		ctx = context.WithValue(ctx, logging.LoggerCtx{}, l)
		ctx = context.WithValue(ctx, inventory.UserCtx{}, user)
		task(ctx)
	}
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
// Cron is the package-level scheduled-task runner; Reload stops and
// recreates it.
var Cron *cron.Cron
|
||||
|
||||
// Reload restarts the crontab jobs: it stops the current scheduler (if one
// is running) and initializes a fresh one.
// NOTE(review): Stop() only has effect when Init assigns the package-level
// Cron variable rather than shadowing it with :=.
func Reload() {
	if Cron != nil {
		Cron.Stop()
	}
	Init()
}
|
||||
|
||||
// Init 初始化定时任务
|
||||
func Init() {
|
||||
util.Log().Info("Initialize crontab jobs...")
|
||||
// 读取cron日程设置
|
||||
options := model.GetSettingByNames(
|
||||
"cron_garbage_collect",
|
||||
"cron_recycle_upload_session",
|
||||
)
|
||||
Cron := cron.New()
|
||||
for k, v := range options {
|
||||
var handler func()
|
||||
switch k {
|
||||
case "cron_garbage_collect":
|
||||
handler = garbageCollect
|
||||
case "cron_recycle_upload_session":
|
||||
handler = uploadSessionCollect
|
||||
default:
|
||||
util.Log().Warning("Unknown crontab job type %q, skipping...", k)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := Cron.AddFunc(v, handler); err != nil {
|
||||
util.Log().Warning("Failed to start crontab job %q: %s", k, err)
|
||||
}
|
||||
|
||||
}
|
||||
Cron.Start()
|
||||
}
|
||||
283
pkg/downloader/aria2/aria2.go
Normal file
283
pkg/downloader/aria2/aria2.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
	// Aria2TempFolder is the subfolder (under the temp path) holding
	// per-task download directories.
	Aria2TempFolder = "aria2"
	// deleteTempFileDuration is how long Cancel waits before removing a
	// task's temp folder, so aria2 has released its file handles.
	deleteTempFileDuration = 120 * time.Second
)
|
||||
|
||||
// aria2Client implements downloader.Downloader on top of an aria2 JSON-RPC
// endpoint.
type aria2Client struct {
	l        logging.Logger
	settings setting.Provider

	options *types.Aria2Setting // RPC server address, token and default options
	timeout time.Duration       // per-RPC-call timeout
	caller  rpc.Client          // when nil, a client is created per call
}
|
||||
|
||||
// New builds an aria2-backed Downloader. When the configured server URL
// parses, its path is forced to /jsonrpc before being stored back into
// options.Server (mutating the caller's options value).
func New(l logging.Logger, settings setting.Provider, options *types.Aria2Setting) downloader.Downloader {
	rpcServer := options.Server
	rpcUrl, err := url.Parse(options.Server)
	if err == nil {
		// add /jsonrpc to the url if not present
		// NOTE(review): this unconditionally overwrites any path already
		// present on the configured URL, not only a missing one — confirm
		// that is intended for servers behind a path prefix.
		rpcUrl.Path = "/jsonrpc"
		rpcServer = rpcUrl.String()
	}

	options.Server = rpcServer
	return &aria2Client{
		l:        l,
		settings: settings,
		options:  options,
		timeout:  time.Duration(10) * time.Second,
	}
}
|
||||
|
||||
func (a *aria2Client) CreateTask(ctx context.Context, url string, options map[string]interface{}) (*downloader.TaskHandle, error) {
|
||||
caller := a.caller
|
||||
if caller == nil {
|
||||
var err error
|
||||
caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot create rpc client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
path := a.tempPath(ctx)
|
||||
a.l.Info("Creating aria2 task with url %q saving to %q...", url, path)
|
||||
|
||||
// Create the download task options
|
||||
downloadOptions := map[string]interface{}{}
|
||||
for k, v := range a.options.Options {
|
||||
downloadOptions[k] = v
|
||||
}
|
||||
for k, v := range options {
|
||||
downloadOptions[k] = v
|
||||
}
|
||||
downloadOptions["dir"] = path
|
||||
downloadOptions["follow-torrent"] = "mem"
|
||||
|
||||
gid, err := caller.AddURI(url, downloadOptions)
|
||||
if err != nil || gid == "" {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &downloader.TaskHandle{
|
||||
ID: gid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *aria2Client) Info(ctx context.Context, handle *downloader.TaskHandle) (*downloader.TaskStatus, error) {
|
||||
caller := a.caller
|
||||
if caller == nil {
|
||||
var err error
|
||||
caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot create rpc client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
status, err := caller.TellStatus(handle.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aria2 rpc error: %w", err)
|
||||
}
|
||||
|
||||
state := downloader.StatusDownloading
|
||||
switch status.Status {
|
||||
case "active":
|
||||
if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength {
|
||||
state = downloader.StatusSeeding
|
||||
} else {
|
||||
state = downloader.StatusDownloading
|
||||
}
|
||||
case "waiting", "paused":
|
||||
state = downloader.StatusDownloading
|
||||
case "complete":
|
||||
state = downloader.StatusCompleted
|
||||
case "error":
|
||||
state = downloader.StatusError
|
||||
case "cancelled", "removed":
|
||||
a.l.Debug("Task %q is cancelled", handle.ID)
|
||||
return nil, fmt.Errorf("Task canceled: %w", downloader.ErrTaskNotFount)
|
||||
}
|
||||
|
||||
totalLength, _ := strconv.ParseInt(status.TotalLength, 10, 64)
|
||||
downloaded, _ := strconv.ParseInt(status.CompletedLength, 10, 64)
|
||||
downloadSpeed, _ := strconv.ParseInt(status.DownloadSpeed, 10, 64)
|
||||
uploaded, _ := strconv.ParseInt(status.UploadLength, 10, 64)
|
||||
uploadSpeed, _ := strconv.ParseInt(status.UploadSpeed, 10, 64)
|
||||
numPieces, _ := strconv.Atoi(status.NumPieces)
|
||||
savePath := filepath.ToSlash(status.Dir)
|
||||
|
||||
res := &downloader.TaskStatus{
|
||||
State: state,
|
||||
Name: status.BitTorrent.Info.Name,
|
||||
Total: totalLength,
|
||||
Downloaded: downloaded,
|
||||
DownloadSpeed: downloadSpeed,
|
||||
Uploaded: uploaded,
|
||||
UploadSpeed: uploadSpeed,
|
||||
SavePath: savePath,
|
||||
NumPieces: numPieces,
|
||||
Hash: status.InfoHash,
|
||||
Files: lo.Map(status.Files, func(item rpc.FileInfo, index int) downloader.TaskFile {
|
||||
index, _ = strconv.Atoi(item.Index)
|
||||
size, _ := strconv.ParseInt(item.Length, 10, 64)
|
||||
completed, _ := strconv.ParseInt(item.CompletedLength, 10, 64)
|
||||
relPath := strings.TrimPrefix(filepath.ToSlash(item.Path), savePath)
|
||||
// Remove first letter if any
|
||||
if len(relPath) > 0 {
|
||||
relPath = relPath[1:]
|
||||
}
|
||||
progress := 0.0
|
||||
if size > 0 {
|
||||
progress = float64(completed) / float64(size)
|
||||
}
|
||||
return downloader.TaskFile{
|
||||
Index: index,
|
||||
Name: relPath,
|
||||
Size: size,
|
||||
Progress: progress,
|
||||
Selected: item.Selected == "true",
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
if len(status.FollowedBy) > 0 {
|
||||
res.FollowedBy = &downloader.TaskHandle{
|
||||
ID: status.FollowedBy[0],
|
||||
}
|
||||
}
|
||||
|
||||
if len(status.Files) == 1 && res.Name == "" {
|
||||
res.Name = path.Base(filepath.ToSlash(status.Files[0].Path))
|
||||
}
|
||||
|
||||
if status.BitField != "" {
|
||||
res.Pieces = make([]byte, len(status.BitField)/2)
|
||||
// Convert hex string to bytes
|
||||
for i := 0; i < len(status.BitField); i += 2 {
|
||||
b, _ := strconv.ParseInt(status.BitField[i:i+2], 16, 8)
|
||||
res.Pieces[i/2] = byte(b)
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Cancel removes the task from aria2 and schedules the task's temp download
// folder for deletion after a grace period.
func (a *aria2Client) Cancel(ctx context.Context, handle *downloader.TaskHandle) error {
	caller := a.caller
	if caller == nil {
		var err error
		caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil)
		if err != nil {
			return fmt.Errorf("cannot create rpc client: %w", err)
		}
	}

	// Look the task up first so we know which folder to clean.
	status, err := a.Info(ctx, handle)
	if err != nil {
		return fmt.Errorf("cannot get task: %w", err)
	}

	// Delay to delete temp download folder to avoid being locked by aria2
	// NOTE(review): the folder is removed even when Remove below fails.
	defer func() {
		go func(parent string, l logging.Logger) {
			time.Sleep(deleteTempFileDuration)
			err := os.RemoveAll(parent)
			if err != nil {
				l.Warning("Failed to delete temp download folder: %q: %s", parent, err)
			}
		}(status.SavePath, a.l)
	}()

	if _, err := caller.Remove(handle.ID); err != nil {
		return fmt.Errorf("aria2 rpc error: %w", err)
	}

	return nil
}
|
||||
|
||||
// SetFilesToDownload rewrites aria2's "select-file" option: the new
// selection is the set of all currently known files minus those explicitly
// deselected in args.
// NOTE(review): files that were previously unselected are re-selected
// unless repeated in args — confirm that is the intended semantics. Map
// iteration also makes the index order in "select-file" nondeterministic;
// aria2 treats it as a set, so order should not matter.
func (a *aria2Client) SetFilesToDownload(ctx context.Context, handle *downloader.TaskHandle, args ...*downloader.SetFileToDownloadArgs) error {
	caller := a.caller
	if caller == nil {
		var err error
		caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil)
		if err != nil {
			return fmt.Errorf("cannot create rpc client: %w", err)
		}
	}

	status, err := a.Info(ctx, handle)
	if err != nil {
		return fmt.Errorf("cannot get task: %w", err)
	}

	// Start from every known file, then drop the deselected ones.
	selected := lo.SliceToMap(status.Files, func(item downloader.TaskFile) (int, bool) {
		return item.Index, true
	})
	for _, arg := range args {
		if !arg.Download {
			delete(selected, arg.Index)
		}
	}

	_, err = caller.ChangeOption(handle.ID, map[string]interface{}{"select-file": strings.Join(lo.MapToSlice(selected, func(key int, value bool) string {
		return strconv.Itoa(key)
	}), ",")})
	return err
}
|
||||
|
||||
func (a *aria2Client) Test(ctx context.Context) (string, error) {
|
||||
caller := a.caller
|
||||
if caller == nil {
|
||||
var err error
|
||||
caller, err = rpc.New(ctx, a.options.Server, a.options.Token, a.timeout, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot create rpc client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
version, err := caller.GetVersion()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot call aria2: %w", err)
|
||||
}
|
||||
|
||||
return version.Version, nil
|
||||
}
|
||||
|
||||
func (a *aria2Client) tempPath(ctx context.Context) string {
|
||||
guid, _ := uuid.NewV4()
|
||||
|
||||
// Generate a unique path for the task
|
||||
base := util.RelativePath(a.options.TempPath)
|
||||
if a.options.TempPath == "" {
|
||||
base = util.DataPath(a.settings.TempPath(ctx))
|
||||
}
|
||||
path := filepath.Join(
|
||||
base,
|
||||
Aria2TempFolder,
|
||||
guid.String(),
|
||||
)
|
||||
return path
|
||||
}
|
||||
@@ -268,8 +268,9 @@ func (c *client) TellStatus(gid string, keys ...string) (info StatusInfo, err er
|
||||
// `aria2.getUris([secret, ]gid)`
|
||||
// This method returns the URIs used in the download denoted by gid (string).
|
||||
// The response is an array of structs and it contains following keys. Values are string.
|
||||
// uri URI
|
||||
// status 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue.
|
||||
//
|
||||
// uri URI
|
||||
// status 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue.
|
||||
func (c *client) GetURIs(gid string) (infos []URIInfo, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
@@ -456,12 +457,14 @@ func (c *client) GetOption(gid string) (m Option, err error) {
|
||||
// `aria2.changeOption([secret, ]gid, options)`
|
||||
// This method changes options of the download denoted by gid (string) dynamically. options is a struct.
|
||||
// The following options are available for active downloads:
|
||||
// bt-max-peers
|
||||
// bt-request-peer-speed-limit
|
||||
// bt-remove-unselected-file
|
||||
// force-save
|
||||
// max-download-limit
|
||||
// max-upload-limit
|
||||
//
|
||||
// bt-max-peers
|
||||
// bt-request-peer-speed-limit
|
||||
// bt-remove-unselected-file
|
||||
// force-save
|
||||
// max-download-limit
|
||||
// max-upload-limit
|
||||
//
|
||||
// For waiting or paused downloads, in addition to the above options, options listed in Input File subsection are available, except for following options: dry-run, metalink-base-uri, parameterized-uri, pause, piece-length and rpc-save-upload-metadata option.
|
||||
// This method returns OK for success.
|
||||
func (c *client) ChangeOption(gid string, option Option) (ok string, err error) {
|
||||
@@ -496,17 +499,19 @@ func (c *client) GetGlobalOption() (m Option, err error) {
|
||||
// This method changes global options dynamically.
|
||||
// options is a struct.
|
||||
// The following options are available:
|
||||
// bt-max-open-files
|
||||
// download-result
|
||||
// log
|
||||
// log-level
|
||||
// max-concurrent-downloads
|
||||
// max-download-result
|
||||
// max-overall-download-limit
|
||||
// max-overall-upload-limit
|
||||
// save-cookies
|
||||
// save-session
|
||||
// server-stat-of
|
||||
//
|
||||
// bt-max-open-files
|
||||
// download-result
|
||||
// log
|
||||
// log-level
|
||||
// max-concurrent-downloads
|
||||
// max-download-result
|
||||
// max-overall-download-limit
|
||||
// max-overall-upload-limit
|
||||
// save-cookies
|
||||
// save-session
|
||||
// server-stat-of
|
||||
//
|
||||
// In addition, options listed in the Input File subsection are available, except for following options: checksum, index-out, out, pause and select-file.
|
||||
// With the log option, you can dynamically start logging or change log file.
|
||||
// To stop logging, specify an empty string("") as the parameter value.
|
||||
@@ -525,13 +530,14 @@ func (c *client) ChangeGlobalOption(options Option) (ok string, err error) {
|
||||
// `aria2.getGlobalStat([secret])`
|
||||
// This method returns global statistics such as the overall download and upload speeds.
|
||||
// The response is a struct and contains the following keys. Values are strings.
|
||||
// downloadSpeed Overall download speed (byte/sec).
|
||||
// uploadSpeed Overall upload speed(byte/sec).
|
||||
// numActive The number of active downloads.
|
||||
// numWaiting The number of waiting downloads.
|
||||
// numStopped The number of stopped downloads in the current session.
|
||||
// This value is capped by the --max-download-result option.
|
||||
// numStoppedTotal The number of stopped downloads in the current session and not capped by the --max-download-result option.
|
||||
//
|
||||
// downloadSpeed Overall download speed (byte/sec).
|
||||
// uploadSpeed Overall upload speed(byte/sec).
|
||||
// numActive The number of active downloads.
|
||||
// numWaiting The number of waiting downloads.
|
||||
// numStopped The number of stopped downloads in the current session.
|
||||
// This value is capped by the --max-download-result option.
|
||||
// numStoppedTotal The number of stopped downloads in the current session and not capped by the --max-download-result option.
|
||||
func (c *client) GetGlobalStat() (info GlobalStatInfo, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
@@ -569,8 +575,9 @@ func (c *client) RemoveDownloadResult(gid string) (ok string, err error) {
|
||||
// `aria2.getVersion([secret])`
|
||||
// This method returns the version of aria2 and the list of enabled features.
|
||||
// The response is a struct and contains following keys.
|
||||
// version Version number of aria2 as a string.
|
||||
// enabledFeatures List of enabled features. Each feature is given as a string.
|
||||
//
|
||||
// version Version number of aria2 as a string.
|
||||
// enabledFeatures List of enabled features. Each feature is given as a string.
|
||||
func (c *client) GetVersion() (info VersionInfo, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
@@ -583,7 +590,8 @@ func (c *client) GetVersion() (info VersionInfo, err error) {
|
||||
// `aria2.getSessionInfo([secret])`
|
||||
// This method returns session information.
|
||||
// The response is a struct and contains following key.
|
||||
// sessionId Session ID, which is generated each time when aria2 is invoked.
|
||||
//
|
||||
// sessionId Session ID, which is generated each time when aria2 is invoked.
|
||||
func (c *client) GetSessionInfo() (info SessionInfo, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
76
pkg/downloader/downloader.go
Normal file
76
pkg/downloader/downloader.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
	// ErrTaskNotFount is returned when the downloader no longer knows the
	// task (cancelled, removed, or never created). Check with errors.Is.
	// NOTE(review): name looks like a typo for ErrTaskNotFound; renaming
	// would touch every caller.
	ErrTaskNotFount = fmt.Errorf("task not found")
)
|
||||
|
||||
type (
	// Downloader abstracts a remote-download backend (e.g. aria2,
	// qBittorrent).
	Downloader interface {
		// Create a task with the given URL and options overwriting the default settings, returns a task handle for future operations.
		CreateTask(ctx context.Context, url string, options map[string]interface{}) (*TaskHandle, error)
		// Info returns the status of the task with the given handle.
		Info(ctx context.Context, handle *TaskHandle) (*TaskStatus, error)
		// Cancel the task with the given handle.
		Cancel(ctx context.Context, handle *TaskHandle) error
		// SetFilesToDownload sets the files to download for the task with the given handle.
		SetFilesToDownload(ctx context.Context, handle *TaskHandle, args ...*SetFileToDownloadArgs) error
		// Test tests the connection to the downloader.
		Test(ctx context.Context) (string, error)
	}

	// TaskHandle represents a task handle for future operations
	TaskHandle struct {
		ID   string `json:"id"`
		Hash string `json:"hash"`
	}
	// Status enumerates the lifecycle states of a download task.
	Status string
	// TaskStatus is the backend-agnostic snapshot of a task's progress.
	TaskStatus struct {
		FollowedBy    *TaskHandle `json:"-"` // Indicate if the task handle is changed
		SavePath      string      `json:"save_path,omitempty"`
		Name          string      `json:"name"`
		State         Status      `json:"state"`
		Total         int64       `json:"total"`          // total size in bytes
		Downloaded    int64       `json:"downloaded"`     // bytes completed
		DownloadSpeed int64       `json:"download_speed"` // bytes/sec
		Uploaded      int64       `json:"uploaded"`
		UploadSpeed   int64       `json:"upload_speed"`
		Hash          string      `json:"hash,omitempty"`
		Files         []TaskFile  `json:"files,omitempty"`
		Pieces        []byte      `json:"pieces,omitempty"` // Hexadecimal representation of the download progress of the peer. The highest bit corresponds to the piece at index 0.
		NumPieces     int         `json:"num_pieces,omitempty"`
	}

	// TaskFile describes one file inside a (torrent) task.
	TaskFile struct {
		Index    int     `json:"index"`
		Name     string  `json:"name"`     // path relative to SavePath
		Size     int64   `json:"size"`     // bytes
		Progress float64 `json:"progress"` // completed/size, in [0, 1]
		Selected bool    `json:"selected"`
	}

	// SetFileToDownloadArgs toggles whether the file at Index is wanted.
	SetFileToDownloadArgs struct {
		Index    int  `json:"index"`
		Download bool `json:"download"`
	}
)
|
||||
|
||||
const (
	// Task lifecycle states reported by TaskStatus.State.
	StatusDownloading Status = "downloading"
	StatusSeeding     Status = "seeding"
	StatusCompleted   Status = "completed"
	StatusError       Status = "error"
	StatusUnknown     Status = "unknown"

	// DownloaderCtxKey is the context key under which a Downloader
	// instance is stored.
	DownloaderCtxKey = "downloader"
)
|
||||
|
||||
// init registers the task types with gob so they survive round-trips
// through gob-encoded payloads.
func init() {
	gob.Register(TaskHandle{})
	gob.Register(TaskStatus{})
}
|
||||
395
pkg/downloader/qbittorrent/qbittorrent.go
Normal file
395
pkg/downloader/qbittorrent/qbittorrent.go
Normal file
@@ -0,0 +1,395 @@
|
||||
package qbittorrent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
	// apiPrefix is the qBittorrent WebUI API base path.
	apiPrefix = "/api/v2"
	// successResponse is the body qBittorrent returns from plain-text
	// endpoints on success.
	successResponse = "Ok."
	// crTagPrefix prefixes the per-task tag Cloudreve attaches to a
	// torrent so it can be looked up again by task ID.
	crTagPrefix = "cr-"

	// qBittorrent file priorities: 0 = do not download, 1 = normal.
	downloadPrioritySkip     = 0
	downloadPriorityDownload = 1
)
|
||||
|
||||
var (
	// supportDownloadOptions whitelists the qBittorrent "torrents/add"
	// form fields that user-supplied download options may set.
	// NOTE(review): not referenced within this chunk — confirm it is
	// consumed by the task-creation path.
	supportDownloadOptions = map[string]bool{
		"cookie":             true,
		"skip_checking":      true,
		"root_folder":        true,
		"rename":             true,
		"upLimit":            true,
		"dlLimit":            true,
		"ratioLimit":         true,
		"seedingTimeLimit":   true,
		"autoTMM":            true,
		"sequentialDownload": true,
		"firstLastPiecePrio": true,
	}
)
|
||||
|
||||
// qbittorrentClient implements downloader.Downloader against the
// qBittorrent WebUI API.
type qbittorrentClient struct {
	c        request.Client // HTTP client rooted at <server>/api/v2
	settings setting.Provider
	l        logging.Logger
	options  *types.QBittorrentSetting
}
|
||||
|
||||
func NewClient(l logging.Logger, c request.Client, setting setting.Provider, options *types.QBittorrentSetting) (downloader.Downloader, error) {
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server, err := url.Parse(options.Server)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid qbittorrent server URL: %w", err)
|
||||
}
|
||||
|
||||
base, _ := url.Parse(apiPrefix)
|
||||
c.Apply(
|
||||
request.WithCookieJar(jar),
|
||||
request.WithLogger(l),
|
||||
request.WithEndpoint(options.Server),
|
||||
request.WithEndpoint(server.ResolveReference(base).String()),
|
||||
)
|
||||
return &qbittorrentClient{c: c, options: options, l: l, settings: setting}, nil
|
||||
}
|
||||
|
||||
func (c *qbittorrentClient) SetFilesToDownload(ctx context.Context, handle *downloader.TaskHandle, args ...*downloader.SetFileToDownloadArgs) error {
|
||||
downloadId := make([]int, 0, len(args))
|
||||
skipId := make([]int, 0, len(args))
|
||||
for _, arg := range args {
|
||||
if arg.Download {
|
||||
downloadId = append(downloadId, arg.Index)
|
||||
} else {
|
||||
skipId = append(skipId, arg.Index)
|
||||
}
|
||||
}
|
||||
|
||||
if len(downloadId) > 0 {
|
||||
if err := c.setFilePriority(ctx, handle.Hash, downloadPriorityDownload, downloadId...); err != nil {
|
||||
return fmt.Errorf("failed to set file priority to download: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(skipId) > 0 {
|
||||
if err := c.setFilePriority(ctx, handle.Hash, downloadPrioritySkip, skipId...); err != nil {
|
||||
return fmt.Errorf("failed to set file priority to skip: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel deletes the torrent (including its downloaded files) from
// qBittorrent, then removes the Cloudreve-specific tag created for the task.
// NOTE(review): the multipart writers are never Close()d, so the request
// bodies lack the terminating boundary — confirm qBittorrent's form parser
// accepts this.
func (c *qbittorrentClient) Cancel(ctx context.Context, handle *downloader.TaskHandle) error {
	buffer := bytes.Buffer{}
	formWriter := multipart.NewWriter(&buffer)
	_ = formWriter.WriteField("hashes", handle.Hash)
	_ = formWriter.WriteField("deleteFiles", "true")

	headers := http.Header{
		"Content-Type": []string{formWriter.FormDataContentType()},
	}

	_, err := c.request(ctx, http.MethodPost, "torrents/delete", buffer.String(), &headers)
	if err != nil {
		return fmt.Errorf("failed to cancel task with hash %q: %w", handle.Hash, err)
	}

	// Delete tags
	buffer = bytes.Buffer{}
	formWriter = multipart.NewWriter(&buffer)
	_ = formWriter.WriteField("tags", crTagPrefix+handle.ID)

	headers = http.Header{
		"Content-Type": []string{formWriter.FormDataContentType()},
	}

	_, err = c.request(ctx, http.MethodPost, "torrents/deleteTags", buffer.String(), &headers)
	if err != nil {
		return fmt.Errorf("failed to delete tag with id %q: %w", handle.ID, err)
	}

	return nil
}
|
||||
|
||||
// Info returns the current status of the download task identified by the
// handle. It issues three Web API calls: torrents/info (looked up by the
// Cloudreve tag), then torrents/files and torrents/pieceStates (looked up by
// the torrent hash), and folds the results into a single TaskStatus.
func (c *qbittorrentClient) Info(ctx context.Context, handle *downloader.TaskHandle) (*downloader.TaskStatus, error) {
	buffer := bytes.Buffer{}
	formWriter := multipart.NewWriter(&buffer)
	_ = formWriter.WriteField("tag", crTagPrefix+handle.ID)

	headers := http.Header{
		"Content-Type": []string{formWriter.FormDataContentType()},
	}

	// Get task info
	resp, err := c.request(ctx, http.MethodPost, "torrents/info", buffer.String(), &headers)
	if err != nil {
		return nil, fmt.Errorf("failed to get task info with tag %q: %w", crTagPrefix+handle.ID, err)
	}

	var torrents []Torrent
	if err := json.Unmarshal([]byte(resp), &torrents); err != nil {
		return nil, fmt.Errorf("failed to unmarshal info response: %w", err)
	}

	// An empty list means the tag no longer exists on the qBittorrent side;
	// wrap the sentinel so callers can detect "task not found" via errors.Is.
	if len(torrents) == 0 {
		return nil, fmt.Errorf("no torrent under tag %q: %w", crTagPrefix+handle.ID, downloader.ErrTaskNotFount)
	}

	// Get file info
	buffer = bytes.Buffer{}
	formWriter = multipart.NewWriter(&buffer)
	_ = formWriter.WriteField("hash", torrents[0].Hash)
	headers = http.Header{
		"Content-Type": []string{formWriter.FormDataContentType()},
	}

	resp, err = c.request(ctx, http.MethodPost, "torrents/files", buffer.String(), &headers)
	if err != nil {
		return nil, fmt.Errorf("failed to get torrent files with hash %q: %w", torrents[0].Hash, err)
	}

	var files []File
	if err := json.Unmarshal([]byte(resp), &files); err != nil {
		return nil, fmt.Errorf("failed to unmarshal files response: %w", err)
	}

	// Get piece status (reuses the same "hash" form body as torrents/files)
	resp, err = c.request(ctx, http.MethodPost, "torrents/pieceStates", buffer.String(), &headers)
	if err != nil {
		return nil, fmt.Errorf("failed to get torrent pieceStates with hash %q: %w", torrents[0].Hash, err)
	}

	var pieceStates []int
	if err := json.Unmarshal([]byte(resp), &pieceStates); err != nil {
		return nil, fmt.Errorf("failed to unmarshal pieceStates response: %w", err)
	}

	// Combining and converting all info.
	// Map qBittorrent's fine-grained state string onto the downloader's
	// coarse state enum; unrecognized states become StatusUnknown.
	state := downloader.StatusDownloading
	switch torrents[0].State {
	case "downloading", "pausedDL", "allocating", "metaDL", "queuedDL", "stalledDL", "checkingDL", "forcedDL", "checkingResumeData", "moving":
		state = downloader.StatusDownloading
	case "uploading", "queuedUP", "stalledUP", "checkingUP":
		state = downloader.StatusSeeding
	case "pausedUP":
		state = downloader.StatusCompleted
	case "error", "missingFiles":
		state = downloader.StatusError
	default:
		state = downloader.StatusUnknown
	}
	status := &downloader.TaskStatus{
		Name:          torrents[0].Name,
		Total:         torrents[0].Size,
		Downloaded:    torrents[0].Completed,
		DownloadSpeed: torrents[0].Dlspeed,
		Uploaded:      torrents[0].Uploaded,
		UploadSpeed:   torrents[0].Upspeed,
		SavePath:      filepath.ToSlash(torrents[0].SavePath),
		State:         state,
		Hash:          torrents[0].Hash,
		Files: lo.Map(files, func(item File, index int) downloader.TaskFile {
			return downloader.TaskFile{
				Index:    item.Index,
				Name:     filepath.ToSlash(item.Name),
				Size:     item.Size,
				Progress: item.Progress,
				// Priority 0 marks a file deselected in qBittorrent.
				Selected: item.Priority > 0,
			}
		}),
	}

	// If the hash reported by qBittorrent differs from the one stored in the
	// handle, update the handle in place and surface it via FollowedBy so the
	// caller persists the new hash.
	if handle.Hash != torrents[0].Hash {
		handle.Hash = torrents[0].Hash
		status.FollowedBy = handle
	}

	// Convert piece states to hex bytes array, The highest bit corresponds to the piece at index 0.
	// Only a piece-state value of 2 sets a bit here (per qBittorrent API this
	// is the "already downloaded" state — confirm against the API docs).
	status.NumPieces = len(pieceStates)
	pieces := make([]byte, 0, len(pieceStates)/8+1)
	for i := 0; i < len(pieceStates); i += 8 {
		var b byte
		for j := 0; j < 8; j++ {
			if i+j >= len(pieceStates) {
				break
			}
			pieceStatus := 0
			if pieceStates[i+j] == 2 {
				pieceStatus = 1
			}
			b |= byte(pieceStatus) << uint(7-j)
		}
		pieces = append(pieces, b)
	}
	status.Pieces = pieces

	return status, nil
}
|
||||
|
||||
func (c *qbittorrentClient) CreateTask(ctx context.Context, url string, options map[string]interface{}) (*downloader.TaskHandle, error) {
|
||||
guid, _ := uuid.NewV4()
|
||||
|
||||
// Generate a unique path for the task
|
||||
base := util.RelativePath(c.options.TempPath)
|
||||
if c.options.TempPath == "" {
|
||||
base = util.DataPath(c.settings.TempPath(ctx))
|
||||
}
|
||||
path := filepath.Join(
|
||||
base,
|
||||
"qbittorrent",
|
||||
guid.String(),
|
||||
)
|
||||
c.l.Info("Creating QBitTorrent task with url %q saving to %q...", url, path)
|
||||
|
||||
buffer := bytes.Buffer{}
|
||||
formWriter := multipart.NewWriter(&buffer)
|
||||
_ = formWriter.WriteField("urls", url)
|
||||
_ = formWriter.WriteField("savepath", path)
|
||||
_ = formWriter.WriteField("tags", crTagPrefix+guid.String())
|
||||
|
||||
// Apply global options
|
||||
for k, v := range c.options.Options {
|
||||
if _, ok := supportDownloadOptions[k]; ok {
|
||||
_ = formWriter.WriteField(k, fmt.Sprintf("%s", v))
|
||||
}
|
||||
}
|
||||
|
||||
// Apply group options
|
||||
for k, v := range options {
|
||||
if _, ok := supportDownloadOptions[k]; ok {
|
||||
_ = formWriter.WriteField(k, fmt.Sprintf("%s", v))
|
||||
}
|
||||
}
|
||||
|
||||
// Send request
|
||||
headers := http.Header{
|
||||
"Content-Type": []string{formWriter.FormDataContentType()},
|
||||
}
|
||||
|
||||
resp, err := c.request(ctx, http.MethodPost, "torrents/add", buffer.String(), &headers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create task qbittorrent failed: %w", err)
|
||||
}
|
||||
|
||||
if resp != successResponse {
|
||||
return nil, fmt.Errorf("create task qbittorrent failed: %s", resp)
|
||||
}
|
||||
|
||||
return &downloader.TaskHandle{
|
||||
ID: guid.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *qbittorrentClient) setFilePriority(ctx context.Context, hash string, priority int, id ...int) error {
|
||||
buffer := bytes.Buffer{}
|
||||
formWriter := multipart.NewWriter(&buffer)
|
||||
_ = formWriter.WriteField("hash", hash)
|
||||
_ = formWriter.WriteField("id", strings.Join(
|
||||
lo.Map(id, func(item int, index int) string {
|
||||
return fmt.Sprintf("%d", item)
|
||||
}), "|"))
|
||||
_ = formWriter.WriteField("priority", fmt.Sprintf("%d", priority))
|
||||
|
||||
headers := http.Header{
|
||||
"Content-Type": []string{formWriter.FormDataContentType()},
|
||||
}
|
||||
|
||||
_, err := c.request(ctx, http.MethodPost, "torrents/filePrio", buffer.String(), &headers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set file priority: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *qbittorrentClient) Test(ctx context.Context) (string, error) {
|
||||
res, err := c.request(ctx, http.MethodGet, "app/version", "", nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("test qbittorrent failed: %w", err)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *qbittorrentClient) login(ctx context.Context) error {
|
||||
form := url.Values{}
|
||||
form.Add("username", c.options.User)
|
||||
form.Add("password", c.options.Password)
|
||||
res, err := c.c.Request(http.MethodPost, "auth/login",
|
||||
strings.NewReader(form.Encode()),
|
||||
request.WithContext(ctx),
|
||||
request.WithHeader(http.Header{
|
||||
"Content-Type": []string{"application/x-www-form-urlencoded"},
|
||||
}),
|
||||
).CheckHTTPResponse(http.StatusOK).GetResponse()
|
||||
if err != nil {
|
||||
return fmt.Errorf("login failed with unexpected status code: %w", err)
|
||||
}
|
||||
|
||||
if res != successResponse {
|
||||
return fmt.Errorf("login failed with response: %s, possibly inccorrect credential is provided", res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// request sends a Web API request relative to the client's configured
// endpoint and returns the raw response body as a string. A nil headers
// pointer means no extra headers. On HTTP 403 the client re-authenticates
// and retries the same request.
func (c *qbittorrentClient) request(ctx context.Context, method, path string, body string, headers *http.Header) (string, error) {
	opts := []request.Option{
		request.WithContext(ctx),
	}

	if headers != nil {
		opts = append(opts, request.WithHeader(*headers))
	}

	res := c.c.Request(method, path, strings.NewReader(body), opts...)

	if res.Err != nil {
		return "", fmt.Errorf("send request failed: %w", res.Err)
	}

	switch res.Response.StatusCode {
	case http.StatusForbidden:
		// Session cookie expired: log in again, then retry the request.
		// NOTE(review): if the server keeps answering 403 even after each
		// successful login, this recursion has no upper bound — confirm
		// whether a retry limit is needed here.
		c.l.Info("QBittorrent cookie expired, sending login request...")
		if err := c.login(ctx); err != nil {
			return "", fmt.Errorf("login failed: %w", err)
		}

		return c.request(ctx, method, path, body, headers)

	case http.StatusOK:
		respContent, err := res.GetResponse()
		if err != nil {
			return "", fmt.Errorf("failed reading response: %w", err)
		}

		return respContent, nil
	case http.StatusUnsupportedMediaType:
		// qBittorrent answers 415 when an uploaded torrent file is invalid.
		return "", fmt.Errorf("invalid torrent file")
	default:
		// Best-effort body read for diagnostics; the read error is ignored
		// deliberately since the status code itself is the failure.
		content, _ := res.GetResponse()
		return "", fmt.Errorf("unexpected status code: %d, content: %s", res.Response.StatusCode, content)
	}
}
|
||||
64
pkg/downloader/qbittorrent/types.go
Normal file
64
pkg/downloader/qbittorrent/types.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package qbittorrent
|
||||
|
||||
// Torrent mirrors one entry of the qBittorrent Web API "torrents/info"
// response; field names and types follow the JSON keys the API returns.
type Torrent struct {
	AddedOn           int64   `json:"added_on"`
	AmountLeft        int64   `json:"amount_left"`
	AutoTmm           bool    `json:"auto_tmm"`
	Availability      float64 `json:"availability"`
	Category          string  `json:"category"`
	Completed         int64   `json:"completed"`
	CompletionOn      int64   `json:"completion_on"`
	ContentPath       string  `json:"content_path"`
	DlLimit           int     `json:"dl_limit"`
	Dlspeed           int64   `json:"dlspeed"`
	DownloadPath      string  `json:"download_path"`
	Downloaded        int64   `json:"downloaded"`
	DownloadedSession int     `json:"downloaded_session"`
	Eta               int     `json:"eta"`
	FLPiecePrio       bool    `json:"f_l_piece_prio"`
	ForceStart        bool    `json:"force_start"`
	Hash              string  `json:"hash"`
	InfohashV1        string  `json:"infohash_v1"`
	InfohashV2        string  `json:"infohash_v2"`
	LastActivity      int     `json:"last_activity"`
	MagnetUri         string  `json:"magnet_uri"`
	MaxRatio          float64 `json:"max_ratio"`
	MaxSeedingTime    int     `json:"max_seeding_time"`
	Name              string  `json:"name"`
	NumComplete       int     `json:"num_complete"`
	NumIncomplete     int     `json:"num_incomplete"`
	NumLeechs         int     `json:"num_leechs"`
	NumSeeds          int     `json:"num_seeds"`
	Priority          int     `json:"priority"`
	Progress          float64 `json:"progress"`
	Ratio             float64 `json:"ratio"`
	RatioLimit        float64 `json:"ratio_limit"`
	SavePath          string  `json:"save_path"`
	SeedingTime       int     `json:"seeding_time"`
	SeedingTimeLimit  int     `json:"seeding_time_limit"`
	SeenComplete      int     `json:"seen_complete"`
	SeqDl             bool    `json:"seq_dl"`
	Size              int64   `json:"size"`
	State             string  `json:"state"`
	SuperSeeding      bool    `json:"super_seeding"`
	Tags              string  `json:"tags"`
	TimeActive        int     `json:"time_active"`
	TotalSize         int64   `json:"total_size"`
	Tracker           string  `json:"tracker"`
	TrackersCount     int     `json:"trackers_count"`
	UpLimit           int     `json:"up_limit"`
	Uploaded          int64   `json:"uploaded"`
	UploadedSession   int     `json:"uploaded_session"`
	Upspeed           int64   `json:"upspeed"`
}
|
||||
|
||||
// File mirrors one entry of the qBittorrent Web API "torrents/files"
// response. A Priority of 0 marks the file as deselected (the client maps
// Priority > 0 to "selected").
type File struct {
	Index        int     `json:"index"`
	IsSeed       bool    `json:"is_seed"`
	Name         string  `json:"name"`
	PieceRange   []int   `json:"piece_range"`
	Priority     int     `json:"priority"`
	Progress     float64 `json:"progress"`
	Size         int64   `json:"size"`
	Availability float64 `json:"availability"`
}
|
||||
258
pkg/downloader/slave/slave.go
Normal file
258
pkg/downloader/slave/slave.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package slave
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cfssl/scan/crypto/sha1"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
)
|
||||
|
||||
// slaveDownloader proxies downloader operations to a slave node over its
// HTTP API. Every request carries the node settings and their SHA-1 hash;
// presumably the hash lets the slave cache/validate settings — confirm
// against the slave-side handler.
type slaveDownloader struct {
	client          request.Client     // HTTP client pre-configured for the slave node
	nodeSetting     *types.NodeSetting // settings forwarded with each request
	nodeSettingHash string             // hex SHA-1 of the marshaled settings
}
|
||||
|
||||
// NewSlaveDownloader creates a new slave downloader that forwards all
// operations to a slave node via the given HTTP client. The node settings
// are hashed once here (hex SHA-1 of their JSON form) and attached to every
// subsequent request.
func NewSlaveDownloader(client request.Client, nodeSetting *types.NodeSetting) downloader.Downloader {
	nodeSettingJson, err := json.Marshal(nodeSetting)
	if err != nil {
		// Best-effort: fall back to hashing an empty payload rather than
		// failing construction; the hash is only an identity token.
		nodeSettingJson = []byte{}
	}

	return &slaveDownloader{
		client:      client,
		nodeSetting: nodeSetting,
		// Hex-encoded SHA-1 digest of the marshaled settings.
		nodeSettingHash: fmt.Sprintf("%x", sha1.Sum(nodeSettingJson)),
	}
}
|
||||
|
||||
func (s *slaveDownloader) CreateTask(ctx context.Context, url string, options map[string]interface{}) (*downloader.TaskHandle, error) {
|
||||
reqBody, err := json.Marshal(&CreateSlaveDownload{
|
||||
NodeSetting: s.nodeSetting,
|
||||
Url: url,
|
||||
Options: options,
|
||||
NodeSettingHash: s.nodeSettingHash,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.client.Request(
|
||||
"POST",
|
||||
constants.APIPrefixSlave+"/download/task",
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return nil, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
var taskHandle *downloader.TaskHandle
|
||||
if resp.GobDecode(&taskHandle); taskHandle != nil {
|
||||
return taskHandle, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unexpected response data: %v", resp.Data)
|
||||
}
|
||||
|
||||
// Info queries the slave node for the current status of a download task.
// The status travels gob-encoded inside the standard response envelope.
// A remote "task not found" error is detected by message text and re-wrapped
// so callers can match it with errors.Is(err, downloader.ErrTaskNotFount).
func (s *slaveDownloader) Info(ctx context.Context, handle *downloader.TaskHandle) (*downloader.TaskStatus, error) {
	reqBody, err := json.Marshal(&GetSlaveDownload{
		NodeSetting:     s.nodeSetting,
		Handle:          handle,
		NodeSettingHash: s.nodeSettingHash,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request body: %w", err)
	}

	resp, err := s.client.Request(
		"POST",
		constants.APIPrefixSlave+"/download/status",
		bytes.NewReader(reqBody),
		request.WithContext(ctx),
		request.WithLogger(logging.FromContext(ctx)),
	).CheckHTTPResponse(200).DecodeResponse()
	if err != nil {
		return nil, err
	}

	// Non-zero code: business-level error from the slave.
	if resp.Code != 0 {
		err = serializer.NewErrorFromResponse(resp)
		// The sentinel does not survive serialization, so it is recovered by
		// substring match on the remote error text and re-attached with %w.
		if strings.Contains(err.Error(), downloader.ErrTaskNotFount.Error()) {
			return nil, fmt.Errorf("%s (%w)", err.Error(), downloader.ErrTaskNotFount)
		}
		return nil, err
	}

	var taskStatus *downloader.TaskStatus
	if resp.GobDecode(&taskStatus); taskStatus != nil {
		return taskStatus, nil
	}

	return nil, fmt.Errorf("unexpected response data: %v", resp.Data)
}
|
||||
|
||||
func (s *slaveDownloader) Cancel(ctx context.Context, handle *downloader.TaskHandle) error {
|
||||
reqBody, err := json.Marshal(&CancelSlaveDownload{
|
||||
NodeSetting: s.nodeSetting,
|
||||
Handle: handle,
|
||||
NodeSettingHash: s.nodeSettingHash,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.client.Request(
|
||||
"POST",
|
||||
constants.APIPrefixSlave+"/download/cancel",
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveDownloader) SetFilesToDownload(ctx context.Context, handle *downloader.TaskHandle, args ...*downloader.SetFileToDownloadArgs) error {
|
||||
reqBody, err := json.Marshal(&SetSlaveFilesToDownload{
|
||||
NodeSetting: s.nodeSetting,
|
||||
Handle: handle,
|
||||
NodeSettingHash: s.nodeSettingHash,
|
||||
Args: args,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.client.Request(
|
||||
"POST",
|
||||
constants.APIPrefixSlave+"/download/select",
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveDownloader) Test(ctx context.Context) (string, error) {
|
||||
reqBody, err := json.Marshal(&TestSlaveDownload{
|
||||
NodeSetting: s.nodeSetting,
|
||||
NodeSettingHash: s.nodeSettingHash,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.client.Request(
|
||||
"POST",
|
||||
constants.APIPrefixSlave+"/download/test",
|
||||
bytes.NewReader(reqBody),
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
return "", serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return resp.Data.(string), nil
|
||||
}
|
||||
|
||||
// Slave remote download related request bodies. Each type carries the node
// settings plus their hash (see GetNodeSetting) alongside the call-specific
// payload; `binding:"required"` tags drive request validation on the slave.
type (
	// CreateSlaveDownload is the request body for creating tasks on a slave node.
	CreateSlaveDownload struct {
		NodeSetting     *types.NodeSetting     `json:"node_setting" binding:"required"`
		NodeSettingHash string                 `json:"node_setting_hash" binding:"required"`
		Url             string                 `json:"url" binding:"required"`
		Options         map[string]interface{} `json:"options"`
	}
	// GetSlaveDownload is the request body for getting download task info from a slave node.
	GetSlaveDownload struct {
		Handle          *downloader.TaskHandle `json:"handle" binding:"required"`
		NodeSetting     *types.NodeSetting     `json:"node_setting" binding:"required"`
		NodeSettingHash string                 `json:"node_setting_hash" binding:"required"`
	}

	// CancelSlaveDownload is the request body for canceling a download task on a slave node.
	CancelSlaveDownload struct {
		Handle          *downloader.TaskHandle `json:"handle" binding:"required"`
		NodeSetting     *types.NodeSetting     `json:"node_setting" binding:"required"`
		NodeSettingHash string                 `json:"node_setting_hash" binding:"required"`
	}

	// SetSlaveFilesToDownload is the request body for selecting files to download on a slave node.
	SetSlaveFilesToDownload struct {
		Handle          *downloader.TaskHandle              `json:"handle" binding:"required"`
		Args            []*downloader.SetFileToDownloadArgs `json:"args" binding:"required"`
		NodeSettingHash string                              `json:"node_setting_hash" binding:"required"`
		NodeSetting     *types.NodeSetting                  `json:"node_setting" binding:"required"`
	}

	// TestSlaveDownload is the request body for testing downloader connectivity on a slave node.
	TestSlaveDownload struct {
		NodeSetting     *types.NodeSetting `json:"node_setting" binding:"required"`
		NodeSettingHash string             `json:"node_setting_hash" binding:"required"`
	}
)
|
||||
|
||||
// GetNodeSetting implements SlaveNodeSettingGetter interface.
// It returns the node settings carried by this request and their hash.
func (d *CreateSlaveDownload) GetNodeSetting() (*types.NodeSetting, string) {
	return d.NodeSetting, d.NodeSettingHash
}
|
||||
|
||||
// GetNodeSetting implements SlaveNodeSettingGetter interface.
// It returns the node settings carried by this request and their hash.
func (d *GetSlaveDownload) GetNodeSetting() (*types.NodeSetting, string) {
	return d.NodeSetting, d.NodeSettingHash
}
|
||||
|
||||
// GetNodeSetting implements SlaveNodeSettingGetter interface.
// It returns the node settings carried by this request and their hash.
func (d *CancelSlaveDownload) GetNodeSetting() (*types.NodeSetting, string) {
	return d.NodeSetting, d.NodeSettingHash
}
|
||||
|
||||
// GetNodeSetting implements SlaveNodeSettingGetter interface.
// It returns the node settings carried by this request and their hash.
func (d *SetSlaveFilesToDownload) GetNodeSetting() (*types.NodeSetting, string) {
	return d.NodeSetting, d.NodeSettingHash
}
|
||||
|
||||
// GetNodeSetting implements SlaveNodeSettingGetter interface.
// It returns the node settings carried by this request and their hash.
func (d *TestSlaveDownload) GetNodeSetting() (*types.NodeSetting, string) {
	return d.NodeSetting, d.NodeSettingHash
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// Client 默认的邮件发送客户端
|
||||
var Client Driver
|
||||
|
||||
// Lock 读写锁
|
||||
var Lock sync.RWMutex
|
||||
|
||||
// Init 初始化
|
||||
func Init() {
|
||||
util.Log().Debug("Initializing email sending queue...")
|
||||
Lock.Lock()
|
||||
defer Lock.Unlock()
|
||||
|
||||
if Client != nil {
|
||||
Client.Close()
|
||||
}
|
||||
|
||||
// 读取SMTP设置
|
||||
options := model.GetSettingByNames(
|
||||
"fromName",
|
||||
"fromAdress",
|
||||
"smtpHost",
|
||||
"replyTo",
|
||||
"smtpUser",
|
||||
"smtpPass",
|
||||
"smtpEncryption",
|
||||
)
|
||||
port := model.GetIntSetting("smtpPort", 25)
|
||||
keepAlive := model.GetIntSetting("mail_keepalive", 30)
|
||||
|
||||
client := NewSMTPClient(SMTPConfig{
|
||||
Name: options["fromName"],
|
||||
Address: options["fromAdress"],
|
||||
ReplyTo: options["replyTo"],
|
||||
Host: options["smtpHost"],
|
||||
Port: port,
|
||||
User: options["smtpUser"],
|
||||
Password: options["smtpPass"],
|
||||
Keepalive: keepAlive,
|
||||
Encryption: model.IsTrueVal(options["smtpEncryption"]),
|
||||
})
|
||||
|
||||
Client = client
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Driver 邮件发送驱动
|
||||
@@ -10,7 +10,7 @@ type Driver interface {
|
||||
// Close 关闭驱动
|
||||
Close()
|
||||
// Send 发送邮件
|
||||
Send(to, title, body string) error
|
||||
Send(ctx context.Context, to, title, body string) error
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -19,20 +19,3 @@ var (
|
||||
// ErrNoActiveDriver 无可用邮件发送服务
|
||||
ErrNoActiveDriver = errors.New("no avaliable email provider")
|
||||
)
|
||||
|
||||
// Send 发送邮件
|
||||
func Send(to, title, body string) error {
|
||||
// 忽略通过QQ登录的邮箱
|
||||
if strings.HasSuffix(to, "@login.qq.com") {
|
||||
return nil
|
||||
}
|
||||
|
||||
Lock.RLock()
|
||||
defer Lock.RUnlock()
|
||||
|
||||
if Client == nil {
|
||||
return ErrNoActiveDriver
|
||||
}
|
||||
|
||||
return Client.Send(to, title, body)
|
||||
}
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/go-mail/mail"
|
||||
"github.com/gofrs/uuid"
|
||||
)
|
||||
|
||||
// SMTP SMTP协议发送邮件
|
||||
type SMTP struct {
|
||||
// SMTPPool SMTP协议发送邮件
|
||||
type SMTPPool struct {
|
||||
// Deprecated
|
||||
Config SMTPConfig
|
||||
ch chan *mail.Message
|
||||
|
||||
config *setting.SMTP
|
||||
ch chan *message
|
||||
chOpen bool
|
||||
l logging.Logger
|
||||
}
|
||||
|
||||
// SMTPConfig SMTP发送配置
|
||||
@@ -26,14 +34,34 @@ type SMTPConfig struct {
|
||||
User string // 用户名
|
||||
Password string // 密码
|
||||
Encryption bool // 是否启用加密
|
||||
Keepalive int // SMTP 连接保留时长
|
||||
Keepalive int // SMTPPool 连接保留时长
|
||||
}
|
||||
|
||||
type message struct {
|
||||
msg *mail.Message
|
||||
cid string
|
||||
userID int
|
||||
}
|
||||
|
||||
// NewSMTPPool initializes a new SMTP based email sending queue.
|
||||
func NewSMTPPool(config setting.Provider, logger logging.Logger) *SMTPPool {
|
||||
client := &SMTPPool{
|
||||
config: config.SMTP(context.Background()),
|
||||
ch: make(chan *message, 30),
|
||||
chOpen: false,
|
||||
l: logger,
|
||||
}
|
||||
|
||||
client.Init()
|
||||
return client
|
||||
}
|
||||
|
||||
// NewSMTPClient 新建SMTP发送队列
|
||||
func NewSMTPClient(config SMTPConfig) *SMTP {
|
||||
client := &SMTP{
|
||||
// Deprecated
|
||||
func NewSMTPClient(config SMTPConfig) *SMTPPool {
|
||||
client := &SMTPPool{
|
||||
Config: config,
|
||||
ch: make(chan *mail.Message, 30),
|
||||
ch: make(chan *message, 30),
|
||||
chOpen: false,
|
||||
}
|
||||
|
||||
@@ -43,46 +71,57 @@ func NewSMTPClient(config SMTPConfig) *SMTP {
|
||||
}
|
||||
|
||||
// Send 发送邮件
|
||||
func (client *SMTP) Send(to, title, body string) error {
|
||||
func (client *SMTPPool) Send(ctx context.Context, to, title, body string) error {
|
||||
if !client.chOpen {
|
||||
return ErrChanNotOpen
|
||||
return fmt.Errorf("SMTP pool is closed")
|
||||
}
|
||||
|
||||
// 忽略通过QQ登录的邮箱
|
||||
if strings.HasSuffix(to, "@login.qq.com") {
|
||||
return nil
|
||||
}
|
||||
|
||||
m := mail.NewMessage()
|
||||
m.SetAddressHeader("From", client.Config.Address, client.Config.Name)
|
||||
m.SetAddressHeader("Reply-To", client.Config.ReplyTo, client.Config.Name)
|
||||
m.SetAddressHeader("From", client.config.From, client.config.FromName)
|
||||
m.SetAddressHeader("Reply-To", client.config.ReplyTo, client.config.FromName)
|
||||
m.SetHeader("To", to)
|
||||
m.SetHeader("Subject", title)
|
||||
m.SetHeader("Message-ID", fmt.Sprintf("<%s@%s>", uuid.NewString(), "cloudreve"))
|
||||
m.SetHeader("Message-ID", fmt.Sprintf("<%s@%s>", uuid.Must(uuid.NewV4()).String(), "cloudreve"))
|
||||
m.SetBody("text/html", body)
|
||||
client.ch <- m
|
||||
client.ch <- &message{
|
||||
msg: m,
|
||||
cid: logging.CorrelationID(ctx).String(),
|
||||
userID: inventory.UserIDFromContext(ctx),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭发送队列
|
||||
func (client *SMTP) Close() {
|
||||
func (client *SMTPPool) Close() {
|
||||
if client.ch != nil {
|
||||
close(client.ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Init 初始化发送队列
|
||||
func (client *SMTP) Init() {
|
||||
func (client *SMTPPool) Init() {
|
||||
go func() {
|
||||
client.l.Info("Initializing and starting SMTP email pool...")
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
client.chOpen = false
|
||||
util.Log().Error("Exception while sending email: %s, queue will be reset in 10 seconds.", err)
|
||||
client.l.Error("Exception while sending email: %s, queue will be reset in 10 seconds.", err)
|
||||
time.Sleep(time.Duration(10) * time.Second)
|
||||
client.Init()
|
||||
}
|
||||
}()
|
||||
|
||||
d := mail.NewDialer(client.Config.Host, client.Config.Port, client.Config.User, client.Config.Password)
|
||||
d.Timeout = time.Duration(client.Config.Keepalive+5) * time.Second
|
||||
d := mail.NewDialer(client.config.Host, client.config.Port, client.config.User, client.config.Password)
|
||||
d.Timeout = time.Duration(client.config.Keepalive+5) * time.Second
|
||||
client.chOpen = true
|
||||
// 是否启用 SSL
|
||||
d.SSL = false
|
||||
if client.Config.Encryption {
|
||||
if client.config.ForceEncryption {
|
||||
d.SSL = true
|
||||
}
|
||||
d.StartTLSPolicy = mail.OpportunisticStartTLS
|
||||
@@ -94,26 +133,29 @@ func (client *SMTP) Init() {
|
||||
select {
|
||||
case m, ok := <-client.ch:
|
||||
if !ok {
|
||||
util.Log().Debug("Email queue closing...")
|
||||
client.l.Info("Email queue closing...")
|
||||
client.chOpen = false
|
||||
return
|
||||
}
|
||||
|
||||
if !open {
|
||||
if s, err = d.Dial(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
open = true
|
||||
}
|
||||
if err := mail.Send(s, m); err != nil {
|
||||
util.Log().Warning("Failed to send email: %s", err)
|
||||
|
||||
l := client.l.CopyWithPrefix(fmt.Sprintf("[Cid: %s]", m.cid))
|
||||
if err := mail.Send(s, m.msg); err != nil {
|
||||
l.Warning("Failed to send email: %s, Cid=%s", err, m.cid)
|
||||
} else {
|
||||
util.Log().Debug("Email sent.")
|
||||
l.Info("Email sent to %q, title: %q.", m.msg.GetHeader("To"), m.msg.GetHeader("Subject"))
|
||||
}
|
||||
// 长时间没有新邮件,则关闭SMTP连接
|
||||
case <-time.After(time.Duration(client.Config.Keepalive) * time.Second):
|
||||
case <-time.After(time.Duration(client.config.Keepalive) * time.Second):
|
||||
if open {
|
||||
if err := s.Close(); err != nil {
|
||||
util.Log().Warning("Failed to close SMTP connection: %s", err)
|
||||
client.l.Warning("Failed to close SMTP connection: %s", err)
|
||||
}
|
||||
open = false
|
||||
}
|
||||
|
||||
@@ -1,36 +1,125 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
)
|
||||
|
||||
// NewActivationEmail 新建激活邮件
|
||||
func NewActivationEmail(userName, activateURL string) (string, string) {
|
||||
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_activation_template")
|
||||
replace := map[string]string{
|
||||
"{siteTitle}": options["siteName"],
|
||||
"{userName}": userName,
|
||||
"{activationUrl}": activateURL,
|
||||
"{siteUrl}": options["siteURL"],
|
||||
"{siteSecTitle}": options["siteTitle"],
|
||||
}
|
||||
return fmt.Sprintf("【%s】注册激活", options["siteName"]),
|
||||
util.Replace(replace, options["mail_activation_template"])
|
||||
// CommonContext carries the site-wide variables shared by all email
// templates (basic site info, logo, and the public site URL).
type CommonContext struct {
	SiteBasic *setting.SiteBasic
	Logo      *setting.Logo
	SiteUrl   string
}
|
||||
|
||||
// NewResetEmail 新建重设密码邮件
|
||||
func NewResetEmail(userName, resetURL string) (string, string) {
|
||||
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_reset_pwd_template")
|
||||
replace := map[string]string{
|
||||
"{siteTitle}": options["siteName"],
|
||||
"{userName}": userName,
|
||||
"{resetUrl}": resetURL,
|
||||
"{siteUrl}": options["siteURL"],
|
||||
"{siteSecTitle}": options["siteTitle"],
|
||||
}
|
||||
return fmt.Sprintf("【%s】密码重置", options["siteName"]),
|
||||
util.Replace(replace, options["mail_reset_pwd_template"])
|
||||
// ResetContext used for variables in reset email templates: the shared site
// context plus the target user and the password-reset URL.
type ResetContext struct {
	*CommonContext
	User *ent.User
	Url  string
}
|
||||
|
||||
// NewResetEmail generates reset email from template
|
||||
func NewResetEmail(ctx context.Context, settings setting.Provider, user *ent.User, url string) (string, string, error) {
|
||||
templates := settings.ResetEmailTemplate(ctx)
|
||||
if len(templates) == 0 {
|
||||
return "", "", fmt.Errorf("reset email template not configured")
|
||||
}
|
||||
|
||||
selected := selectTemplate(templates, user)
|
||||
resetCtx := ResetContext{
|
||||
CommonContext: commonContext(ctx, settings),
|
||||
User: user,
|
||||
Url: url,
|
||||
}
|
||||
|
||||
tmpl, err := template.New("reset").Parse(selected.Body)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to parse email template: %w", err)
|
||||
}
|
||||
|
||||
var res strings.Builder
|
||||
err = tmpl.Execute(&res, resetCtx)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to execute email template: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[%s] %s", resetCtx.SiteBasic.Name, selected.Title), res.String(), nil
|
||||
}
|
||||
|
||||
// ActivationContext used for variables in activation email. It embeds
// CommonContext and adds the user being activated and the activation URL.
type ActivationContext struct {
	*CommonContext
	User *ent.User
	Url  string
}
|
||||
|
||||
// NewActivationEmail generates activation email from template
|
||||
func NewActivationEmail(ctx context.Context, settings setting.Provider, user *ent.User, url string) (string, string, error) {
|
||||
templates := settings.ActivationEmailTemplate(ctx)
|
||||
if len(templates) == 0 {
|
||||
return "", "", fmt.Errorf("activation email template not configured")
|
||||
}
|
||||
|
||||
selected := selectTemplate(templates, user)
|
||||
activationCtx := ActivationContext{
|
||||
CommonContext: commonContext(ctx, settings),
|
||||
User: user,
|
||||
Url: url,
|
||||
}
|
||||
|
||||
tmpl, err := template.New("activation").Parse(selected.Body)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to parse email template: %w", err)
|
||||
}
|
||||
|
||||
var res strings.Builder
|
||||
err = tmpl.Execute(&res, activationCtx)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to execute email template: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[%s] %s", activationCtx.SiteBasic.Name, selected.Title), res.String(), nil
|
||||
}
|
||||
|
||||
func commonContext(ctx context.Context, settings setting.Provider) *CommonContext {
|
||||
logo := settings.Logo(ctx)
|
||||
siteUrl := settings.SiteURL(ctx)
|
||||
res := &CommonContext{
|
||||
SiteBasic: settings.SiteBasic(ctx),
|
||||
Logo: settings.Logo(ctx),
|
||||
SiteUrl: siteUrl.String(),
|
||||
}
|
||||
|
||||
// Add site url if logo is not an url
|
||||
if !strings.HasPrefix(logo.Light, "http") {
|
||||
logoPath, _ := url.Parse(logo.Light)
|
||||
res.Logo.Light = siteUrl.ResolveReference(logoPath).String()
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(logo.Normal, "http") {
|
||||
logoPath, _ := url.Parse(logo.Normal)
|
||||
res.Logo.Normal = siteUrl.ResolveReference(logoPath).String()
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func selectTemplate(templates []setting.EmailTemplate, u *ent.User) setting.EmailTemplate {
|
||||
selected := templates[0]
|
||||
if u != nil {
|
||||
for _, t := range templates {
|
||||
if strings.EqualFold(t.Language, u.Settings.Language) {
|
||||
selected = t
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package backoff
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -3,10 +3,11 @@ package chunk
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
@@ -18,36 +19,38 @@ type ChunkProcessFunc func(c *ChunkGroup, chunk io.Reader) error
|
||||
|
||||
// ChunkGroup manage groups of chunks
|
||||
type ChunkGroup struct {
|
||||
file fsctx.FileHeader
|
||||
chunkSize uint64
|
||||
file *fs.UploadRequest
|
||||
chunkSize int64
|
||||
backoff backoff.Backoff
|
||||
enableRetryBuffer bool
|
||||
l logging.Logger
|
||||
|
||||
fileInfo *fsctx.UploadTaskInfo
|
||||
currentIndex int
|
||||
chunkNum uint64
|
||||
chunkNum int64
|
||||
bufferTemp *os.File
|
||||
tempPath string
|
||||
}
|
||||
|
||||
func NewChunkGroup(file fsctx.FileHeader, chunkSize uint64, backoff backoff.Backoff, useBuffer bool) *ChunkGroup {
|
||||
func NewChunkGroup(file *fs.UploadRequest, chunkSize int64, backoff backoff.Backoff, useBuffer bool, l logging.Logger, tempPath string) *ChunkGroup {
|
||||
c := &ChunkGroup{
|
||||
file: file,
|
||||
chunkSize: chunkSize,
|
||||
backoff: backoff,
|
||||
fileInfo: file.Info(),
|
||||
currentIndex: -1,
|
||||
enableRetryBuffer: useBuffer,
|
||||
l: l,
|
||||
tempPath: tempPath,
|
||||
}
|
||||
|
||||
if c.chunkSize == 0 {
|
||||
c.chunkSize = c.fileInfo.Size
|
||||
c.chunkSize = c.file.Props.Size
|
||||
}
|
||||
|
||||
if c.fileInfo.Size == 0 {
|
||||
if c.file.Props.Size == 0 {
|
||||
c.chunkNum = 1
|
||||
} else {
|
||||
c.chunkNum = c.fileInfo.Size / c.chunkSize
|
||||
if c.fileInfo.Size%c.chunkSize != 0 {
|
||||
c.chunkNum = c.file.Props.Size / c.chunkSize
|
||||
if c.file.Props.Size%c.chunkSize != 0 {
|
||||
c.chunkNum++
|
||||
}
|
||||
}
|
||||
@@ -71,7 +74,7 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
|
||||
|
||||
// If useBuffer is enabled, tee the reader to a temp file
|
||||
if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() {
|
||||
c.bufferTemp, _ = os.CreateTemp("", bufferTempPattern)
|
||||
c.bufferTemp, _ = os.CreateTemp(util.DataPath(c.tempPath), bufferTempPattern)
|
||||
reader = io.TeeReader(reader, c.bufferTemp)
|
||||
}
|
||||
|
||||
@@ -90,7 +93,7 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
|
||||
return fmt.Errorf("failed to seek temp file back to chunk start: %w", err)
|
||||
}
|
||||
|
||||
util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name())
|
||||
c.l.Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name())
|
||||
reader = io.NopCloser(c.bufferTemp)
|
||||
}
|
||||
}
|
||||
@@ -108,25 +111,25 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
|
||||
}
|
||||
}
|
||||
|
||||
util.Log().Debug("Retrying chunk %d, last error: %s", c.currentIndex, err)
|
||||
c.l.Debug("Retrying chunk %d, last error: %s", c.currentIndex, err)
|
||||
return c.Process(processor)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
util.Log().Debug("Chunk %d processed", c.currentIndex)
|
||||
c.l.Debug("Chunk %d processed", c.currentIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start returns the byte index of current chunk
|
||||
func (c *ChunkGroup) Start() int64 {
|
||||
return int64(uint64(c.Index()) * c.chunkSize)
|
||||
return int64(int64(c.Index()) * c.chunkSize)
|
||||
}
|
||||
|
||||
// Total returns the total length
|
||||
func (c *ChunkGroup) Total() int64 {
|
||||
return int64(c.fileInfo.Size)
|
||||
return int64(c.file.Props.Size)
|
||||
}
|
||||
|
||||
// Num returns the total chunk number
|
||||
@@ -155,7 +158,7 @@ func (c *ChunkGroup) Next() bool {
|
||||
func (c *ChunkGroup) Length() int64 {
|
||||
contentLength := c.chunkSize
|
||||
if c.Index() == int(c.chunkNum-1) {
|
||||
contentLength = c.fileInfo.Size - c.chunkSize*(c.chunkNum-1)
|
||||
contentLength = c.file.Props.Size - c.chunkSize*(c.chunkNum-1)
|
||||
}
|
||||
|
||||
return int64(contentLength)
|
||||
588
pkg/filemanager/driver/cos/cos.go
Normal file
588
pkg/filemanager/driver/cos/cos.go
Normal file
@@ -0,0 +1,588 @@
|
||||
package cos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/google/go-querystring/query"
|
||||
"github.com/samber/lo"
|
||||
cossdk "github.com/tencentyun/cos-go-sdk-v5"
|
||||
)
|
||||
|
||||
// UploadPolicy is the POST-policy document used for direct browser
// uploads to Tencent COS.
type UploadPolicy struct {
	Expiration string        `json:"expiration"`
	Conditions []interface{} `json:"conditions"`
}

// MetaData holds object metadata obtained via a HEAD request, including
// the callback key/URL stored as custom COS metadata headers.
type MetaData struct {
	Size        uint64
	CallbackKey string
	CallbackURL string
}

// urlOption encodes per-URL query options supported by COS: traffic
// limiting, response-header overrides and CI (content inspection)
// processors.
type urlOption struct {
	Speed              int64   `url:"x-cos-traffic-limit,omitempty"`
	ContentDescription string  `url:"response-content-disposition,omitempty"`
	Exif               *string `url:"exif,omitempty"`
	CiProcess          string  `url:"ci-process,omitempty"`
}

type (
	// CosParts records one uploaded part (ETag + part number) for the
	// final CompleteMultipartUpload call.
	CosParts struct {
		ETag       string
		PartNumber int
	}
)

// Driver is the storage driver backed by Tencent COS.
type Driver struct {
	policy     *ent.StoragePolicy
	client     *cossdk.Client
	settings   setting.Provider
	config     conf.ConfigProvider
	httpClient request.Client
	l          logging.Logger
	mime       mime.MimeDetector

	// chunkSize is the multipart chunk size in bytes (policy setting, or
	// a 25 MB default — see New).
	chunkSize int64
}

const (
	// MultiPartUploadThreshold 服务端使用分片上传的阈值
	MultiPartUploadThreshold int64 = 5 * (1 << 30) // 5GB

	// maxDeleteBatch is the COS limit of objects per DeleteMulti request.
	maxDeleteBatch  = 1000
	chunkRetrySleep = time.Duration(5) * time.Second
	// overwriteOptionHeader toggles COS overwrite protection.
	// NOTE(review): the header forbids overwrite when "true"; callers in
	// this file pass the *allow*-overwrite flag directly — confirm the
	// intended polarity against the COS documentation.
	overwriteOptionHeader = "x-cos-forbid-overwrite"
	partNumberParam       = "partNumber"
	uploadIdParam         = "uploadId"
	contentTypeHeader     = "Content-Type"
	contentLengthHeader   = "Content-Length"
)

var (
	// features lists the static capabilities of this driver (populated
	// in init).
	features = &boolset.BooleanSet{}
)
|
||||
|
||||
// init disables signing of the host/origin headers (they may be rewritten
// by proxies in between) and declares the static driver capabilities.
func init() {
	cossdk.SetNeedSignHeaders("host", false)
	cossdk.SetNeedSignHeaders("origin", false)
	boolset.Sets(map[driver.HandlerCapability]bool{
		driver.HandlerCapabilityUploadSentinelRequired: true,
	}, features)
}
||||
|
||||
func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider,
|
||||
config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) {
|
||||
chunkSize := policy.Settings.ChunkSize
|
||||
if policy.Settings.ChunkSize == 0 {
|
||||
chunkSize = 25 << 20 // 25 MB
|
||||
}
|
||||
|
||||
driver := &Driver{
|
||||
policy: policy,
|
||||
settings: settings,
|
||||
chunkSize: chunkSize,
|
||||
config: config,
|
||||
l: l,
|
||||
mime: mime,
|
||||
httpClient: request.NewClient(config, request.WithLogger(l)),
|
||||
}
|
||||
|
||||
u, err := url.Parse(policy.Server)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse COS bucket server url: %w", err)
|
||||
}
|
||||
driver.client = cossdk.NewClient(&cossdk.BaseURL{BucketURL: u}, &http.Client{
|
||||
Transport: &cossdk.AuthorizationTransport{
|
||||
SecretID: policy.AccessKey,
|
||||
SecretKey: policy.SecretKey,
|
||||
},
|
||||
})
|
||||
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
//
|
||||
//// List 列出COS文件
|
||||
//func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
// // 初始化列目录参数
|
||||
// opt := &cossdk.BucketGetOptions{
|
||||
// Prefix: strings.TrimPrefix(base, "/"),
|
||||
// EncodingType: "",
|
||||
// MaxKeys: 1000,
|
||||
// }
|
||||
// // 是否为递归列出
|
||||
// if !recursive {
|
||||
// opt.Delimiter = "/"
|
||||
// }
|
||||
// // 手动补齐结尾的slash
|
||||
// if opt.Prefix != "" {
|
||||
// opt.Prefix += "/"
|
||||
// }
|
||||
//
|
||||
// var (
|
||||
// marker string
|
||||
// objects []cossdk.Object
|
||||
// commons []string
|
||||
// )
|
||||
//
|
||||
// for {
|
||||
// res, _, err := handler.client.Bucket.Get(ctx, opt)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// objects = append(objects, res.Contents...)
|
||||
// commons = append(commons, res.CommonPrefixes...)
|
||||
// // 如果本次未列取完,则继续使用marker获取结果
|
||||
// marker = res.NextMarker
|
||||
// // marker 为空时结果列取完毕,跳出
|
||||
// if marker == "" {
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 处理列取结果
|
||||
// res := make([]response.Object, 0, len(objects)+len(commons))
|
||||
// // 处理目录
|
||||
// for _, object := range commons {
|
||||
// rel, err := filepath.Rel(opt.Prefix, object)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// res = append(res, response.Object{
|
||||
// Name: path.Base(object),
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Size: 0,
|
||||
// IsDir: true,
|
||||
// LastModify: time.Now(),
|
||||
// })
|
||||
// }
|
||||
// // 处理文件
|
||||
// for _, object := range objects {
|
||||
// rel, err := filepath.Rel(opt.Prefix, object.Key)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// res = append(res, response.Object{
|
||||
// Name: path.Base(object.Key),
|
||||
// Source: object.Key,
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Size: uint64(object.Size),
|
||||
// IsDir: false,
|
||||
// LastModify: time.Now(),
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// return res, nil
|
||||
//
|
||||
//}
|
||||
|
||||
// CORS 创建跨域策略
|
||||
func (handler Driver) CORS() error {
|
||||
_, err := handler.client.Bucket.PutCORS(context.Background(), &cossdk.BucketPutCORSOptions{
|
||||
Rules: []cossdk.BucketCORSRule{{
|
||||
AllowedMethods: []string{
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
},
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedHeaders: []string{"*"},
|
||||
MaxAgeSeconds: 3600,
|
||||
ExposeHeaders: []string{"ETag"},
|
||||
}},
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Open is not supported by the COS driver: remote objects cannot be
// exposed as local *os.File handles.
func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) {
	return nil, errors.New("not implemented")
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
|
||||
defer file.Close()
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
// 是否允许覆盖
|
||||
overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite
|
||||
opt := &cossdk.ObjectPutHeaderOptions{
|
||||
ContentType: mimeType,
|
||||
XOptionHeader: &http.Header{
|
||||
overwriteOptionHeader: []string{fmt.Sprintf("%t", overwrite)},
|
||||
},
|
||||
}
|
||||
|
||||
// 小文件直接上传
|
||||
if file.Props.Size < MultiPartUploadThreshold {
|
||||
_, err := handler.client.Object.Put(ctx, file.Props.SavePath, file, &cossdk.ObjectPutOptions{
|
||||
ObjectPutHeaderOptions: opt,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
imur, _, err := handler.client.Object.InitiateMultipartUpload(ctx, file.Props.SavePath, &cossdk.InitiateMultipartUploadOptions{
|
||||
ObjectPutHeaderOptions: opt,
|
||||
})
|
||||
|
||||
chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{
|
||||
Max: handler.settings.ChunkRetryLimit(ctx),
|
||||
Sleep: chunkRetrySleep,
|
||||
}, handler.settings.UseChunkBuffer(ctx), handler.l, handler.settings.TempPath(ctx))
|
||||
|
||||
parts := make([]CosParts, 0, chunks.Num())
|
||||
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
|
||||
res, err := handler.client.Object.UploadPart(ctx, file.Props.SavePath, imur.UploadID, current.Index()+1, content, &cossdk.ObjectUploadPartOptions{
|
||||
ContentLength: current.Length(),
|
||||
})
|
||||
if err == nil {
|
||||
parts = append(parts, CosParts{
|
||||
ETag: res.Header.Get("ETag"),
|
||||
PartNumber: current.Index() + 1,
|
||||
})
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for chunks.Next() {
|
||||
if err := chunks.Process(uploadFunc); err != nil {
|
||||
handler.cancelUpload(file.Props.SavePath, imur.UploadID)
|
||||
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
|
||||
}
|
||||
}
|
||||
|
||||
_, _, err = handler.client.Object.CompleteMultipartUpload(ctx, file.Props.SavePath, imur.UploadID, &cossdk.CompleteMultipartUploadOptions{
|
||||
Parts: lo.Map(parts, func(v CosParts, i int) cossdk.Object {
|
||||
return cossdk.Object{
|
||||
ETag: v.ETag,
|
||||
PartNumber: v.PartNumber,
|
||||
}
|
||||
}),
|
||||
XOptionHeader: &http.Header{
|
||||
overwriteOptionHeader: []string{fmt.Sprintf("%t", overwrite)},
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
handler.cancelUpload(file.Props.SavePath, imur.UploadID)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
|
||||
groups := lo.Chunk(files, maxDeleteBatch)
|
||||
failed := make([]string, 0)
|
||||
var lastError error
|
||||
for index, group := range groups {
|
||||
handler.l.Debug("Process delete group #%d: %v", index, group)
|
||||
res, _, err := handler.client.Object.DeleteMulti(ctx,
|
||||
&cossdk.ObjectDeleteMultiOptions{
|
||||
Objects: lo.Map(group, func(item string, index int) cossdk.Object {
|
||||
return cossdk.Object{Key: item}
|
||||
}),
|
||||
Quiet: true,
|
||||
})
|
||||
if err != nil {
|
||||
lastError = err
|
||||
failed = append(failed, group...)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range res.Errors {
|
||||
handler.l.Debug("Failed to delete file: %s, Code:%s, Message:%s", v.Key, v.Code, v.Key)
|
||||
failed = append(failed, v.Key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(failed) > 0 && lastError == nil {
|
||||
lastError = fmt.Errorf("failed to delete files: %v", failed)
|
||||
}
|
||||
|
||||
return failed, lastError
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
|
||||
w, h := handler.settings.ThumbSize(ctx)
|
||||
thumbParam := fmt.Sprintf("imageMogr2/thumbnail/%dx%d", w, h)
|
||||
|
||||
source, err := handler.signSourceURL(
|
||||
ctx,
|
||||
e.Source(),
|
||||
expire,
|
||||
&urlOption{},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
thumbURL, _ := url.Parse(source)
|
||||
thumbQuery := thumbURL.Query()
|
||||
thumbQuery.Add(thumbParam, "")
|
||||
thumbURL.RawQuery = thumbQuery.Encode()
|
||||
|
||||
return thumbURL.String(), nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
|
||||
// 添加各项设置
|
||||
options := urlOption{}
|
||||
if args.Speed > 0 {
|
||||
if args.Speed < 819200 {
|
||||
args.Speed = 819200
|
||||
}
|
||||
if args.Speed > 838860800 {
|
||||
args.Speed = 838860800
|
||||
}
|
||||
options.Speed = args.Speed
|
||||
}
|
||||
if args.IsDownload {
|
||||
encodedFilename := url.PathEscape(args.DisplayName)
|
||||
options.ContentDescription = fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`,
|
||||
encodedFilename, encodedFilename)
|
||||
}
|
||||
|
||||
return handler.signSourceURL(ctx, e.Source(), args.Expire, &options)
|
||||
}
|
||||
|
||||
// signSourceURL builds a URL for the object at path. Public buckets (or
// buckets whose source auth is delegated to a custom proxy) get a plain
// unsigned URL; private buckets get a presigned GET URL valid until
// expire, or roughly 20 years when expire is nil (permanent link).
func (handler Driver) signSourceURL(ctx context.Context, path string, expire *time.Time, options *urlOption) (string, error) {
	// Public buckets need no signature.
	if !handler.policy.IsPrivate || (handler.policy.Settings.SourceAuth && handler.policy.Settings.CustomProxy) {
		file, err := url.Parse(handler.policy.Server)
		if err != nil {
			return "", err
		}

		file.Path = path

		// Unsigned URLs cannot carry response-header overrides.
		options.ContentDescription = ""

		optionQuery, err := query.Values(*options)
		if err != nil {
			return "", err
		}
		file.RawQuery = optionQuery.Encode()

		return file.String(), nil
	}

	ttl := time.Duration(0)
	if expire != nil {
		ttl = time.Until(*expire)
	} else {
		// 20 years for permanent link
		ttl = time.Duration(24) * time.Hour * 365 * 20
	}

	presignedURL, err := handler.client.Object.GetPresignedURL(ctx, http.MethodGet, path,
		handler.policy.AccessKey, handler.policy.SecretKey, ttl, options)
	if err != nil {
		return "", err
	}

	return presignedURL.String(), nil
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
|
||||
// 生成回调地址
|
||||
siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx))
|
||||
// 在从机端创建上传会话
|
||||
uploadSession.ChunkSize = handler.chunkSize
|
||||
uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeCos, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String()
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
// 初始化分片上传
|
||||
opt := &cossdk.ObjectPutHeaderOptions{
|
||||
ContentType: mimeType,
|
||||
XOptionHeader: &http.Header{
|
||||
overwriteOptionHeader: []string{"true"},
|
||||
},
|
||||
}
|
||||
|
||||
imur, _, err := handler.client.Object.InitiateMultipartUpload(ctx, file.Props.SavePath, &cossdk.InitiateMultipartUploadOptions{
|
||||
ObjectPutHeaderOptions: opt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize multipart upload: %w", err)
|
||||
}
|
||||
uploadSession.UploadID = imur.UploadID
|
||||
|
||||
// 为每个分片签名上传 URL
|
||||
chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{}, false, handler.l, "")
|
||||
urls := make([]string, chunks.Num())
|
||||
ttl := time.Until(uploadSession.Props.ExpireAt)
|
||||
for chunks.Next() {
|
||||
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
|
||||
signedURL, err := handler.client.Object.GetPresignedURL(
|
||||
ctx,
|
||||
http.MethodPut,
|
||||
file.Props.SavePath,
|
||||
handler.policy.AccessKey,
|
||||
handler.policy.SecretKey,
|
||||
ttl,
|
||||
&cossdk.PresignedURLOptions{
|
||||
Query: &url.Values{
|
||||
partNumberParam: []string{fmt.Sprintf("%d", c.Index()+1)},
|
||||
uploadIdParam: []string{imur.UploadID},
|
||||
},
|
||||
Header: &http.Header{
|
||||
contentTypeHeader: []string{"application/octet-stream"},
|
||||
contentLengthHeader: []string{fmt.Sprintf("%d", c.Length())},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
urls[c.Index()] = signedURL.String()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 签名完成分片上传的URL
|
||||
completeURL, err := handler.client.Object.GetPresignedURL(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
file.Props.SavePath,
|
||||
handler.policy.AccessKey,
|
||||
handler.policy.SecretKey,
|
||||
time.Until(uploadSession.Props.ExpireAt),
|
||||
&cossdk.PresignedURLOptions{
|
||||
Query: &url.Values{
|
||||
uploadIdParam: []string{imur.UploadID},
|
||||
},
|
||||
Header: &http.Header{
|
||||
overwriteOptionHeader: []string{"true"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &fs.UploadCredential{
|
||||
UploadID: imur.UploadID,
|
||||
UploadURLs: urls,
|
||||
CompleteURL: completeURL.String(),
|
||||
SessionID: uploadSession.Props.UploadSessionID,
|
||||
ChunkSize: handler.chunkSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CancelToken aborts the multipart upload associated with the given
// upload session, discarding any parts already uploaded.
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
	_, err := handler.client.Object.AbortMultipartUpload(ctx, uploadSession.Props.SavePath, uploadSession.UploadID)
	return err
}
|
||||
|
||||
func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
|
||||
if session.SentinelTaskID == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure uploaded file size is correct
|
||||
res, err := handler.client.Object.Head(ctx, session.Props.SavePath, &cossdk.ObjectHeadOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get uploaded file size: %w", err)
|
||||
}
|
||||
|
||||
if res.ContentLength != session.Props.Size {
|
||||
return serializer.NewError(
|
||||
serializer.CodeMetaMismatch,
|
||||
fmt.Sprintf("File size not match, expected: %d, actual: %d", session.Props.Size, res.ContentLength),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (handler *Driver) Capabilities() *driver.Capabilities {
|
||||
mediaMetaExts := handler.policy.Settings.MediaMetaExts
|
||||
if !handler.policy.Settings.NativeMediaProcessing {
|
||||
mediaMetaExts = nil
|
||||
}
|
||||
return &driver.Capabilities{
|
||||
StaticFeatures: features,
|
||||
MediaMetaSupportedExts: mediaMetaExts,
|
||||
MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy,
|
||||
ThumbSupportedExts: handler.policy.Settings.ThumbExts,
|
||||
ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy,
|
||||
ThumbMaxSize: handler.policy.Settings.ThumbMaxSize,
|
||||
ThumbSupportAllExts: handler.policy.Settings.ThumbSupportAllExts,
|
||||
}
|
||||
}
|
||||
|
||||
// Meta 获取文件信息
|
||||
func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
|
||||
res, err := handler.client.Object.Head(ctx, path, &cossdk.ObjectHeadOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MetaData{
|
||||
Size: uint64(res.ContentLength),
|
||||
CallbackKey: res.Header.Get("x-cos-meta-key"),
|
||||
CallbackURL: res.Header.Get("x-cos-meta-callback"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
|
||||
if util.ContainsString(supportedImageExt, ext) {
|
||||
return handler.extractImageMeta(ctx, path)
|
||||
}
|
||||
|
||||
return handler.extractStreamMeta(ctx, path)
|
||||
}
|
||||
|
||||
// LocalPath always returns an empty string: COS objects have no local
// filesystem path.
func (handler *Driver) LocalPath(ctx context.Context, path string) string {
	return ""
}
|
||||
|
||||
// cancelUpload aborts a multipart upload in a best-effort manner,
// logging (but not propagating) any failure.
func (handler *Driver) cancelUpload(path, uploadId string) {
	if _, err := handler.client.Object.AbortMultipartUpload(context.Background(), path, uploadId); err != nil {
		handler.l.Warning("failed to abort multipart upload: %s", err)
	}
}
|
||||
294
pkg/filemanager/driver/cos/media.go
Normal file
294
pkg/filemanager/driver/cos/media.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package cos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/mediameta"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/samber/lo"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// mediaInfoTTL bounds the validity of presigned media-info URLs.
	mediaInfoTTL = time.Duration(10) * time.Minute
	// videoInfo is the COS CI processor name for stream/format probing.
	videoInfo = "videoinfo"
)

var (
	// supportedImageExt lists extensions routed to the EXIF extractor;
	// all other files go through the videoinfo processor instead.
	supportedImageExt = []string{"jpg", "jpeg", "png", "gif", "bmp", "webp", "tiff", "heic", "heif"}
)

type (
	// ImageProp is a single EXIF property value as returned by the COS
	// "exif" image processor.
	ImageProp struct {
		Value string `json:"val"`
	}
	// ImageInfo maps EXIF property names to their values.
	ImageInfo map[string]ImageProp
	// Error is the XML error payload returned by COS CI endpoints.
	Error struct {
		XMLName   xml.Name `xml:"Error"`
		Code      string   `xml:"Code"`
		Message   string   `xml:"Message"`
		RequestId string   `xml:"RequestId"`
	}
	// Video describes one video stream in a videoinfo response.
	Video struct {
		Index          int    `xml:"Index"`
		CodecName      string `xml:"CodecName"`
		CodecLongName  string `xml:"CodecLongName"`
		CodecTimeBase  string `xml:"CodecTimeBase"`
		CodecTagString string `xml:"CodecTagString"`
		CodecTag       string `xml:"CodecTag"`
		ColorPrimaries string `xml:"ColorPrimaries"`
		ColorRange     string `xml:"ColorRange"`
		ColorTransfer  string `xml:"ColorTransfer"`
		Profile        string `xml:"Profile"`
		Width          int    `xml:"Width"`
		Height         int    `xml:"Height"`
		HasBFrame      string `xml:"HasBFrame"`
		RefFrames      string `xml:"RefFrames"`
		Sar            string `xml:"Sar"`
		Dar            string `xml:"Dar"`
		PixFormat      string `xml:"PixFormat"`
		FieldOrder     string `xml:"FieldOrder"`
		Level          string `xml:"Level"`
		Fps            string `xml:"Fps"`
		AvgFps         string `xml:"AvgFps"`
		Timebase       string `xml:"Timebase"`
		StartTime      string `xml:"StartTime"`
		Duration       string `xml:"Duration"`
		Bitrate        string `xml:"Bitrate"`
		NumFrames      string `xml:"NumFrames"`
		Language       string `xml:"Language"`
	}
	// Audio describes one audio stream in a videoinfo response.
	Audio struct {
		Index          int    `xml:"Index"`
		CodecName      string `xml:"CodecName"`
		CodecLongName  string `xml:"CodecLongName"`
		CodecTimeBase  string `xml:"CodecTimeBase"`
		CodecTagString string `xml:"CodecTagString"`
		CodecTag       string `xml:"CodecTag"`
		SampleFmt      string `xml:"SampleFmt"`
		SampleRate     string `xml:"SampleRate"`
		Channel        string `xml:"Channel"`
		ChannelLayout  string `xml:"ChannelLayout"`
		Timebase       string `xml:"Timebase"`
		StartTime      string `xml:"StartTime"`
		Duration       string `xml:"Duration"`
		Bitrate        string `xml:"Bitrate"`
		Language       string `xml:"Language"`
	}
	// Subtitle describes one subtitle stream in a videoinfo response.
	Subtitle struct {
		Index    string `xml:"Index"`
		Language string `xml:"Language"`
	}
	// Response is the top-level XML document returned by the videoinfo
	// CI processor: per-type stream lists plus container format info.
	Response struct {
		XMLName   xml.Name `xml:"Response"`
		MediaInfo struct {
			Stream struct {
				Video    []Video    `xml:"Video"`
				Audio    []Audio    `xml:"Audio"`
				Subtitle []Subtitle `xml:"Subtitle"`
			} `xml:"Stream"`
			Format struct {
				NumStream      string `xml:"NumStream"`
				NumProgram     string `xml:"NumProgram"`
				FormatName     string `xml:"FormatName"`
				FormatLongName string `xml:"FormatLongName"`
				StartTime      string `xml:"StartTime"`
				Duration       string `xml:"Duration"`
				Bitrate        string `xml:"Bitrate"`
				Size           string `xml:"Size"`
			} `xml:"Format"`
		} `xml:"MediaInfo"`
	}
)
|
||||
|
||||
func (handler *Driver) extractStreamMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) {
|
||||
resp, err := handler.extractMediaInfo(ctx, path, &urlOption{CiProcess: videoInfo})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var info Response
|
||||
if err := xml.Unmarshal([]byte(resp), &info); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal media info: %w", err)
|
||||
}
|
||||
|
||||
streams := lo.Map(info.MediaInfo.Stream.Video, func(stream Video, index int) mediameta.Stream {
|
||||
return mediameta.Stream{
|
||||
Index: stream.Index,
|
||||
CodecName: stream.CodecName,
|
||||
CodecLongName: stream.CodecLongName,
|
||||
CodecType: "video",
|
||||
Width: stream.Width,
|
||||
Height: stream.Height,
|
||||
Bitrate: stream.Bitrate,
|
||||
}
|
||||
})
|
||||
streams = append(streams, lo.Map(info.MediaInfo.Stream.Audio, func(stream Audio, index int) mediameta.Stream {
|
||||
return mediameta.Stream{
|
||||
Index: stream.Index,
|
||||
CodecName: stream.CodecName,
|
||||
CodecLongName: stream.CodecLongName,
|
||||
CodecType: "audio",
|
||||
Bitrate: stream.Bitrate,
|
||||
}
|
||||
})...)
|
||||
|
||||
metas := make([]driver.MediaMeta, 0)
|
||||
metas = append(metas, mediameta.ProbeMetaTransform(&mediameta.FFProbeMeta{
|
||||
Format: &mediameta.Format{
|
||||
FormatName: info.MediaInfo.Format.FormatName,
|
||||
FormatLongName: info.MediaInfo.Format.FormatLongName,
|
||||
Duration: info.MediaInfo.Format.Duration,
|
||||
Bitrate: info.MediaInfo.Format.Bitrate,
|
||||
},
|
||||
Streams: streams,
|
||||
})...)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (handler *Driver) extractImageMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) {
|
||||
exif := ""
|
||||
resp, err := handler.extractMediaInfo(ctx, path, &urlOption{
|
||||
Exif: &exif,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var imageInfo ImageInfo
|
||||
if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal media info: %w", err)
|
||||
}
|
||||
|
||||
metas := make([]driver.MediaMeta, 0)
|
||||
exifMap := lo.MapEntries(imageInfo, func(key string, value ImageProp) (string, string) {
|
||||
return key, value.Value
|
||||
})
|
||||
metas = append(metas, mediameta.ExtractExifMap(exifMap, time.Time{})...)
|
||||
metas = append(metas, parseGpsInfo(imageInfo)...)
|
||||
for i := 0; i < len(metas); i++ {
|
||||
metas[i].Type = driver.MetaTypeExif
|
||||
}
|
||||
|
||||
return metas, nil
|
||||
}
|
||||
|
||||
// extractMediaInfo Sends API calls to COS service to extract media info.
|
||||
func (handler *Driver) extractMediaInfo(ctx context.Context, path string, opt *urlOption) (string, error) {
|
||||
mediaInfoExpire := time.Now().Add(mediaInfoTTL)
|
||||
thumbURL, err := handler.signSourceURL(
|
||||
ctx,
|
||||
path,
|
||||
&mediaInfoExpire,
|
||||
opt,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign media info url: %w", err)
|
||||
}
|
||||
|
||||
resp, err := handler.httpClient.
|
||||
Request(http.MethodGet, thumbURL, nil, request.WithContext(ctx)).
|
||||
CheckHTTPResponse(http.StatusOK).
|
||||
GetResponseIgnoreErr()
|
||||
if err != nil {
|
||||
return "", handleCosError(resp, err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseGpsInfo(imageInfo ImageInfo) []driver.MediaMeta {
|
||||
latitude := imageInfo["GPSLatitude"] // 31deg 16.26808'
|
||||
longitude := imageInfo["GPSLongitude"] // 120deg 42.91039'
|
||||
latRef := imageInfo["GPSLatitudeRef"] // North
|
||||
lonRef := imageInfo["GPSLongitudeRef"] // East
|
||||
|
||||
// Make sure all value exist in map
|
||||
if latitude.Value == "" || longitude.Value == "" || latRef.Value == "" || lonRef.Value == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
lat := parseRawGPS(latitude.Value, latRef.Value)
|
||||
lon := parseRawGPS(longitude.Value, lonRef.Value)
|
||||
if !math.IsNaN(lat) && !math.IsNaN(lon) {
|
||||
lat, lng := mediameta.NormalizeGPS(lat, lon)
|
||||
return []driver.MediaMeta{{
|
||||
Key: mediameta.GpsLat,
|
||||
Value: fmt.Sprintf("%f", lat),
|
||||
}, {
|
||||
Key: mediameta.GpsLng,
|
||||
Value: fmt.Sprintf("%f", lng),
|
||||
}}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRawGPS(gpsStr string, ref string) float64 {
|
||||
elem := strings.Split(gpsStr, " ")
|
||||
if len(elem) < 1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var (
|
||||
deg float64
|
||||
minutes float64
|
||||
seconds float64
|
||||
)
|
||||
|
||||
deg = getGpsElemValue(elem[0])
|
||||
if len(elem) >= 2 {
|
||||
minutes = getGpsElemValue(elem[1])
|
||||
}
|
||||
if len(elem) >= 3 {
|
||||
seconds = getGpsElemValue(elem[2])
|
||||
}
|
||||
|
||||
decimal := deg + minutes/60.0 + seconds/3600.0
|
||||
|
||||
if ref == "S" || ref == "W" {
|
||||
return -decimal
|
||||
}
|
||||
|
||||
return decimal
|
||||
}
|
||||
|
||||
// getGpsElemValue parses a single GPS rational element of the form
// "numerator/denominator" into a float64. Malformed input, extra separators,
// or a zero denominator all yield 0.
func getGpsElemValue(elm string) float64 {
	parts := strings.SplitN(elm, "/", 3)
	if len(parts) != 2 {
		return 0
	}

	num, err := strconv.ParseFloat(parts[0], 64)
	if err != nil {
		return 0
	}

	den, err := strconv.ParseFloat(parts[1], 64)
	if err != nil || den == 0 {
		return 0
	}

	return num / den
}
|
||||
|
||||
func handleCosError(resp string, originErr error) error {
|
||||
if resp == "" {
|
||||
return originErr
|
||||
}
|
||||
|
||||
var err Error
|
||||
if err := xml.Unmarshal([]byte(resp), &err); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal cos error: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cos error: %s", err.Message)
|
||||
}
|
||||
118
pkg/filemanager/driver/cos/scf.go
Normal file
118
pkg/filemanager/driver/cos/scf.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package cos

// TODO: revisit para error

// scfFunc is the Python source of the Tencent SCF (serverless) function that
// is triggered by COS object uploads and relays the upload callback URL
// (stored in the x-cos-meta-callback object metadata) back to Cloudreve.
// NOTE(review): inside the script, `return 'Success'` on a missing callback
// exits after the first record, and `return "Fail"` after `raise e` is
// unreachable — confirm intended behavior before reuse.
const scfFunc = `# -*- coding: utf8 -*-
# SCF配置COS触发,向 Cloudreve 发送回调
from qcloud_cos_v5 import CosConfig
from qcloud_cos_v5 import CosS3Client
from qcloud_cos_v5 import CosServiceError
from qcloud_cos_v5 import CosClientError
import sys
import logging
import requests

logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logging = logging.getLogger()


def main_handler(event, context):
    logging.info("start main handler")
    for record in event['Records']:
        try:
            if "x-cos-meta-callback" not in record['cos']['cosObject']['meta']:
                logging.info("Cannot find callback URL, skiped.")
                return 'Success'
            callback = record['cos']['cosObject']['meta']['x-cos-meta-callback']
            key = record['cos']['cosObject']['key']
            logging.info("Callback URL is " + callback)

            r = requests.get(callback)
            print(r.text)




        except Exception as e:
            print(e)
            print('Error getting object {} callback url. '.format(key))
            raise e
            return "Fail"

    return "Success"
`
|
||||
|
||||
//
|
||||
//// CreateSCF 创建回调云函数
|
||||
//func CreateSCF(policy *model.Policy, region string) error {
|
||||
// // 初始化客户端
|
||||
// credential := common.NewCredential(
|
||||
// policy.AccessKey,
|
||||
// policy.SecretKey,
|
||||
// )
|
||||
// cpf := profile.NewClientProfile()
|
||||
// client, err := scf.NewClient(credential, region, cpf)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// // 创建回调代码数据
|
||||
// buff := &bytes.Buffer{}
|
||||
// bs64 := base64.NewEncoder(base64.StdEncoding, buff)
|
||||
// zipWriter := zip.NewWriter(bs64)
|
||||
// header := zip.FileHeader{
|
||||
// Name: "callback.py",
|
||||
// Method: zip.Deflate,
|
||||
// }
|
||||
// writer, err := zipWriter.CreateHeader(&header)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// _, err = io.Copy(writer, strings.NewReader(scfFunc))
|
||||
// zipWriter.Close()
|
||||
//
|
||||
// // 创建云函数
|
||||
// req := scf.NewCreateFunctionRequest()
|
||||
// funcName := "cloudreve_" + hashid.HashID(policy.ID, hashid.PolicyID) + strconv.FormatInt(time.Now().Unix(), 10)
|
||||
// zipFileBytes, _ := ioutil.ReadAll(buff)
|
||||
// zipFileStr := string(zipFileBytes)
|
||||
// codeSource := "ZipFile"
|
||||
// handler := "callback.main_handler"
|
||||
// desc := "Cloudreve 用回调函数"
|
||||
// timeout := int64(60)
|
||||
// runtime := "Python3.6"
|
||||
// req.FunctionName = &funcName
|
||||
// req.Code = &scf.Code{
|
||||
// ZipFile: &zipFileStr,
|
||||
// }
|
||||
// req.Handler = &handler
|
||||
// req.Description = &desc
|
||||
// req.Timeout = &timeout
|
||||
// req.Runtime = &runtime
|
||||
// req.CodeSource = &codeSource
|
||||
//
|
||||
// _, err = client.CreateFunction(req)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// time.Sleep(time.Duration(5) * time.Second)
|
||||
//
|
||||
// // 创建触发器
|
||||
// server, _ := url.Parse(policy.Server)
|
||||
// triggerType := "cos"
|
||||
// triggerDesc := `{"event":"cos:ObjectCreated:Post","filter":{"Prefix":"","Suffix":""}}`
|
||||
// enable := "OPEN"
|
||||
//
|
||||
// trigger := scf.NewCreateTriggerRequest()
|
||||
// trigger.FunctionName = &funcName
|
||||
// trigger.TriggerName = &server.Host
|
||||
// trigger.Type = &triggerType
|
||||
// trigger.TriggerDesc = &triggerDesc
|
||||
// trigger.Enable = &enable
|
||||
//
|
||||
// _, err = client.CreateTrigger(trigger)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// return nil
|
||||
//}
|
||||
122
pkg/filemanager/driver/handler.go
Normal file
122
pkg/filemanager/driver/handler.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
)
|
||||
|
||||
const (
	// HandlerCapabilityProxyRequired indicates this handler requires Cloudreve's proxy to get file content.
	HandlerCapabilityProxyRequired HandlerCapability = iota
	// HandlerCapabilityInboundGet indicates this handler supports directly getting a file's RSCloser, usually
	// meaning the file is stored in the same machine as Cloudreve.
	HandlerCapabilityInboundGet
	// HandlerCapabilityUploadSentinelRequired indicates this handler does not support a compliance callback
	// mechanism, thus it requires Cloudreve's sentinel to guarantee the upload is under control. Cloudreve will
	// try to delete the placeholder file and cancel the upload session if no upload callback is made before the
	// upload session expires.
	HandlerCapabilityUploadSentinelRequired
)
|
||||
|
||||
type (
	// MetaType categorizes a media metadata entry (see the MetaTypeXxx constants below).
	MetaType string
	// MediaMeta is a single key/value metadata entry extracted from a media file.
	MediaMeta struct {
		Key   string   `json:"key"`
		Value string   `json:"value"`
		Type  MetaType `json:"type"`
	}

	// HandlerCapability enumerates optional features a storage handler may support.
	HandlerCapability int

	// GetSourceArgs carries the options used when generating a source URL.
	GetSourceArgs struct {
		Expire      *time.Time // optional expiration instant of the URL
		IsDownload  bool       // whether the URL should trigger a direct download
		Speed       int64
		DisplayName string
	}

	// Handler is the adapter interface implemented by every storage policy driver.
	Handler interface {
		// Put uploads a file to the storage path given in the request. When the
		// context is closed, the upload should be canceled and temporary files
		// cleaned up.
		Put(ctx context.Context, file *fs.UploadRequest) error

		// Delete removes one or more files at the given paths, returning the list
		// of paths that failed to delete together with an error.
		Delete(ctx context.Context, files ...string) ([]string, error)

		// Open physical files. Only implemented if HandlerCapabilityInboundGet capability is set.
		// Returns file path and an os.File object.
		Open(ctx context.Context, path string) (*os.File, error)

		// LocalPath returns the local path of a file.
		// Only implemented if HandlerCapabilityInboundGet capability is set.
		LocalPath(ctx context.Context, path string) string

		// Thumb returns the URL for a thumbnail of given entity.
		Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error)

		// Source returns the external/download URL of an entity, honoring the
		// expiration, speed limit and download flag carried in args.
		Source(ctx context.Context, e fs.Entity, args *GetSourceArgs) (string, error)

		// Token issues an upload credential and signature valid for the session's TTL.
		Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error)

		// CancelToken cancels a previously created stateful upload credential.
		CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error

		// CompleteUpload completes a previously created upload session.
		CompleteUpload(ctx context.Context, session *fs.UploadSession) error

		// List recursively lists files and directories under path on the remote
		// side, excluding path itself; returned object paths are rooted at path.
		// recursive - whether to descend into subdirectories
		// List(ctx context.Context, path string, recursive bool) ([]response.Object, error)

		// Capabilities returns the capabilities of this handler
		Capabilities() *Capabilities

		// MediaMeta extracts media metadata from the given file.
		MediaMeta(ctx context.Context, path, ext string) ([]MediaMeta, error)
	}

	// Capabilities describes the static and tunable features of a handler.
	Capabilities struct {
		StaticFeatures *boolset.BooleanSet
		// MaxSourceExpire indicates the maximum allowed expiration duration of a source URL
		MaxSourceExpire time.Duration
		// MinSourceExpire indicates the minimum allowed expiration duration of a source URL
		MinSourceExpire time.Duration
		// MediaMetaSupportedExts indicates the extensions of files that support media metadata. Empty list
		// indicates that no file supports extracting media metadata.
		MediaMetaSupportedExts []string
		// MediaMetaProxy indicates whether to generate media metadata using local generators.
		// (The original comment referred to this field as "GenerateMediaMeta"; the field is MediaMetaProxy.)
		MediaMetaProxy bool
		// ThumbSupportedExts indicates the extensions of files that support thumbnail generation. Empty list
		// indicates that no file supports thumbnail generation.
		ThumbSupportedExts []string
		// ThumbSupportAllExts indicates whether to generate thumbnails for all files, regardless of their extensions.
		ThumbSupportAllExts bool
		// ThumbMaxSize indicates the maximum allowed size of a thumbnail. 0 indicates that no limit is set.
		ThumbMaxSize int64
		// ThumbProxy indicates whether to generate thumbnails using local generators.
		ThumbProxy bool
	}
)
|
||||
|
||||
// Known media metadata categories.
const (
	MetaTypeExif MetaType = "exif"
	// NOTE(review): "MediaTypeMusic" breaks the MetaTypeXxx naming pattern of
	// its siblings; renaming would change the exported API, so it is only
	// flagged here.
	MediaTypeMusic      MetaType = "music"
	MetaTypeStreamMedia MetaType = "stream"
)
|
||||
|
||||
type ForceUsePublicEndpointCtx struct{}
|
||||
|
||||
// WithForcePublicEndpoint sets the context to force using public endpoint for supported storage policies.
|
||||
func WithForcePublicEndpoint(ctx context.Context, value bool) context.Context {
|
||||
return context.WithValue(ctx, ForceUsePublicEndpointCtx{}, value)
|
||||
}
|
||||
75
pkg/filemanager/driver/local/entity.go
Normal file
75
pkg/filemanager/driver/local/entity.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gofrs/uuid"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewLocalFileEntity creates a new local file entity.
|
||||
func NewLocalFileEntity(t types.EntityType, src string) (fs.Entity, error) {
|
||||
info, err := os.Stat(util.RelativePath(src))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &localFileEntity{
|
||||
t: t,
|
||||
src: src,
|
||||
size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// localFileEntity is a minimal fs.Entity implementation for files on the
// local filesystem that have no database record backing them.
type localFileEntity struct {
	t    types.EntityType // entity type supplied at construction
	src  string           // source path as given to NewLocalFileEntity
	size int64            // file size captured at construction time
}
|
||||
|
||||
// ID always returns 0: local file entities have no database identity.
func (l *localFileEntity) ID() int {
	return 0
}

// Type returns the entity type supplied at construction.
func (l *localFileEntity) Type() types.EntityType {
	return l.t
}

// Size returns the file size captured when the entity was created.
func (l *localFileEntity) Size() int64 {
	return l.size
}

// UpdatedAt returns the current time; no modification time is tracked.
func (l *localFileEntity) UpdatedAt() time.Time {
	return time.Now()
}

// CreatedAt returns the current time; no creation time is tracked.
func (l *localFileEntity) CreatedAt() time.Time {
	return time.Now()
}

// CreatedBy returns nil: the entity is not associated with any user.
func (l *localFileEntity) CreatedBy() *ent.User {
	return nil
}

// Source returns the source path of the underlying file.
func (l *localFileEntity) Source() string {
	return l.src
}

// ReferenceCount always reports a single reference.
func (l *localFileEntity) ReferenceCount() int {
	return 1
}

// PolicyID always returns 0: no storage policy record is attached.
func (l *localFileEntity) PolicyID() int {
	return 0
}

// UploadSessionID returns nil: local file entities are not part of an upload session.
func (l *localFileEntity) UploadSessionID() *uuid.UUID {
	return nil
}

// Model returns nil: there is no backing ent.Entity row.
func (l *localFileEntity) Model() *ent.Entity {
	return nil
}
|
||||
11
pkg/filemanager/driver/local/fallocate.go
Normal file
11
pkg/filemanager/driver/local/fallocate.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !linux && !darwin
// +build !linux,!darwin

package local

import "os"

// Fallocate is a no-op on platforms other than Linux and Darwin: no disk
// space is preallocated and nil is always returned.
func Fallocate(file *os.File, offset int64, length int64) error {
	return nil
}
|
||||
27
pkg/filemanager/driver/local/fallocate_darwin.go
Normal file
27
pkg/filemanager/driver/local/fallocate_darwin.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Fallocate preallocates offset+length bytes for file using the Darwin
// F_PREALLOCATE fcntl, then truncates the file to that length. It first
// requests a contiguous allocation and, if that fails, retries allowing
// fragmented extents (F_ALLOCATEALL), ignoring the second result.
func Fallocate(file *os.File, offset int64, length int64) error {
	var fst syscall.Fstore_t

	fst.Flags = syscall.F_ALLOCATECONTIG
	fst.Posmode = syscall.F_PREALLOCATE
	fst.Offset = 0
	fst.Length = offset + length
	fst.Bytesalloc = 0

	// Check https://lists.apple.com/archives/darwin-dev/2007/Dec/msg00040.html
	_, _, err := syscall.Syscall(syscall.SYS_FCNTL, file.Fd(), syscall.F_PREALLOCATE, uintptr(unsafe.Pointer(&fst)))
	if err != syscall.Errno(0x0) {
		// Contiguous allocation failed; fall back to non-contiguous.
		fst.Flags = syscall.F_ALLOCATEALL
		// Ignore the return value
		_, _, _ = syscall.Syscall(syscall.SYS_FCNTL, file.Fd(), syscall.F_PREALLOCATE, uintptr(unsafe.Pointer(&fst)))
	}

	// Extend the file itself to the preallocated length.
	return syscall.Ftruncate(int(file.Fd()), fst.Length)
}
|
||||
14
pkg/filemanager/driver/local/fallocate_linux.go
Normal file
14
pkg/filemanager/driver/local/fallocate_linux.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Fallocate preallocates length bytes starting at offset for file via the
// Linux fallocate(2) syscall. A zero length is a no-op.
func Fallocate(file *os.File, offset int64, length int64) error {
	if length == 0 {
		return nil
	}

	return syscall.Fallocate(int(file.Fd()), 0, offset, length)
}
|
||||
301
pkg/filemanager/driver/local/local.go
Normal file
301
pkg/filemanager/driver/local/local.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
)
|
||||
|
||||
const (
	// Perm is the default permission bits (octal 0744) used for files and
	// directories created by the local driver.
	Perm = 0744
)
|
||||
|
||||
var (
	// capabilities describes what the local driver supports: media metadata
	// and thumbnails are generated by Cloudreve itself (proxy mode); static
	// feature flags are filled in by init below.
	capabilities = &driver.Capabilities{
		StaticFeatures: &boolset.BooleanSet{},
		MediaMetaProxy: true,
		ThumbProxy:     true,
	}
)
|
||||
|
||||
// init registers the static capabilities of the local driver: content must be
// served through Cloudreve's proxy, and files can be opened directly (inbound
// get) since they live on the same machine.
func init() {
	boolset.Sets(map[driver.HandlerCapability]bool{
		driver.HandlerCapabilityProxyRequired: true,
		driver.HandlerCapabilityInboundGet:    true,
	}, capabilities.StaticFeatures)
}
|
||||
|
||||
// Driver is the storage adapter for the local filesystem policy.
type Driver struct {
	Policy     *ent.StoragePolicy
	httpClient request.Client // used for slave-to-master upload callbacks (see CompleteUpload)
	l          logging.Logger
	config     conf.ConfigProvider
}
|
||||
|
||||
// New constructs a new local driver
|
||||
func New(p *ent.StoragePolicy, l logging.Logger, config conf.ConfigProvider) *Driver {
|
||||
return &Driver{
|
||||
Policy: p,
|
||||
l: l,
|
||||
httpClient: request.NewClient(config, request.WithLogger(l)),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
//// List 递归列取给定物理路径下所有文件
|
||||
//func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||
// var res []response.Object
|
||||
//
|
||||
// // 取得起始路径
|
||||
// root := util.RelativePath(filepath.FromSlash(path))
|
||||
//
|
||||
// // 开始遍历路径下的文件、目录
|
||||
// err := filepath.Walk(root,
|
||||
// func(path string, info os.FileInfo, err error) error {
|
||||
// // 跳过根目录
|
||||
// if path == root {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// if err != nil {
|
||||
// util.Log().Warning("Failed to walk folder %q: %s", path, err)
|
||||
// return filepath.SkipDir
|
||||
// }
|
||||
//
|
||||
// // 将遍历对象的绝对路径转换为相对路径
|
||||
// rel, err := filepath.Rel(root, path)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// res = append(res, response.Object{
|
||||
// Name: info.Name(),
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Source: path,
|
||||
// Size: uint64(info.Size()),
|
||||
// IsDir: info.IsDir(),
|
||||
// LastModify: info.ModTime(),
|
||||
// })
|
||||
//
|
||||
// // 如果非递归,则不步入目录
|
||||
// if !recursive && info.IsDir() {
|
||||
// return filepath.SkipDir
|
||||
// }
|
||||
//
|
||||
// return nil
|
||||
// })
|
||||
//
|
||||
// return res, err
|
||||
//}
|
||||
|
||||
// Open opens the physical file backing the given path for reading.
// (The original comment labeled this "Get"; the method is Open.)
func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) {
	// Normalize the slash path to a local relative path, then open it.
	file, err := os.Open(handler.LocalPath(ctx, path))
	if err != nil {
		handler.l.Debug("Failed to open file: %s", err)
		return nil, err
	}

	return file, nil
}
|
||||
|
||||
// LocalPath converts a slash-separated virtual path into the OS-specific
// relative path on local disk.
func (handler *Driver) LocalPath(ctx context.Context, path string) string {
	return util.RelativePath(filepath.FromSlash(path))
}
|
||||
|
||||
// Put streams the upload request's content into the file at its SavePath.
// Without ModeOverwrite an existing file with the same name is rejected; with
// ModeOverwrite and offset 0 the file is truncated first. Non-zero offsets
// resume a chunked upload and require the on-disk file to already hold at
// least Offset bytes.
func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
	defer file.Close()
	dst := util.RelativePath(filepath.FromSlash(file.Props.SavePath))

	// Without overwrite mode, refuse to clobber an existing file.
	if file.Mode&fs.ModeOverwrite != fs.ModeOverwrite {
		if util.Exists(dst) {
			handler.l.Warning("File with the same name existed or unavailable: %s", dst)
			return errors.New("file with the same name existed or unavailable")
		}
	}

	if err := handler.prepareFileDirectory(dst); err != nil {
		return err
	}

	// Truncate only when overwriting from the beginning; resumed chunk
	// uploads (offset > 0) must keep previously written data.
	openMode := os.O_CREATE | os.O_RDWR
	if file.Mode&fs.ModeOverwrite == fs.ModeOverwrite && file.Offset == 0 {
		openMode |= os.O_TRUNC
	}

	out, err := os.OpenFile(dst, openMode, Perm)
	if err != nil {
		handler.l.Warning("Failed to open or create file: %s", err)
		return err
	}
	defer out.Close()

	stat, err := out.Stat()
	if err != nil {
		handler.l.Warning("Failed to read file info: %s", err)
		return err
	}

	// The file on disk must already contain at least Offset bytes, otherwise
	// earlier chunks are missing or incomplete.
	if stat.Size() < file.Offset {
		return errors.New("size of unfinished uploaded chunks is not as expected")
	}

	if _, err := out.Seek(file.Offset, io.SeekStart); err != nil {
		return fmt.Errorf("failed to seek to desired offset %d: %s", file.Offset, err)
	}

	// Copy the request body into the file starting at the chunk offset.
	_, err = io.Copy(out, file)
	return err
}
|
||||
|
||||
// Delete removes one or more files given by virtual path. It returns the
// paths that could not be deleted together with the last error encountered.
// Files that do not exist are silently treated as already deleted.
func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
	deleteFailed := make([]string, 0, len(files))
	var retErr error

	for _, value := range files {
		filePath := util.RelativePath(filepath.FromSlash(value))
		if util.Exists(filePath) {
			err := os.Remove(filePath)
			if err != nil {
				handler.l.Warning("Failed to delete file: %s", err)
				retErr = err
				deleteFailed = append(deleteFailed, value)
			}
		}

		//// Try to delete the file's thumbnail, if any (kept from v3, disabled).
		//_ = os.Remove(util.RelativePath(value + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")))
	}

	return deleteFailed, retErr
}
|
||||
|
||||
// Thumb is not supported by the local driver; thumbnails are generated by
// Cloudreve itself (see capabilities.ThumbProxy).
func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
	return "", errors.New("not implemented")
}
|
||||
|
||||
// Source is not supported by the local driver; content is served through
// Cloudreve's proxy (see HandlerCapabilityProxyRequired in init).
func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
	return "", errors.New("not implemented")
}
|
||||
|
||||
// Token creates a placeholder file for the upload session and returns the
// upload credential. The local policy needs no real signed token, so only the
// session ID and chunk size are returned.
func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
	// NOTE(review): this existence check uses the raw SavePath while the rest
	// of the function (and Put) operate on util.RelativePath(SavePath) —
	// confirm both resolve to the same location.
	if file.Mode&fs.ModeOverwrite != fs.ModeOverwrite && util.Exists(uploadSession.Props.SavePath) {
		return nil, errors.New("placeholder file already exist")
	}

	dst := util.RelativePath(filepath.FromSlash(uploadSession.Props.SavePath))
	if err := handler.prepareFileDirectory(dst); err != nil {
		return nil, fmt.Errorf("failed to prepare file directory: %w", err)
	}

	// Create (or truncate) the empty placeholder file for the session.
	f, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, Perm)
	if err != nil {
		return nil, fmt.Errorf("failed to create placeholder file: %w", err)
	}

	// Preallocate disk space
	defer f.Close()
	if handler.Policy.Settings.PreAllocate {
		// Preallocation failure is non-fatal; the upload can still proceed.
		if err := Fallocate(f, 0, uploadSession.Props.Size); err != nil {
			handler.l.Warning("Failed to preallocate file: %s", err)
		}
	}

	return &fs.UploadCredential{
		SessionID: uploadSession.Props.UploadSessionID,
		ChunkSize: handler.Policy.Settings.ChunkSize,
	}, nil
}
|
||||
|
||||
func (h *Driver) prepareFileDirectory(dst string) error {
|
||||
basePath := filepath.Dir(dst)
|
||||
if !util.Exists(basePath) {
|
||||
err := os.MkdirAll(basePath, Perm)
|
||||
if err != nil {
|
||||
h.l.Warning("Failed to create directory: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelToken cancels a previously issued upload credential. The local policy
// keeps no server-side upload state, so there is nothing to clean up.
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
	return nil
}
|
||||
|
||||
// CompleteUpload finishes a previously created upload session. When the
// session carries a callback URL — meaning this local handler is acting as a
// shadowed handler on a slave node for a remote policy — a signed callback
// request is sent to the master node; otherwise nothing needs to be done.
func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
	if session.Callback == "" {
		return nil
	}

	// The callback must be signed with the node's slave key.
	if session.Policy.Edges.Node == nil {
		return serializer.NewError(serializer.CodeCallbackError, "Node not found", nil)
	}

	// If callback is set, indicating this handler is used in slave node as a shadowed handler for remote policy,
	// we need to send callback request to master node.
	resp := handler.httpClient.Request(
		"POST",
		session.Callback,
		nil,
		request.WithTimeout(time.Duration(handler.config.Slave().CallbackTimeout)*time.Second),
		request.WithCredential(
			auth.HMACAuth{[]byte(session.Policy.Edges.Node.SlaveKey)},
			int64(handler.config.Slave().SignatureTTL),
		),
		request.WithContext(ctx),
		request.WithCorrelationID(),
	)

	if resp.Err != nil {
		return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err)
	}

	// Decode the master node's response and surface any non-zero result code.
	res, err := resp.DecodeResponse()
	if err != nil {
		msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode)
		return serializer.NewError(serializer.CodeCallbackError, msg, err)
	}

	if res.Code != 0 {
		return serializer.NewError(res.Code, res.Msg, errors.New(res.Error))
	}

	return nil
}
|
||||
|
||||
// Capabilities returns the static capability descriptor of the local driver.
func (handler *Driver) Capabilities() *driver.Capabilities {
	return capabilities
}
|
||||
|
||||
// MediaMeta is not implemented for the local driver; metadata extraction is
// proxied through Cloudreve (see capabilities.MediaMetaProxy).
func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
	return nil, errors.New("not implemented")
}
|
||||
137
pkg/filemanager/driver/obs/media.go
Normal file
137
pkg/filemanager/driver/obs/media.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package obs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/mediameta"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/huaweicloud/huaweicloud-sdk-go-obs/obs"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// MediaMeta extracts EXIF metadata for an image stored in OBS by requesting
// the "image/info" data-process endpoint through a signed URL, then converting
// the returned key/value pairs (including GPS fields) into driver.MediaMeta
// entries tagged as EXIF.
func (d *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
	thumbURL, err := d.signSourceURL(&obs.CreateSignedUrlInput{
		Method:  obs.HttpMethodGet,
		Bucket:  d.policy.BucketName,
		Key:     path,
		Expires: int(mediaInfoTTL.Seconds()),
		QueryParams: map[string]string{
			imageProcessHeader: imageInfoProcessor,
		},
	})

	if err != nil {
		return nil, fmt.Errorf("failed to sign media info url: %w", err)
	}

	resp, err := d.httpClient.
		Request(http.MethodGet, thumbURL, nil, request.WithContext(ctx)).
		CheckHTTPResponse(http.StatusOK).
		GetResponseIgnoreErr()
	if err != nil {
		return nil, handleJsonError(resp, err)
	}

	var imageInfo map[string]any
	if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil {
		return nil, fmt.Errorf("failed to unmarshal media info: %w", err)
	}

	// String values keep their text and lose any "exif:" key prefix;
	// non-string values are stringified with %v.
	imageInfoMap := lo.MapEntries(imageInfo, func(k string, v any) (string, string) {
		if vStr, ok := v.(string); ok {
			return strings.TrimPrefix(k, "exif:"), vStr
		}

		return k, fmt.Sprintf("%v", v)
	})
	metas := make([]driver.MediaMeta, 0)
	metas = append(metas, mediameta.ExtractExifMap(imageInfoMap, time.Time{})...)
	metas = append(metas, parseGpsInfo(imageInfoMap)...)
	for i := 0; i < len(metas); i++ {
		metas[i].Type = driver.MetaTypeExif
	}
	return metas, nil
}
|
||||
|
||||
func parseGpsInfo(imageInfo map[string]string) []driver.MediaMeta {
|
||||
latitude := imageInfo["GPSLatitude"] // 31/1, 162680820/10000000, 0/1
|
||||
longitude := imageInfo["GPSLongitude"] // 120/1, 429103939/10000000, 0/1
|
||||
latRef := imageInfo["GPSLatitudeRef"] // N
|
||||
lonRef := imageInfo["GPSLongitudeRef"] // E
|
||||
|
||||
// Make sure all value exist in map
|
||||
if latitude == "" || longitude == "" || latRef == "" || lonRef == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
lat := parseRawGPS(latitude, latRef)
|
||||
lon := parseRawGPS(longitude, lonRef)
|
||||
if !math.IsNaN(lat) && !math.IsNaN(lon) {
|
||||
lat, lng := mediameta.NormalizeGPS(lat, lon)
|
||||
return []driver.MediaMeta{{
|
||||
Key: mediameta.GpsLat,
|
||||
Value: fmt.Sprintf("%f", lat),
|
||||
}, {
|
||||
Key: mediameta.GpsLng,
|
||||
Value: fmt.Sprintf("%f", lng),
|
||||
}}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRawGPS(gpsStr string, ref string) float64 {
|
||||
elem := strings.Split(gpsStr, ", ")
|
||||
if len(elem) < 1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var (
|
||||
deg float64
|
||||
minutes float64
|
||||
seconds float64
|
||||
)
|
||||
|
||||
deg = getGpsElemValue(elem[0])
|
||||
if len(elem) >= 2 {
|
||||
minutes = getGpsElemValue(elem[1])
|
||||
}
|
||||
if len(elem) >= 3 {
|
||||
seconds = getGpsElemValue(elem[2])
|
||||
}
|
||||
|
||||
decimal := deg + minutes/60.0 + seconds/3600.0
|
||||
|
||||
if ref == "S" || ref == "W" {
|
||||
return -decimal
|
||||
}
|
||||
|
||||
return decimal
|
||||
}
|
||||
|
||||
// getGpsElemValue parses a single GPS rational element of the form
// "numerator/denominator" into a float64. Malformed input, extra separators,
// or a zero denominator all yield 0.
func getGpsElemValue(elm string) float64 {
	numStr, denStr, ok := strings.Cut(elm, "/")
	if !ok || strings.Contains(denStr, "/") {
		return 0
	}

	numerator, err := strconv.ParseFloat(numStr, 64)
	if err != nil {
		return 0
	}

	denominator, err := strconv.ParseFloat(denStr, 64)
	if err != nil || denominator == 0 {
		return 0
	}

	return numerator / denominator
}
|
||||
513
pkg/filemanager/driver/obs/obs.go
Normal file
513
pkg/filemanager/driver/obs/obs.go
Normal file
@@ -0,0 +1,513 @@
|
||||
package obs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/huaweicloud/huaweicloud-sdk-go-obs/obs"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
	// chunkRetrySleep is the pause between retries of a failed chunk upload.
	chunkRetrySleep = time.Duration(5) * time.Second
	// maxDeleteBatch is the maximum number of objects sent in one DeleteObjects request.
	maxDeleteBatch = 1000
	// imageProcessHeader is the query parameter carrying image processing directives.
	imageProcessHeader = "x-image-process"
	// trafficLimitHeader limits download speed (bit/s) on signed URLs.
	trafficLimitHeader = "x-obs-traffic-limit"
	// partNumberParam identifies the chunk index in pre-signed part-upload URLs.
	partNumberParam = "partNumber"
	// callbackParam carries the base64-encoded upload callback policy.
	callbackParam = "x-obs-callback"
	// uploadIdParam identifies the multipart upload session in signed URLs.
	uploadIdParam = "uploadId"
	// mediaInfoTTL — presumably the cache TTL for media metadata lookups; its
	// usage is outside this file, confirm against the media-meta extractor.
	mediaInfoTTL = time.Duration(10) * time.Minute
	// imageInfoProcessor — processing directive for querying image metadata;
	// usage not visible here.
	imageInfoProcessor = "image/info"

	// MultiPartUploadThreshold is the file size at or above which the server
	// side switches to multipart upload.
	MultiPartUploadThreshold int64 = 5 << 30 // 5GB
)
|
||||
|
||||
var (
	// features is the set of static capabilities advertised by this driver
	// via Capabilities().
	features = &boolset.BooleanSet{}
)

type (
	// CallbackPolicy is the upload-completion callback configuration; it is
	// JSON-encoded, base64'd, and attached to the signed completion URL in Token.
	CallbackPolicy struct {
		CallbackURL      string `json:"callbackUrl"`
		CallbackBody     string `json:"callbackBody"`
		CallbackBodyType string `json:"callbackBodyType"`
	}
	// JsonError is a JSON-formatted error body returned by OBS, decoded by
	// handleJsonError.
	JsonError struct {
		Message string `json:"message"`
		Code    string `json:"code"`
	}
)
|
||||
|
||||
// Driver Huawei Cloud OBS driver
type Driver struct {
	policy    *ent.StoragePolicy // storage policy this driver instance serves
	chunkSize int64              // chunk size in bytes for multipart uploads

	settings   setting.Provider    // runtime settings provider
	l          logging.Logger      // logger
	config     conf.ConfigProvider // application configuration
	mime       mime.MimeDetector   // filename-based MIME type detection
	httpClient request.Client      // HTTP client for outgoing requests
	obs        *obs.ObsClient      // underlying Huawei OBS SDK client
}
|
||||
|
||||
func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider,
|
||||
config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) {
|
||||
chunkSize := policy.Settings.ChunkSize
|
||||
if policy.Settings.ChunkSize == 0 {
|
||||
chunkSize = 25 << 20 // 25 MB
|
||||
}
|
||||
|
||||
driver := &Driver{
|
||||
policy: policy,
|
||||
settings: settings,
|
||||
chunkSize: chunkSize,
|
||||
config: config,
|
||||
l: l,
|
||||
mime: mime,
|
||||
httpClient: request.NewClient(config, request.WithLogger(l)),
|
||||
}
|
||||
|
||||
useCname := false
|
||||
if policy.Settings != nil && policy.Settings.UseCname {
|
||||
useCname = true
|
||||
}
|
||||
|
||||
obsClient, err := obs.New(policy.AccessKey, policy.SecretKey, policy.Server, obs.WithSignature(obs.SignatureObs), obs.WithCustomDomainName(useCname))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
driver.obs = obsClient
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
func (d *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
|
||||
defer file.Close()
|
||||
|
||||
// 是否允许覆盖
|
||||
overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite
|
||||
if !overwrite {
|
||||
// Check for duplicated file
|
||||
if _, err := d.obs.HeadObject(&obs.HeadObjectInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: file.Props.SavePath,
|
||||
}, obs.WithRequestContext(ctx)); err == nil {
|
||||
return fs.ErrFileExisted
|
||||
}
|
||||
}
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
d.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
// 小文件直接上传
|
||||
if file.Props.Size < MultiPartUploadThreshold {
|
||||
_, err := d.obs.PutObject(&obs.PutObjectInput{
|
||||
PutObjectBasicInput: obs.PutObjectBasicInput{
|
||||
ObjectOperationInput: obs.ObjectOperationInput{
|
||||
Key: file.Props.SavePath,
|
||||
Bucket: d.policy.BucketName,
|
||||
},
|
||||
HttpHeader: obs.HttpHeader{
|
||||
ContentType: mimeType,
|
||||
},
|
||||
ContentLength: file.Props.Size,
|
||||
},
|
||||
Body: file,
|
||||
}, obs.WithRequestContext(ctx))
|
||||
return err
|
||||
}
|
||||
|
||||
// 超过阈值时使用分片上传
|
||||
imur, err := d.obs.InitiateMultipartUpload(&obs.InitiateMultipartUploadInput{
|
||||
ObjectOperationInput: obs.ObjectOperationInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: file.Props.SavePath,
|
||||
},
|
||||
HttpHeader: obs.HttpHeader{
|
||||
ContentType: d.mime.TypeByName(file.Props.Uri.Name()),
|
||||
},
|
||||
}, obs.WithRequestContext(ctx))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initiate multipart upload: %w", err)
|
||||
}
|
||||
|
||||
chunks := chunk.NewChunkGroup(file, d.chunkSize, &backoff.ConstantBackoff{
|
||||
Max: d.settings.ChunkRetryLimit(ctx),
|
||||
Sleep: chunkRetrySleep,
|
||||
}, d.settings.UseChunkBuffer(ctx), d.l, d.settings.TempPath(ctx))
|
||||
|
||||
parts := make([]*obs.UploadPartOutput, 0, chunks.Num())
|
||||
|
||||
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
|
||||
part, err := d.obs.UploadPart(&obs.UploadPartInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: file.Props.SavePath,
|
||||
PartNumber: current.Index() + 1,
|
||||
UploadId: imur.UploadId,
|
||||
Body: content,
|
||||
SourceFile: "",
|
||||
PartSize: current.Length(),
|
||||
}, obs.WithRequestContext(ctx))
|
||||
if err == nil {
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for chunks.Next() {
|
||||
if err := chunks.Process(uploadFunc); err != nil {
|
||||
d.cancelUpload(file.Props.SavePath, imur)
|
||||
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = d.obs.CompleteMultipartUpload(&obs.CompleteMultipartUploadInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: file.Props.SavePath,
|
||||
UploadId: imur.UploadId,
|
||||
Parts: lo.Map(parts, func(part *obs.UploadPartOutput, i int) obs.Part {
|
||||
return obs.Part{
|
||||
PartNumber: i + 1,
|
||||
ETag: part.ETag,
|
||||
}
|
||||
}),
|
||||
}, obs.WithRequestContext(ctx))
|
||||
if err != nil {
|
||||
d.cancelUpload(file.Props.SavePath, imur)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
|
||||
groups := lo.Chunk(files, maxDeleteBatch)
|
||||
failed := make([]string, 0)
|
||||
var lastError error
|
||||
for index, group := range groups {
|
||||
d.l.Debug("Process delete group #%d: %v", index, group)
|
||||
// 删除文件
|
||||
delRes, err := d.obs.DeleteObjects(&obs.DeleteObjectsInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
Quiet: true,
|
||||
Objects: lo.Map(group, func(item string, index int) obs.ObjectToDelete {
|
||||
return obs.ObjectToDelete{
|
||||
Key: item,
|
||||
}
|
||||
}),
|
||||
}, obs.WithRequestContext(ctx))
|
||||
if err != nil {
|
||||
failed = append(failed, group...)
|
||||
lastError = err
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range delRes.Errors {
|
||||
d.l.Debug("Failed to delete file: %s, Code:%s, Message:%s", v.Key, v.Code, v.Key)
|
||||
failed = append(failed, v.Key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(failed) > 0 && lastError == nil {
|
||||
lastError = fmt.Errorf("failed to delete files: %v", failed)
|
||||
}
|
||||
|
||||
return failed, lastError
|
||||
}
|
||||
|
||||
// Open is not supported by the OBS driver: objects live remotely and cannot
// be exposed as local *os.File handles.
func (d *Driver) Open(ctx context.Context, path string) (*os.File, error) {
	return nil, errors.New("not implemented")
}
|
||||
|
||||
// LocalPath always returns an empty string: OBS objects have no local
// filesystem path.
func (d *Driver) LocalPath(ctx context.Context, path string) string {
	return ""
}
|
||||
|
||||
func (d *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
|
||||
w, h := d.settings.ThumbSize(ctx)
|
||||
thumbURL, err := d.signSourceURL(&obs.CreateSignedUrlInput{
|
||||
Method: obs.HttpMethodGet,
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: e.Source(),
|
||||
Expires: int(time.Until(*expire).Seconds()),
|
||||
QueryParams: map[string]string{
|
||||
imageProcessHeader: fmt.Sprintf("image/resize,m_lfit,w_%d,h_%d", w, h),
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return thumbURL, nil
|
||||
}
|
||||
|
||||
func (d *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
|
||||
params := make(map[string]string)
|
||||
if args.IsDownload {
|
||||
encodedFilename := url.PathEscape(args.DisplayName)
|
||||
params["response-content-disposition"] = fmt.Sprintf("attachment; filename=\"%s\"; filename*=UTF-8''%s",
|
||||
args.DisplayName, encodedFilename)
|
||||
}
|
||||
|
||||
expires := 86400 * 265 * 20
|
||||
if args.Expire != nil {
|
||||
expires = int(time.Until(*args.Expire).Seconds())
|
||||
}
|
||||
|
||||
if args.Speed > 0 {
|
||||
// Byte 转换为 bit
|
||||
args.Speed *= 8
|
||||
|
||||
// OSS对速度值有范围限制
|
||||
if args.Speed < 819200 {
|
||||
args.Speed = 819200
|
||||
}
|
||||
if args.Speed > 838860800 {
|
||||
args.Speed = 838860800
|
||||
}
|
||||
}
|
||||
|
||||
if args.Speed > 0 {
|
||||
params[trafficLimitHeader] = strconv.FormatInt(args.Speed, 10)
|
||||
}
|
||||
|
||||
return d.signSourceURL(&obs.CreateSignedUrlInput{
|
||||
Method: obs.HttpMethodGet,
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: e.Source(),
|
||||
Expires: expires,
|
||||
QueryParams: params,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
|
||||
// Check for duplicated file
|
||||
if _, err := d.obs.HeadObject(&obs.HeadObjectInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: file.Props.SavePath,
|
||||
}, obs.WithRequestContext(ctx)); err == nil {
|
||||
return nil, fs.ErrFileExisted
|
||||
}
|
||||
|
||||
// 生成回调地址
|
||||
siteURL := d.settings.SiteURL(setting.UseFirstSiteUrl(ctx))
|
||||
// 在从机端创建上传会话
|
||||
uploadSession.ChunkSize = d.chunkSize
|
||||
uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeObs, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String()
|
||||
// 回调策略
|
||||
callbackPolicy := CallbackPolicy{
|
||||
CallbackURL: uploadSession.Callback,
|
||||
CallbackBody: `{"name":${key},"source_name":${fname},"size":${size}}`,
|
||||
CallbackBodyType: "application/json",
|
||||
}
|
||||
|
||||
callbackPolicyJSON, err := json.Marshal(callbackPolicy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode callback policy: %w", err)
|
||||
}
|
||||
callbackPolicyEncoded := base64.StdEncoding.EncodeToString(callbackPolicyJSON)
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
d.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
imur, err := d.obs.InitiateMultipartUpload(&obs.InitiateMultipartUploadInput{
|
||||
ObjectOperationInput: obs.ObjectOperationInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: file.Props.SavePath,
|
||||
},
|
||||
HttpHeader: obs.HttpHeader{
|
||||
ContentType: mimeType,
|
||||
},
|
||||
}, obs.WithRequestContext(ctx))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize multipart upload: %w", err)
|
||||
}
|
||||
uploadSession.UploadID = imur.UploadId
|
||||
|
||||
// 为每个分片签名上传 URL
|
||||
chunks := chunk.NewChunkGroup(file, d.chunkSize, &backoff.ConstantBackoff{}, false, d.l, "")
|
||||
urls := make([]string, chunks.Num())
|
||||
ttl := int64(time.Until(uploadSession.Props.ExpireAt).Seconds())
|
||||
for chunks.Next() {
|
||||
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
|
||||
signedURL, err := d.obs.CreateSignedUrl(&obs.CreateSignedUrlInput{
|
||||
Method: obs.HttpMethodPut,
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: file.Props.SavePath,
|
||||
QueryParams: map[string]string{
|
||||
partNumberParam: strconv.Itoa(c.Index() + 1),
|
||||
uploadIdParam: uploadSession.UploadID,
|
||||
},
|
||||
Expires: int(ttl),
|
||||
Headers: map[string]string{
|
||||
"Content-Length": strconv.FormatInt(c.Length(), 10),
|
||||
"Content-Type": "application/octet-stream",
|
||||
}, //TODO: Validate +1
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
urls[c.Index()] = signedURL.SignedUrl
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 签名完成分片上传的URL
|
||||
completeURL, err := d.obs.CreateSignedUrl(&obs.CreateSignedUrlInput{
|
||||
Method: obs.HttpMethodPost,
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: file.Props.SavePath,
|
||||
QueryParams: map[string]string{
|
||||
uploadIdParam: uploadSession.UploadID,
|
||||
callbackParam: callbackPolicyEncoded,
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/octet-stream",
|
||||
},
|
||||
Expires: int(ttl),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &fs.UploadCredential{
|
||||
UploadID: imur.UploadId,
|
||||
UploadURLs: urls,
|
||||
CompleteURL: completeURL.SignedUrl,
|
||||
SessionID: uploadSession.Props.UploadSessionID,
|
||||
ChunkSize: d.chunkSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
|
||||
_, err := d.obs.AbortMultipartUpload(&obs.AbortMultipartUploadInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: uploadSession.Props.SavePath,
|
||||
UploadId: uploadSession.UploadID,
|
||||
}, obs.WithRequestContext(ctx))
|
||||
return err
|
||||
}
|
||||
|
||||
// CompleteUpload is a no-op for OBS: the client finishes the multipart upload
// itself through the pre-signed completion URL issued by Token, so there is
// nothing left to do server-side.
func (d *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
	return nil
}
|
||||
|
||||
//func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||
// return nil, errors.New("not implemented")
|
||||
//}
|
||||
|
||||
func (d *Driver) Capabilities() *driver.Capabilities {
|
||||
mediaMetaExts := d.policy.Settings.MediaMetaExts
|
||||
if !d.policy.Settings.NativeMediaProcessing {
|
||||
mediaMetaExts = nil
|
||||
}
|
||||
return &driver.Capabilities{
|
||||
StaticFeatures: features,
|
||||
MediaMetaSupportedExts: mediaMetaExts,
|
||||
MediaMetaProxy: d.policy.Settings.MediaMetaGeneratorProxy,
|
||||
ThumbSupportedExts: d.policy.Settings.ThumbExts,
|
||||
ThumbProxy: d.policy.Settings.ThumbGeneratorProxy,
|
||||
ThumbSupportAllExts: d.policy.Settings.ThumbSupportAllExts,
|
||||
ThumbMaxSize: d.policy.Settings.ThumbMaxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// CORS 创建跨域策略
|
||||
func (d *Driver) CORS() error {
|
||||
_, err := d.obs.SetBucketCors(&obs.SetBucketCorsInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
BucketCors: obs.BucketCors{
|
||||
CorsRules: []obs.CorsRule{
|
||||
{
|
||||
AllowedOrigin: []string{"*"},
|
||||
AllowedMethod: []string{
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
},
|
||||
ExposeHeader: []string{"Etag"},
|
||||
AllowedHeader: []string{"*"},
|
||||
MaxAgeSeconds: 3600,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Driver) cancelUpload(path string, imur *obs.InitiateMultipartUploadOutput) {
|
||||
if _, err := d.obs.AbortMultipartUpload(&obs.AbortMultipartUploadInput{
|
||||
Bucket: d.policy.BucketName,
|
||||
Key: path,
|
||||
UploadId: imur.UploadId,
|
||||
}); err != nil {
|
||||
d.l.Warning("failed to abort multipart upload: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *Driver) signSourceURL(input *obs.CreateSignedUrlInput) (string, error) {
|
||||
signedURL, err := handler.obs.CreateSignedUrl(input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
finalURL, err := url.Parse(signedURL.SignedUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 公有空间替换掉Key及不支持的头
|
||||
if !handler.policy.IsPrivate {
|
||||
query := finalURL.Query()
|
||||
query.Del("AccessKeyId")
|
||||
query.Del("Signature")
|
||||
finalURL.RawQuery = query.Encode()
|
||||
}
|
||||
return finalURL.String(), nil
|
||||
}
|
||||
|
||||
func handleJsonError(resp string, originErr error) error {
|
||||
if resp == "" {
|
||||
return originErr
|
||||
}
|
||||
|
||||
var err JsonError
|
||||
if err := json.Unmarshal([]byte(resp), &err); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal cos error: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("obs error: %s", err.Message)
|
||||
}
|
||||
@@ -4,23 +4,16 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -35,6 +28,8 @@ const (
|
||||
notFoundError = "itemNotFound"
|
||||
)
|
||||
|
||||
type RetryCtx struct{}
|
||||
|
||||
// GetSourcePath 获取文件的绝对路径
|
||||
func (info *FileInfo) GetSourcePath() string {
|
||||
res, err := url.PathUnescape(info.ParentReference.Path)
|
||||
@@ -51,19 +46,19 @@ func (info *FileInfo) GetSourcePath() string {
|
||||
)
|
||||
}
|
||||
|
||||
func (client *Client) getRequestURL(api string, opts ...Option) string {
|
||||
func (client *client) getRequestURL(api string, opts ...Option) string {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
base, _ := url.Parse(client.Endpoints.EndpointURL)
|
||||
base, _ := url.Parse(client.endpoints.endpointURL)
|
||||
if base == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if options.useDriverResource {
|
||||
base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api)
|
||||
base.Path = path.Join(base.Path, client.endpoints.driverResource, api)
|
||||
} else {
|
||||
base.Path = path.Join(base.Path, api)
|
||||
}
|
||||
@@ -72,7 +67,7 @@ func (client *Client) getRequestURL(api string, opts ...Option) string {
|
||||
}
|
||||
|
||||
// ListChildren 根据路径列取子对象
|
||||
func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo, error) {
|
||||
func (client *client) ListChildren(ctx context.Context, path string) ([]FileInfo, error) {
|
||||
var requestURL string
|
||||
dst := strings.TrimPrefix(path, "/")
|
||||
if dst == "" {
|
||||
@@ -84,14 +79,14 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
|
||||
if err != nil {
|
||||
retried := 0
|
||||
if v, ok := ctx.Value(fsctx.RetryCtx).(int); ok {
|
||||
if v, ok := ctx.Value(RetryCtx{}).(int); ok {
|
||||
retried = v
|
||||
}
|
||||
if retried < ListRetry {
|
||||
retried++
|
||||
util.Log().Debug("Failed to list path %q: %s, will retry in 5 seconds.", path, err)
|
||||
client.l.Debug("Failed to list path %q: %s, will retry in 5 seconds.", path, err)
|
||||
time.Sleep(time.Duration(5) * time.Second)
|
||||
return client.ListChildren(context.WithValue(ctx, fsctx.RetryCtx, retried), path)
|
||||
return client.ListChildren(context.WithValue(ctx, RetryCtx{}, retried), path)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -109,7 +104,7 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
|
||||
}
|
||||
|
||||
// Meta 根据资源ID或文件路径获取文件元信息
|
||||
func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
|
||||
func (client *client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
|
||||
var requestURL string
|
||||
if id != "" {
|
||||
requestURL = client.getRequestURL("items/" + id)
|
||||
@@ -137,7 +132,7 @@ func (client *Client) Meta(ctx context.Context, id string, path string) (*FileIn
|
||||
}
|
||||
|
||||
// CreateUploadSession 创建分片上传会话
|
||||
func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
|
||||
func (client *client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
@@ -170,7 +165,7 @@ func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts
|
||||
}
|
||||
|
||||
// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
|
||||
func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
|
||||
func (client *client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
|
||||
siteUrlParsed, err := url.Parse(siteUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -197,7 +192,7 @@ func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (strin
|
||||
}
|
||||
|
||||
// GetUploadSessionStatus 查询上传会话状态
|
||||
func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
|
||||
func (client *client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
|
||||
res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -216,7 +211,7 @@ func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL stri
|
||||
}
|
||||
|
||||
// UploadChunk 上传分片
|
||||
func (client *Client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) {
|
||||
func (client *client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) {
|
||||
res, err := client.request(
|
||||
ctx, "PUT", uploadURL, content,
|
||||
request.WithContentLength(current.Length()),
|
||||
@@ -247,16 +242,15 @@ func (client *Client) UploadChunk(ctx context.Context, uploadURL string, content
|
||||
}
|
||||
|
||||
// Upload 上传文件
|
||||
func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error {
|
||||
fileInfo := file.Info()
|
||||
func (client *client) Upload(ctx context.Context, file *fs.UploadRequest) error {
|
||||
// 决定是否覆盖文件
|
||||
overwrite := "fail"
|
||||
if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite {
|
||||
if file.Mode&fs.ModeOverwrite == fs.ModeOverwrite {
|
||||
overwrite = "replace"
|
||||
}
|
||||
|
||||
size := int(fileInfo.Size)
|
||||
dst := fileInfo.SavePath
|
||||
size := int(file.Props.Size)
|
||||
dst := file.Props.SavePath
|
||||
|
||||
// 小文件,使用简单上传接口上传
|
||||
if size <= int(SmallFileSize) {
|
||||
@@ -272,10 +266,10 @@ func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error {
|
||||
}
|
||||
|
||||
// Initial chunk groups
|
||||
chunks := chunk.NewChunkGroup(file, client.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
|
||||
Max: model.GetIntSetting("chunk_retries", 5),
|
||||
chunks := chunk.NewChunkGroup(file, client.chunkSize, &backoff.ConstantBackoff{
|
||||
Max: client.settings.ChunkRetryLimit(ctx),
|
||||
Sleep: chunkRetrySleep,
|
||||
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
|
||||
}, client.settings.UseChunkBuffer(ctx), client.l, client.settings.TempPath(ctx))
|
||||
|
||||
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
|
||||
_, err := client.UploadChunk(ctx, uploadURL, content, current)
|
||||
@@ -285,6 +279,9 @@ func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error {
|
||||
// upload chunks
|
||||
for chunks.Next() {
|
||||
if err := chunks.Process(uploadFunc); err != nil {
|
||||
if err := client.DeleteUploadSession(ctx, uploadURL); err != nil {
|
||||
client.l.Warning("Failed to delete upload session: %s", err)
|
||||
}
|
||||
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
|
||||
}
|
||||
}
|
||||
@@ -293,7 +290,7 @@ func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error {
|
||||
}
|
||||
|
||||
// DeleteUploadSession 删除上传会话
|
||||
func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string) error {
|
||||
func (client *client) DeleteUploadSession(ctx context.Context, uploadURL string) error {
|
||||
_, err := client.requestWithStr(ctx, "DELETE", uploadURL, "", 204)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -303,7 +300,7 @@ func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string)
|
||||
}
|
||||
|
||||
// SimpleUpload 上传小文件到dst
|
||||
func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) {
|
||||
func (client *client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
@@ -334,8 +331,7 @@ func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Read
|
||||
|
||||
// BatchDelete 并行删除给出的文件,返回删除失败的文件,及第一个遇到的错误。此方法将文件分为
|
||||
// 20个一组,调用Delete并行删除
|
||||
// TODO 测试
|
||||
func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string, error) {
|
||||
func (client *client) BatchDelete(ctx context.Context, dst []string) ([]string, error) {
|
||||
groupNum := len(dst)/20 + 1
|
||||
finalRes := make([]string, 0, len(dst))
|
||||
res := make([]string, 0, 20)
|
||||
@@ -346,6 +342,8 @@ func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string,
|
||||
if i == groupNum-1 {
|
||||
end = len(dst)
|
||||
}
|
||||
|
||||
client.l.Debug("Delete file group: %v.", dst[20*i:end])
|
||||
res, err = client.Delete(ctx, dst[20*i:end])
|
||||
finalRes = append(finalRes, res...)
|
||||
}
|
||||
@@ -355,7 +353,7 @@ func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string,
|
||||
|
||||
// Delete 并行删除文件,返回删除失败的文件,及第一个遇到的错误,
|
||||
// 由于API限制,最多删除20个
|
||||
func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) {
|
||||
func (client *client) Delete(ctx context.Context, dst []string) ([]string, error) {
|
||||
body := client.makeBatchDeleteRequestsBody(dst)
|
||||
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
|
||||
WithDriverResource(false)), body, 200)
|
||||
@@ -391,13 +389,13 @@ func getDeleteFailed(res *BatchResponses) []string {
|
||||
}
|
||||
|
||||
// makeBatchDeleteRequestsBody 生成批量删除请求正文
|
||||
func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
|
||||
func (client *client) makeBatchDeleteRequestsBody(files []string) string {
|
||||
req := BatchRequests{
|
||||
Requests: make([]BatchRequest, len(files)),
|
||||
}
|
||||
for i, v := range files {
|
||||
v = strings.TrimPrefix(v, "/")
|
||||
filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/")
|
||||
filePath, _ := url.Parse("/" + client.endpoints.driverResource + "/root:/")
|
||||
filePath.Path = path.Join(filePath.Path, v)
|
||||
req.Requests[i] = BatchRequest{
|
||||
ID: v,
|
||||
@@ -411,7 +409,7 @@ func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
|
||||
}
|
||||
|
||||
// GetThumbURL 获取给定尺寸的缩略图URL
|
||||
func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (string, error) {
|
||||
func (client *client) GetThumbURL(ctx context.Context, dst string) (string, error) {
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/large"
|
||||
|
||||
@@ -442,82 +440,6 @@ func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (s
|
||||
return "", ErrThumbSizeNotFound
|
||||
}
|
||||
|
||||
// MonitorUpload watches a client-side chunked upload session until it either
// completes (signaled on the callback MQ topic), expires, or is detected as
// stalled or size-mismatched — in which case the session is cancelled and an
// empty placeholder file is uploaded to block further client writes.
func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size uint64, ttl int64) {
	// Channel that fires once the upload callback has been processed.
	callbackChan := mq.GlobalMQ.Subscribe(callbackKey, 1)
	defer mq.GlobalMQ.Unsubscribe(callbackKey, callbackChan)

	timeout := model.GetIntSetting("onedrive_monitor_timeout", 600)
	interval := model.GetIntSetting("onedrive_callback_check", 20)

	for {
		select {
		case <-callbackChan:
			util.Log().Debug("Client finished OneDrive callback.")
			return
		case <-time.After(time.Duration(ttl) * time.Second):
			// Session expired without completing; cancel it and create a placeholder.
			client.DeleteUploadSession(context.Background(), uploadURL)
			_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
			if err != nil {
				util.Log().Debug("Failed to create placeholder file: %s", err)
			}
			return
		case <-time.After(time.Duration(timeout) * time.Second):
			util.Log().Debug("Checking OneDrive upload status.")
			status, err := client.GetUploadSessionStatus(context.Background(), uploadURL)

			if err != nil {
				if resErr, ok := err.(*RespError); ok {
					if resErr.APIError.Code == notFoundError {
						// Session gone: upload finished; give the client a grace
						// period to deliver the callback before deleting the file.
						util.Log().Debug("Upload completed, will check upload callback later.")
						select {
						case <-time.After(time.Duration(interval) * time.Second):
							util.Log().Warning("No callback is made, file will be deleted.")
							cache.Deletes([]string{callbackKey}, "callback_")
							_, err = client.Delete(context.Background(), []string{path})
							if err != nil {
								util.Log().Warning("Failed to delete file without callback: %s", err)
							}
						case <-callbackChan:
							util.Log().Debug("Client finished callback.")
						}
						return
					}
				}
				util.Log().Debug("Failed to get upload session status: %s, continue next iteration.", err.Error())
				continue
			}

			// Status fetched successfully; verify the uploaded size so far.
			if len(status.NextExpectedRanges) == 0 {
				continue
			}
			sizeRange := strings.Split(
				status.NextExpectedRanges[len(status.NextExpectedRanges)-1],
				"-",
			)
			if len(sizeRange) != 2 {
				continue
			}
			uploadFullSize, _ := strconv.ParseUint(sizeRange[1], 10, 64)
			if (sizeRange[0] == "0" && sizeRange[1] == "") || uploadFullSize+1 != size {
				util.Log().Debug("Upload has not started, or uploaded file size not match, canceling upload session...")
				// Cancelling the session alone does not stop the client from
				// uploading (observed OneDrive behavior), so upload an empty
				// placeholder file to block it.
				client.DeleteUploadSession(context.Background(), uploadURL)
				_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
				if err != nil {
					util.Log().Debug("无法创建占位文件,%s", err)
				}
				return
			}

		}
	}
}
|
||||
|
||||
func sysError(err error) *RespError {
|
||||
return &RespError{APIError: APIError{
|
||||
Code: "system",
|
||||
@@ -525,32 +447,32 @@ func sysError(err error) *RespError {
|
||||
}}
|
||||
}
|
||||
|
||||
func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) {
|
||||
func (client *client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) {
|
||||
// 获取凭证
|
||||
err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave")
|
||||
err := client.UpdateCredential(ctx)
|
||||
if err != nil {
|
||||
return "", sysError(err)
|
||||
}
|
||||
|
||||
option = append(option,
|
||||
opts := []request.Option{
|
||||
request.WithHeader(http.Header{
|
||||
"Authorization": {"Bearer " + client.Credential.AccessToken},
|
||||
"Authorization": {"Bearer " + client.credential.String()},
|
||||
"Content-Type": {"application/json"},
|
||||
}),
|
||||
request.WithContext(ctx),
|
||||
request.WithTPSLimit(
|
||||
fmt.Sprintf("policy_%d", client.Policy.ID),
|
||||
client.Policy.OptionsSerialized.TPSLimit,
|
||||
client.Policy.OptionsSerialized.TPSLimitBurst,
|
||||
fmt.Sprintf("policy_%d", client.policy.ID),
|
||||
client.policy.Settings.TPSLimit,
|
||||
client.policy.Settings.TPSLimitBurst,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
res := client.Request.Request(
|
||||
res := client.httpClient.Request(
|
||||
method,
|
||||
url,
|
||||
body,
|
||||
option...,
|
||||
append(opts, option...)...,
|
||||
)
|
||||
|
||||
if res.Err != nil {
|
||||
@@ -571,12 +493,12 @@ func (client *Client) request(ctx context.Context, method string, url string, bo
|
||||
if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 {
|
||||
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
|
||||
if decodeErr != nil {
|
||||
util.Log().Debug("Onedrive returns unknown response: %s", respBody)
|
||||
client.l.Debug("Onedrive returns unknown response: %s", respBody)
|
||||
return "", sysError(decodeErr)
|
||||
}
|
||||
|
||||
if res.Response.StatusCode == 429 {
|
||||
util.Log().Warning("OneDrive request is throttled.")
|
||||
client.l.Warning("OneDrive request is throttled.")
|
||||
return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header)
|
||||
}
|
||||
|
||||
@@ -586,7 +508,7 @@ func (client *Client) request(ctx context.Context, method string, url string, bo
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) {
|
||||
func (client *client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) {
|
||||
// 发送请求
|
||||
bodyReader := io.NopCloser(strings.NewReader(body))
|
||||
return client.request(ctx, method, url, bodyReader,
|
||||
90
pkg/filemanager/driver/onedrive/client.go
Normal file
90
pkg/filemanager/driver/onedrive/client.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/credmanager"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
)
|
||||
|
||||
var (
	// ErrAuthEndpoint indicates the OAuth endpoint URL could not be parsed.
	ErrAuthEndpoint = errors.New("failed to parse endpoint url")
	// ErrInvalidRefreshToken indicates the policy holds no valid refresh token.
	ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy")
	// ErrDeleteFile indicates a file could not be deleted.
	ErrDeleteFile = errors.New("cannot delete file")
	// ErrClientCanceled indicates the operation was canceled by the client.
	ErrClientCanceled = errors.New("client canceled")
	// ErrThumbSizeNotFound indicates the desired thumb size is not available.
	ErrThumbSizeNotFound = errors.New("thumb size not found")
)
|
||||
|
||||
// Client is the operation surface of a OneDrive (Microsoft Graph) drive:
// item metadata, uploads, deletion, thumbnails and the OAuth token flow.
type Client interface {
	// ListChildren lists the items directly under the given path.
	ListChildren(ctx context.Context, path string) ([]FileInfo, error)
	// Meta fetches metadata for an item, addressed by id or by path.
	Meta(ctx context.Context, id string, path string) (*FileInfo, error)
	// CreateUploadSession starts an upload session for dst and returns its upload URL.
	CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error)
	// GetSiteIDByURL resolves a SharePoint site URL to a site ID.
	GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error)
	// GetUploadSessionStatus queries the state of an existing upload session.
	GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error)
	// Upload uploads the content of an upload request.
	Upload(ctx context.Context, file *fs.UploadRequest) error
	// SimpleUpload uploads body to dst in a single request.
	SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error)
	// DeleteUploadSession cancels a previously created upload session.
	DeleteUploadSession(ctx context.Context, uploadURL string) error
	// BatchDelete deletes the given paths, returning those that failed.
	BatchDelete(ctx context.Context, dst []string) ([]string, error)
	// GetThumbURL returns a thumbnail URL for dst.
	GetThumbURL(ctx context.Context, dst string) (string, error)
	// OAuthURL builds the OAuth consent page URL with the given scopes.
	OAuthURL(ctx context.Context, scopes []string) string
	// ObtainToken exchanges an authorization code or refresh token for a Credential.
	ObtainToken(ctx context.Context, opts ...Option) (*Credential, error)
}
|
||||
|
||||
// client is the default OneDrive Client implementation, bound to one storage
// policy.
type client struct {
	endpoints  *endpoints
	policy     *ent.StoragePolicy
	credential credmanager.Credential // current access credential; set by UpdateCredential

	httpClient request.Client
	cred       credmanager.CredManager
	l          logging.Logger
	settings   setting.Provider

	// chunkSize is the chunk size in bytes used for sessioned uploads.
	chunkSize int64
}

// endpoints groups the URLs the client talks to.
type endpoints struct {
	oAuthEndpoints *oauthEndpoint
	endpointURL    string // base URL for API requests
	driverResource string // drive resource to use, defaults to "me/drive"
}
|
||||
|
||||
// NewClient 根据存储策略获取新的client
|
||||
func NewClient(policy *ent.StoragePolicy, httpClient request.Client, cred credmanager.CredManager,
|
||||
l logging.Logger, settings setting.Provider, chunkSize int64) Client {
|
||||
client := &client{
|
||||
endpoints: &endpoints{
|
||||
endpointURL: policy.Server,
|
||||
driverResource: policy.Settings.OdDriver,
|
||||
},
|
||||
policy: policy,
|
||||
httpClient: httpClient,
|
||||
cred: cred,
|
||||
l: l,
|
||||
settings: settings,
|
||||
chunkSize: chunkSize,
|
||||
}
|
||||
|
||||
if client.endpoints.driverResource == "" {
|
||||
client.endpoints.driverResource = "me/drive"
|
||||
}
|
||||
|
||||
oauthBase := getOAuthEndpoint(policy.Server)
|
||||
client.endpoints.oAuthEndpoints = oauthBase
|
||||
|
||||
return client
|
||||
}
|
||||
271
pkg/filemanager/driver/onedrive/oauth.go
Normal file
271
pkg/filemanager/driver/onedrive/oauth.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/credmanager"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
	// AccessTokenExpiryMargin, in seconds, is subtracted from a credential's
	// expiry (see Credential.Expiry) so tokens are treated as expired early.
	AccessTokenExpiryMargin = 600 // 10 minutes
)
|
||||
|
||||
// Error 实现error接口
|
||||
func (err OAuthError) Error() string {
|
||||
return err.ErrorDescription
|
||||
}
|
||||
|
||||
// OAuthURL 获取OAuth认证页面URL
|
||||
func (client *client) OAuthURL(ctx context.Context, scope []string) string {
|
||||
query := url.Values{
|
||||
"client_id": {client.policy.BucketName},
|
||||
"scope": {strings.Join(scope, " ")},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {client.policy.Settings.OauthRedirect},
|
||||
"state": {strconv.Itoa(client.policy.ID)},
|
||||
}
|
||||
client.endpoints.oAuthEndpoints.authorize.RawQuery = query.Encode()
|
||||
return client.endpoints.oAuthEndpoints.authorize.String()
|
||||
}
|
||||
|
||||
// getOAuthEndpoint gets OAuth endpoints from API endpoint
|
||||
func getOAuthEndpoint(apiEndpoint string) *oauthEndpoint {
|
||||
base, err := url.Parse(apiEndpoint)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
token *url.URL
|
||||
authorize *url.URL
|
||||
)
|
||||
switch base.Host {
|
||||
//case "login.live.com":
|
||||
// token, _ = url.Parse("https://login.live.com/oauth20_token.srf")
|
||||
// authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf")
|
||||
case "microsoftgraph.chinacloudapi.cn":
|
||||
token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token")
|
||||
authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize")
|
||||
default:
|
||||
token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token")
|
||||
authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
|
||||
}
|
||||
|
||||
return &oauthEndpoint{
|
||||
token: *token,
|
||||
authorize: *authorize,
|
||||
}
|
||||
}
|
||||
|
||||
// Credential is the token set returned by the OAuth token endpoint, cached
// per storage policy.
type Credential struct {
	// ExpiresIn is rebased by obtainToken from the provider's relative
	// "expires_in" seconds onto an absolute unix timestamp before storage.
	ExpiresIn    int64  `json:"expires_in"`
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	// RefreshedAtUnix is the unix time of the last refresh, 0 if never.
	RefreshedAtUnix int64 `json:"refreshed_at"`

	// PolicyID links the credential back to its storage policy.
	PolicyID int `json:"policy_id"`
}

// init registers Credential so it can round-trip through encoding/gob.
func init() {
	gob.Register(Credential{})
}
|
||||
|
||||
// Refresh exchanges the stored refresh token for a fresh access token and
// persists the rotated refresh token back to the storage policy row. The
// receiver is a value, so the updated credential takes effect only through
// the returned value.
func (c Credential) Refresh(ctx context.Context) (credmanager.Credential, error) {
	if c.RefreshToken == "" {
		return nil, ErrInvalidRefreshToken
	}

	// Resolve the owning storage policy; it carries the OAuth client
	// parameters (client id, secret, redirect) needed for the refresh call.
	dep := dependency.FromContext(ctx)
	storagePolicyClient := dep.StoragePolicyClient()
	policy, err := storagePolicyClient.GetPolicyByID(ctx, c.PolicyID)
	if err != nil {
		return nil, fmt.Errorf("failed to get storage policy: %w", err)
	}

	// NOTE(review): getOAuthEndpoint returns nil when policy.Server fails to
	// parse, which would panic below — confirm the server URL is validated
	// upstream.
	oauthBase := getOAuthEndpoint(policy.Server)

	newCredential, err := obtainToken(ctx, &obtainTokenArgs{
		clientId:      policy.BucketName,
		redirect:      policy.Settings.OauthRedirect,
		secret:        policy.SecretKey,
		refreshToken:  c.RefreshToken,
		client:        dep.RequestClient(request.WithLogger(dep.Logger())),
		tokenEndpoint: oauthBase.token.String(),
		policyID:      c.PolicyID,
	})

	if err != nil {
		return nil, err
	}

	c.RefreshToken = newCredential.RefreshToken
	c.AccessToken = newCredential.AccessToken
	c.ExpiresIn = newCredential.ExpiresIn
	c.RefreshedAtUnix = time.Now().Unix()

	// Write the rotated refresh token to the DB so it survives restarts.
	if err := storagePolicyClient.UpdateAccessKey(ctx, policy, newCredential.RefreshToken); err != nil {
		return nil, err
	}

	return c, nil
}
|
||||
|
||||
// Key returns the credmanager cache key for this credential's policy.
func (c Credential) Key() string {
	return CredentialKey(c.PolicyID)
}

// Expiry returns when the credential should be considered expired.
// ExpiresIn holds an absolute unix timestamp (rebased in obtainToken); the
// margin makes tokens expire early so they are refreshed before real expiry.
func (c Credential) Expiry() time.Time {
	return time.Unix(c.ExpiresIn-AccessTokenExpiryMargin, 0)
}

// String returns the bearer access token (used as the Authorization value).
func (c Credential) String() string {
	return c.AccessToken
}

// RefreshedAt returns the time of the last refresh, or nil if the credential
// has never been refreshed.
func (c Credential) RefreshedAt() *time.Time {
	if c.RefreshedAtUnix == 0 {
		return nil
	}
	refreshedAt := time.Unix(c.RefreshedAtUnix, 0)
	return &refreshedAt
}
|
||||
|
||||
// ObtainToken 通过code或refresh_token兑换token
|
||||
func (client *client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
return obtainToken(ctx, &obtainTokenArgs{
|
||||
clientId: client.policy.BucketName,
|
||||
redirect: client.policy.Settings.OauthRedirect,
|
||||
secret: client.policy.SecretKey,
|
||||
code: options.code,
|
||||
refreshToken: options.refreshToken,
|
||||
client: client.httpClient,
|
||||
tokenEndpoint: client.endpoints.oAuthEndpoints.token.String(),
|
||||
policyID: client.policy.ID,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// obtainTokenArgs bundles everything obtainToken needs to call the OAuth
// token endpoint, independent of a live client instance.
type obtainTokenArgs struct {
	clientId      string // OAuth application (client) ID; stored in policy.BucketName
	redirect      string // registered redirect URI
	secret        string // OAuth client secret
	code          string // authorization code; takes precedence when non-empty
	refreshToken  string // refresh token; used when code is empty
	client        request.Client
	tokenEndpoint string
	policyID      int // storage policy the resulting credential belongs to
}
|
||||
|
||||
// obtainToken fetches a new access token from the OAuth token endpoint, using
// the authorization-code grant when args.code is set and the refresh-token
// grant otherwise.
func obtainToken(ctx context.Context, args *obtainTokenArgs) (*Credential, error) {
	body := url.Values{
		"client_id":     {args.clientId},
		"redirect_uri":  {args.redirect},
		"client_secret": {args.secret},
	}
	// An authorization code takes precedence over a refresh token.
	if args.code != "" {
		body.Add("grant_type", "authorization_code")
		body.Add("code", args.code)
	} else {
		body.Add("grant_type", "refresh_token")
		body.Add("refresh_token", args.refreshToken)
	}
	strBody := body.Encode()

	res := args.client.Request(
		"POST",
		args.tokenEndpoint,
		io.NopCloser(strings.NewReader(strBody)),
		request.WithHeader(http.Header{
			"Content-Type": {"application/x-www-form-urlencoded"}},
		),
		request.WithContentLength(int64(len(strBody))),
		request.WithContext(ctx),
	)
	if res.Err != nil {
		return nil, res.Err
	}

	respBody, err := res.GetResponse()
	if err != nil {
		return nil, err
	}

	var (
		errResp    OAuthError
		credential Credential
		decodeErr  error
	)

	// Non-200 responses carry an OAuthError payload instead of a credential.
	if res.Response.StatusCode != 200 {
		decodeErr = json.Unmarshal([]byte(respBody), &errResp)
	} else {
		decodeErr = json.Unmarshal([]byte(respBody), &credential)
	}
	if decodeErr != nil {
		return nil, decodeErr
	}

	if errResp.ErrorType != "" {
		return nil, errResp
	}

	// Rebase the provider's relative expires_in (seconds) onto an absolute
	// unix timestamp; Credential.Expiry relies on this.
	credential.PolicyID = args.policyID
	credential.ExpiresIn = time.Now().Unix() + credential.ExpiresIn
	if args.code != "" {
		// For code grants the token is marked as already expired — presumably
		// to force an immediate refresh so the refresh token is persisted.
		// TODO(review): confirm this is intentional.
		credential.ExpiresIn = time.Now().Unix() - 10
	}
	return &credential, nil
}
|
||||
|
||||
// UpdateCredential 更新凭证,并检查有效期
|
||||
func (client *client) UpdateCredential(ctx context.Context) error {
|
||||
newCred, err := client.cred.Obtain(ctx, CredentialKey(client.policy.ID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to obtain token from CredManager: %w", err)
|
||||
}
|
||||
|
||||
client.credential = newCred
|
||||
return nil
|
||||
}
|
||||
|
||||
// RetrieveOneDriveCredentials retrieves OneDrive credentials from DB inventory
|
||||
func RetrieveOneDriveCredentials(ctx context.Context, storagePolicyClient inventory.StoragePolicyClient) ([]credmanager.Credential, error) {
|
||||
odPolicies, err := storagePolicyClient.ListPolicyByType(ctx, types.PolicyTypeOd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list OneDrive policies: %w", err)
|
||||
}
|
||||
|
||||
return lo.Map(odPolicies, func(item *ent.StoragePolicy, index int) credmanager.Credential {
|
||||
return &Credential{
|
||||
PolicyID: item.ID,
|
||||
ExpiresIn: 0,
|
||||
RefreshToken: item.AccessKey,
|
||||
}
|
||||
}), nil
|
||||
}
|
||||
|
||||
// CredentialKey returns the credmanager cache key for the OneDrive policy
// with the given ID.
func CredentialKey(policyId int) string {
	// strconv.Itoa avoids fmt.Sprintf's boxing/reflection overhead for a
	// plain int conversion; this key is built on every credential lookup.
	return "cred_od_" + strconv.Itoa(policyId)
}
|
||||
247
pkg/filemanager/driver/onedrive/onedrive.go
Normal file
247
pkg/filemanager/driver/onedrive/onedrive.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/credmanager"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Driver is the OneDrive storage driver adapter.
type Driver struct {
	policy   *ent.StoragePolicy
	client   Client
	settings setting.Provider
	config   conf.ConfigProvider
	l        logging.Logger
	// chunkSize is the upload chunk size in bytes (defaulted in New).
	chunkSize int64
}

var (
	// features holds the static handler capabilities of this driver.
	features = &boolset.BooleanSet{}
)

const (
	// streamSaverParam is the query parameter appended to download URLs when
	// the policy's stream saver option is enabled (see Source).
	streamSaverParam = "stream_saver"
)

// init registers the driver's static capabilities: OneDrive uploads require a
// sentinel task (see CompleteUpload).
func init() {
	boolset.Sets(map[driver.HandlerCapability]bool{
		driver.HandlerCapabilityUploadSentinelRequired: true,
	}, features)
}
|
||||
|
||||
// New initializes a Driver instance from the given storage policy.
// (The function is named New, not NewDriver, so callers read onedrive.New.)
func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider,
	config conf.ConfigProvider, l logging.Logger, cred credmanager.CredManager) (*Driver, error) {
	// Default to 50 MB chunks when the policy does not specify a size.
	chunkSize := policy.Settings.ChunkSize
	if policy.Settings.ChunkSize == 0 {
		chunkSize = 50 << 20 // 50MB
	}

	c := NewClient(policy, request.NewClient(config, request.WithLogger(l)), cred, l, settings, chunkSize)

	return &Driver{
		policy:    policy,
		client:    c,
		settings:  settings,
		l:         l,
		config:    config,
		chunkSize: chunkSize,
	}, nil
}
|
||||
|
||||
//// List 列取项目
|
||||
//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
// base = strings.TrimPrefix(base, "/")
|
||||
// // 列取子项目
|
||||
// objects, _ := handler.client.ListChildren(ctx, base)
|
||||
//
|
||||
// // 获取真实的列取起始根目录
|
||||
// rootPath := base
|
||||
// if realBase, ok := ctx.Value(fsctx.PathCtx).(string); ok {
|
||||
// rootPath = realBase
|
||||
// } else {
|
||||
// ctx = context.WithValue(ctx, fsctx.PathCtx, base)
|
||||
// }
|
||||
//
|
||||
// // 整理结果
|
||||
// res := make([]response.Object, 0, len(objects))
|
||||
// for _, object := range objects {
|
||||
// source := path.Join(base, object.Name)
|
||||
// rel, err := filepath.Rel(rootPath, source)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// res = append(res, response.Object{
|
||||
// Name: object.Name,
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Source: source,
|
||||
// Size: uint64(object.Size),
|
||||
// IsDir: object.Folder != nil,
|
||||
// LastModify: time.Now(),
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// // 递归列取子目录
|
||||
// if recursive {
|
||||
// for _, object := range objects {
|
||||
// if object.Folder != nil {
|
||||
// sub, _ := handler.List(ctx, path.Join(base, object.Name), recursive)
|
||||
// res = append(res, sub...)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// return res, nil
|
||||
//}
|
||||
|
||||
func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
|
||||
defer file.Close()
|
||||
|
||||
return handler.client.Upload(ctx, file)
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
|
||||
return handler.client.BatchDelete(ctx, files)
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
|
||||
res, err := handler.client.GetThumbURL(ctx, e.Source())
|
||||
if err != nil {
|
||||
var apiErr *RespError
|
||||
if errors.As(err, &apiErr); err == ErrThumbSizeNotFound || (apiErr != nil && apiErr.APIError.Code == notFoundError) {
|
||||
// OneDrive cannot generate thumbnail for this file
|
||||
return "", fmt.Errorf("thumb not supported in OneDrive: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
|
||||
// 缓存不存在,重新获取
|
||||
res, err := handler.client.Meta(ctx, "", e.Source())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if args.IsDownload && handler.policy.Settings.StreamSaver {
|
||||
downloadUrl := res.DownloadURL + "&" + streamSaverParam + "=" + url.QueryEscape(args.DisplayName)
|
||||
return downloadUrl, nil
|
||||
}
|
||||
|
||||
return res.DownloadURL, nil
|
||||
}
|
||||
|
||||
// Token creates a OneDrive upload session for the request and returns the
// upload credential handed back to the uploading client.
func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
	// Build the callback URL the client reports to after finishing the upload.
	siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx))
	uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeOd, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String()

	// "fail" conflict behavior: refuse to overwrite an existing file.
	uploadURL, err := handler.client.CreateUploadSession(ctx, file.Props.SavePath, WithConflictBehavior("fail"))
	if err != nil {
		return nil, err
	}

	// Upload monitoring from V3 is currently disabled.
	//go handler.client.MonitorUpload(uploadURL, uploadSession.Key, fileInfo.SavePath, fileInfo.Size, ttl)

	uploadSession.ChunkSize = handler.chunkSize
	uploadSession.UploadURL = uploadURL
	return &fs.UploadCredential{
		ChunkSize:  handler.chunkSize,
		UploadURLs: []string{uploadURL},
	}, nil
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
|
||||
err := handler.client.DeleteUploadSession(ctx, uploadSession.UploadURL)
|
||||
// Create empty placeholder file to stop upload
|
||||
if err == nil {
|
||||
_, err := handler.client.SimpleUpload(ctx, uploadSession.Props.SavePath, strings.NewReader(""), 0, WithConflictBehavior("replace"))
|
||||
if err != nil {
|
||||
handler.l.Warning("Failed to create placeholder file %q:%s", uploadSession.Props.SavePath, err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// CompleteUpload verifies a sentinel-guarded upload by checking that the file
// OneDrive stored has the size the session promised.
func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
	// No sentinel task: nothing to verify.
	if session.SentinelTaskID == 0 {
		return nil
	}

	// Make sure the uploaded file size is correct.
	res, err := handler.client.Meta(ctx, "", session.Props.SavePath)
	if err != nil {
		return fmt.Errorf("failed to get uploaded file size: %w", err)
	}

	isSharePoint := strings.Contains(handler.policy.Settings.OdDriver, "sharepoint.com") ||
		strings.Contains(handler.policy.Settings.OdDriver, "sharepoint.cn")
	sizeMismatch := res.Size != session.Props.Size
	// SharePoint may add metadata to Office documents, growing the stored
	// file; allow up to 1 MiB of growth in that case.
	// See: https://github.com/OneDrive/onedrive-api-docs/issues/935
	if isSharePoint && sizeMismatch && (res.Size > session.Props.Size) && (res.Size-session.Props.Size <= 1048576) {
		sizeMismatch = false
	}

	if sizeMismatch {
		return serializer.NewError(
			serializer.CodeMetaMismatch,
			fmt.Sprintf("File size not match, expected: %d, actual: %d", session.Props.Size, res.Size),
			nil,
		)
	}

	return nil
}
|
||||
|
||||
func (handler *Driver) Capabilities() *driver.Capabilities {
|
||||
return &driver.Capabilities{
|
||||
StaticFeatures: features,
|
||||
ThumbSupportedExts: handler.policy.Settings.ThumbExts,
|
||||
ThumbSupportAllExts: handler.policy.Settings.ThumbSupportAllExts,
|
||||
ThumbMaxSize: handler.policy.Settings.ThumbMaxSize,
|
||||
ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy,
|
||||
MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy,
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (handler *Driver) LocalPath(ctx context.Context, path string) string {
|
||||
return ""
|
||||
}
|
||||
@@ -27,7 +27,7 @@ type UploadSessionResponse struct {
|
||||
// FileInfo 文件元信息
|
||||
type FileInfo struct {
|
||||
Name string `json:"name"`
|
||||
Size uint64 `json:"size"`
|
||||
Size int64 `json:"size"`
|
||||
Image imageInfo `json:"image"`
|
||||
ParentReference parentReference `json:"parentReference"`
|
||||
DownloadURL string `json:"@microsoft.graph.downloadUrl"`
|
||||
@@ -104,16 +104,6 @@ type oauthEndpoint struct {
|
||||
authorize url.URL
|
||||
}
|
||||
|
||||
// Credential 获取token时返回的凭证
|
||||
type Credential struct {
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// OAuthError OAuth相关接口的错误响应
|
||||
type OAuthError struct {
|
||||
ErrorType string `json:"error"`
|
||||
@@ -10,39 +10,44 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
const (
	// pubKeyHeader carries the URL of the OSS callback signing public key.
	pubKeyHeader = "x-oss-pub-key-url"
	// pubKeyPrefix / pubKeyPrefixHttps are the only origins a public key URL
	// may come from, guarding against forged callback signatures.
	pubKeyPrefix      = "http://gosspublic.alicdn.com/"
	pubKeyPrefixHttps = "https://gosspublic.alicdn.com/"
	// pubKeyCacheKey is the cache key under which the fetched key is stored.
	pubKeyCacheKey = "oss_public_key"
)
||||
|
||||
// GetPublicKey 从回调请求或缓存中获取OSS的回调签名公钥
|
||||
func GetPublicKey(r *http.Request) ([]byte, error) {
|
||||
func GetPublicKey(r *http.Request, kv cache.Driver, client request.Client) ([]byte, error) {
|
||||
var pubKey []byte
|
||||
|
||||
// 尝试从缓存中获取
|
||||
pub, exist := cache.Get("oss_public_key")
|
||||
pub, exist := kv.Get(pubKeyCacheKey)
|
||||
if exist {
|
||||
return pub.([]byte), nil
|
||||
}
|
||||
|
||||
// 从请求中获取
|
||||
pubURL, err := base64.StdEncoding.DecodeString(r.Header.Get("x-oss-pub-key-url"))
|
||||
pubURL, err := base64.StdEncoding.DecodeString(r.Header.Get(pubKeyHeader))
|
||||
if err != nil {
|
||||
return pubKey, err
|
||||
}
|
||||
|
||||
// 确保这个 public key 是由 OSS 颁发的
|
||||
if !strings.HasPrefix(string(pubURL), "http://gosspublic.alicdn.com/") &&
|
||||
!strings.HasPrefix(string(pubURL), "https://gosspublic.alicdn.com/") {
|
||||
if !strings.HasPrefix(string(pubURL), pubKeyPrefix) &&
|
||||
!strings.HasPrefix(string(pubURL), pubKeyPrefixHttps) {
|
||||
return pubKey, errors.New("public key url invalid")
|
||||
}
|
||||
|
||||
// 获取公钥
|
||||
client := request.NewClient()
|
||||
body, err := client.Request("GET", string(pubURL), nil).
|
||||
CheckHTTPResponse(200).
|
||||
GetResponse()
|
||||
@@ -51,7 +56,7 @@ func GetPublicKey(r *http.Request) ([]byte, error) {
|
||||
}
|
||||
|
||||
// 写入缓存
|
||||
_ = cache.Set("oss_public_key", []byte(body), 86400*7)
|
||||
_ = kv.Set(pubKeyCacheKey, []byte(body), 86400*7)
|
||||
|
||||
return []byte(body), nil
|
||||
}
|
||||
@@ -60,12 +65,12 @@ func getRequestMD5(r *http.Request) ([]byte, error) {
|
||||
var byteMD5 []byte
|
||||
|
||||
// 获取请求正文
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
if err != nil {
|
||||
return byteMD5, err
|
||||
}
|
||||
r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
strURLPathDecode, err := url.PathUnescape(r.URL.Path)
|
||||
if err != nil {
|
||||
@@ -81,8 +86,8 @@ func getRequestMD5(r *http.Request) ([]byte, error) {
|
||||
}
|
||||
|
||||
// VerifyCallbackSignature 验证OSS回调请求
|
||||
func VerifyCallbackSignature(r *http.Request) error {
|
||||
bytePublicKey, err := GetPublicKey(r)
|
||||
func VerifyCallbackSignature(r *http.Request, kv cache.Driver, client request.Client) error {
|
||||
bytePublicKey, err := GetPublicKey(r, kv, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
359
pkg/filemanager/driver/oss/media.go
Normal file
359
pkg/filemanager/driver/oss/media.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package oss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"github.com/aliyun/aliyun-oss-go-sdk/oss"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/mediameta"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/samber/lo"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// Media-info process categories passed to the OSS processing API
	// (see extractIMMMeta / extractMediaInfo).
	imageInfoProcess = "image/info"
	videoInfoProcess = "video/info"
	audioInfoProcess = "audio/info"
	// mediaInfoTTL is a 10-minute duration used by media-info requests.
	// NOTE(review): its consumer is not visible here — confirm whether it is
	// a cache TTL or a request deadline.
	mediaInfoTTL = time.Duration(10) * time.Minute
)
|
||||
|
||||
var (
	// File extensions for which OSS media-info extraction is attempted.
	supportedImageExt = []string{"jpg", "jpeg", "png", "gif", "bmp", "webp", "tiff", "heic", "heif"}
	supportedAudioExt = []string{"mp3", "wav", "flac", "aac", "m4a", "ogg", "wma", "ape", "alac", "amr", "opus"}
	// NOTE(review): this list contains duplicates ("avi" and "m4v" appear
	// twice, and "m4a" also appears in supportedAudioExt) — harmless for
	// membership checks, but worth deduplicating.
	supportedVideoExt = []string{"mp4", "mkv", "avi", "mov", "flv", "wmv", "rmvb", "webm", "3gp", "mpg", "mpeg", "m4v", "ts", "m3u8", "vob", "f4v", "rm", "asf", "divx", "ogv", "dat", "mts", "m2ts", "swf", "avi", "3g2", "m2v", "m4p", "m4b", "m4r", "m4v", "m4a"}
)
|
||||
|
||||
type (
	// ImageProp is one property value in an OSS image/info response.
	ImageProp struct {
		Value string `json:"value"`
	}
	// ImageInfo maps property names to their values as returned by image/info.
	ImageInfo map[string]ImageProp

	// Error is the XML error payload returned by OSS on failed requests.
	Error struct {
		XMLName      xml.Name `xml:"Error"`
		Text         string   `xml:",chardata"`
		Code         string   `xml:"Code"`
		Message      string   `xml:"Message"`
		RequestId    string   `xml:"RequestId"`
		HostId       string   `xml:"HostId"`
		EC           string   `xml:"EC"`
		RecommendDoc string   `xml:"RecommendDoc"`
	}

	// StreamMediaInfo is the JSON payload of an OSS video/audio info request:
	// container-level format data plus per-stream details.
	StreamMediaInfo struct {
		RequestID      string        `json:"RequestId"`
		Language       string        `json:"Language"`
		Title          string        `json:"Title"`
		VideoStreams   []VideoStream `json:"VideoStreams"`
		AudioStreams   []AudioStream `json:"AudioStreams"`
		Subtitles      []Subtitle    `json:"Subtitles"`
		StreamCount    int64         `json:"StreamCount"`
		ProgramCount   int64         `json:"ProgramCount"`
		FormatName     string        `json:"FormatName"`
		FormatLongName string        `json:"FormatLongName"`
		Size           int64         `json:"Size"`
		StartTime      float64       `json:"StartTime"`
		Bitrate        int64         `json:"Bitrate"`
		Artist         string        `json:"Artist"`
		AlbumArtist    string        `json:"AlbumArtist"`
		Composer       string        `json:"Composer"`
		Performer      string        `json:"Performer"`
		Album          string        `json:"Album"`
		Duration       float64       `json:"Duration"`
		ProduceTime    string        `json:"ProduceTime"`
		LatLong        string        `json:"LatLong"`
		VideoWidth     int64         `json:"VideoWidth"`
		VideoHeight    int64         `json:"VideoHeight"`
		Addresses      []Address     `json:"Addresses"`
	}

	// Address is a geocoded location attached to the media info.
	Address struct {
		Language    string `json:"Language"`
		AddressLine string `json:"AddressLine"`
		Country     string `json:"Country"`
		Province    string `json:"Province"`
		City        string `json:"City"`
		District    string `json:"District"`
		Township    string `json:"Township"`
	}

	// AudioStream describes one audio stream in the media container.
	AudioStream struct {
		Index          int     `json:"Index"`
		Language       string  `json:"Language"`
		CodecName      string  `json:"CodecName"`
		CodecLongName  string  `json:"CodecLongName"`
		CodecTimeBase  string  `json:"CodecTimeBase"`
		CodecTagString string  `json:"CodecTagString"`
		CodecTag       string  `json:"CodecTag"`
		TimeBase       string  `json:"TimeBase"`
		StartTime      float64 `json:"StartTime"`
		Duration       float64 `json:"Duration"`
		Bitrate        int64   `json:"Bitrate"`
		FrameCount     int64   `json:"FrameCount"`
		Lyric          string  `json:"Lyric"`
		SampleFormat   string  `json:"SampleFormat"`
		SampleRate     int64   `json:"SampleRate"`
		Channels       int64   `json:"Channels"`
		ChannelLayout  string  `json:"ChannelLayout"`
	}

	// Subtitle describes one subtitle stream in the media container.
	Subtitle struct {
		Index          int64   `json:"Index"`
		Language       string  `json:"Language"`
		CodecName      string  `json:"CodecName"`
		CodecLongName  string  `json:"CodecLongName"`
		CodecTagString string  `json:"CodecTagString"`
		CodecTag       string  `json:"CodecTag"`
		StartTime      float64 `json:"StartTime"`
		Duration       float64 `json:"Duration"`
		Bitrate        int64   `json:"Bitrate"`
		Content        string  `json:"Content"`
		Width          int64   `json:"Width"`
		Height         int64   `json:"Height"`
	}

	// VideoStream describes one video stream in the media container.
	VideoStream struct {
		Index              int     `json:"Index"`
		Language           string  `json:"Language"`
		CodecName          string  `json:"CodecName"`
		CodecLongName      string  `json:"CodecLongName"`
		Profile            string  `json:"Profile"`
		CodecTimeBase      string  `json:"CodecTimeBase"`
		CodecTagString     string  `json:"CodecTagString"`
		CodecTag           string  `json:"CodecTag"`
		Width              int     `json:"Width"`
		Height             int     `json:"Height"`
		HasBFrames         int     `json:"HasBFrames"`
		SampleAspectRatio  string  `json:"SampleAspectRatio"`
		DisplayAspectRatio string  `json:"DisplayAspectRatio"`
		PixelFormat        string  `json:"PixelFormat"`
		Level              int     `json:"Level"`
		FrameRate          string  `json:"FrameRate"`
		AverageFrameRate   string  `json:"AverageFrameRate"`
		TimeBase           string  `json:"TimeBase"`
		StartTime          float64 `json:"StartTime"`
		Duration           float64 `json:"Duration"`
		Bitrate            int64   `json:"Bitrate"`
		FrameCount         int64   `json:"FrameCount"`
		Rotate             string  `json:"Rotate"`
		BitDepth           int     `json:"BitDepth"`
		ColorSpace         string  `json:"ColorSpace"`
		ColorRange         string  `json:"ColorRange"`
		ColorTransfer      string  `json:"ColorTransfer"`
		ColorPrimaries     string  `json:"ColorPrimaries"`
	}
)
|
||||
|
||||
// extractIMMMeta extracts media metadata for a video/audio object by invoking
// the OSS media-info process named by category (always with a signed URL), and
// maps the response into generic driver.MediaMeta entries: an FFProbe-style
// format/stream summary plus music tags (artist, album artist, composer,
// album) when present.
func (handler *Driver) extractIMMMeta(ctx context.Context, path, category string) ([]driver.MediaMeta, error) {
	resp, err := handler.extractMediaInfo(ctx, path, category, true)
	if err != nil {
		return nil, err
	}

	var info StreamMediaInfo
	if err := json.Unmarshal([]byte(resp), &info); err != nil {
		return nil, fmt.Errorf("failed to unmarshal media info: %w", err)
	}

	// Map video streams; a zero bitrate is rendered as an empty string so the
	// transform layer can tell "unknown" from "0".
	streams := lo.Map(info.VideoStreams, func(stream VideoStream, index int) mediameta.Stream {
		bitrate := ""
		if stream.Bitrate != 0 {
			bitrate = strconv.FormatInt(stream.Bitrate, 10)
		}
		return mediameta.Stream{
			Index:         stream.Index,
			CodecName:     stream.CodecName,
			CodecLongName: stream.CodecLongName,
			CodecType:     "video",
			Width:         stream.Width,
			Height:        stream.Height,
			Duration:      strconv.FormatFloat(stream.Duration, 'f', -1, 64),
			Bitrate:       bitrate,
		}
	})
	// Append audio streams after the video streams.
	streams = append(streams, lo.Map(info.AudioStreams, func(stream AudioStream, index int) mediameta.Stream {
		bitrate := ""
		if stream.Bitrate != 0 {
			bitrate = strconv.FormatInt(stream.Bitrate, 10)
		}
		return mediameta.Stream{
			Index:         stream.Index,
			CodecName:     stream.CodecName,
			CodecLongName: stream.CodecLongName,
			CodecType:     "audio",
			Duration:      strconv.FormatFloat(stream.Duration, 'f', -1, 64),
			Bitrate:       bitrate,
		}
	})...)

	metas := make([]driver.MediaMeta, 0)
	metas = append(metas, mediameta.ProbeMetaTransform(&mediameta.FFProbeMeta{
		Format: &mediameta.Format{
			FormatName:     info.FormatName,
			FormatLongName: info.FormatLongName,
			Duration:       strconv.FormatFloat(info.Duration, 'f', -1, 64),
			Bitrate:        strconv.FormatInt(info.Bitrate, 10),
		},
		Streams: streams,
	})...)

	// Music tags are emitted only when the IMM response carries them.
	if info.Artist != "" {
		metas = append(metas, driver.MediaMeta{
			Key:   mediameta.MusicArtist,
			Value: info.Artist,
			Type:  driver.MediaTypeMusic,
		})
	}

	if info.AlbumArtist != "" {
		metas = append(metas, driver.MediaMeta{
			Key:   mediameta.MusicAlbumArtists,
			Value: info.AlbumArtist,
			Type:  driver.MediaTypeMusic,
		})
	}

	if info.Composer != "" {
		metas = append(metas, driver.MediaMeta{
			Key:   mediameta.MusicComposer,
			Value: info.Composer,
			Type:  driver.MediaTypeMusic,
		})
	}

	if info.Album != "" {
		metas = append(metas, driver.MediaMeta{
			Key:   mediameta.MusicAlbum,
			Value: info.Album,
			Type:  driver.MediaTypeMusic,
		})
	}

	return metas, nil
}
|
||||
|
||||
func (handler *Driver) extractImageMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) {
|
||||
resp, err := handler.extractMediaInfo(ctx, path, imageInfoProcess, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var imageInfo ImageInfo
|
||||
if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal media info: %w", err)
|
||||
}
|
||||
|
||||
metas := make([]driver.MediaMeta, 0)
|
||||
exifMap := lo.MapEntries(imageInfo, func(key string, value ImageProp) (string, string) {
|
||||
return key, value.Value
|
||||
})
|
||||
metas = append(metas, mediameta.ExtractExifMap(exifMap, time.Time{})...)
|
||||
metas = append(metas, parseGpsInfo(imageInfo)...)
|
||||
for i := 0; i < len(metas); i++ {
|
||||
metas[i].Type = driver.MetaTypeExif
|
||||
}
|
||||
|
||||
return metas, nil
|
||||
}
|
||||
|
||||
// extractMediaInfo sends API calls to the OSS IMM service to extract media
// info. It signs a source URL carrying the given process directive (category)
// and performs a GET request, returning the raw response body. When forceSign
// is true the URL is signed even for public buckets.
func (handler *Driver) extractMediaInfo(ctx context.Context, path string, category string, forceSign bool) (string, error) {
	mediaOption := []oss.Option{oss.Process(category)}
	// Signed URL only needs to stay valid for the duration of this request.
	mediaInfoExpire := time.Now().Add(mediaInfoTTL)
	thumbURL, err := handler.signSourceURL(
		ctx,
		path,
		&mediaInfoExpire,
		mediaOption,
		forceSign,
	)
	if err != nil {
		return "", fmt.Errorf("failed to sign media info url: %w", err)
	}

	resp, err := handler.httpClient.
		Request(http.MethodGet, thumbURL, nil, request.WithContext(ctx)).
		CheckHTTPResponse(http.StatusOK).
		GetResponseIgnoreErr()
	if err != nil {
		// The body may contain an OSS XML error envelope with more detail.
		return "", handleOssError(resp, err)
	}

	return resp, nil
}
|
||||
|
||||
// parseGpsInfo extracts latitude/longitude from the EXIF map returned by the
// OSS image-info process and normalizes them into decimal-degree MediaMeta
// entries. Returns nil when any of the four GPS tags is missing or the parsed
// coordinates are NaN.
func parseGpsInfo(imageInfo ImageInfo) []driver.MediaMeta {
	latitude := imageInfo["GPSLatitude"]   // 31deg 16.26808'
	longitude := imageInfo["GPSLongitude"] // 120deg 42.91039'
	latRef := imageInfo["GPSLatitudeRef"]  // North
	lonRef := imageInfo["GPSLongitudeRef"] // East

	// Make sure all values exist in the map (missing keys yield zero values).
	if latitude.Value == "" || longitude.Value == "" || latRef.Value == "" || lonRef.Value == "" {
		return nil
	}

	lat := parseRawGPS(latitude.Value, latRef.Value)
	lon := parseRawGPS(longitude.Value, lonRef.Value)
	if !math.IsNaN(lat) && !math.IsNaN(lon) {
		lat, lng := mediameta.NormalizeGPS(lat, lon)
		return []driver.MediaMeta{{
			Key:   mediameta.GpsLat,
			Value: fmt.Sprintf("%f", lat),
		}, {
			Key:   mediameta.GpsLng,
			Value: fmt.Sprintf("%f", lng),
		}}
	}

	return nil
}
|
||||
|
||||
// parseRawGPS converts an OSS-style DMS coordinate such as
// `31deg 16.26808'` into decimal degrees. Missing minute/second components
// default to zero; the value is negated for "South"/"West" references.
func parseRawGPS(gpsStr string, ref string) float64 {
	parts := strings.Split(gpsStr, " ")
	if len(parts) < 1 {
		return 0
	}

	// Component i carries a unit suffix that must be stripped before parsing;
	// parse errors leave the component at zero, matching ParseFloat's zero
	// return on failure.
	var components [3]float64
	suffixes := [3]string{"deg", "'", "\""}
	for i := 0; i < len(parts) && i < 3; i++ {
		components[i], _ = strconv.ParseFloat(strings.TrimSuffix(parts[i], suffixes[i]), 64)
	}

	decimal := components[0] + components[1]/60.0 + components[2]/3600.0
	if ref == "South" || ref == "West" {
		decimal = -decimal
	}

	return decimal
}
|
||||
|
||||
// handleOssError tries to decode resp as an OSS XML error envelope and
// returns a descriptive error built from its message. If resp is empty, the
// original transport error is returned unchanged.
func handleOssError(resp string, originErr error) error {
	if resp == "" {
		return originErr
	}

	// NOTE(review): the inner `err` in the if-initializer shadows the outer
	// `var err Error`; this compiles and behaves correctly, but renaming one
	// of them would make the code easier to read.
	var err Error
	if err := xml.Unmarshal([]byte(resp), &err); err != nil {
		return fmt.Errorf("failed to unmarshal oss error: %w", err)
	}

	return fmt.Errorf("oss error: %s", err.Message)
}
|
||||
548
pkg/filemanager/driver/oss/oss.go
Normal file
548
pkg/filemanager/driver/oss/oss.go
Normal file
@@ -0,0 +1,548 @@
|
||||
package oss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aliyun/aliyun-oss-go-sdk/oss"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// UploadPolicy is the Aliyun OSS upload policy document (base64-encoded and
// signed when issuing browser-side upload credentials).
type UploadPolicy struct {
	Expiration string        `json:"expiration"` // RFC3339 expiry of the policy
	Conditions []interface{} `json:"conditions"` // OSS policy condition clauses
}
|
||||
|
||||
// CallbackPolicy describes the OSS upload callback: after a client-side
// upload completes, OSS POSTs CallbackBody to CallbackURL so the master can
// finalize the upload session.
type CallbackPolicy struct {
	CallbackURL      string `json:"callbackUrl"`
	CallbackBody     string `json:"callbackBody"`
	CallbackBodyType string `json:"callbackBodyType"`
	CallbackSNI      bool   `json:"callbackSNI"`
}
|
||||
|
||||
// Driver is the storage-driver adapter for Aliyun OSS policies.
type Driver struct {
	policy *ent.StoragePolicy

	client     *oss.Client      // OSS API client, rebuilt by InitOSSClient
	bucket     *oss.Bucket      // handle for the policy's bucket
	settings   setting.Provider
	l          logging.Logger
	config     conf.ConfigProvider
	mime       mime.MimeDetector
	httpClient request.Client

	// chunkSize is the multipart-upload part size in bytes.
	chunkSize int64
}
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
chunkRetrySleep = time.Duration(5) * time.Second
|
||||
uploadIdParam = "uploadId"
|
||||
partNumberParam = "partNumber"
|
||||
callbackParam = "callback"
|
||||
completeAllHeader = "x-oss-complete-all"
|
||||
maxDeleteBatch = 1000
|
||||
|
||||
// MultiPartUploadThreshold 服务端使用分片上传的阈值
|
||||
MultiPartUploadThreshold int64 = 5 * (1 << 30) // 5GB
|
||||
)
|
||||
|
||||
var (
|
||||
features = &boolset.BooleanSet{}
|
||||
)
|
||||
|
||||
// New constructs an OSS driver for the given storage policy and initializes
// its OSS client against the policy's server-side (or public) endpoint.
func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider,
	config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) {
	chunkSize := policy.Settings.ChunkSize
	if policy.Settings.ChunkSize == 0 {
		// Default multipart chunk size when the policy does not specify one.
		chunkSize = 25 << 20 // 25 MB
	}

	driver := &Driver{
		policy:     policy,
		settings:   settings,
		chunkSize:  chunkSize,
		config:     config,
		l:          l,
		mime:       mime,
		httpClient: request.NewClient(config, request.WithLogger(l)),
	}

	return driver, driver.InitOSSClient(false)
}
|
||||
|
||||
// CORS applies a permissive cross-origin policy to the policy's bucket so
// browsers can upload/download directly against OSS.
func (handler *Driver) CORS() error {
	return handler.client.SetBucketCORS(handler.policy.BucketName, []oss.CORSRule{
		{
			AllowedOrigin: []string{"*"},
			AllowedMethod: []string{
				"GET",
				"POST",
				"PUT",
				"DELETE",
				"HEAD",
			},
			ExposeHeader:  []string{},
			AllowedHeader: []string{"*"},
			MaxAgeSeconds: 3600,
		},
	})
}
|
||||
|
||||
// InitOSSClient (re)initializes the authenticated OSS client and bucket
// handle. When forceUsePublicEndpoint is true the public endpoint is used
// even if the policy configures an internal (server-side) endpoint — needed
// when generating URLs that end users must be able to reach.
func (handler *Driver) InitOSSClient(forceUsePublicEndpoint bool) error {
	if handler.policy == nil {
		return errors.New("empty policy")
	}

	opt := make([]oss.ClientOption, 0)

	// Decide whether to use the internal endpoint.
	endpoint := handler.policy.Server
	if handler.policy.Settings.ServerSideEndpoint != "" && !forceUsePublicEndpoint {
		endpoint = handler.policy.Settings.ServerSideEndpoint
	} else if handler.policy.Settings.UseCname {
		// CNAME only applies to the public endpoint.
		opt = append(opt, oss.UseCname(true))
	}

	// Default to HTTPS when the endpoint has no scheme.
	if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
		endpoint = "https://" + endpoint
	}

	// Initialize the client.
	client, err := oss.New(endpoint, handler.policy.AccessKey, handler.policy.SecretKey, opt...)
	if err != nil {
		return err
	}
	handler.client = client

	// Initialize the bucket handle.
	bucket, err := client.Bucket(handler.policy.BucketName)
	if err != nil {
		return err
	}
	handler.bucket = bucket

	return nil
}
|
||||
|
||||
//// List 列出OSS上的文件
|
||||
//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
// // 列取文件
|
||||
// base = strings.TrimPrefix(base, "/")
|
||||
// if base != "" {
|
||||
// base += "/"
|
||||
// }
|
||||
//
|
||||
// var (
|
||||
// delimiter string
|
||||
// marker string
|
||||
// objects []oss.ObjectProperties
|
||||
// commons []string
|
||||
// )
|
||||
// if !recursive {
|
||||
// delimiter = "/"
|
||||
// }
|
||||
//
|
||||
// for {
|
||||
// subRes, err := handler.bucket.ListObjects(oss.Marker(marker), oss.Prefix(base),
|
||||
// oss.MaxKeys(1000), oss.Delimiter(delimiter))
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// objects = append(objects, subRes.Objects...)
|
||||
// commons = append(commons, subRes.CommonPrefixes...)
|
||||
// marker = subRes.NextMarker
|
||||
// if marker == "" {
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 处理列取结果
|
||||
// res := make([]response.Object, 0, len(objects)+len(commons))
|
||||
// // 处理目录
|
||||
// for _, object := range commons {
|
||||
// rel, err := filepath.Rel(base, object)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// res = append(res, response.Object{
|
||||
// Name: path.Base(object),
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Size: 0,
|
||||
// IsDir: true,
|
||||
// LastModify: time.Now(),
|
||||
// })
|
||||
// }
|
||||
// // 处理文件
|
||||
// for _, object := range objects {
|
||||
// rel, err := filepath.Rel(base, object.Key)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// res = append(res, response.Object{
|
||||
// Name: path.Base(object.Key),
|
||||
// Source: object.Key,
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Size: uint64(object.Size),
|
||||
// IsDir: false,
|
||||
// LastModify: object.LastModified,
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// return res, nil
|
||||
//}
|
||||
|
||||
// Open is not supported for remote OSS objects: there is no local *os.File
// to hand out. Callers must use signed URLs (Source/Thumb) instead.
func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) {
	return nil, errors.New("not implemented")
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
|
||||
defer file.Close()
|
||||
|
||||
// 凭证有效期
|
||||
credentialTTL := handler.settings.UploadSessionTTL(ctx)
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
// 是否允许覆盖
|
||||
overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite
|
||||
options := []oss.Option{
|
||||
oss.WithContext(ctx),
|
||||
oss.Expires(time.Now().Add(credentialTTL * time.Second)),
|
||||
oss.ForbidOverWrite(!overwrite),
|
||||
oss.ContentType(mimeType),
|
||||
}
|
||||
|
||||
// 小文件直接上传
|
||||
if file.Props.Size < MultiPartUploadThreshold {
|
||||
return handler.bucket.PutObject(file.Props.SavePath, file, options...)
|
||||
}
|
||||
|
||||
// 超过阈值时使用分片上传
|
||||
imur, err := handler.bucket.InitiateMultipartUpload(file.Props.SavePath, options...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initiate multipart upload: %w", err)
|
||||
}
|
||||
|
||||
parts := make([]oss.UploadPart, 0)
|
||||
|
||||
chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{
|
||||
Max: handler.settings.ChunkRetryLimit(ctx),
|
||||
Sleep: chunkRetrySleep,
|
||||
}, handler.settings.UseChunkBuffer(ctx), handler.l, handler.settings.TempPath(ctx))
|
||||
|
||||
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
|
||||
part, err := handler.bucket.UploadPart(imur, content, current.Length(), current.Index()+1, oss.WithContext(ctx))
|
||||
if err == nil {
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for chunks.Next() {
|
||||
if err := chunks.Process(uploadFunc); err != nil {
|
||||
handler.cancelUpload(imur)
|
||||
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = handler.bucket.CompleteMultipartUpload(imur, parts, oss.ForbidOverWrite(!overwrite), oss.WithContext(ctx))
|
||||
if err != nil {
|
||||
handler.cancelUpload(imur)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件
|
||||
func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
|
||||
groups := lo.Chunk(files, maxDeleteBatch)
|
||||
failed := make([]string, 0)
|
||||
var lastError error
|
||||
for index, group := range groups {
|
||||
handler.l.Debug("Process delete group #%d: %v", index, group)
|
||||
// 删除文件
|
||||
delRes, err := handler.bucket.DeleteObjects(group)
|
||||
if err != nil {
|
||||
failed = append(failed, group...)
|
||||
lastError = err
|
||||
continue
|
||||
}
|
||||
|
||||
// 统计未删除的文件
|
||||
failed = append(failed, util.SliceDifference(files, delRes.DeletedObjects)...)
|
||||
}
|
||||
|
||||
if len(failed) > 0 && lastError == nil {
|
||||
lastError = fmt.Errorf("failed to delete files: %v", failed)
|
||||
}
|
||||
|
||||
return failed, lastError
|
||||
}
|
||||
|
||||
// Thumb returns a signed URL that resizes the entity through OSS image
// processing (lfit within the configured thumbnail width/height). The public
// endpoint is used unless the context explicitly opts out via
// driver.ForceUsePublicEndpointCtx.
func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
	usePublicEndpoint := true
	if forceUsePublicEndpoint, ok := ctx.Value(driver.ForceUsePublicEndpointCtx{}).(bool); ok {
		usePublicEndpoint = forceUsePublicEndpoint
	}

	// Re-initialize the client against the chosen endpoint.
	if err := handler.InitOSSClient(usePublicEndpoint); err != nil {
		return "", err
	}

	w, h := handler.settings.ThumbSize(ctx)
	thumbParam := fmt.Sprintf("image/resize,m_lfit,h_%d,w_%d", h, w)
	thumbOption := []oss.Option{oss.Process(thumbParam)}
	thumbURL, err := handler.signSourceURL(
		ctx,
		e.Source(),
		expire,
		thumbOption,
		false,
	)
	if err != nil {
		return "", err
	}

	return thumbURL, nil
}
|
||||
|
||||
// Source builds an external (possibly signed) URL for the entity, optionally
// forcing a download disposition and applying an OSS traffic limit clamped to
// the range the service accepts.
func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
	// Initialize the client; default to the public endpoint unless the
	// context opts out.
	usePublicEndpoint := true
	if forceUsePublicEndpoint, ok := ctx.Value(driver.ForceUsePublicEndpointCtx{}).(bool); ok {
		usePublicEndpoint = forceUsePublicEndpoint
	}
	if err := handler.InitOSSClient(usePublicEndpoint); err != nil {
		return "", err
	}

	// Assemble signing options.
	var signOptions = make([]oss.Option, 0, 2)
	if args.IsDownload {
		encodedFilename := url.PathEscape(args.DisplayName)
		signOptions = append(signOptions, oss.ResponseContentDisposition(fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`,
			encodedFilename, encodedFilename)))
	}
	if args.Speed > 0 {
		// Convert bytes to bits for the OSS traffic-limit parameter.
		args.Speed *= 8

		// OSS restricts the traffic limit to [819200, 838860800] bit/s.
		if args.Speed < 819200 {
			args.Speed = 819200
		}
		if args.Speed > 838860800 {
			args.Speed = 838860800
		}
		signOptions = append(signOptions, oss.TrafficLimitParam(args.Speed))
	}

	return handler.signSourceURL(ctx, e.Source(), args.Expire, signOptions, false)
}
|
||||
|
||||
// signSourceURL signs a GET URL for path with the given options. A nil expire
// yields an effectively permanent URL (~20 years). For public buckets (unless
// forceSign) the signature and unsupported query parameters are stripped so
// the URL works unauthenticated; the host may have been replaced by a custom
// CNAME during client initialization.
func (handler *Driver) signSourceURL(ctx context.Context, path string, expire *time.Time, options []oss.Option, forceSign bool) (string, error) {
	// Default TTL when no expiry is requested: ~20 years.
	ttl := int64(86400 * 365 * 20)
	if expire != nil {
		ttl = int64(time.Until(*expire).Seconds())
	}

	signedURL, err := handler.bucket.SignURL(path, oss.HTTPGet, ttl, options...)
	if err != nil {
		return "", err
	}

	// Re-parse so the query string can be edited below.
	finalURL, err := url.Parse(signedURL)
	if err != nil {
		return "", err
	}

	// Public buckets: drop the signature and headers OSS ignores without auth.
	if !handler.policy.IsPrivate && !forceSign {
		query := finalURL.Query()
		query.Del("OSSAccessKeyId")
		query.Del("Signature")
		query.Del("response-content-disposition")
		query.Del("x-oss-traffic-limit")
		finalURL.RawQuery = query.Encode()
	}
	return finalURL.String(), nil
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
|
||||
// 初始化客户端
|
||||
if err := handler.InitOSSClient(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成回调地址
|
||||
siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx))
|
||||
// 在从机端创建上传会话
|
||||
uploadSession.ChunkSize = handler.chunkSize
|
||||
uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeOss, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String()
|
||||
|
||||
// 回调策略
|
||||
callbackPolicy := CallbackPolicy{
|
||||
CallbackURL: uploadSession.Callback,
|
||||
CallbackBody: `{"name":${x:fname},"source_name":${object},"size":${size},"pic_info":"${imageInfo.width},${imageInfo.height}"}`,
|
||||
CallbackBodyType: "application/json",
|
||||
CallbackSNI: true,
|
||||
}
|
||||
callbackPolicyJSON, err := json.Marshal(callbackPolicy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode callback policy: %w", err)
|
||||
}
|
||||
callbackPolicyEncoded := base64.StdEncoding.EncodeToString(callbackPolicyJSON)
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
// 初始化分片上传
|
||||
options := []oss.Option{
|
||||
oss.WithContext(ctx),
|
||||
oss.Expires(uploadSession.Props.ExpireAt),
|
||||
oss.ForbidOverWrite(true),
|
||||
oss.ContentType(mimeType),
|
||||
}
|
||||
imur, err := handler.bucket.InitiateMultipartUpload(file.Props.SavePath, options...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize multipart upload: %w", err)
|
||||
}
|
||||
uploadSession.UploadID = imur.UploadID
|
||||
|
||||
// 为每个分片签名上传 URL
|
||||
chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{}, false, handler.l, "")
|
||||
urls := make([]string, chunks.Num())
|
||||
ttl := int64(time.Until(uploadSession.Props.ExpireAt).Seconds())
|
||||
for chunks.Next() {
|
||||
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
|
||||
signedURL, err := handler.bucket.SignURL(file.Props.SavePath, oss.HTTPPut,
|
||||
ttl,
|
||||
oss.AddParam(partNumberParam, strconv.Itoa(c.Index()+1)),
|
||||
oss.AddParam(uploadIdParam, imur.UploadID),
|
||||
oss.ContentType("application/octet-stream"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
urls[c.Index()] = signedURL
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 签名完成分片上传的URL
|
||||
completeURL, err := handler.bucket.SignURL(file.Props.SavePath, oss.HTTPPost, ttl,
|
||||
oss.ContentType("application/octet-stream"),
|
||||
oss.AddParam(uploadIdParam, imur.UploadID),
|
||||
oss.Expires(time.Now().Add(time.Duration(ttl)*time.Second)),
|
||||
oss.SetHeader(completeAllHeader, "yes"),
|
||||
oss.ForbidOverWrite(true),
|
||||
oss.AddParam(callbackParam, callbackPolicyEncoded))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &fs.UploadCredential{
|
||||
UploadID: imur.UploadID,
|
||||
UploadURLs: urls,
|
||||
CompleteURL: completeURL,
|
||||
SessionID: uploadSession.Props.UploadSessionID,
|
||||
ChunkSize: handler.chunkSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CancelToken aborts the multipart upload backing the given upload session,
// invalidating its previously issued credentials.
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
	return handler.bucket.AbortMultipartUpload(oss.InitiateMultipartUploadResult{UploadID: uploadSession.UploadID, Key: uploadSession.Props.SavePath}, oss.WithContext(ctx))
}
|
||||
|
||||
// CompleteUpload is a no-op for OSS: completion happens client-side via the
// pre-signed complete URL issued by Token.
func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
	return nil
}
|
||||
|
||||
// Capabilities reports the driver's feature set derived from the policy
// settings. Native media-meta extraction is advertised only when the policy
// enables OSS-side media processing.
func (handler *Driver) Capabilities() *driver.Capabilities {
	mediaMetaExts := handler.policy.Settings.MediaMetaExts
	if !handler.policy.Settings.NativeMediaProcessing {
		mediaMetaExts = nil
	}
	return &driver.Capabilities{
		StaticFeatures:         features,
		MediaMetaSupportedExts: mediaMetaExts,
		MediaMetaProxy:         handler.policy.Settings.MediaMetaGeneratorProxy,
		ThumbSupportedExts:     handler.policy.Settings.ThumbExts,
		ThumbProxy:             handler.policy.Settings.ThumbGeneratorProxy,
		ThumbSupportAllExts:    handler.policy.Settings.ThumbSupportAllExts,
		ThumbMaxSize:           handler.policy.Settings.ThumbMaxSize,
	}
}
|
||||
|
||||
// MediaMeta dispatches metadata extraction by file extension: EXIF for
// images, IMM video/audio info otherwise. Unknown extensions are an error.
func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
	if util.ContainsString(supportedImageExt, ext) {
		return handler.extractImageMeta(ctx, path)
	}

	if util.ContainsString(supportedVideoExt, ext) {
		return handler.extractIMMMeta(ctx, path, videoInfoProcess)
	}

	if util.ContainsString(supportedAudioExt, ext) {
		return handler.extractIMMMeta(ctx, path, audioInfoProcess)
	}

	return nil, fmt.Errorf("unsupported media type in oss: %s", ext)
}
|
||||
|
||||
// LocalPath returns an empty string: OSS objects have no local filesystem
// path.
func (handler *Driver) LocalPath(ctx context.Context, path string) string {
	return ""
}
|
||||
|
||||
// cancelUpload aborts a multipart upload as best-effort cleanup, logging (but
// not propagating) any abort failure.
func (handler *Driver) cancelUpload(imur oss.InitiateMultipartUploadResult) {
	if err := handler.bucket.AbortMultipartUpload(imur); err != nil {
		handler.l.Warning("failed to abort multipart upload: %s", err)
	}
}
|
||||
183
pkg/filemanager/driver/qiniu/media.go
Normal file
183
pkg/filemanager/driver/qiniu/media.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package qiniu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/mediameta"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/samber/lo"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
exifParam = "exif"
|
||||
avInfoParam = "avinfo"
|
||||
mediaInfoTTL = time.Duration(10) * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
supportedImageExt = []string{"jpg", "jpeg", "png", "gif", "bmp", "webp", "tiff"}
|
||||
)
|
||||
|
||||
type (
	// ImageProp is a single EXIF property as returned by Qiniu's ?exif API.
	ImageProp struct {
		Value string `json:"val"`
	}
	// ImageInfo maps EXIF tag names to their property values.
	ImageInfo map[string]ImageProp
	// QiniuMediaError is the JSON error envelope returned by Qiniu media
	// processing endpoints.
	QiniuMediaError struct {
		Error string `json:"error"`
		Code  int    `json:"code"`
	}
)
|
||||
|
||||
// extractAvMeta fetches audio/video info via Qiniu's ?avinfo endpoint (whose
// response matches the FFProbe schema), transforms it into MediaMeta entries,
// and appends music tags (artist/album/title) when present in the format
// tags.
func (handler *Driver) extractAvMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) {
	resp, err := handler.extractMediaInfo(ctx, path, avInfoParam)
	if err != nil {
		return nil, err
	}

	var avInfo *mediameta.FFProbeMeta
	if err := json.Unmarshal([]byte(resp), &avInfo); err != nil {
		return nil, fmt.Errorf("failed to unmarshal media info: %w", err)
	}

	metas := mediameta.ProbeMetaTransform(avInfo)
	if artist, ok := avInfo.Format.Tags["artist"]; ok {
		metas = append(metas, driver.MediaMeta{
			Key:   mediameta.Artist,
			Value: artist,
			Type:  driver.MediaTypeMusic,
		})
	}

	if album, ok := avInfo.Format.Tags["album"]; ok {
		metas = append(metas, driver.MediaMeta{
			Key:   mediameta.MusicAlbum,
			Value: album,
			Type:  driver.MediaTypeMusic,
		})
	}

	if title, ok := avInfo.Format.Tags["title"]; ok {
		metas = append(metas, driver.MediaMeta{
			Key:   mediameta.MusicTitle,
			Value: title,
			Type:  driver.MediaTypeMusic,
		})
	}

	return metas, nil
}
|
||||
|
||||
// extractImageMeta fetches EXIF data via Qiniu's ?exif endpoint and converts
// it into MediaMeta entries, including normalized GPS coordinates. Every
// returned entry is tagged as EXIF data.
func (handler *Driver) extractImageMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) {
	resp, err := handler.extractMediaInfo(ctx, path, exifParam)
	if err != nil {
		return nil, err
	}

	var imageInfo ImageInfo
	if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil {
		return nil, fmt.Errorf("failed to unmarshal media info: %w", err)
	}

	metas := make([]driver.MediaMeta, 0)
	// Flatten {tag: {val: ...}} into a plain string map for the EXIF parser.
	exifMap := lo.MapEntries(imageInfo, func(key string, value ImageProp) (string, string) {
		return key, value.Value
	})
	metas = append(metas, mediameta.ExtractExifMap(exifMap, time.Time{})...)
	metas = append(metas, parseGpsInfo(imageInfo)...)
	for i := 0; i < len(metas); i++ {
		metas[i].Type = driver.MetaTypeExif
	}

	return metas, nil
}
|
||||
|
||||
func (handler *Driver) extractMediaInfo(ctx context.Context, path string, param string) (string, error) {
|
||||
mediaInfoExpire := time.Now().Add(mediaInfoTTL)
|
||||
ediaInfoUrl := handler.signSourceURL(fmt.Sprintf("%s?%s", path, param), &mediaInfoExpire)
|
||||
resp, err := handler.httpClient.
|
||||
Request(http.MethodGet, ediaInfoUrl, nil, request.WithContext(ctx)).
|
||||
CheckHTTPResponse(http.StatusOK).
|
||||
GetResponseIgnoreErr()
|
||||
if err != nil {
|
||||
return "", unmarshalError(resp, err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// unmarshalError tries to decode resp as a Qiniu JSON error envelope and
// returns a descriptive error built from its message. If resp is empty, the
// original transport error is returned unchanged.
func unmarshalError(resp string, originErr error) error {
	if resp == "" {
		return originErr
	}

	// NOTE(review): the inner `err` in the if-initializer shadows the outer
	// `var err QiniuMediaError`; correct, but renaming one would read better.
	var err QiniuMediaError
	if err := json.Unmarshal([]byte(resp), &err); err != nil {
		return fmt.Errorf("failed to unmarshal qiniu error: %w", err)
	}

	return fmt.Errorf("qiniu error: %s", err.Error)
}
|
||||
|
||||
// parseGpsInfo extracts latitude/longitude from the Qiniu EXIF map (comma
// separated DMS components, single-letter hemisphere refs) and normalizes
// them into decimal-degree MediaMeta entries. Returns nil when any of the
// four GPS tags is missing or a parsed coordinate is NaN.
func parseGpsInfo(imageInfo ImageInfo) []driver.MediaMeta {
	latitude := imageInfo["GPSLatitude"]   // 31, 16.2680820, 0
	longitude := imageInfo["GPSLongitude"] // 120, 42.9103939, 0
	latRef := imageInfo["GPSLatitudeRef"]  // N
	lonRef := imageInfo["GPSLongitudeRef"] // E

	// Make sure all values exist in the map (missing keys yield zero values).
	if latitude.Value == "" || longitude.Value == "" || latRef.Value == "" || lonRef.Value == "" {
		return nil
	}

	lat := parseRawGPS(latitude.Value, latRef.Value)
	lon := parseRawGPS(longitude.Value, lonRef.Value)
	if !math.IsNaN(lat) && !math.IsNaN(lon) {
		lat, lng := mediameta.NormalizeGPS(lat, lon)
		return []driver.MediaMeta{{
			Key:   mediameta.GpsLat,
			Value: fmt.Sprintf("%f", lat),
		}, {
			Key:   mediameta.GpsLng,
			Value: fmt.Sprintf("%f", lng),
		}}
	}

	return nil
}
|
||||
|
||||
// parseRawGPS converts a Qiniu-style comma-separated DMS coordinate such as
// "31, 16.2680820, 0" into decimal degrees. Missing minute/second components
// default to zero; the value is negated for "S"/"W" references.
func parseRawGPS(gpsStr string, ref string) float64 {
	parts := strings.Split(gpsStr, ", ")
	if len(parts) < 1 {
		return 0
	}

	// Parse up to three components (degrees, minutes, seconds); a parse
	// failure leaves the component at zero, matching ParseFloat's zero
	// return on error.
	var components [3]float64
	for i := 0; i < len(parts) && i < 3; i++ {
		components[i], _ = strconv.ParseFloat(parts[i], 64)
	}

	decimal := components[0] + components[1]/60.0 + components[2]/3600.0
	if ref == "S" || ref == "W" {
		decimal = -decimal
	}

	return decimal
}
|
||||
428
pkg/filemanager/driver/qiniu/qiniu.go
Normal file
428
pkg/filemanager/driver/qiniu/qiniu.go
Normal file
@@ -0,0 +1,428 @@
|
||||
package qiniu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
||||
"github.com/qiniu/go-sdk/v7/storage"
|
||||
"github.com/samber/lo"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
chunkRetrySleep = time.Duration(5) * time.Second
|
||||
maxDeleteBatch = 1000
|
||||
trafficLimitParam = "X-Qiniu-Traffic-Limit"
|
||||
)
|
||||
|
||||
var (
|
||||
features = &boolset.BooleanSet{}
|
||||
)
|
||||
|
||||
// Driver 本地策略适配器
|
||||
type Driver struct {
|
||||
policy *ent.StoragePolicy
|
||||
|
||||
mac *qbox.Mac
|
||||
cfg *storage.Config
|
||||
bucket *storage.BucketManager
|
||||
settings setting.Provider
|
||||
l logging.Logger
|
||||
config conf.ConfigProvider
|
||||
mime mime.MimeDetector
|
||||
httpClient request.Client
|
||||
|
||||
chunkSize int64
|
||||
}
|
||||
|
||||
func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider,
|
||||
config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) {
|
||||
chunkSize := policy.Settings.ChunkSize
|
||||
if policy.Settings.ChunkSize == 0 {
|
||||
chunkSize = 25 << 20 // 25 MB
|
||||
}
|
||||
|
||||
mac := qbox.NewMac(policy.AccessKey, policy.SecretKey)
|
||||
cfg := &storage.Config{UseHTTPS: true}
|
||||
|
||||
driver := &Driver{
|
||||
policy: policy,
|
||||
settings: settings,
|
||||
chunkSize: chunkSize,
|
||||
config: config,
|
||||
l: l,
|
||||
mime: mime,
|
||||
mac: mac,
|
||||
cfg: cfg,
|
||||
bucket: storage.NewBucketManager(mac, cfg),
|
||||
httpClient: request.NewClient(config, request.WithLogger(l)),
|
||||
}
|
||||
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
//
|
||||
//// List 列出给定路径下的文件
|
||||
//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
// base = strings.TrimPrefix(base, "/")
|
||||
// if base != "" {
|
||||
// base += "/"
|
||||
// }
|
||||
//
|
||||
// var (
|
||||
// delimiter string
|
||||
// marker string
|
||||
// objects []storage.ListItem
|
||||
// commons []string
|
||||
// )
|
||||
// if !recursive {
|
||||
// delimiter = "/"
|
||||
// }
|
||||
//
|
||||
// for {
|
||||
// entries, folders, nextMarker, hashNext, err := handler.bucket.ListFiles(
|
||||
// handler.policy.BucketName,
|
||||
// base, delimiter, marker, 1000)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// objects = append(objects, entries...)
|
||||
// commons = append(commons, folders...)
|
||||
// if !hashNext {
|
||||
// break
|
||||
// }
|
||||
// marker = nextMarker
|
||||
// }
|
||||
//
|
||||
// // 处理列取结果
|
||||
// res := make([]response.Object, 0, len(objects)+len(commons))
|
||||
// // 处理目录
|
||||
// for _, object := range commons {
|
||||
// rel, err := filepath.Rel(base, object)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// res = append(res, response.Object{
|
||||
// Name: path.Base(object),
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Size: 0,
|
||||
// IsDir: true,
|
||||
// LastModify: time.Now(),
|
||||
// })
|
||||
// }
|
||||
// // 处理文件
|
||||
// for _, object := range objects {
|
||||
// rel, err := filepath.Rel(base, object.Key)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// res = append(res, response.Object{
|
||||
// Name: path.Base(object.Key),
|
||||
// Source: object.Key,
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Size: uint64(object.Fsize),
|
||||
// IsDir: false,
|
||||
// LastModify: time.Unix(object.PutTime/10000000, 0),
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// return res, nil
|
||||
//}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
|
||||
defer file.Close()
|
||||
|
||||
// 凭证有效期
|
||||
credentialTTL := handler.settings.UploadSessionTTL(ctx)
|
||||
|
||||
// 是否允许覆盖
|
||||
overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite
|
||||
|
||||
// 生成上传策略
|
||||
scope := handler.policy.BucketName
|
||||
if overwrite {
|
||||
scope = fmt.Sprintf("%s:%s", handler.policy.BucketName, file.Props.SavePath)
|
||||
}
|
||||
putPolicy := storage.PutPolicy{
|
||||
// 指定为覆盖策略
|
||||
Scope: scope,
|
||||
SaveKey: file.Props.SavePath,
|
||||
ForceSaveKey: true,
|
||||
FsizeLimit: file.Props.Size,
|
||||
Expires: uint64(time.Now().Add(credentialTTL).Unix()),
|
||||
}
|
||||
upToken := putPolicy.UploadToken(handler.mac)
|
||||
|
||||
// 初始化分片上传
|
||||
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
|
||||
upHost, err := resumeUploader.UpHost(handler.policy.AccessKey, handler.policy.BucketName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get upload host: %w", err)
|
||||
}
|
||||
|
||||
ret := &storage.InitPartsRet{}
|
||||
err = resumeUploader.InitParts(ctx, upToken, upHost, handler.policy.BucketName, file.Props.SavePath, true, ret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initiate multipart upload: %w", err)
|
||||
}
|
||||
|
||||
chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{
|
||||
Max: handler.settings.ChunkRetryLimit(ctx),
|
||||
Sleep: chunkRetrySleep,
|
||||
}, handler.settings.UseChunkBuffer(ctx), handler.l, handler.settings.TempPath(ctx))
|
||||
|
||||
parts := make([]*storage.UploadPartsRet, 0, chunks.Num())
|
||||
|
||||
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
|
||||
partRet := &storage.UploadPartsRet{}
|
||||
err := resumeUploader.UploadParts(
|
||||
ctx, upToken, upHost, handler.policy.BucketName, file.Props.SavePath, true, ret.UploadID,
|
||||
int64(current.Index()+1), "", partRet, content, int(current.Length()))
|
||||
if err == nil {
|
||||
parts = append(parts, partRet)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for chunks.Next() {
|
||||
if err := chunks.Process(uploadFunc); err != nil {
|
||||
_ = handler.cancelUpload(upHost, file.Props.SavePath, ret.UploadID, upToken)
|
||||
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
|
||||
}
|
||||
}
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
err = resumeUploader.CompleteParts(ctx, upToken, upHost, nil, handler.policy.BucketName,
|
||||
file.Props.SavePath, true, ret.UploadID, &storage.RputV2Extra{
|
||||
MimeType: mimeType,
|
||||
Progresses: lo.Map(parts, func(part *storage.UploadPartsRet, i int) storage.UploadPartInfo {
|
||||
return storage.UploadPartInfo{
|
||||
Etag: part.Etag,
|
||||
PartNumber: int64(i) + 1,
|
||||
}
|
||||
}),
|
||||
})
|
||||
if err != nil {
|
||||
_ = handler.cancelUpload(upHost, file.Props.SavePath, ret.UploadID, upToken)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件
|
||||
func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
|
||||
groups := lo.Chunk(files, maxDeleteBatch)
|
||||
failed := make([]string, 0)
|
||||
var lastError error
|
||||
|
||||
for index, group := range groups {
|
||||
handler.l.Debug("Process delete group #%d: %v", index, group)
|
||||
// 删除文件
|
||||
rets, err := handler.bucket.BatchWithContext(ctx, handler.policy.BucketName, lo.Map(group, func(key string, index int) string {
|
||||
return storage.URIDelete(handler.policy.BucketName, key)
|
||||
}))
|
||||
|
||||
// 处理删除结果
|
||||
if err != nil {
|
||||
for k, ret := range rets {
|
||||
if ret.Code != 200 && ret.Code != 612 {
|
||||
failed = append(failed, group[k])
|
||||
lastError = err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(failed) > 0 && lastError == nil {
|
||||
lastError = fmt.Errorf("failed to delete files: %v", failed)
|
||||
}
|
||||
|
||||
return failed, lastError
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
|
||||
w, h := handler.settings.ThumbSize(ctx)
|
||||
|
||||
thumb := fmt.Sprintf("%s?imageView2/1/w/%d/h/%d", e.Source(), w, h)
|
||||
return handler.signSourceURL(
|
||||
thumb,
|
||||
expire,
|
||||
), nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
|
||||
path := e.Source()
|
||||
|
||||
query := url.Values{}
|
||||
|
||||
// 加入下载相关设置
|
||||
if args.IsDownload {
|
||||
query.Add("attname", args.DisplayName)
|
||||
}
|
||||
|
||||
if args.Speed > 0 {
|
||||
// Byte 转换为 bit
|
||||
args.Speed *= 8
|
||||
|
||||
// Qiniu 对速度值有范围限制
|
||||
if args.Speed < 819200 {
|
||||
args.Speed = 819200
|
||||
}
|
||||
if args.Speed > 838860800 {
|
||||
args.Speed = 838860800
|
||||
}
|
||||
query.Add(trafficLimitParam, fmt.Sprintf("%d", args.Speed))
|
||||
}
|
||||
|
||||
if len(query) > 0 {
|
||||
path = path + "?" + query.Encode()
|
||||
}
|
||||
|
||||
// 取得原始文件地址
|
||||
return handler.signSourceURL(path, args.Expire), nil
|
||||
}
|
||||
|
||||
func (handler *Driver) signSourceURL(path string, expire *time.Time) string {
|
||||
var sourceURL string
|
||||
if handler.policy.IsPrivate {
|
||||
deadline := time.Now().Add(time.Duration(24) * time.Hour * 365 * 20).Unix()
|
||||
if expire != nil {
|
||||
deadline = expire.Unix()
|
||||
}
|
||||
sourceURL = storage.MakePrivateURL(handler.mac, handler.policy.Settings.ProxyServer, path, deadline)
|
||||
} else {
|
||||
sourceURL = storage.MakePublicURL(handler.policy.Settings.ProxyServer, path)
|
||||
}
|
||||
return sourceURL
|
||||
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
|
||||
// 生成回调地址
|
||||
siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx))
|
||||
apiUrl := routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeQiniu, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String()
|
||||
|
||||
// 创建上传策略
|
||||
putPolicy := storage.PutPolicy{
|
||||
Scope: fmt.Sprintf("%s:%s", handler.policy.BucketName, file.Props.SavePath),
|
||||
CallbackURL: apiUrl,
|
||||
CallbackBody: `{"size":$(fsize),"pic_info":"$(imageInfo.width),$(imageInfo.height)"}`,
|
||||
CallbackBodyType: "application/json",
|
||||
SaveKey: file.Props.SavePath,
|
||||
ForceSaveKey: true,
|
||||
FsizeLimit: file.Props.Size,
|
||||
Expires: uint64(file.Props.ExpireAt.Unix()),
|
||||
}
|
||||
|
||||
// 初始化分片上传
|
||||
upToken := putPolicy.UploadToken(handler.mac)
|
||||
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
|
||||
upHost, err := resumeUploader.UpHost(handler.policy.AccessKey, handler.policy.BucketName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get upload host: %w", err)
|
||||
}
|
||||
|
||||
ret := &storage.InitPartsRet{}
|
||||
err = resumeUploader.InitParts(ctx, upToken, upHost, handler.policy.BucketName, file.Props.SavePath, true, ret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initiate multipart upload: %w", err)
|
||||
}
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
uploadSession.UploadID = ret.UploadID
|
||||
return &fs.UploadCredential{
|
||||
UploadID: ret.UploadID,
|
||||
UploadURLs: []string{getUploadUrl(upHost, handler.policy.BucketName, file.Props.SavePath, ret.UploadID)},
|
||||
Credential: upToken,
|
||||
SessionID: uploadSession.Props.UploadSessionID,
|
||||
ChunkSize: handler.chunkSize,
|
||||
MimeType: mimeType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
|
||||
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
|
||||
return resumeUploader.Client.CallWith(ctx, nil, "DELETE", uploadSession.UploadURL, http.Header{"Authorization": {"UpToken " + uploadSession.Credential}}, nil, 0)
|
||||
}
|
||||
|
||||
func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (handler *Driver) Capabilities() *driver.Capabilities {
|
||||
mediaMetaExts := handler.policy.Settings.MediaMetaExts
|
||||
if !handler.policy.Settings.NativeMediaProcessing {
|
||||
mediaMetaExts = nil
|
||||
}
|
||||
return &driver.Capabilities{
|
||||
StaticFeatures: features,
|
||||
MediaMetaSupportedExts: mediaMetaExts,
|
||||
MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy,
|
||||
ThumbSupportedExts: handler.policy.Settings.ThumbExts,
|
||||
ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy,
|
||||
ThumbSupportAllExts: handler.policy.Settings.ThumbSupportAllExts,
|
||||
ThumbMaxSize: handler.policy.Settings.ThumbMaxSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
|
||||
if util.ContainsString(supportedImageExt, ext) {
|
||||
return handler.extractImageMeta(ctx, path)
|
||||
}
|
||||
|
||||
return handler.extractAvMeta(ctx, path)
|
||||
}
|
||||
|
||||
func (handler *Driver) LocalPath(ctx context.Context, path string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (handler *Driver) cancelUpload(upHost, savePath, uploadId, upToken string) error {
|
||||
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
|
||||
uploadUrl := getUploadUrl(upHost, handler.policy.BucketName, savePath, uploadId)
|
||||
err := resumeUploader.Client.CallWith(context.Background(), nil, "DELETE", uploadUrl, http.Header{"Authorization": {"UpToken " + upToken}}, nil, 0)
|
||||
if err != nil {
|
||||
handler.l.Error("Failed to cancel upload session for %q: %s", savePath, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// getUploadUrl assembles the Qiniu resumable-upload v2 endpoint for a given
// upload session; the object key is URL-safe base64 encoded as the API
// requires.
func getUploadUrl(upHost, bucket, key, uploadId string) string {
	encodedKey := base64.URLEncoding.EncodeToString([]byte(key))
	return fmt.Sprintf("%s/buckets/%s/objects/%s/uploads/%s", upHost, bucket, encodedKey, uploadId)
}
|
||||
266
pkg/filemanager/driver/remote/client.go
Normal file
266
pkg/filemanager/driver/remote/client.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/gofrs/uuid"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
OverwriteHeader = constants.CrHeaderPrefix + "Overwrite"
|
||||
chunkRetrySleep = time.Duration(5) * time.Second
|
||||
)
|
||||
|
||||
// Client to operate uploading to remote slave server
|
||||
type Client interface {
|
||||
// CreateUploadSession creates remote upload session
|
||||
CreateUploadSession(ctx context.Context, session *fs.UploadSession, overwrite bool) error
|
||||
// GetUploadURL signs an url for uploading file
|
||||
GetUploadURL(ctx context.Context, expires time.Time, sessionID string) (string, string, error)
|
||||
// Upload uploads file to remote server
|
||||
Upload(ctx context.Context, file *fs.UploadRequest) error
|
||||
// DeleteUploadSession deletes remote upload session
|
||||
DeleteUploadSession(ctx context.Context, sessionID string) error
|
||||
// MediaMeta gets media meta from remote server
|
||||
MediaMeta(ctx context.Context, src, ext string) ([]driver.MediaMeta, error)
|
||||
// DeleteFiles deletes files from remote server
|
||||
DeleteFiles(ctx context.Context, files ...string) ([]string, error)
|
||||
}
|
||||
|
||||
type DeleteFileRequest struct {
|
||||
Files []string `json:"files"`
|
||||
}
|
||||
|
||||
// NewClient creates new Client from given policy
|
||||
func NewClient(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider, config conf.ConfigProvider, l logging.Logger) (Client, error) {
|
||||
if policy.Edges.Node == nil {
|
||||
return nil, fmt.Errorf("remote storage policy %d has no node", policy.ID)
|
||||
}
|
||||
|
||||
authInstance := auth.HMACAuth{[]byte(policy.Edges.Node.SlaveKey)}
|
||||
serverURL, err := url.Parse(policy.Edges.Node.Server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base, _ := url.Parse(constants.APIPrefixSlave)
|
||||
|
||||
return &remoteClient{
|
||||
policy: policy,
|
||||
authInstance: authInstance,
|
||||
httpClient: request.NewClient(
|
||||
config,
|
||||
request.WithEndpoint(serverURL.ResolveReference(base).String()),
|
||||
request.WithCredential(authInstance, int64(settings.SlaveRequestSignTTL(ctx))),
|
||||
request.WithSlaveMeta(policy.Edges.Node.ID),
|
||||
request.WithMasterMeta(settings.SiteBasic(ctx).ID, settings.SiteURL(setting.UseFirstSiteUrl(ctx)).String()),
|
||||
request.WithCorrelationID(),
|
||||
),
|
||||
settings: settings,
|
||||
l: l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type remoteClient struct {
|
||||
policy *ent.StoragePolicy
|
||||
authInstance auth.Auth
|
||||
httpClient request.Client
|
||||
settings setting.Provider
|
||||
l logging.Logger
|
||||
}
|
||||
|
||||
func (c *remoteClient) Upload(ctx context.Context, file *fs.UploadRequest) error {
|
||||
ttl := c.settings.UploadSessionTTL(ctx)
|
||||
session := &fs.UploadSession{
|
||||
Props: file.Props.Copy(),
|
||||
Policy: c.policy,
|
||||
}
|
||||
session.Props.UploadSessionID = uuid.Must(uuid.NewV4()).String()
|
||||
session.Props.ExpireAt = time.Now().Add(ttl)
|
||||
|
||||
// Create upload session
|
||||
overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite
|
||||
if err := c.CreateUploadSession(ctx, session, overwrite); err != nil {
|
||||
return fmt.Errorf("failed to create upload session: %w", err)
|
||||
}
|
||||
|
||||
// Initial chunk groups
|
||||
chunks := chunk.NewChunkGroup(file, c.policy.Settings.ChunkSize, &backoff.ConstantBackoff{
|
||||
Max: c.settings.ChunkRetryLimit(ctx),
|
||||
Sleep: chunkRetrySleep,
|
||||
}, c.settings.UseChunkBuffer(ctx), c.l, c.settings.TempPath(ctx))
|
||||
|
||||
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
|
||||
return c.uploadChunk(ctx, session.Props.UploadSessionID, current.Index(), content, overwrite, current.Length())
|
||||
}
|
||||
|
||||
// upload chunks
|
||||
for chunks.Next() {
|
||||
if err := chunks.Process(uploadFunc); err != nil {
|
||||
if err := c.DeleteUploadSession(ctx, session.Props.UploadSessionID); err != nil {
|
||||
c.l.Warning("failed to delete upload session: %s", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) DeleteUploadSession(ctx context.Context, sessionID string) error {
|
||||
resp, err := c.httpClient.Request(
|
||||
"DELETE",
|
||||
"upload/"+sessionID,
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) DeleteFiles(ctx context.Context, files ...string) ([]string, error) {
|
||||
req := &DeleteFileRequest{
|
||||
Files: files,
|
||||
}
|
||||
|
||||
reqStr, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return files, fmt.Errorf("failed to marshal delete request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Request(
|
||||
"DELETE",
|
||||
"file",
|
||||
bytes.NewReader(reqStr),
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(logging.FromContext(ctx)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
var failed []string
|
||||
failed = files
|
||||
if resp.Code == serializer.CodeNotFullySuccess {
|
||||
resp.GobDecode(&failed)
|
||||
}
|
||||
return failed, fmt.Errorf(resp.Error)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) MediaMeta(ctx context.Context, src, ext string) ([]driver.MediaMeta, error) {
|
||||
resp, err := c.httpClient.Request(
|
||||
http.MethodGet,
|
||||
routes.SlaveMediaMetaRoute(src, ext),
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(c.l),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
return nil, fmt.Errorf(resp.Error)
|
||||
}
|
||||
|
||||
var metas []driver.MediaMeta
|
||||
resp.GobDecode(&metas)
|
||||
return metas, nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) CreateUploadSession(ctx context.Context, session *fs.UploadSession, overwrite bool) error {
|
||||
reqBodyEncoded, err := json.Marshal(map[string]interface{}{
|
||||
"session": session,
|
||||
"overwrite": overwrite,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
resp, err := c.httpClient.Request(
|
||||
"PUT",
|
||||
"upload",
|
||||
bodyReader,
|
||||
request.WithContext(ctx),
|
||||
request.WithLogger(c.l),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) GetUploadURL(ctx context.Context, expires time.Time, sessionID string) (string, string, error) {
|
||||
base, err := url.Parse(c.policy.Edges.Node.Server)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", routes.SlaveUploadUrl(base, sessionID).String(), nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
req = auth.SignRequest(ctx, c.authInstance, req, &expires)
|
||||
return req.URL.String(), req.Header["Authorization"][0], nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) uploadChunk(ctx context.Context, sessionID string, index int, chunk io.Reader, overwrite bool, size int64) error {
|
||||
resp, err := c.httpClient.Request(
|
||||
"POST",
|
||||
fmt.Sprintf("upload/%s?chunk=%d", sessionID, index),
|
||||
chunk,
|
||||
request.WithContext(ctx),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
request.WithContentLength(size),
|
||||
request.WithHeader(map[string][]string{OverwriteHeader: {fmt.Sprintf("%t", overwrite)}}),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
273
pkg/filemanager/driver/remote/remote.go
Normal file
273
pkg/filemanager/driver/remote/remote.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
features = &boolset.BooleanSet{}
|
||||
)
|
||||
|
||||
// Driver 远程存储策略适配器
|
||||
type Driver struct {
|
||||
Client request.Client
|
||||
Policy *ent.StoragePolicy
|
||||
AuthInstance auth.Auth
|
||||
|
||||
uploadClient Client
|
||||
config conf.ConfigProvider
|
||||
settings setting.Provider
|
||||
}
|
||||
|
||||
// New initializes a new Driver from policy
|
||||
func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider,
|
||||
config conf.ConfigProvider, l logging.Logger) (*Driver, error) {
|
||||
client, err := NewClient(ctx, policy, settings, config, l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Driver{
|
||||
Policy: policy,
|
||||
Client: request.NewClient(config),
|
||||
AuthInstance: auth.HMACAuth{[]byte(policy.Edges.Node.SlaveKey)},
|
||||
uploadClient: client,
|
||||
settings: settings,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//// List 列取文件
|
||||
//func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||
// var res []response.Object
|
||||
//
|
||||
// reqBody := serializer.ListRequest{
|
||||
// Path: path,
|
||||
// Recursive: recursive,
|
||||
// }
|
||||
// reqBodyEncoded, err := json.Marshal(reqBody)
|
||||
// if err != nil {
|
||||
// return res, err
|
||||
// }
|
||||
//
|
||||
// // 发送列表请求
|
||||
// bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
// signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
// resp, err := handler.Client.Request(
|
||||
// "POST",
|
||||
// handler.getAPIUrl("list"),
|
||||
// bodyReader,
|
||||
// request.WithCredential(handler.AuthInstance, int64(signTTL)),
|
||||
// request.WithMasterMeta(handler.settings.SiteBasic(ctx).ID, handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx)).String()),
|
||||
// ).CheckHTTPResponse(200).DecodeResponse()
|
||||
// if err != nil {
|
||||
// return res, err
|
||||
// }
|
||||
//
|
||||
// // 处理列取结果
|
||||
// if resp.Code != 0 {
|
||||
// return res, errors.New(resp.Error)
|
||||
// }
|
||||
//
|
||||
// if resStr, ok := resp.Data.(string); ok {
|
||||
// err = json.Unmarshal([]byte(resStr), &res)
|
||||
// if err != nil {
|
||||
// return res, err
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// return res, nil
|
||||
//}
|
||||
|
||||
// getAPIUrl 获取接口请求地址
|
||||
func (handler *Driver) getAPIUrl(scope string, routes ...string) string {
|
||||
serverURL, err := url.Parse(handler.Policy.Edges.Node.Server)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var controller *url.URL
|
||||
|
||||
switch scope {
|
||||
case "delete":
|
||||
controller, _ = url.Parse("/api/v3/slave/delete")
|
||||
case "thumb":
|
||||
controller, _ = url.Parse("/api/v3/slave/thumb")
|
||||
case "list":
|
||||
controller, _ = url.Parse("/api/v3/slave/list")
|
||||
default:
|
||||
controller = serverURL
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
controller.Path = path.Join(controller.Path, r)
|
||||
}
|
||||
|
||||
return serverURL.ResolveReference(controller).String()
|
||||
}
|
||||
|
||||
// Open 获取文件内容
|
||||
func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) {
|
||||
//// 尝试获取速度限制
|
||||
//speedLimit := 0
|
||||
//if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok {
|
||||
// speedLimit = user.Group.SpeedLimit
|
||||
//}
|
||||
//
|
||||
//// 获取文件源地址
|
||||
//downloadURL, err := handler.Source(ctx, path, nil, true, int64(speedLimit))
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
//
|
||||
//// 获取文件数据流
|
||||
//resp, err := handler.Client.Request(
|
||||
// "GET",
|
||||
// downloadURL,
|
||||
// nil,
|
||||
// request.WithContext(ctx),
|
||||
// request.WithTimeout(time.Duration(0)),
|
||||
// request.WithMasterMeta(handler.settings.SiteBasic(ctx).ID, handler.settings.SiteURL(ctx).String()),
|
||||
//).CheckHTTPResponse(200).GetRSCloser()
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
//
|
||||
//resp.SetFirstFakeChunk()
|
||||
//
|
||||
//// 尝试获取文件大小
|
||||
//if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
// resp.SetContentLength(int64(file.Size))
|
||||
//}
|
||||
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (handler *Driver) LocalPath(ctx context.Context, path string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
|
||||
defer file.Close()
|
||||
|
||||
return handler.uploadClient.Upload(ctx, file)
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
|
||||
failed, err := handler.uploadClient.DeleteFiles(ctx, files...)
|
||||
if err != nil {
|
||||
return failed, err
|
||||
}
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
|
||||
serverURL, err := url.Parse(handler.Policy.Edges.Node.Server)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse server url failed: %w", err)
|
||||
}
|
||||
|
||||
thumbURL := routes.SlaveThumbUrl(serverURL, e.Source(), ext)
|
||||
signedThumbURL, err := auth.SignURI(ctx, handler.AuthInstance, thumbURL.String(), expire)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signedThumbURL.String(), nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
|
||||
server, err := url.Parse(handler.Policy.Edges.Node.Server)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nodeId := 0
|
||||
if handler.config.System().Mode == conf.SlaveMode {
|
||||
nodeId = handler.Policy.NodeID
|
||||
}
|
||||
|
||||
base := routes.SlaveFileContentUrl(
|
||||
server,
|
||||
e.Source(),
|
||||
args.DisplayName,
|
||||
args.IsDownload,
|
||||
args.Speed,
|
||||
nodeId,
|
||||
)
|
||||
internalProxyed, err := auth.SignURI(ctx, handler.AuthInstance, base.String(), args.Expire)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign internal slave content URL: %w", err)
|
||||
}
|
||||
|
||||
return internalProxyed.String(), nil
|
||||
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
|
||||
siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx))
|
||||
// 在从机端创建上传会话
|
||||
uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeRemote, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String()
|
||||
if err := handler.uploadClient.CreateUploadSession(ctx, uploadSession, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取上传地址
|
||||
uploadURL, sign, err := handler.uploadClient.GetUploadURL(ctx, uploadSession.Props.ExpireAt, uploadSession.Props.UploadSessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign upload url: %w", err)
|
||||
}
|
||||
|
||||
return &fs.UploadCredential{
|
||||
SessionID: uploadSession.Props.UploadSessionID,
|
||||
ChunkSize: handler.Policy.Settings.ChunkSize,
|
||||
UploadURLs: []string{uploadURL},
|
||||
Credential: sign,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
|
||||
return handler.uploadClient.DeleteUploadSession(ctx, uploadSession.Props.UploadSessionID)
|
||||
}
|
||||
|
||||
func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Capabilities reports the feature set of the remote driver, mirroring the
// media-meta and thumbnail options configured on the storage policy.
func (handler *Driver) Capabilities() *driver.Capabilities {
	return &driver.Capabilities{
		StaticFeatures:         features,
		MediaMetaSupportedExts: handler.Policy.Settings.MediaMetaExts,
		MediaMetaProxy:         handler.Policy.Settings.MediaMetaGeneratorProxy,
		ThumbSupportedExts:     handler.Policy.Settings.ThumbExts,
		ThumbProxy:             handler.Policy.Settings.ThumbGeneratorProxy,
		ThumbMaxSize:           handler.Policy.Settings.ThumbMaxSize,
		ThumbSupportAllExts:    handler.Policy.Settings.ThumbSupportAllExts,
	}
}
|
||||
|
||||
// MediaMeta extracts media metadata by delegating to the slave node client.
func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
	return handler.uploadClient.MediaMeta(ctx, path, ext)
}
|
||||
514
pkg/filemanager/driver/s3/s3.go
Normal file
514
pkg/filemanager/driver/s3/s3.go
Normal file
@@ -0,0 +1,514 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
)
|
||||
|
||||
// Driver S3 compatible driver
type Driver struct {
	policy    *ent.StoragePolicy // storage policy this driver serves
	chunkSize int64              // multipart part size in bytes

	settings setting.Provider
	l        logging.Logger
	config   conf.ConfigProvider
	mime     mime.MimeDetector // fallback Content-Type detection by file name

	sess *session.Session // shared AWS session
	svc  *s3.S3           // S3 service client built from sess
}
|
||||
|
||||
// UploadPolicy is the S3 POST upload policy document.
type UploadPolicy struct {
	Expiration string        `json:"expiration"`
	Conditions []interface{} `json:"conditions"`
}
|
||||
|
||||
// MetaData describes basic object metadata returned by a HEAD request.
type MetaData struct {
	Size int64  // object size in bytes
	Etag string // object ETag
}
|
||||
|
||||
var (
	// features is the static capability set advertised by this driver.
	features = &boolset.BooleanSet{}
)
|
||||
|
||||
// init marks the capabilities that are statically enabled for this driver.
func init() {
	boolset.Sets(map[driver.HandlerCapability]bool{
		driver.HandlerCapabilityUploadSentinelRequired: true,
	}, features)
}
|
||||
|
||||
// New constructs an S3 driver for the given storage policy, creating the AWS
// session and service client. A zero ChunkSize in the policy defaults to 25 MB.
func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider,
	config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) {
	chunkSize := policy.Settings.ChunkSize
	if policy.Settings.ChunkSize == 0 {
		chunkSize = 25 << 20 // 25 MB
	}

	driver := &Driver{
		policy:    policy,
		settings:  settings,
		chunkSize: chunkSize,
		config:    config,
		l:         l,
		mime:      mime,
	}

	sess, err := session.NewSession(&aws.Config{
		Credentials:      credentials.NewStaticCredentials(policy.AccessKey, policy.SecretKey, ""),
		Endpoint:         &policy.Server,
		Region:           &policy.Settings.Region,
		S3ForcePathStyle: &policy.Settings.S3ForcePathStyle,
	})

	if err != nil {
		return nil, err
	}
	driver.sess = sess
	driver.svc = s3.New(sess)

	return driver, nil
}
|
||||
|
||||
//// List 列出给定路径下的文件
|
||||
//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
// // 初始化列目录参数
|
||||
// base = strings.TrimPrefix(base, "/")
|
||||
// if base != "" {
|
||||
// base += "/"
|
||||
// }
|
||||
//
|
||||
// opt := &s3.ListObjectsInput{
|
||||
// Bucket: &handler.policy.BucketName,
|
||||
// Prefix: &base,
|
||||
// MaxKeys: aws.Int64(1000),
|
||||
// }
|
||||
//
|
||||
// // 是否为递归列出
|
||||
// if !recursive {
|
||||
// opt.Delimiter = aws.String("/")
|
||||
// }
|
||||
//
|
||||
// var (
|
||||
// objects []*s3.Object
|
||||
// commons []*s3.CommonPrefix
|
||||
// )
|
||||
//
|
||||
// for {
|
||||
// res, err := handler.svc.ListObjectsWithContext(ctx, opt)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// objects = append(objects, res.Contents...)
|
||||
// commons = append(commons, res.CommonPrefixes...)
|
||||
//
|
||||
// // 如果本次未列取完,则继续使用marker获取结果
|
||||
// if *res.IsTruncated {
|
||||
// opt.Marker = res.NextMarker
|
||||
// } else {
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 处理列取结果
|
||||
// res := make([]response.Object, 0, len(objects)+len(commons))
|
||||
//
|
||||
// // 处理目录
|
||||
// for _, object := range commons {
|
||||
// rel, err := filepath.Rel(*opt.Prefix, *object.Prefix)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// res = append(res, response.Object{
|
||||
// Name: path.Base(*object.Prefix),
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Size: 0,
|
||||
// IsDir: true,
|
||||
// LastModify: time.Now(),
|
||||
// })
|
||||
// }
|
||||
// // 处理文件
|
||||
// for _, object := range objects {
|
||||
// rel, err := filepath.Rel(*opt.Prefix, *object.Key)
|
||||
// if err != nil {
|
||||
// continue
|
||||
// }
|
||||
// res = append(res, response.Object{
|
||||
// Name: path.Base(*object.Key),
|
||||
// Source: *object.Key,
|
||||
// RelativePath: filepath.ToSlash(rel),
|
||||
// Size: uint64(*object.Size),
|
||||
// IsDir: false,
|
||||
// LastModify: time.Now(),
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// return res, nil
|
||||
//
|
||||
//}
|
||||
|
||||
// Open is not supported by the S3 driver.
func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) {
	return nil, errors.New("not implemented")
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
|
||||
defer file.Close()
|
||||
|
||||
// 是否允许覆盖
|
||||
overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite
|
||||
if !overwrite {
|
||||
// Check for duplicated file
|
||||
if _, err := handler.Meta(ctx, file.Props.SavePath); err == nil {
|
||||
return fs.ErrFileExisted
|
||||
}
|
||||
}
|
||||
|
||||
uploader := s3manager.NewUploader(handler.sess, func(u *s3manager.Uploader) {
|
||||
u.PartSize = handler.chunkSize
|
||||
})
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
_, err := uploader.UploadWithContext(ctx, &s3manager.UploadInput{
|
||||
Bucket: &handler.policy.BucketName,
|
||||
Key: &file.Props.SavePath,
|
||||
Body: io.LimitReader(file, file.Props.Size),
|
||||
ContentType: aws.String(mimeType),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
|
||||
failed := make([]string, 0, len(files))
|
||||
batchSize := handler.policy.Settings.S3DeleteBatchSize
|
||||
if batchSize == 0 {
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html
|
||||
// The request can contain a list of up to 1000 keys that you want to delete.
|
||||
batchSize = 1000
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
|
||||
groups := lo.Chunk(files, batchSize)
|
||||
for _, group := range groups {
|
||||
if len(group) == 1 {
|
||||
// Invoke single file delete API
|
||||
_, err := handler.svc.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &handler.policy.BucketName,
|
||||
Key: &group[0],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
// Ignore NoSuchKey error
|
||||
if aerr.Code() == s3.ErrCodeNoSuchKey {
|
||||
continue
|
||||
}
|
||||
}
|
||||
failed = append(failed, group[0])
|
||||
lastErr = err
|
||||
}
|
||||
} else {
|
||||
// Invoke batch delete API
|
||||
res, err := handler.svc.DeleteObjects(
|
||||
&s3.DeleteObjectsInput{
|
||||
Bucket: &handler.policy.BucketName,
|
||||
Delete: &s3.Delete{
|
||||
Objects: lo.Map(group, func(s string, i int) *s3.ObjectIdentifier {
|
||||
return &s3.ObjectIdentifier{Key: &s}
|
||||
}),
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
failed = append(failed, group...)
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range res.Errors {
|
||||
handler.l.Debug("Failed to delete file: %s, Code:%s, Message:%s", v.Key, v.Code, v.Key)
|
||||
failed = append(failed, *v.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return failed, lastErr
|
||||
|
||||
}
|
||||
|
||||
// Thumb is not supported by the S3 driver.
func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
	return "", errors.New("not implemented")
}
|
||||
|
||||
// Source returns a pre-signed (or public) URL for the entity's object. For
// download requests a Content-Disposition override is attached; for public
// buckets the signing query string is stripped from the final URL.
func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
	var contentDescription *string
	if args.IsDownload {
		// Force attachment download with both plain and RFC 5987 encoded names.
		encodedFilename := url.PathEscape(args.DisplayName)
		contentDescription = aws.String(fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`,
			encodedFilename, encodedFilename))
	}

	req, _ := handler.svc.GetObjectRequest(
		&s3.GetObjectInput{
			Bucket:                     &handler.policy.BucketName,
			Key:                        aws.String(e.Source()),
			ResponseContentDisposition: contentDescription,
		})

	// Default to 7 days (the S3 pre-sign maximum) when no expiry is given.
	ttl := time.Duration(604800) * time.Second // 7 days
	if args.Expire != nil {
		ttl = time.Until(*args.Expire)
	}
	signedURL, err := req.Presign(ttl)
	if err != nil {
		return "", err
	}

	// Re-parse the signed URL so it can be post-processed (e.g. swapping the
	// host for a user-defined CDN domain, if configured elsewhere).
	finalURL, err := url.Parse(signedURL)
	if err != nil {
		return "", err
	}

	// Public buckets: drop the signing query string and unsupported headers.
	if !handler.policy.IsPrivate {
		finalURL.RawQuery = ""
	}

	return finalURL.String(), nil
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
|
||||
// Check for duplicated file
|
||||
if _, err := handler.Meta(ctx, file.Props.SavePath); err == nil {
|
||||
return nil, fs.ErrFileExisted
|
||||
}
|
||||
|
||||
// 生成回调地址
|
||||
siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx))
|
||||
// 在从机端创建上传会话
|
||||
uploadSession.ChunkSize = handler.chunkSize
|
||||
uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeS3, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String()
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
// 创建分片上传
|
||||
res, err := handler.svc.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
|
||||
Bucket: &handler.policy.BucketName,
|
||||
Key: &uploadSession.Props.SavePath,
|
||||
Expires: &uploadSession.Props.ExpireAt,
|
||||
ContentType: aws.String(mimeType),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create multipart upload: %w", err)
|
||||
}
|
||||
|
||||
uploadSession.UploadID = *res.UploadId
|
||||
|
||||
// 为每个分片签名上传 URL
|
||||
chunks := chunk.NewChunkGroup(file, handler.chunkSize, &backoff.ConstantBackoff{}, false, handler.l, "")
|
||||
urls := make([]string, chunks.Num())
|
||||
for chunks.Next() {
|
||||
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
|
||||
signedReq, _ := handler.svc.UploadPartRequest(&s3.UploadPartInput{
|
||||
Bucket: &handler.policy.BucketName,
|
||||
Key: &uploadSession.Props.SavePath,
|
||||
PartNumber: aws.Int64(int64(c.Index() + 1)),
|
||||
ContentLength: aws.Int64(c.Length()),
|
||||
UploadId: res.UploadId,
|
||||
})
|
||||
|
||||
signedURL, err := signedReq.Presign(time.Until(uploadSession.Props.ExpireAt))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
urls[c.Index()] = signedURL
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 签名完成分片上传的请求URL
|
||||
signedReq, _ := handler.svc.CompleteMultipartUploadRequest(&s3.CompleteMultipartUploadInput{
|
||||
Bucket: &handler.policy.BucketName,
|
||||
Key: &file.Props.SavePath,
|
||||
UploadId: res.UploadId,
|
||||
})
|
||||
|
||||
signedURL, err := signedReq.Presign(time.Until(uploadSession.Props.ExpireAt))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成上传凭证
|
||||
return &fs.UploadCredential{
|
||||
UploadID: *res.UploadId,
|
||||
UploadURLs: urls,
|
||||
CompleteURL: signedURL,
|
||||
SessionID: uploadSession.Props.UploadSessionID,
|
||||
ChunkSize: handler.chunkSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Meta issues a HEAD request for the given key and returns its size and ETag.
func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
	res, err := handler.svc.HeadObjectWithContext(ctx,
		&s3.HeadObjectInput{
			Bucket: &handler.policy.BucketName,
			Key:    &path,
		})

	if err != nil {
		return nil, err
	}

	return &MetaData{
		Size: *res.ContentLength,
		Etag: *res.ETag,
	}, nil
}
|
||||
|
||||
// CORS installs a permissive CORS rule on the bucket so browsers can upload
// directly: all origins and headers, the common methods, ETag exposed, and a
// one-hour preflight cache.
func (handler *Driver) CORS() error {
	rule := s3.CORSRule{
		AllowedMethods: aws.StringSlice([]string{
			"GET",
			"POST",
			"PUT",
			"DELETE",
			"HEAD",
		}),
		AllowedOrigins: aws.StringSlice([]string{"*"}),
		AllowedHeaders: aws.StringSlice([]string{"*"}),
		ExposeHeaders:  aws.StringSlice([]string{"ETag"}),
		MaxAgeSeconds:  aws.Int64(3600),
	}

	_, err := handler.svc.PutBucketCors(&s3.PutBucketCorsInput{
		Bucket: &handler.policy.BucketName,
		CORSConfiguration: &s3.CORSConfiguration{
			CORSRules: []*s3.CORSRule{&rule},
		},
	})

	return err
}
|
||||
|
||||
// CancelToken aborts the multipart upload bound to the session so S3 can
// release the already uploaded parts.
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
	_, err := handler.svc.AbortMultipartUploadWithContext(ctx, &s3.AbortMultipartUploadInput{
		UploadId: &uploadSession.UploadID,
		Bucket:   &handler.policy.BucketName,
		Key:      &uploadSession.Props.SavePath,
	})
	return err
}
|
||||
|
||||
// cancelUpload best-effort aborts a multipart upload, logging (not returning)
// any failure.
func (handler *Driver) cancelUpload(key, id *string) {
	if _, err := handler.svc.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
		Bucket:   &handler.policy.BucketName,
		UploadId: id,
		Key:      key,
	}); err != nil {
		handler.l.Warning("failed to abort multipart upload: %s", err)
	}
}
|
||||
|
||||
// Capabilities reports the feature set of the S3 driver, including the 7-day
// upper bound on pre-signed source URL lifetime.
func (handler *Driver) Capabilities() *driver.Capabilities {
	return &driver.Capabilities{
		StaticFeatures:  features,
		MediaMetaProxy:  handler.policy.Settings.MediaMetaGeneratorProxy,
		ThumbProxy:      handler.policy.Settings.ThumbGeneratorProxy,
		MaxSourceExpire: time.Duration(604800) * time.Second,
	}
}
|
||||
|
||||
// MediaMeta is not supported by the S3 driver.
func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
	return nil, errors.New("not implemented")
}
|
||||
|
||||
// LocalPath returns an empty string: S3 objects have no local filesystem path.
func (handler *Driver) LocalPath(ctx context.Context, path string) string {
	return ""
}
|
||||
|
||||
// CompleteUpload verifies a finished client-driven upload. Sessions without a
// sentinel task skip verification; otherwise the actual object size must
// match the size declared in the session.
func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
	if session.SentinelTaskID == 0 {
		return nil
	}

	// Make sure uploaded file size is correct
	res, err := handler.Meta(ctx, session.Props.SavePath)
	if err != nil {
		return fmt.Errorf("failed to get uploaded file size: %w", err)
	}

	if res.Size != session.Props.Size {
		return serializer.NewError(
			serializer.CodeMetaMismatch,
			fmt.Sprintf("File size not match, expected: %d, actual: %d", session.Props.Size, res.Size),
			nil,
		)
	}
	return nil
}
|
||||
|
||||
// Reader is a thin wrapper that hides the concrete type of the underlying
// io.Reader.
type Reader struct {
	r io.Reader
}

// Read delegates to the wrapped reader.
func (r Reader) Read(p []byte) (int, error) {
	return r.r.Read(p)
}
|
||||
154
pkg/filemanager/driver/upyun/media.go
Normal file
154
pkg/filemanager/driver/upyun/media.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package upyun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/mediameta"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/samber/lo"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
	// mediaInfoTTL bounds the validity of signed media-info URLs.
	mediaInfoTTL = time.Duration(10) * time.Minute
)
|
||||
|
||||
type (
	// ImageInfo models Upyun's "!/meta" response; only the EXIF map is consumed.
	ImageInfo struct {
		Exif map[string]string `json:"EXIF"`
	}
)
|
||||
|
||||
func (handler *Driver) extractImageMeta(ctx context.Context, path string) ([]driver.MediaMeta, error) {
|
||||
resp, err := handler.extractMediaInfo(ctx, path, "!/meta")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println(resp)
|
||||
|
||||
var imageInfo ImageInfo
|
||||
if err := json.Unmarshal([]byte(resp), &imageInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal image info: %w", err)
|
||||
}
|
||||
|
||||
metas := make([]driver.MediaMeta, 0, len(imageInfo.Exif))
|
||||
exifMap := lo.MapEntries(imageInfo.Exif, func(key string, value string) (string, string) {
|
||||
switch key {
|
||||
case "0xA434":
|
||||
key = "LensModel"
|
||||
}
|
||||
return key, value
|
||||
})
|
||||
metas = append(metas, mediameta.ExtractExifMap(exifMap, time.Time{})...)
|
||||
metas = append(metas, parseGpsInfo(imageInfo.Exif)...)
|
||||
|
||||
for i := 0; i < len(metas); i++ {
|
||||
metas[i].Type = driver.MetaTypeExif
|
||||
}
|
||||
|
||||
return metas, nil
|
||||
}
|
||||
|
||||
// extractMediaInfo signs a short-lived URL for path+param and fetches it,
// returning the raw response body as a string.
func (handler *Driver) extractMediaInfo(ctx context.Context, path string, param string) (string, error) {
	mediaInfoExpire := time.Now().Add(mediaInfoTTL)
	mediaInfoUrl, err := handler.signURL(ctx, path+param, nil, &mediaInfoExpire)
	if err != nil {
		return "", err
	}

	resp, err := handler.httpClient.
		Request(http.MethodGet, mediaInfoUrl, nil, request.WithContext(ctx)).
		CheckHTTPResponse(http.StatusOK).
		GetResponseIgnoreErr()
	if err != nil {
		return "", unmarshalError(resp, err)
	}

	return resp, nil
}
|
||||
|
||||
func unmarshalError(resp string, err error) error {
|
||||
return fmt.Errorf("upyun error: %s", err)
|
||||
}
|
||||
|
||||
func parseGpsInfo(imageInfo map[string]string) []driver.MediaMeta {
|
||||
latitude := imageInfo["GPSLatitude"] // 31/1, 162680820/10000000, 0/1
|
||||
longitude := imageInfo["GPSLongitude"] // 120/1, 429103939/10000000, 0/1
|
||||
latRef := imageInfo["GPSLatitudeRef"] // N
|
||||
lonRef := imageInfo["GPSLongitudeRef"] // E
|
||||
|
||||
// Make sure all value exist in map
|
||||
if latitude == "" || longitude == "" || latRef == "" || lonRef == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
lat := parseRawGPS(latitude, latRef)
|
||||
lon := parseRawGPS(longitude, lonRef)
|
||||
if !math.IsNaN(lat) && !math.IsNaN(lon) {
|
||||
lat, lng := mediameta.NormalizeGPS(lat, lon)
|
||||
return []driver.MediaMeta{{
|
||||
Key: mediameta.GpsLat,
|
||||
Value: fmt.Sprintf("%f", lat),
|
||||
}, {
|
||||
Key: mediameta.GpsLng,
|
||||
Value: fmt.Sprintf("%f", lng),
|
||||
}}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRawGPS(gpsStr string, ref string) float64 {
|
||||
elem := strings.Split(gpsStr, ",")
|
||||
if len(elem) < 1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var (
|
||||
deg float64
|
||||
minutes float64
|
||||
seconds float64
|
||||
)
|
||||
|
||||
deg = getGpsElemValue(elem[0])
|
||||
if len(elem) >= 2 {
|
||||
minutes = getGpsElemValue(elem[1])
|
||||
}
|
||||
if len(elem) >= 3 {
|
||||
seconds = getGpsElemValue(elem[2])
|
||||
}
|
||||
|
||||
decimal := deg + minutes/60.0 + seconds/3600.0
|
||||
|
||||
if ref == "S" || ref == "W" {
|
||||
return -decimal
|
||||
}
|
||||
|
||||
return decimal
|
||||
}
|
||||
|
||||
// getGpsElemValue converts one rational EXIF GPS component ("num/den") into a
// float. Malformed input, extra separators, or a zero denominator yield 0.
func getGpsElemValue(elm string) float64 {
	numStr, denStr, ok := strings.Cut(elm, "/")
	if !ok {
		return 0
	}

	num, err := strconv.ParseFloat(numStr, 64)
	if err != nil {
		return 0
	}

	den, err := strconv.ParseFloat(denStr, 64)
	if err != nil || den == 0 {
		return 0
	}

	return num / den
}
|
||||
382
pkg/filemanager/driver/upyun/upyun.go
Normal file
382
pkg/filemanager/driver/upyun/upyun.go
Normal file
@@ -0,0 +1,382 @@
|
||||
package upyun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/upyun/go-sdk/upyun"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
	// UploadPolicy is the Upyun form-upload policy document.
	UploadPolicy struct {
		Bucket             string `json:"bucket"`
		SaveKey            string `json:"save-key"`
		Expiration         int64  `json:"expiration"`
		CallbackURL        string `json:"notify-url"`
		ContentLength      uint64 `json:"content-length"`
		ContentLengthRange string `json:"content-length-range,omitempty"`
		AllowFileType      string `json:"allow-file-type,omitempty"`
	}
	// Driver is the Upyun storage policy adapter.
	Driver struct {
		policy *ent.StoragePolicy // storage policy this driver serves

		up         *upyun.UpYun // Upyun SDK client
		settings   setting.Provider
		l          logging.Logger
		config     conf.ConfigProvider
		mime       mime.MimeDetector // fallback Content-Type detection by file name
		httpClient request.Client    // used for media-info requests
	}
)
|
||||
|
||||
var (
	// features is the static capability set advertised by this driver.
	features = &boolset.BooleanSet{}
)
|
||||
|
||||
// New constructs an Upyun driver bound to the given storage policy, wiring up
// the SDK client and an HTTP client for media-info requests.
func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider,
	config conf.ConfigProvider, l logging.Logger, mime mime.MimeDetector) (*Driver, error) {
	driver := &Driver{
		policy:     policy,
		settings:   settings,
		config:     config,
		l:          l,
		mime:       mime,
		httpClient: request.NewClient(config, request.WithLogger(l)),
		up: upyun.NewUpYun(&upyun.UpYunConfig{
			Bucket:   policy.BucketName,
			Operator: policy.AccessKey,
			Password: policy.SecretKey,
		}),
	}

	return driver, nil
}
|
||||
|
||||
//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
// base = strings.TrimPrefix(base, "/")
|
||||
//
|
||||
// // 用于接受SDK返回对象的chan
|
||||
// objChan := make(chan *upyun.FileInfo)
|
||||
// objects := []*upyun.FileInfo{}
|
||||
//
|
||||
// // 列取配置
|
||||
// listConf := &upyun.GetObjectsConfig{
|
||||
// Path: "/" + base,
|
||||
// ObjectsChan: objChan,
|
||||
// MaxListTries: 1,
|
||||
// }
|
||||
// // 递归列取时不限制递归次数
|
||||
// if recursive {
|
||||
// listConf.MaxListLevel = -1
|
||||
// }
|
||||
//
|
||||
// // 启动一个goroutine收集列取的对象信
|
||||
// wg := &sync.WaitGroup{}
|
||||
// wg.Add(1)
|
||||
// go func(input chan *upyun.FileInfo, output *[]*upyun.FileInfo, wg *sync.WaitGroup) {
|
||||
// defer wg.Done()
|
||||
// for {
|
||||
// file, ok := <-input
|
||||
// if !ok {
|
||||
// return
|
||||
// }
|
||||
// *output = append(*output, file)
|
||||
// }
|
||||
// }(objChan, &objects, wg)
|
||||
//
|
||||
// up := upyun.NewUpYun(&upyun.UpYunConfig{
|
||||
// Bucket: handler.policy.BucketName,
|
||||
// Operator: handler.policy.AccessKey,
|
||||
// Password: handler.policy.SecretKey,
|
||||
// })
|
||||
//
|
||||
// err := up.List(listConf)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// wg.Wait()
|
||||
//
|
||||
// // 汇总处理列取结果
|
||||
// res := make([]response.Object, 0, len(objects))
|
||||
// for _, object := range objects {
|
||||
// res = append(res, response.Object{
|
||||
// Name: path.Base(object.Name),
|
||||
// RelativePath: object.Name,
|
||||
// Source: path.Join(base, object.Name),
|
||||
// Size: uint64(object.Size),
|
||||
// IsDir: object.IsDir,
|
||||
// LastModify: object.Time,
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// return res, nil
|
||||
//}
|
||||
|
||||
// Open is not supported by the Upyun driver.
func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) {
	return nil, errors.New("not implemented")
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
|
||||
defer file.Close()
|
||||
|
||||
// 是否允许覆盖
|
||||
overwrite := file.Mode&fs.ModeOverwrite == fs.ModeOverwrite
|
||||
if !overwrite {
|
||||
if _, err := handler.up.GetInfo(file.Props.SavePath); err == nil {
|
||||
return fs.ErrFileExisted
|
||||
}
|
||||
}
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
err := handler.up.Put(&upyun.PutObjectConfig{
|
||||
Path: file.Props.SavePath,
|
||||
Reader: file,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": mimeType,
|
||||
},
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete removes one or more files asynchronously, returning the paths that
// could not be removed plus the last error encountered. "Not found" style
// errors are treated as success.
func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
	failed := make([]string, 0)
	var lastErr error

	for _, file := range files {
		if err := handler.up.Delete(&upyun.DeleteObjectConfig{
			Path:  file,
			Async: true,
		}); err != nil {
			// Strip the path from the message first, so a file name that
			// itself contains "Not found"/"NoSuchKey" does not mask real errors.
			filteredErr := strings.ReplaceAll(err.Error(), file, "")
			if strings.Contains(filteredErr, "Not found") ||
				strings.Contains(filteredErr, "NoSuchKey") {
				continue
			}

			failed = append(failed, file)
			lastErr = err
		}
	}

	return failed, lastErr
}
|
||||
|
||||
// Thumb builds a signed URL using Upyun's "!/fwfh/WxH" processing parameter to
// produce a thumbnail scaled to the configured size.
func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
	w, h := handler.settings.ThumbSize(ctx)

	thumbParam := fmt.Sprintf("!/fwfh/%dx%d", w, h)
	thumbURL, err := handler.signURL(ctx, e.Source()+thumbParam, nil, expire)
	if err != nil {
		return "", err
	}

	return thumbURL, nil
}
|
||||
|
||||
// Source returns the external URL for the entity, optionally adding Upyun's
// "_upd" parameter so the file is served as a named download.
func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
	query := url.Values{}

	// For downloads, ask Upyun to serve the file as an attachment with this name.
	if args.IsDownload {
		query.Add("_upd", args.DisplayName)
	}

	return handler.signURL(ctx, e.Source(), &query, args.Expire)
}
|
||||
|
||||
// signURL resolves path against the policy's proxy server and, for private
// buckets, appends Upyun's "_upt" token signature. A nil expire falls back to
// a roughly 20-year validity window.
func (handler *Driver) signURL(ctx context.Context, path string, query *url.Values, expire *time.Time) (string, error) {
	sourceURL, err := url.Parse(handler.policy.Settings.ProxyServer)
	if err != nil {
		return "", err
	}

	fileKey, err := url.Parse(url.PathEscape(path))
	if err != nil {
		return "", err
	}

	sourceURL = sourceURL.ResolveReference(fileKey)
	if query != nil {
		sourceURL.RawQuery = query.Encode()
	}

	if !handler.policy.IsPrivate {
		// Token anti-leech is disabled; return the plain URL.
		return sourceURL.String(), nil
	}

	etime := time.Now().Add(time.Duration(24) * time.Hour * 365 * 20).Unix()
	if expire != nil {
		etime = expire.Unix()
	}
	// Upyun URL token: MD5("token&etime&path"), middle 8 hex chars + expiry.
	signStr := fmt.Sprintf(
		"%s&%d&%s",
		handler.policy.Settings.Token,
		etime,
		sourceURL.Path,
	)
	signMd5 := fmt.Sprintf("%x", md5.Sum([]byte(signStr)))
	finalSign := signMd5[12:20] + strconv.FormatInt(etime, 10)

	// Attach the signature to the URL.
	q := sourceURL.Query()
	q.Add("_upt", finalSign)
	sourceURL.RawQuery = q.Encode()

	return sourceURL.String(), nil
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
|
||||
if _, err := handler.up.GetInfo(file.Props.SavePath); err == nil {
|
||||
return nil, fs.ErrFileExisted
|
||||
}
|
||||
|
||||
// 生成回调地址
|
||||
siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx))
|
||||
apiUrl := routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeUpyun, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String()
|
||||
|
||||
// 上传策略
|
||||
putPolicy := UploadPolicy{
|
||||
Bucket: handler.policy.BucketName,
|
||||
SaveKey: file.Props.SavePath,
|
||||
Expiration: uploadSession.Props.ExpireAt.Unix(),
|
||||
CallbackURL: apiUrl,
|
||||
ContentLength: uint64(file.Props.Size),
|
||||
ContentLengthRange: fmt.Sprintf("0,%d", file.Props.Size),
|
||||
}
|
||||
|
||||
// 生成上传凭证
|
||||
policyJSON, err := json.Marshal(putPolicy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
policyEncoded := base64.StdEncoding.EncodeToString(policyJSON)
|
||||
|
||||
// 生成签名
|
||||
elements := []string{"POST", "/" + handler.policy.BucketName, policyEncoded}
|
||||
signStr := sign(handler.policy.AccessKey, handler.policy.SecretKey, elements)
|
||||
|
||||
mimeType := file.Props.MimeType
|
||||
if mimeType == "" {
|
||||
handler.mime.TypeByName(file.Props.Uri.Name())
|
||||
}
|
||||
|
||||
return &fs.UploadCredential{
|
||||
UploadPolicy: policyEncoded,
|
||||
UploadURLs: []string{"https://v0.api.upyun.com/" + handler.policy.BucketName},
|
||||
Credential: signStr,
|
||||
MimeType: mimeType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CancelToken is a no-op for the Upyun driver; the form-upload policy simply
// expires on its own.
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
	return nil
}
|
||||
|
||||
// CompleteUpload is a no-op for the Upyun driver; no master-side
// finalization is performed here.
func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
	return nil
}
|
||||
|
||||
// Capabilities reports the feature set of the Upyun driver. Native media-meta
// extraction is only advertised when enabled in the policy settings.
func (handler *Driver) Capabilities() *driver.Capabilities {
	mediaMetaExts := handler.policy.Settings.MediaMetaExts
	if !handler.policy.Settings.NativeMediaProcessing {
		mediaMetaExts = nil
	}
	return &driver.Capabilities{
		StaticFeatures:         features,
		MediaMetaSupportedExts: mediaMetaExts,
		MediaMetaProxy:         handler.policy.Settings.MediaMetaGeneratorProxy,
		ThumbSupportedExts:     handler.policy.Settings.ThumbExts,
		ThumbProxy:             handler.policy.Settings.ThumbGeneratorProxy,
		ThumbMaxSize:           handler.policy.Settings.ThumbMaxSize,
		ThumbSupportAllExts:    handler.policy.Settings.ThumbSupportAllExts,
	}
}
|
||||
|
||||
// MediaMeta extracts media metadata for the object at path. This driver
// delegates to its image metadata extractor regardless of ext.
func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
	return handler.extractImageMeta(ctx, path)
}
|
||||
|
||||
// LocalPath returns the local filesystem path of the object. Objects live in
// remote storage, so there is no local path and an empty string is returned.
func (handler *Driver) LocalPath(ctx context.Context, path string) string {
	return ""
}
|
||||
|
||||
func ValidateCallback(c *gin.Context, session *fs.UploadSession) error {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
contentMD5 := c.Request.Header.Get("Content-Md5")
|
||||
date := c.Request.Header.Get("Date")
|
||||
actualSignature := c.Request.Header.Get("Authorization")
|
||||
actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body))
|
||||
if actualContentMD5 != contentMD5 {
|
||||
return errors.New("MD5 mismatch")
|
||||
}
|
||||
|
||||
// Compare signature
|
||||
signature := sign(session.Policy.AccessKey, session.Policy.SecretKey, []string{
|
||||
"POST",
|
||||
c.Request.URL.Path,
|
||||
date,
|
||||
contentMD5,
|
||||
})
|
||||
if signature != actualSignature {
|
||||
return errors.New("Signature not match")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign 计算又拍云的签名头
|
||||
func sign(ak, sk string, elements []string) string {
|
||||
password := fmt.Sprintf("%x", md5.Sum([]byte(sk)))
|
||||
mac := hmac.New(sha1.New, []byte(password))
|
||||
value := strings.Join(elements, "&")
|
||||
mac.Write([]byte(value))
|
||||
signStr := base64.StdEncoding.EncodeToString((mac.Sum(nil)))
|
||||
return fmt.Sprintf("UPYUN %s:%s", ak, signStr)
|
||||
}
|
||||
37
pkg/filemanager/driver/util.go
Normal file
37
pkg/filemanager/driver/util.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ApplyProxyIfNeeded(policy *ent.StoragePolicy, srcUrl *url.URL) (*url.URL, error) {
|
||||
// For custom proxy, generate a new proxyed URL:
|
||||
// [Proxy Scheme][Proxy Host][Proxy Port][ProxyPath + OriginSrcPath][OriginSrcQuery + ProxyQuery]
|
||||
if policy.Settings.CustomProxy {
|
||||
proxy, err := url.Parse(policy.Settings.ProxyServer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse proxy URL: %w", err)
|
||||
}
|
||||
proxy.Path = path.Join(proxy.Path, strings.TrimPrefix(srcUrl.Path, "/"))
|
||||
q := proxy.Query()
|
||||
if len(q) == 0 {
|
||||
proxy.RawQuery = srcUrl.RawQuery
|
||||
} else {
|
||||
// Merge query parameters
|
||||
srcQ := srcUrl.Query()
|
||||
for k, _ := range srcQ {
|
||||
q.Set(k, srcQ.Get(k))
|
||||
}
|
||||
|
||||
proxy.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
srcUrl = proxy
|
||||
}
|
||||
|
||||
return srcUrl, nil
|
||||
}
|
||||
877
pkg/filemanager/fs/dbfs/dbfs.go
Normal file
877
pkg/filemanager/fs/dbfs/dbfs.go
Normal file
@@ -0,0 +1,877 @@
|
||||
package dbfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/samber/lo"
|
||||
"golang.org/x/tools/container/intsets"
|
||||
)
|
||||
|
||||
const (
|
||||
ContextHintHeader = constants.CrHeaderPrefix + "Context-Hint"
|
||||
NavigatorStateCachePrefix = "navigator_state_"
|
||||
ContextHintTTL = 5 * 60 // 5 minutes
|
||||
|
||||
folderSummaryCachePrefix = "folder_summary_"
|
||||
)
|
||||
|
||||
type (
|
||||
ContextHintCtxKey struct{}
|
||||
ByPassOwnerCheckCtxKey struct{}
|
||||
)
|
||||
|
||||
// NewDatabaseFS assembles a database-backed file system bound to user u.
// All collaborators are injected: cache holds derived data (e.g. folder
// summaries) while stateKv persists navigator pagination state.
func NewDatabaseFS(u *ent.User, fileClient inventory.FileClient, shareClient inventory.ShareClient,
	l logging.Logger, ls lock.LockSystem, settingClient setting.Provider,
	storagePolicyClient inventory.StoragePolicyClient, hasher hashid.Encoder, userClient inventory.UserClient,
	cache, stateKv cache.Driver) fs.FileSystem {
	return &DBFS{
		user:                u,
		navigators:          make(map[string]Navigator),
		fileClient:          fileClient,
		shareClient:         shareClient,
		l:                   l,
		ls:                  ls,
		settingClient:       settingClient,
		storagePolicyClient: storagePolicyClient,
		hasher:              hasher,
		userClient:          userClient,
		cache:               cache,
		stateKv:             stateKv,
	}
}
|
||||
|
||||
// DBFS is a database-backed fs.FileSystem implementation scoped to a single
// user. Navigators are created lazily per file system and cached in
// navigators (guarded by mu).
type DBFS struct {
	user                *ent.User                     // user this instance acts on behalf of
	navigators          map[string]Navigator          // lazily built navigators, keyed by navigatorId; guarded by mu
	fileClient          inventory.FileClient          // file/entity persistence
	userClient          inventory.UserClient          // user persistence (storage diffs)
	storagePolicyClient inventory.StoragePolicyClient // storage policy lookups
	shareClient         inventory.ShareClient         // share lookups for share navigators
	l                   logging.Logger
	ls                  lock.LockSystem // path-based lock system
	settingClient       setting.Provider
	hasher              hashid.Encoder // hash ID encode/decode
	cache               cache.Driver   // derived data cache (e.g. folder summaries)
	stateKv             cache.Driver   // navigator pagination state store
	mu                  sync.Mutex     // protects navigators
}
|
||||
|
||||
// Recycle releases resources held by all cached navigators. Call when this
// DBFS instance is done serving a request.
func (f *DBFS) Recycle() {
	for _, navigator := range f.navigators {
		navigator.Recycle()
	}
}
|
||||
|
||||
func (f *DBFS) GetEntity(ctx context.Context, entityID int) (fs.Entity, error) {
|
||||
if entityID == 0 {
|
||||
return fs.NewEmptyEntity(f.user), nil
|
||||
}
|
||||
|
||||
files, _, err := f.fileClient.GetEntitiesByIDs(ctx, []int{entityID}, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get entity: %w", err)
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
return nil, fs.ErrEntityNotExist
|
||||
}
|
||||
|
||||
return fs.NewEntity(files[0]), nil
|
||||
|
||||
}
|
||||
|
||||
// List resolves the folder at path and returns it together with its children.
// Pagination, ordering and search behavior are controlled via opts; when a
// streaming callback is configured, pages are also delivered incrementally.
func (f *DBFS) List(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, *fs.ListFileResult, error) {
	o := newDbfsOption()
	for _, opt := range opts {
		o.apply(opt)
	}

	// Get navigator
	navigator, err := f.getNavigator(ctx, path, NavigatorCapabilityListChildren)
	if err != nil {
		return nil, nil, err
	}

	searchParams := path.SearchParameters()
	isSearching := searchParams != nil

	// Validate pagination args; clamp to the navigator's maximum page size.
	props := navigator.Capabilities(isSearching)
	if o.PageSize > props.MaxPageSize {
		o.PageSize = props.MaxPageSize
	}

	parent, err := f.getFileByPath(ctx, navigator, path)
	if err != nil {
		return nil, nil, fmt.Errorf("Parent not exist: %w", err)
	}

	// Generate a fresh hint ID so later requests can reuse navigator state.
	var hintId *uuid.UUID
	if o.generateContextHint {
		newHintId := uuid.Must(uuid.NewV4())
		hintId = &newHintId
	}

	if o.loadFilePublicMetadata {
		ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true)
	}
	// Share edges are only eager-loaded when the requester owns the parent.
	if o.loadFileShareIfOwned && parent != nil && parent.OwnerID() == f.user.ID {
		ctx = context.WithValue(ctx, inventory.LoadFileShare{}, true)
	}

	var streamCallback func([]*File)
	if o.streamListResponseCallback != nil {
		streamCallback = func(files []*File) {
			o.streamListResponseCallback(parent, lo.Map(files, func(item *File, index int) fs.File {
				return item
			}))
		}
	}

	children, err := navigator.Children(ctx, parent, &ListArgs{
		Page: &inventory.PaginationArgs{
			Page:                o.FsOption.Page,
			PageSize:            o.PageSize,
			OrderBy:             o.OrderBy,
			Order:               inventory.OrderDirection(o.OrderDirection),
			UseCursorPagination: o.useCursorPagination,
			PageToken:           o.pageToken,
		},
		Search:         searchParams,
		StreamCallback: streamCallback,
	})
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get children: %w", err)
	}

	// The preferred policy is advisory; listing proceeds even if it fails.
	var storagePolicy *ent.StoragePolicy
	if parent != nil {
		storagePolicy, err = f.getPreferredPolicy(ctx, parent)
		if err != nil {
			f.l.Warning("Failed to get preferred policy: %v", err)
		}
	}

	return parent, &fs.ListFileResult{
		Files: lo.Map(children.Files, func(item *File, index int) fs.File {
			return item
		}),
		Props:                 props,
		Pagination:            children.Pagination,
		ContextHint:           hintId,
		RecursionLimitReached: children.RecursionLimitReached,
		MixedType:             children.MixedType,
		SingleFileView:        children.SingleFileView,
		Parent:                parent,
		StoragePolicy:         storagePolicy,
	}, nil
}
|
||||
|
||||
// Capacity reports the storage quota: Total comes from the user group's
// MaxStorage, Used from the user's accumulated storage counter.
// NOTE(review): Used is read from f.user.Storage while the group is resolved
// from the u parameter — confirm callers always pass u == f.user.
func (f *DBFS) Capacity(ctx context.Context, u *ent.User) (*fs.Capacity, error) {
	// First, get user's available storage packs
	var (
		res = &fs.Capacity{}
	)

	requesterGroup, err := u.Edges.GroupOrErr()
	if err != nil {
		return nil, serializer.NewError(serializer.CodeDBError, "Failed to get user's group", err)
	}

	res.Used = f.user.Storage
	res.Total = requesterGroup.MaxStorage
	return res, nil
}
|
||||
|
||||
// CreateEntity creates a new entity of entityType under file, stored through
// policy at the save path described by req. When o.previousVersion is set it
// acts as an optimistic-concurrency etag: a stale latest version aborts with
// fs.ErrStaleVersion. The operation runs inside a transaction that also
// accumulates storage usage diffs, committed via CommitWithStorageDiff.
func (f *DBFS) CreateEntity(ctx context.Context, file fs.File, policy *ent.StoragePolicy,
	entityType types.EntityType, req *fs.UploadRequest, opts ...fs.Option) (fs.Entity, error) {
	o := newDbfsOption()
	for _, opt := range opts {
		o.apply(opt)
	}

	// If uploader specified previous latest version ID (etag), we should check if it's still valid.
	if o.previousVersion != "" {
		entityId, err := f.hasher.Decode(o.previousVersion, hashid.EntityID)
		if err != nil {
			return nil, serializer.NewError(serializer.CodeParamErr, "Unknown version ID", err)
		}

		entities, err := file.(*File).Model.Edges.EntitiesOrErr()
		if err != nil || entities == nil {
			return nil, fmt.Errorf("create entity: previous entities not load")
		}

		// File is stale during edit if the latest entity is not the same as the one specified by uploader.
		if e := file.PrimaryEntity(); e == nil || e.ID() != entityId {
			return nil, fs.ErrStaleVersion
		}
	}

	fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient)
	if err != nil {
		return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)
	}

	fileModel := file.(*File).Model
	if o.removeStaleEntities {
		// Drop superseded entities first; freed space is tracked in the diff.
		storageDiff, err := fc.RemoveStaleEntities(ctx, fileModel)
		if err != nil {
			_ = inventory.Rollback(tx)
			return nil, serializer.NewError(serializer.CodeDBError, "Failed to remove stale entities", err)
		}

		tx.AppendStorageDiff(storageDiff)
	}

	entity, storageDiff, err := fc.CreateEntity(ctx, fileModel, &inventory.EntityParameters{
		OwnerID:         file.(*File).Owner().ID,
		EntityType:      entityType,
		StoragePolicyID: policy.ID,
		Source:          req.Props.SavePath,
		Size:            req.Props.Size,
		UploadSessionID: uuid.FromStringOrNil(o.UploadRequest.Props.UploadSessionID),
	})
	if err != nil {
		_ = inventory.Rollback(tx)

		return nil, serializer.NewError(serializer.CodeDBError, "Failed to create entity", err)
	}
	tx.AppendStorageDiff(storageDiff)

	if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil {
		return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit create change", err)
	}

	return fs.NewEntity(entity), nil
}
|
||||
|
||||
// PatchMetadata applies metadata patches (upserts and removals) to every file
// in path. Per-path resolution errors are aggregated; all resolved targets
// are locked, then patched inside a single transaction.
// NOTE(review): the owner check returns immediately instead of aggregating
// like the other per-path errors, and the root-folder error message says
// "cannot move" in a metadata function — both look like copy-paste
// artifacts; confirm intent.
func (f *DBFS) PatchMetadata(ctx context.Context, path []*fs.URI, metas ...fs.MetadataPatch) error {
	ae := serializer.NewAggregateError()
	targets := make([]*File, 0, len(path))
	for _, p := range path {
		navigator, err := f.getNavigator(ctx, p, NavigatorCapabilityUpdateMetadata, NavigatorCapabilityLockFile)
		if err != nil {
			ae.Add(p.String(), err)
			continue
		}

		target, err := f.getFileByPath(ctx, navigator, p)
		if err != nil {
			ae.Add(p.String(), fmt.Errorf("failed to get target file: %w", err))
			continue
		}

		// Require Update permission
		if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.OwnerID() != f.user.ID {
			return fs.ErrOwnerOnly.WithError(fmt.Errorf("permission denied"))
		}

		if target.IsRootFolder() {
			ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot move root folder")))
			continue
		}

		targets = append(targets, target)
	}

	if len(targets) == 0 {
		return ae.Aggregate()
	}

	// Lock all targets
	lockTargets := lo.Map(targets, func(value *File, key int) *LockByPath {
		return &LockByPath{value.Uri(true), value, value.Type(), ""}
	})
	ls, err := f.acquireByPath(ctx, -1, f.user, true, fs.LockApp(fs.ApplicationUpdateMetadata), lockTargets...)
	// Release is deferred before the error check so partial acquisitions are
	// cleaned up as well.
	defer func() { _ = f.Release(ctx, ls) }()
	if err != nil {
		return err
	}

	// Split the patches into upserts (with a private-key mask) and removals.
	metadataMap := make(map[string]string)
	privateMap := make(map[string]bool)
	deleted := make([]string, 0)
	for _, meta := range metas {
		if meta.Remove {
			deleted = append(deleted, meta.Key)
			continue
		}
		metadataMap[meta.Key] = meta.Value
		if meta.Private {
			privateMap[meta.Key] = meta.Private
		}
	}

	fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient)
	if err != nil {
		return serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)
	}

	for _, target := range targets {
		if err := fc.UpsertMetadata(ctx, target.Model, metadataMap, privateMap); err != nil {
			_ = inventory.Rollback(tx)
			return fmt.Errorf("failed to upsert metadata: %w", err)
		}

		if len(deleted) > 0 {
			if err := fc.RemoveMetadata(ctx, target.Model, deleted...); err != nil {
				_ = inventory.Rollback(tx)
				return fmt.Errorf("failed to remove metadata: %w", err)
			}
		}
	}

	if err := inventory.Commit(tx); err != nil {
		return serializer.NewError(serializer.CodeDBError, "Failed to commit metadata change", err)
	}

	return ae.Aggregate()
}
|
||||
|
||||
// SharedAddressTranslation resolves path, following symbolic shared folders:
// when one is encountered, its redirect URI (stored in metadata under
// MetadataSharedRedirect) is resolved recursively. On a not-found error the
// most recent existing ancestor is returned together with the original error.
func (f *DBFS) SharedAddressTranslation(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, *fs.URI, error) {
	o := newDbfsOption()
	for _, opt := range opts {
		o.apply(opt)
	}

	// Get navigator
	navigator, err := f.getNavigator(ctx, path, o.requiredCapabilities...)
	if err != nil {
		return nil, nil, err
	}

	// Public metadata is always needed here to read the redirect target.
	ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true)
	if o.loadFileEntities {
		ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true)
	}

	// uriTranslation maps a symbolic folder to its redirect target and
	// recurses; rebase carries the remainder of the original path over.
	uriTranslation := func(target *File, rebase bool) (fs.File, *fs.URI, error) {
		// Translate shared address to real address
		metadata := target.Metadata()
		if metadata == nil {
			if err := f.fileClient.QueryMetadata(ctx, target.Model); err != nil {
				return nil, nil, fmt.Errorf("failed to query metadata: %w", err)
			}
			metadata = target.Metadata()
		}
		redirect, ok := metadata[MetadataSharedRedirect]
		if !ok {
			return nil, nil, fmt.Errorf("missing metadata %s in symbolic folder %s", MetadataSharedRedirect, path)
		}

		redirectUri, err := fs.NewUriFromString(redirect)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid redirect uri %s in symbolic folder %s", redirect, path)
		}
		newUri := redirectUri
		if rebase {
			newUri = redirectUri.Rebase(path, target.Uri(false))
		}
		return f.SharedAddressTranslation(ctx, newUri, opts...)
	}

	target, err := f.getFileByPath(ctx, navigator, path)
	if err != nil {
		if errors.Is(err, ErrSymbolicFolderFound) && target.Type() == types.FileTypeFolder {
			return uriTranslation(target, true)
		}

		if !ent.IsNotFound(err) {
			return nil, nil, fmt.Errorf("failed to get target file: %w", err)
		}

		// Request URI does not exist, return most recent ancestor
		return target, path, err
	}

	if target.IsSymbolic() {
		return uriTranslation(target, false)
	}

	return target, path, nil
}
|
||||
|
||||
// Get resolves the file at path. Options control which related data are
// loaded: public metadata, entities, shares, entity users, extended info
// (per-entity storage policies) and a folder size summary (served from
// cache, otherwise computed by walking the subtree).
func (f *DBFS) Get(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.File, error) {
	o := newDbfsOption()
	for _, opt := range opts {
		o.apply(opt)
	}

	// Get navigator
	navigator, err := f.getNavigator(ctx, path, o.requiredCapabilities...)
	if err != nil {
		return nil, err
	}

	// Toggle eager-loading flags consumed by the inventory layer.
	if o.loadFilePublicMetadata || o.extendedInfo {
		ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true)
	}

	if o.loadFileEntities || o.extendedInfo || o.loadFolderSummary {
		ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true)
	}

	if o.loadFileShareIfOwned {
		ctx = context.WithValue(ctx, inventory.LoadFileShare{}, true)
	}

	if o.loadEntityUser {
		ctx = context.WithValue(ctx, inventory.LoadEntityUser{}, true)
	}

	// Get target file
	target, err := f.getFileByPath(ctx, navigator, path)
	if err != nil {
		return nil, fmt.Errorf("failed to get target file: %w", err)
	}

	if o.extendedInfo && target != nil {
		extendedInfo := &fs.FileExtendedInfo{
			StorageUsed:           target.SizeUsed(),
			EntityStoragePolicies: make(map[int]*ent.StoragePolicy),
		}
		policyID := target.PolicyID()
		if policyID > 0 {
			// Best effort: a lookup failure just leaves StoragePolicy nil.
			policy, err := f.storagePolicyClient.GetPolicyByID(ctx, policyID)
			if err == nil {
				extendedInfo.StoragePolicy = policy
			}
		}

		target.FileExtendedInfo = extendedInfo
		// Shares are only exposed to the owner or administrators.
		if target.OwnerID() == f.user.ID || f.user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionIsAdmin)) {
			target.FileExtendedInfo.Shares = target.Model.Edges.Shares
		}

		// Resolve each entity's storage policy exactly once.
		entities := target.Entities()
		for _, entity := range entities {
			if _, ok := extendedInfo.EntityStoragePolicies[entity.PolicyID()]; !ok {
				policy, err := f.storagePolicyClient.GetPolicyByID(ctx, entity.PolicyID())
				if err != nil {
					return nil, fmt.Errorf("failed to get policy: %w", err)
				}

				extendedInfo.EntityStoragePolicies[entity.PolicyID()] = policy
			}
		}
	}

	// Calculate folder summary if requested
	if o.loadFolderSummary && target != nil && target.Type() == types.FileTypeFolder {
		if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.OwnerID() != f.user.ID {
			return nil, fs.ErrOwnerOnly
		}

		// first, try to load from cache
		// NOTE(review): Set below stores a *fs.FolderSummary while this hit
		// path asserts a value fs.FolderSummary — confirm the cache driver
		// round-trips through serialization, otherwise a hit would panic.
		summary, ok := f.cache.Get(fmt.Sprintf("%s%d", folderSummaryCachePrefix, target.ID()))
		if ok {
			summaryTyped := summary.(fs.FolderSummary)
			target.FileFolderSummary = &summaryTyped
		} else {
			// cache miss, walk the folder to get the summary
			newSummary := &fs.FolderSummary{Completed: true}
			if f.user.Edges.Group == nil {
				return nil, fmt.Errorf("user group not loaded")
			}
			limit := max(f.user.Edges.Group.Settings.MaxWalkedFiles, 1)

			// disable load metadata to speed up
			ctxWalk := context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, false)
			if err := navigator.Walk(ctxWalk, []*File{target}, limit, intsets.MaxInt, func(files []*File, l int) error {
				for _, file := range files {
					// Skip the walk root itself.
					if file.ID() == target.ID() {
						continue
					}
					if file.Type() == types.FileTypeFile {
						newSummary.Files++
					} else {
						newSummary.Folders++
					}

					newSummary.Size += file.SizeUsed()
				}
				return nil
			}); err != nil {
				if !errors.Is(err, ErrFileCountLimitedReached) {
					return nil, fmt.Errorf("failed to walk: %w", err)
				}

				// Hitting the walk limit yields a partial summary.
				newSummary.Completed = false
			}

			// cache the summary
			newSummary.CalculatedAt = time.Now()
			f.cache.Set(fmt.Sprintf("%s%d", folderSummaryCachePrefix, target.ID()), newSummary, f.settingClient.FolderPropsCacheTTL(ctx))
			target.FileFolderSummary = newSummary
		}
	}

	if target == nil {
		return nil, fmt.Errorf("cannot get root file with nil root")
	}

	return target, nil
}
|
||||
|
||||
func (f *DBFS) CheckCapability(ctx context.Context, uri *fs.URI, opts ...fs.Option) error {
|
||||
o := newDbfsOption()
|
||||
for _, opt := range opts {
|
||||
o.apply(opt)
|
||||
}
|
||||
|
||||
// Get navigator
|
||||
_, err := f.getNavigator(ctx, uri, o.requiredCapabilities...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Walk traverses the tree rooted at path down to depth levels, invoking walk
// for every visited file. Traversal is restricted to the owner unless the
// bypass flag is present in ctx, and capped by the group's MaxWalkedFiles.
func (f *DBFS) Walk(ctx context.Context, path *fs.URI, depth int, walk fs.WalkFunc, opts ...fs.Option) error {
	o := newDbfsOption()
	for _, opt := range opts {
		o.apply(opt)
	}

	if o.loadFilePublicMetadata {
		ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true)
	}

	if o.loadFileEntities {
		ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true)
	}

	// Get navigator
	navigator, err := f.getNavigator(ctx, path, o.requiredCapabilities...)
	if err != nil {
		return err
	}

	target, err := f.getFileByPath(ctx, navigator, path)
	if err != nil {
		return err
	}

	// Require Read permission
	if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.OwnerID() != f.user.ID {
		return fs.ErrOwnerOnly
	}

	// Walk
	if f.user.Edges.Group == nil {
		return fmt.Errorf("user group not loaded")
	}
	limit := max(f.user.Edges.Group.Settings.MaxWalkedFiles, 1)

	if err := navigator.Walk(ctx, []*File{target}, limit, depth, func(files []*File, l int) error {
		for _, file := range files {
			if err := walk(file, l); err != nil {
				return err
			}
		}
		return nil
	}); err != nil {
		return fmt.Errorf("failed to walk: %w", err)
	}

	return nil
}
|
||||
|
||||
// ExecuteNavigatorHooks runs hooks of hookType on the navigator that owns
// file's URI. Non-DBFS file implementations are silently ignored.
func (f *DBFS) ExecuteNavigatorHooks(ctx context.Context, hookType fs.HookType, file fs.File) error {
	navigator, err := f.getNavigator(ctx, file.Uri(false))
	if err != nil {
		return err
	}

	if dbfsFile, ok := file.(*File); ok {
		return navigator.ExecuteHook(ctx, hookType, dbfsFile)
	}

	return nil
}
|
||||
|
||||
// createFile creates a file with given name and type under given parent folder.
// Metadata, the symbolic-link flag, storage policy and (for uploads) an
// initial version entity are taken from o; creation runs in a transaction
// that also records the resulting storage usage diff.
func (f *DBFS) createFile(ctx context.Context, parent *File, name string, fileType types.FileType, o *dbfsOption) (*File, error) {
	createFileArgs := &inventory.CreateFileParameters{
		FileType:            fileType,
		Name:                name,
		MetadataPrivateMask: make(map[string]bool),
		Metadata:            make(map[string]string),
		IsSymbolic:          o.isSymbolicLink,
	}

	if o.Metadata != nil {
		for k, v := range o.Metadata {
			createFileArgs.Metadata[k] = v
		}
	}

	if o.preferredStoragePolicy != nil {
		createFileArgs.StoragePolicyID = o.preferredStoragePolicy.ID
	} else {
		// get preferred storage policy from the parent's owner group
		policy, err := f.getPreferredPolicy(ctx, parent)
		if err != nil {
			return nil, err
		}

		createFileArgs.StoragePolicyID = policy.ID
	}

	if o.UploadRequest != nil {
		createFileArgs.EntityParameters = &inventory.EntityParameters{
			EntityType:      types.EntityTypeVersion,
			Source:          o.UploadRequest.Props.SavePath,
			Size:            o.UploadRequest.Props.Size,
			ModifiedAt:      o.UploadRequest.Props.LastModified,
			UploadSessionID: uuid.FromStringOrNil(o.UploadRequest.Props.UploadSessionID),
		}
	}

	// Start transaction to create files
	fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient)
	if err != nil {
		return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)
	}

	file, entity, storageDiff, err := fc.CreateFile(ctx, parent.Model, createFileArgs)
	if err != nil {
		_ = inventory.Rollback(tx)
		// A constraint violation means a sibling with this name exists.
		if ent.IsConstraintError(err) {
			return nil, fs.ErrFileExisted.WithError(err)
		}

		return nil, serializer.NewError(serializer.CodeDBError, "Failed to create file", err)
	}

	tx.AppendStorageDiff(storageDiff)
	if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil {
		return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit create change", err)
	}

	file.SetEntities([]*ent.Entity{entity})
	return newFile(parent, file), nil
}
|
||||
|
||||
// getPreferredPolicy tries to get the preferred storage policy for the given file.
|
||||
func (f *DBFS) getPreferredPolicy(ctx context.Context, file *File) (*ent.StoragePolicy, error) {
|
||||
ownerGroup := file.Owner().Edges.Group
|
||||
if ownerGroup == nil {
|
||||
return nil, fmt.Errorf("owner group not loaded")
|
||||
}
|
||||
|
||||
groupPolicy, err := f.storagePolicyClient.GetByGroup(ctx, ownerGroup)
|
||||
if err != nil {
|
||||
return nil, serializer.NewError(serializer.CodeDBError, "Failed to get available storage policies", err)
|
||||
}
|
||||
|
||||
return groupPolicy, nil
|
||||
}
|
||||
|
||||
// getFileByPath resolves path through navigator. If the target user's file
// system has not been initialized yet (no root folder), it is created on
// demand and the resolution retried once.
func (f *DBFS) getFileByPath(ctx context.Context, navigator Navigator, path *fs.URI) (*File, error) {
	file, err := navigator.To(ctx, path)
	if err != nil && errors.Is(err, ErrFsNotInitialized) {
		// Initialize file system for user if root folder does not exist.
		// The URI may carry its own (hashed) user ID; fall back to the
		// current user's when it does not.
		uid := path.ID(hashid.EncodeUserID(f.hasher, f.user.ID))
		uidInt, err := f.hasher.Decode(uid, hashid.UserID)
		if err != nil {
			return nil, fmt.Errorf("failed to decode user ID: %w", err)
		}

		if err := f.initFs(ctx, uidInt); err != nil {
			return nil, fmt.Errorf("failed to initialize file system: %w", err)
		}
		return navigator.To(ctx, path)
	}

	return file, err
}
|
||||
|
||||
// initFs initializes the file system for the user.
// It creates the root folder owned by uid; everything else is created
// lazily afterwards.
func (f *DBFS) initFs(ctx context.Context, uid int) error {
	f.l.Info("Initialize database file system for user %q", f.user.Email)
	_, err := f.fileClient.CreateFolder(ctx, nil,
		&inventory.CreateFolderParameters{
			Owner: uid,
			Name:  inventory.RootFolderName,
		})
	if err != nil {
		return fmt.Errorf("failed to create root folder: %w", err)
	}

	return nil
}
|
||||
|
||||
// getNavigator returns the navigator for path's file system, creating and
// caching it on first use, optionally restoring persisted pagination state
// from a context hint, and verifying that it supports requiredCapabilities.
func (f *DBFS) getNavigator(ctx context.Context, path *fs.URI, requiredCapabilities ...NavigatorCapability) (Navigator, error) {
	pathFs := path.FileSystem()
	config := f.settingClient.DBFS(ctx)
	navigatorId := f.navigatorId(path)
	var (
		res Navigator
	)
	// The navigators map is shared by concurrent calls on the same DBFS.
	f.mu.Lock()
	defer f.mu.Unlock()
	if navigator, ok := f.navigators[navigatorId]; ok {
		res = navigator
	} else {
		var n Navigator
		switch pathFs {
		case constants.FileSystemMy:
			n = NewMyNavigator(f.user, f.fileClient, f.userClient, f.l, config, f.hasher)
		case constants.FileSystemShare:
			n = NewShareNavigator(f.user, f.fileClient, f.shareClient, f.l, config, f.hasher)
		case constants.FileSystemTrash:
			n = NewTrashNavigator(f.user, f.fileClient, f.l, config, f.hasher)
		case constants.FileSystemSharedWithMe:
			n = NewSharedWithMeNavigator(f.user, f.fileClient, f.l, config, f.hasher)
		default:
			return nil, fmt.Errorf("unknown file system %q", pathFs)
		}

		// retrieve state if context hint is provided
		if stateID, ok := ctx.Value(ContextHintCtxKey{}).(uuid.UUID); ok && stateID != uuid.Nil {
			cacheKey := NavigatorStateCachePrefix + stateID.String() + "_" + navigatorId
			if stateRaw, ok := f.stateKv.Get(cacheKey); ok {
				// Restore failures are tolerated; the navigator just starts fresh.
				if err := n.RestoreState(stateRaw.(State)); err != nil {
					f.l.Warning("Failed to restore state for navigator %q: %s", navigatorId, err)
				} else {
					f.l.Info("Navigator %q restored state (%q) successfully", navigatorId, stateID)
				}
			} else {
				// State expire, refresh it
				n.PersistState(f.stateKv, cacheKey)
			}
		}

		f.navigators[navigatorId] = n
		res = n
	}

	// Check fs capabilities
	capabilities := res.Capabilities(false).Capability
	for _, capability := range requiredCapabilities {
		if !capabilities.Enabled(int(capability)) {
			return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("action %q is not supported under current fs", capability))
		}
	}

	return res, nil
}
|
||||
|
||||
func (f *DBFS) navigatorId(path *fs.URI) string {
|
||||
uidHashed := hashid.EncodeUserID(f.hasher, f.user.ID)
|
||||
switch path.FileSystem() {
|
||||
case constants.FileSystemMy:
|
||||
return fmt.Sprintf("%s/%s/%d", constants.FileSystemMy, path.ID(uidHashed), f.user.ID)
|
||||
case constants.FileSystemShare:
|
||||
return fmt.Sprintf("%s/%s/%d", constants.FileSystemShare, path.ID(uidHashed), f.user.ID)
|
||||
case constants.FileSystemTrash:
|
||||
return fmt.Sprintf("%s/%s", constants.FileSystemTrash, path.ID(uidHashed))
|
||||
default:
|
||||
return fmt.Sprintf("%s/%s/%d", path.FileSystem(), path.ID(uidHashed), f.user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// generateSavePath generates the physical save path for the upload request.
// Placeholders in the policy's directory and file name rules are substituted
// with time-, randomness- and request-derived values.
// NOTE(review): {randomnum2}..{randomnum8} use rand.Intn(N), i.e. values in
// [0, N) — confirm the placeholder names are meant as an exclusive bound.
func generateSavePath(policy *ent.StoragePolicy, req *fs.UploadRequest, user *ent.User) string {
	baseTable := map[string]string{
		"{randomkey16}":    util.RandStringRunes(16),
		"{randomkey8}":     util.RandStringRunes(8),
		"{timestamp}":      strconv.FormatInt(time.Now().Unix(), 10),
		"{timestamp_nano}": strconv.FormatInt(time.Now().UnixNano(), 10),
		"{randomnum2}":     strconv.Itoa(rand.Intn(2)),
		"{randomnum3}":     strconv.Itoa(rand.Intn(3)),
		"{randomnum4}":     strconv.Itoa(rand.Intn(4)),
		"{randomnum8}":     strconv.Itoa(rand.Intn(8)),
		"{uid}":            strconv.Itoa(user.ID),
		"{datetime}":       time.Now().Format("20060102150405"),
		"{date}":           time.Now().Format("20060102"),
		"{year}":           time.Now().Format("2006"),
		"{month}":          time.Now().Format("01"),
		"{day}":            time.Now().Format("02"),
		"{hour}":           time.Now().Format("15"),
		"{minute}":         time.Now().Format("04"),
		"{second}":         time.Now().Format("05"),
	}

	// Directory part: substitute shared placeholders first, then {path},
	// which expands to the request URI's parent directory.
	dirRule := policy.DirNameRule
	dirRule = filepath.ToSlash(dirRule)
	dirRule = util.Replace(baseTable, dirRule)
	dirRule = util.Replace(map[string]string{
		"{path}": req.Props.Uri.Dir() + fs.Separator,
	}, dirRule)

	// File name part: placeholders derived from the original upload name.
	originName := req.Props.Uri.Name()
	nameTable := map[string]string{
		"{originname}":             originName,
		"{ext}":                    filepath.Ext(originName),
		"{originname_without_ext}": strings.TrimSuffix(originName, filepath.Ext(originName)),
		"{uuid}":                   uuid.Must(uuid.NewV4()).String(),
	}

	nameRule := policy.FileNameRule
	nameRule = util.Replace(baseTable, nameRule)
	nameRule = util.Replace(nameTable, nameRule)

	return path.Join(path.Clean(dirRule), nameRule)
}
|
||||
|
||||
func canMoveOrCopyTo(src, dst *fs.URI, isCopy bool) bool {
|
||||
if isCopy {
|
||||
return src.FileSystem() == dst.FileSystem() && src.FileSystem() == constants.FileSystemMy
|
||||
} else {
|
||||
switch src.FileSystem() {
|
||||
case constants.FileSystemMy:
|
||||
return dst.FileSystem() == constants.FileSystemMy || dst.FileSystem() == constants.FileSystemTrash
|
||||
case constants.FileSystemTrash:
|
||||
return dst.FileSystem() == constants.FileSystemMy
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func allAncestors(targets []*File) []*ent.File {
|
||||
return lo.Map(
|
||||
lo.UniqBy(
|
||||
lo.FlatMap(targets, func(value *File, index int) []*File {
|
||||
return value.Ancestors()
|
||||
}),
|
||||
func(item *File) int {
|
||||
return item.ID()
|
||||
},
|
||||
),
|
||||
func(item *File, index int) *ent.File {
|
||||
return item.Model
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func WithBypassOwnerCheck(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, ByPassOwnerCheckCtxKey{}, true)
|
||||
}
|
||||
335
pkg/filemanager/fs/dbfs/file.go
Normal file
335
pkg/filemanager/fs/dbfs/file.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package dbfs
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"path"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// init registers the types that may be gob-encoded (e.g. when File trees or
// navigator state are serialized into caches/sessions).
func init() {
	gob.Register(File{})
	gob.Register(shareNavigatorState{})
	gob.Register(map[string]*File{})
	gob.Register(map[int]*File{})
}
|
||||
|
||||
// filePool recycles File objects to reduce allocations; files are returned
// to the pool via (*File).Recycle.
var filePool = &sync.Pool{
	New: func() any {
		return &File{
			Children: make(map[string]*File),
		}
	},
}
|
||||
|
||||
type (
	// File is the in-memory node of a DBFS file tree. It wraps the ent model
	// together with tree links (Parent/Children), cached URIs and state
	// shared across the tree.
	File struct {
		Model    *ent.File        // underlying database model
		Children map[string]*File // loaded children, keyed by file name
		Parent   *File            // parent node; nil for the tree root
		// Cached URIs; index pathIndexRoot (0) is the owner's view,
		// pathIndexUser (1) is the user's view.
		Path              [2]*fs.URI
		OwnerModel        *ent.User           // owner model, set where known
		IsUserRoot        bool                // marks the root from the user's view
		CapabilitiesBs    *boolset.BooleanSet // capability set shared across the tree
		FileExtendedInfo  *fs.FileExtendedInfo
		FileFolderSummary *fs.FolderSummary

		// mu guards Children; shared by all nodes of one tree (see newFile).
		mu *sync.Mutex
	}
)
|
||||
|
||||
const (
	// Keys of reserved system metadata entries.
	MetadataSysPrefix           = "sys:"
	MetadataUploadSessionPrefix = MetadataSysPrefix + "upload_session"
	MetadataUploadSessionID     = MetadataUploadSessionPrefix + "_id"
	MetadataSharedRedirect      = MetadataSysPrefix + "shared_redirect"
	MetadataRestoreUri          = MetadataSysPrefix + "restore_uri"
	MetadataExpectedCollectTime = MetadataSysPrefix + "expected_collect_time"

	// Keys of thumbnail-related metadata entries.
	ThumbMetadataPrefix = "thumb:"
	ThumbDisabledKey    = ThumbMetadataPrefix + "disabled"

	// Indexes into File.Path: owner's view vs. user's view.
	pathIndexRoot = 0
	pathIndexUser = 1
)
|
||||
|
||||
// Name returns the raw file name from the database model.
func (f *File) Name() string {
	return f.Model.Name
}
|
||||
|
||||
// IsNil reports whether this File pointer is nil.
func (f *File) IsNil() bool {
	return f == nil
}
|
||||
|
||||
func (f *File) DisplayName() string {
|
||||
if uri, ok := f.Metadata()[MetadataRestoreUri]; ok {
|
||||
restoreUri, err := fs.NewUriFromString(uri)
|
||||
if err != nil {
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
return path.Base(restoreUri.Path())
|
||||
}
|
||||
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
// CanHaveChildren reports whether this file may contain children: it must be
// a folder and not a symbolic link.
func (f *File) CanHaveChildren() bool {
	return f.Type() == types.FileTypeFolder && !f.IsSymbolic()
}
|
||||
|
||||
// Ext returns the file name extension, as computed by util.Ext.
func (f *File) Ext() string {
	return util.Ext(f.Name())
}
|
||||
|
||||
// ID returns the database ID of the file.
func (f *File) ID() int {
	return f.Model.ID
}
|
||||
|
||||
// IsSymbolic reports whether the file is a symbolic link.
func (f *File) IsSymbolic() bool {
	return f.Model.IsSymbolic
}
|
||||
|
||||
// Type returns the file type (file or folder) of this file.
func (f *File) Type() types.FileType {
	return types.FileType(f.Model.Type)
}
|
||||
|
||||
// Size returns the logical size of the file in bytes.
func (f *File) Size() int64 {
	return f.Model.Size
}
|
||||
|
||||
func (f *File) SizeUsed() int64 {
|
||||
return lo.SumBy(f.Entities(), func(item fs.Entity) int64 {
|
||||
return item.Size()
|
||||
})
|
||||
}
|
||||
|
||||
// UpdatedAt returns the last modification time of the file.
func (f *File) UpdatedAt() time.Time {
	return f.Model.UpdatedAt
}
|
||||
|
||||
// CreatedAt returns the creation time of the file.
func (f *File) CreatedAt() time.Time {
	return f.Model.CreatedAt
}
|
||||
|
||||
// ExtendedInfo returns the extended info attached to this file, or nil if
// it has not been populated.
func (f *File) ExtendedInfo() *fs.FileExtendedInfo {
	return f.FileExtendedInfo
}
|
||||
|
||||
func (f *File) Owner() *ent.User {
|
||||
parent := f
|
||||
for parent != nil {
|
||||
if parent.OwnerModel != nil {
|
||||
return parent.OwnerModel
|
||||
}
|
||||
parent = parent.Parent
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OwnerID returns the owner's user ID stored on the file model.
func (f *File) OwnerID() int {
	return f.Model.OwnerID
}
|
||||
|
||||
// Shared reports whether any share link exists for this file (based on the
// loaded shares edge).
func (f *File) Shared() bool {
	return len(f.Model.Edges.Shares) > 0
}
|
||||
|
||||
// Metadata returns the file's metadata as a name->value map, or nil when the
// metadata edge has not been loaded.
func (f *File) Metadata() map[string]string {
	if f.Model.Edges.Metadata == nil {
		return nil
	}
	return lo.Associate(f.Model.Edges.Metadata, func(item *ent.Metadata) (string, string) {
		return item.Name, item.Value
	})
}
|
||||
|
||||
// Uri returns the URI of the file.
// If isRoot is true, the URI will be returned from owner's view.
// Otherwise, the URI will be returned from user's view.
// Returns nil when neither this file nor any ancestor carries a cached path
// for the requested view.
func (f *File) Uri(isRoot bool) *fs.URI {
	// Path[pathIndexRoot] (0) is the owner's view, Path[pathIndexUser] (1)
	// the user's view.
	index := 1
	if isRoot {
		index = 0
	}
	if f.Path[index] != nil || f.Parent == nil {
		return f.Path[index]
	}

	// Find the root file
	// Walk up until an ancestor with a cached path is found, collecting the
	// intermediate names so they can be re-joined onto that ancestor's path.
	elements := make([]string, 0)
	parent := f
	for parent.Parent != nil && parent.Path[index] == nil {
		elements = append([]string{parent.Name()}, elements...)
		parent = parent.Parent
	}

	if parent.Path[index] == nil {
		return nil
	}

	return parent.Path[index].Join(elements...)
}
|
||||
|
||||
// UserRoot return the root file from user's view.
|
||||
func (f *File) UserRoot() *File {
|
||||
root := f
|
||||
for root != nil && !root.IsUserRoot {
|
||||
root = root.Parent
|
||||
}
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
// Root return the root file from owner's view.
|
||||
func (f *File) Root() *File {
|
||||
root := f
|
||||
for root.Parent != nil {
|
||||
root = root.Parent
|
||||
}
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
// RootUri return the URI of the user root file under owner's view.
func (f *File) RootUri() *fs.URI {
	return f.UserRoot().Uri(true)
}
|
||||
|
||||
// Replace swaps this file for a new File wrapping the given model, detaching
// f from its parent and recycling it. The returned File takes f's place
// under the same parent. f must not be used after this call.
func (f *File) Replace(model *ent.File) *File {
	f.mu.Lock()
	delete(f.Parent.Children, f.Model.Name)
	f.mu.Unlock()

	// Recycle only after the replacement is built: newFile reads f.Parent.
	defer f.Recycle()
	replaced := newFile(f.Parent, model)
	if f.IsRootFile() {
		// If target is a root file, the user path should remain the same.
		replaced.Path[pathIndexUser] = f.Path[pathIndexUser]
	}

	return replaced
}
|
||||
|
||||
// Ancestors return all ancestors of the file, until the owner root is reached.
// The file itself is excluded (it is AncestorsChain minus the first element).
func (f *File) Ancestors() []*File {
	return f.AncestorsChain()[1:]
}
|
||||
|
||||
// AncestorsChain return all ancestors of the file (including itself), until the owner root is reached.
|
||||
func (f *File) AncestorsChain() []*File {
|
||||
ancestors := make([]*File, 0)
|
||||
parent := f
|
||||
for parent != nil {
|
||||
ancestors = append(ancestors, parent)
|
||||
parent = parent.Parent
|
||||
}
|
||||
|
||||
return ancestors
|
||||
}
|
||||
|
||||
func (f *File) PolicyID() int {
|
||||
root := f
|
||||
return root.Model.StoragePolicyFiles
|
||||
}
|
||||
|
||||
// IsRootFolder return true if the file is the root folder under user's view.
func (f *File) IsRootFolder() bool {
	return f.Type() == types.FileTypeFolder && f.IsRootFile()
}
|
||||
|
||||
// IsRootFile return true if the file is the root file under user's view.
func (f *File) IsRootFile() bool {
	uri := f.Uri(false)
	// NOTE(review): Uri may return nil when no user-view path is cached —
	// confirm callers only invoke this on files with a resolvable URI.
	p := uri.Path()
	return f.Model.Name == inventory.RootFolderName || p == fs.Separator || p == ""
}
|
||||
|
||||
func (f *File) Entities() []fs.Entity {
|
||||
return lo.Map(f.Model.Edges.Entities, func(item *ent.Entity, index int) fs.Entity {
|
||||
return fs.NewEntity(item)
|
||||
})
|
||||
}
|
||||
|
||||
// PrimaryEntity returns the version entity referenced by the model's
// PrimaryEntity field. If it is not found among the loaded entities, an
// empty entity owned by the file's owner is returned instead.
func (f *File) PrimaryEntity() fs.Entity {
	primary, _ := lo.Find(f.Model.Edges.Entities, func(item *ent.Entity) bool {
		return item.Type == int(types.EntityTypeVersion) && item.ID == f.Model.PrimaryEntity
	})
	if primary != nil {
		return fs.NewEntity(primary)
	}

	return fs.NewEmptyEntity(f.Owner())
}
|
||||
|
||||
// PrimaryEntityID returns the ID of the file's primary entity.
func (f *File) PrimaryEntityID() int {
	return f.Model.PrimaryEntity
}
|
||||
|
||||
// FolderSummary returns the folder summary attached to this file, or nil if
// it has not been populated.
func (f *File) FolderSummary() *fs.FolderSummary {
	return f.FileFolderSummary
}
|
||||
|
||||
// Capabilities returns the capability set shared across this file's tree.
func (f *File) Capabilities() *boolset.BooleanSet {
	return f.CapabilitiesBs
}
|
||||
|
||||
// newFile obtains a File from the pool and initializes it with the given
// model. When parent is non-nil, the new file is registered as its child,
// inherits its capability set and mutex, and derives both cached paths from
// the parent's paths where available.
func newFile(parent *File, model *ent.File) *File {
	f := filePool.Get().(*File)
	f.Model = model

	if parent != nil {
		f.Parent = parent
		parent.mu.Lock()
		parent.Children[model.Name] = f
		if parent.Path[pathIndexUser] != nil {
			f.Path[pathIndexUser] = parent.Path[pathIndexUser].Join(model.Name)
		}

		if parent.Path[pathIndexRoot] != nil {
			f.Path[pathIndexRoot] = parent.Path[pathIndexRoot].Join(model.Name)
		}

		// The whole tree shares one capability set and one mutex.
		f.CapabilitiesBs = parent.CapabilitiesBs
		f.mu = parent.mu
		parent.mu.Unlock()
	} else {
		// Root of a new tree: it gets its own mutex.
		f.mu = &sync.Mutex{}
	}

	return f
}
|
||||
|
||||
// newParentFile wraps the given parent model into a File and attaches child
// under it, sharing the child's mutex so the pair forms a single tree.
func newParentFile(parent *ent.File, child *File) *File {
	newParent := newFile(nil, parent)
	newParent.Children[child.Name()] = child
	child.Parent = newParent
	newParent.mu = child.mu
	return newParent
}
|
||||
|
||||
func (f *File) Recycle() {
|
||||
for _, child := range f.Children {
|
||||
child.Recycle()
|
||||
}
|
||||
|
||||
f.Model = nil
|
||||
f.Children = make(map[string]*File)
|
||||
f.Path[0] = nil
|
||||
f.Path[1] = nil
|
||||
f.Parent = nil
|
||||
f.OwnerModel = nil
|
||||
f.IsUserRoot = false
|
||||
f.mu = nil
|
||||
|
||||
filePool.Put(f)
|
||||
}
|
||||
55
pkg/filemanager/fs/dbfs/global.go
Normal file
55
pkg/filemanager/fs/dbfs/global.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package dbfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func (f *DBFS) StaleEntities(ctx context.Context, entities ...int) ([]fs.Entity, error) {
|
||||
res, err := f.fileClient.StaleEntities(ctx, entities...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return lo.Map(res, func(e *ent.Entity, i int) fs.Entity {
|
||||
return fs.NewEntity(e)
|
||||
}), nil
|
||||
}
|
||||
|
||||
// AllFilesInTrashBin lists children of the trash bin root, applying the
// pagination and ordering options from opts.
func (f *DBFS) AllFilesInTrashBin(ctx context.Context, opts ...fs.Option) (*fs.ListFileResult, error) {
	o := newDbfsOption()
	for _, opt := range opts {
		o.apply(opt)
	}

	navigator, err := f.getNavigator(ctx, newTrashUri(""), NavigatorCapabilityListChildren)
	if err != nil {
		return nil, err
	}

	// Public metadata (e.g. restore URI) is needed in trash listings.
	ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true)
	children, err := navigator.Children(ctx, nil, &ListArgs{
		Page: &inventory.PaginationArgs{
			Page:                o.FsOption.Page,
			PageSize:            o.PageSize,
			OrderBy:             o.OrderBy,
			Order:               inventory.OrderDirection(o.OrderDirection),
			UseCursorPagination: o.useCursorPagination,
			PageToken:           o.pageToken,
		},
	})
	if err != nil {
		return nil, err
	}

	return &fs.ListFileResult{
		Files: lo.Map(children.Files, func(item *File, index int) fs.File {
			return item
		}),
		Pagination:            children.Pagination,
		RecursionLimitReached: children.RecursionLimitReached,
	}, nil
}
|
||||
325
pkg/filemanager/fs/dbfs/lock.go
Normal file
325
pkg/filemanager/fs/dbfs/lock.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package dbfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type (
	// LockSession tracks lock tokens acquired during an operation. Tokens is
	// keyed by the lock key (see lockTupleFromUri); TokenStack groups keys
	// into frames so nested operations can release only their own tokens.
	LockSession struct {
		Tokens     map[string]string
		TokenStack [][]string
	}

	// LockByPath describes a lock request for a URI whose target may not
	// exist yet; ClosestAncestor is the deepest existing ancestor of Uri.
	LockByPath struct {
		Uri             *fs.URI
		ClosestAncestor *File
		Type            types.FileType
		Token           string
	}

	// AlwaysIncludeTokenCtx is the context key enabling tokens in conflict
	// responses; see WithAlwaysIncludeToken.
	AlwaysIncludeTokenCtx struct{}
)
|
||||
|
||||
// ConfirmLock verifies that one of the given tokens holds a valid lock
// covering uri (resolved against ancestor's root). A successfully confirmed
// token is recorded in the current lock session so repeated confirmations
// are short-circuited; a release function is returned on success.
func (f *DBFS) ConfirmLock(ctx context.Context, ancestor fs.File, uri *fs.URI, token ...string) (func(), fs.LockSession, error) {
	session := LockSessionFromCtx(ctx)
	lockUri := ancestor.RootUri().JoinRaw(uri.PathTrimmed())
	ns, root, lKey := lockTupleFromUri(lockUri, f.user, f.hasher)
	lc := lock.LockInfo{
		Ns:    ns,
		Root:  root,
		Token: token,
	}

	// Skip if already locked in current session
	if _, ok := session.Tokens[lKey]; ok {
		return func() {}, session, nil
	}

	release, tokenHit, err := f.ls.Confirm(time.Now(), lc)
	if err != nil {
		return nil, nil, err
	}

	// Record the confirmed token in the session's current stack frame.
	session.Tokens[lKey] = tokenHit
	stackIndex := len(session.TokenStack) - 1
	session.TokenStack[stackIndex] = append(session.TokenStack[stackIndex], lKey)
	return release, session, nil
}
|
||||
|
||||
// Lock acquires a lock on uri for requester, valid for duration d. Locking
// the root folder itself is rejected, and (unless bypassed via ctx) only the
// owner may lock. The lock is folder-typed only when the target itself
// exists and is a folder. Returns a session holding the acquired token.
func (f *DBFS) Lock(ctx context.Context, d time.Duration, requester *ent.User, zeroDepth bool, application lock.Application,
	uri *fs.URI, token string) (fs.LockSession, error) {
	// Get navigator
	navigator, err := f.getNavigator(ctx, uri, NavigatorCapabilityLockFile)
	if err != nil {
		return nil, err
	}

	ancestor, err := f.getFileByPath(ctx, navigator, uri)
	if err != nil && !ent.IsNotFound(err) {
		return nil, fmt.Errorf("failed to get ancestor: %w", err)
	}

	if ancestor.IsRootFolder() && ancestor.Uri(false).IsSame(uri, hashid.EncodeUserID(f.hasher, f.user.ID)) {
		return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot lock root folder"))
	}

	// Lock require create or update permission
	if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && ancestor.Owner().ID != requester.ID {
		return nil, fs.ErrOwnerOnly
	}

	// Default to a file-typed lock; use the target's actual type only when
	// the target itself exists (ancestor URI equals the requested URI).
	t := types.FileTypeFile
	if ancestor.Uri(false).IsSame(uri, hashid.EncodeUserID(f.hasher, f.user.ID)) {
		t = ancestor.Type()
	}
	lr := &LockByPath{
		Uri:             ancestor.RootUri().JoinRaw(uri.PathTrimmed()),
		ClosestAncestor: ancestor,
		Type:            t,
		Token:           token,
	}
	ls, err := f.acquireByPath(ctx, d, requester, zeroDepth, application, lr)
	if err != nil {
		return nil, err
	}

	return ls, nil
}
|
||||
|
||||
// Unlock releases the locks identified by the given tokens.
func (f *DBFS) Unlock(ctx context.Context, tokens ...string) error {
	return f.ls.Unlock(time.Now(), tokens...)
}
|
||||
|
||||
// Refresh extends the lifetime of the lock identified by token by d and
// returns the refreshed lock details.
func (f *DBFS) Refresh(ctx context.Context, d time.Duration, token string) (lock.LockDetails, error) {
	return f.ls.Refresh(time.Now(), d, token)
}
|
||||
|
||||
func (f *DBFS) acquireByPath(ctx context.Context, duration time.Duration,
|
||||
requester *ent.User, zeroDepth bool, application lock.Application, locks ...*LockByPath) (*LockSession, error) {
|
||||
session := LockSessionFromCtx(ctx)
|
||||
|
||||
// Prepare lock details for each file
|
||||
lockDetails := make([]lock.LockDetails, 0, len(locks))
|
||||
lockedRequest := make([]*LockByPath, 0, len(locks))
|
||||
for _, l := range locks {
|
||||
ns, root, lKey := lockTupleFromUri(l.Uri, f.user, f.hasher)
|
||||
ld := lock.LockDetails{
|
||||
Owner: lock.Owner{
|
||||
Application: application,
|
||||
},
|
||||
Ns: ns,
|
||||
Root: root,
|
||||
ZeroDepth: zeroDepth,
|
||||
Duration: duration,
|
||||
Type: l.Type,
|
||||
Token: l.Token,
|
||||
}
|
||||
|
||||
// Skip if already locked in current session
|
||||
if _, ok := session.Tokens[lKey]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
lockDetails = append(lockDetails, ld)
|
||||
lockedRequest = append(lockedRequest, l)
|
||||
}
|
||||
|
||||
// Acquire lock
|
||||
tokens, err := f.ls.Create(time.Now(), lockDetails...)
|
||||
if len(tokens) > 0 {
|
||||
for i, token := range tokens {
|
||||
key := lockDetails[i].Key()
|
||||
session.Tokens[key] = token
|
||||
stackIndex := len(session.TokenStack) - 1
|
||||
session.TokenStack[stackIndex] = append(session.TokenStack[stackIndex], key)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
var conflicts lock.ConflictError
|
||||
if errors.As(err, &conflicts) {
|
||||
// Conflict with existing lock, generate user-friendly error message
|
||||
conflicts = lo.Map(conflicts, func(c *lock.ConflictDetail, index int) *lock.ConflictDetail {
|
||||
lr := lockedRequest[c.Index]
|
||||
if lr.ClosestAncestor.Root().Model.OwnerID == requester.ID {
|
||||
// Add absolute path for owner issued lock request
|
||||
c.Path = newMyUri().JoinRaw(c.Path).String()
|
||||
return c
|
||||
}
|
||||
|
||||
// Hide token for non-owner requester
|
||||
if v, ok := ctx.Value(AlwaysIncludeTokenCtx{}).(bool); !ok || !v {
|
||||
c.Token = ""
|
||||
}
|
||||
|
||||
// If conflicted resources still under user root, expose the relative path
|
||||
userRoot := lr.ClosestAncestor.UserRoot()
|
||||
userRootPath := userRoot.Uri(true).Path()
|
||||
if strings.HasPrefix(c.Path, userRootPath) {
|
||||
c.Path = userRoot.
|
||||
Uri(false).
|
||||
Join(strings.Split(strings.TrimPrefix(c.Path, userRootPath), fs.Separator)...).String()
|
||||
return c
|
||||
}
|
||||
|
||||
// Hide sensitive information for non-owner issued lock request
|
||||
c.Path = ""
|
||||
return c
|
||||
})
|
||||
|
||||
return session, fs.ErrLockConflict.WithError(conflicts)
|
||||
}
|
||||
|
||||
return session, fmt.Errorf("faield to create lock: %w", err)
|
||||
}
|
||||
|
||||
// Check if any ancestor is modified during `getFileByPath` and `lock`.
|
||||
if err := f.ensureConsistency(
|
||||
ctx,
|
||||
lo.Map(lockedRequest, func(item *LockByPath, index int) *File {
|
||||
return item.ClosestAncestor
|
||||
})...,
|
||||
); err != nil {
|
||||
return session, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Release unlocks all tokens acquired in the top stack frame of the given
// lock session. On success the frame is popped and its tokens removed from
// the session; on failure the session is left unchanged. A nil session is a
// no-op.
func (f *DBFS) Release(ctx context.Context, session *LockSession) error {
	if session == nil {
		return nil
	}

	stackIndex := len(session.TokenStack) - 1
	err := f.ls.Unlock(time.Now(), lo.Map(session.TokenStack[stackIndex], func(key string, index int) string {
		return session.Tokens[key]
	})...)
	if err == nil {
		for _, key := range session.TokenStack[stackIndex] {
			delete(session.Tokens, key)
		}
		session.TokenStack = session.TokenStack[:len(session.TokenStack)-1]
	}

	return err
}
|
||||
|
||||
// ensureConsistency queries database for all given files and its ancestors, make sure there's no modification in
// between. This is to make sure there's no modification between navigator's first query and lock acquisition.
func (f *DBFS) ensureConsistency(ctx context.Context, files ...*File) error {
	if len(files) == 0 {
		return nil
	}

	// Generate a list of unique files (include ancestors) to check
	uniqueFiles := make(map[int]*File)
	for _, file := range files {
		for root := file; root != nil; root = root.Parent {
			if _, ok := uniqueFiles[root.Model.ID]; ok {
				// This file and its ancestors are already included
				break
			}

			uniqueFiles[root.Model.ID] = root
		}
	}

	// Re-fetch the files page by page and compare the identity-defining
	// fields against the in-memory snapshots.
	page := 0
	fileIds := lo.Keys(uniqueFiles)
	for page >= 0 {
		files, next, err := f.fileClient.GetByIDs(ctx, fileIds, page)
		if err != nil {
			return fmt.Errorf("failed to check file consistency: %w", err)
		}

		for _, file := range files {
			latest := uniqueFiles[file.ID].Model
			if file.Name != latest.Name ||
				file.FileChildren != latest.FileChildren ||
				file.OwnerID != latest.OwnerID ||
				file.Type != latest.Type {
				return fs.ErrModified.
					WithError(fmt.Errorf("file %s has been modified before lock acquisition", file.Name))
			}
		}

		page = next
	}

	return nil
}
|
||||
|
||||
// LockSessionFromCtx retrieves lock session from context. If no lock session
// found, a new empty lock session will be returned.
// In both cases a fresh stack frame is pushed onto the session's token
// stack, scoping tokens acquired afterwards to the current operation.
func LockSessionFromCtx(ctx context.Context) *LockSession {
	l, _ := ctx.Value(fs.LockSessionCtxKey{}).(*LockSession)
	if l == nil {
		ls := &LockSession{
			Tokens:     make(map[string]string),
			TokenStack: make([][]string, 0),
		}

		l = ls
	}

	l.TokenStack = append(l.TokenStack, make([]string, 0))
	return l
}
|
||||
|
||||
// Exclude removes lock from session, so that it won't be released.
// The token is removed (and returned) only when it was acquired in the
// session's current (topmost) stack frame; otherwise an empty string is
// returned and the session is left unchanged.
func (l *LockSession) Exclude(lock *LockByPath, u *ent.User, hasher hashid.Encoder) string {
	_, _, lKey := lockTupleFromUri(lock.Uri, u, hasher)
	foundInCurrentStack := false
	token, found := l.Tokens[lKey]
	if found {
		stackIndex := len(l.TokenStack) - 1
		l.TokenStack[stackIndex] = lo.Filter(l.TokenStack[stackIndex], func(t string, index int) bool {
			if t == lKey {
				foundInCurrentStack = true
			}
			return t != lKey
		})
		if foundInCurrentStack {
			delete(l.Tokens, lKey)
			return token
		}
	}

	return ""
}
|
||||
|
||||
func (l *LockSession) LastToken() string {
|
||||
stackIndex := len(l.TokenStack) - 1
|
||||
if len(l.TokenStack[stackIndex]) == 0 {
|
||||
return ""
|
||||
}
|
||||
return l.Tokens[l.TokenStack[stackIndex][len(l.TokenStack[stackIndex])-1]]
|
||||
}
|
||||
|
||||
// WithAlwaysIncludeToken returns a new context with a flag to always include token in conflic response.
|
||||
func WithAlwaysIncludeToken(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, AlwaysIncludeTokenCtx{}, true)
|
||||
}
|
||||
|
||||
func lockTupleFromUri(uri *fs.URI, u *ent.User, hasher hashid.Encoder) (string, string, string) {
|
||||
id := uri.ID(hashid.EncodeUserID(hasher, u.ID))
|
||||
if id == "" {
|
||||
id = strconv.Itoa(u.ID)
|
||||
}
|
||||
ns := fmt.Sprintf(id + "/" + string(uri.FileSystem()))
|
||||
root := uri.Path()
|
||||
return ns, root, ns + "/" + root
|
||||
}
|
||||
831
pkg/filemanager/fs/dbfs/manage.go
Normal file
831
pkg/filemanager/fs/dbfs/manage.go
Normal file
@@ -0,0 +1,831 @@
|
||||
package dbfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/samber/lo"
|
||||
"golang.org/x/tools/container/intsets"
|
||||
)
|
||||
|
||||
// Create creates a file or folder at path, creating any missing intermediate
// folders along the way (unless chained creation is disabled via options).
// If the target already exists with the same type it is returned as-is (or
// with fs.ErrFileExisted when errOnConflict is set); an existing object of a
// different type is always an error.
func (f *DBFS) Create(ctx context.Context, path *fs.URI, fileType types.FileType, opts ...fs.Option) (fs.File, error) {
	o := newDbfsOption()
	for _, opt := range opts {
		o.apply(opt)
	}

	// Get navigator
	navigator, err := f.getNavigator(ctx, path, NavigatorCapabilityCreateFile, NavigatorCapabilityLockFile)
	if err != nil {
		return nil, err
	}

	// Get most recent ancestor
	var ancestor *File
	if o.ancestor != nil {
		ancestor = o.ancestor
	} else {
		ancestor, err = f.getFileByPath(ctx, navigator, path)
		if err != nil && !ent.IsNotFound(err) {
			return nil, fmt.Errorf("failed to get ancestor: %w", err)
		}
	}

	// Ancestor URI equal to the requested URI means the target exists.
	if ancestor.Uri(false).IsSame(path, hashid.EncodeUserID(f.hasher, f.user.ID)) {
		if ancestor.Type() == fileType {
			if o.errOnConflict {
				return ancestor, fs.ErrFileExisted
			}

			// Target file already exist, return it.
			return ancestor, nil
		}

		// File with the same name but different type already exist
		return nil, fs.ErrFileExisted.
			WithError(fmt.Errorf("object with the same name but different type %q already exist", ancestor.Type()))
	}

	// Only the owner may create files here, unless bypassed via ctx.
	if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && ancestor.Owner().ID != f.user.ID {
		return nil, fs.ErrOwnerOnly
	}

	// Lock ancestor
	lockedPath := ancestor.RootUri().JoinRaw(path.PathTrimmed())
	ls, err := f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationCreate),
		&LockByPath{lockedPath, ancestor, fileType, ""})
	defer func() { _ = f.Release(ctx, ls) }()
	if err != nil {
		return nil, err
	}

	// For all ancestors in user's desired path, create folders if not exist
	existedElements := ancestor.Uri(false).Elements()
	desired := path.Elements()
	if (len(desired)-len(existedElements) > 1) && o.noChainedCreation {
		return nil, fs.ErrPathNotExist
	}

	for i := len(existedElements); i < len(desired); i++ {
		// Make sure parent is a folder
		if !ancestor.CanHaveChildren() {
			return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("parent must be a valid folder"))
		}

		// Validate object name
		if err := validateFileName(desired[i]); err != nil {
			return nil, fs.ErrIllegalObjectName.WithError(err)
		}

		if i < len(desired)-1 || fileType == types.FileTypeFolder {
			args := &inventory.CreateFolderParameters{
				Owner: ancestor.Model.OwnerID,
				Name:  desired[i],
			}

			// Apply options for last element
			if i == len(desired)-1 {
				if o.Metadata != nil {
					args.Metadata = o.Metadata
				}
				args.IsSymbolic = o.isSymbolicLink
			}

			// Create folder if it is not the last element or the target is a folder
			fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient)
			if err != nil {
				return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)
			}

			newFolder, err := fc.CreateFolder(ctx, ancestor.Model, args)
			if err != nil {
				_ = inventory.Rollback(tx)
				return nil, fmt.Errorf("failed to create folder %q: %w", desired[i], err)
			}

			if err := inventory.Commit(tx); err != nil {
				return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit folder creation", err)
			}

			// Descend into the newly created folder.
			ancestor = newFile(ancestor, newFolder)
		} else {
			// Last element and the target is a regular file.
			file, err := f.createFile(ctx, ancestor, desired[i], fileType, o)
			if err != nil {
				return nil, err
			}

			return file, nil
		}
	}

	return ancestor, nil
}
|
||||
|
||||
// Rename renames the file at path to newName. The new name is validated
// against global naming rules and, for regular files, against the storage
// policy's allowed extensions. Returns the renamed file.
func (f *DBFS) Rename(ctx context.Context, path *fs.URI, newName string) (fs.File, error) {
	// Get navigator
	navigator, err := f.getNavigator(ctx, path, NavigatorCapabilityRenameFile, NavigatorCapabilityLockFile)
	if err != nil {
		return nil, err
	}

	// Get target file
	target, err := f.getFileByPath(ctx, navigator, path)
	if err != nil {
		return nil, fmt.Errorf("failed to get target file: %w", err)
	}
	oldName := target.Name()

	// Only the owner may rename, unless bypassed via ctx.
	if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.Owner().ID != f.user.ID {
		return nil, fs.ErrOwnerOnly
	}

	// Root folder cannot be modified
	if target.IsRootFolder() {
		return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot modify root folder"))
	}

	// Validate new name
	if err := validateFileName(newName); err != nil {
		return nil, fs.ErrIllegalObjectName.WithError(err)
	}

	// If target is a file, validate file extension
	policy, err := f.getPreferredPolicy(ctx, target)
	if err != nil {
		return nil, err
	}

	if target.Type() == types.FileTypeFile {
		if err := validateExtension(newName, policy); err != nil {
			return nil, fs.ErrIllegalObjectName.WithError(err)
		}
	}

	// Lock target
	ls, err := f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationRename),
		&LockByPath{target.Uri(true), target, target.Type(), ""})
	defer func() { _ = f.Release(ctx, ls) }()
	if err != nil {
		return nil, err
	}

	// Rename target
	fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient)
	if err != nil {
		return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)
	}

	updated, err := fc.Rename(ctx, target.Model, newName)
	if err != nil {
		_ = inventory.Rollback(tx)
		// A constraint violation means a sibling with the new name exists.
		if ent.IsConstraintError(err) {
			return nil, fs.ErrFileExisted.WithError(err)
		}

		return nil, serializer.NewError(serializer.CodeDBError, "failed to update file", err)
	}

	// Changing the extension may change thumbnail availability, so clear the
	// "thumbnail disabled" mark.
	if target.Type() == types.FileTypeFile && !strings.EqualFold(filepath.Ext(newName), filepath.Ext(oldName)) {
		if err := fc.RemoveMetadata(ctx, target.Model, ThumbDisabledKey); err != nil {
			_ = inventory.Rollback(tx)
			return nil, serializer.NewError(serializer.CodeDBError, "failed to remove disabled thumbnail mark", err)
		}
	}

	if err := inventory.Commit(tx); err != nil {
		return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit rename change", err)
	}

	return target.Replace(updated), nil
}
|
||||
|
||||
// SoftDelete moves the files at the given paths into the trash bin. All
// targets are validated and locked first, then soft-deleted in a single
// transaction; per-path validation errors are aggregated and returned.
func (f *DBFS) SoftDelete(ctx context.Context, path ...*fs.URI) error {
	ae := serializer.NewAggregateError()
	targets := make([]*File, 0, len(path))
	for _, p := range path {
		// Get navigator
		navigator, err := f.getNavigator(ctx, p, NavigatorCapabilitySoftDelete)
		if err != nil {
			ae.Add(p.String(), err)
			continue
		}

		// Get target file
		target, err := f.getFileByPath(ctx, navigator, p)
		if err != nil {
			ae.Add(p.String(), fmt.Errorf("failed to get target file: %w", err))
			continue
		}

		// Only the owner may soft-delete, unless bypassed via ctx.
		if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.Owner().ID != f.user.ID {
			ae.Add(p.String(), fs.ErrOwnerOnly.WithError(fmt.Errorf("only file owner can delete file without trash bin")))
			continue
		}

		// Root folder cannot be deleted
		if target.IsRootFolder() {
			ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot delete root folder")))
			continue
		}

		targets = append(targets, target)
	}

	if len(targets) == 0 {
		return ae.Aggregate()
	}
	// Lock all targets
	lockTargets := lo.Map(targets, func(value *File, key int) *LockByPath {
		return &LockByPath{value.Uri(true), value, value.Type(), ""}
	})
	ls, err := f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationSoftDelete), lockTargets...)
	defer func() { _ = f.Release(ctx, ls) }()
	if err != nil {
		return err
	}

	// Start transaction to soft-delete files
	fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient)
	if err != nil {
		return serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)
	}

	for _, target := range targets {
		// Perform soft-delete
		if err := fc.SoftDelete(ctx, target.Model); err != nil {
			_ = inventory.Rollback(tx)
			return serializer.NewError(serializer.CodeDBError, "failed to soft-delete file", err)
		}

		// Save restore uri into metadata, along with the time after which
		// the trash collector may permanently remove the file (now + the
		// owner group's trash retention, in seconds).
		if err := fc.UpsertMetadata(ctx, target.Model, map[string]string{
			MetadataRestoreUri: target.Uri(true).String(),
			MetadataExpectedCollectTime: strconv.FormatInt(
				time.Now().Add(time.Duration(target.Owner().Edges.Group.Settings.TrashRetention)*time.Second).Unix(),
				10),
		}, nil); err != nil {
			_ = inventory.Rollback(tx)
			return serializer.NewError(serializer.CodeDBError, "failed to update metadata", err)
		}
	}

	// Commit transaction
	if err := inventory.Commit(tx); err != nil {
		return serializer.NewError(serializer.CodeDBError, "Failed to commit soft-delete change", err)
	}

	return ae.Aggregate()
}
|
||||
|
||||
// Delete permanently removes the given files (and their walked descendants)
// and returns the entities that became stale as a result. Per-path validation
// errors are aggregated rather than aborting the batch. Targets are grouped by
// navigator, locked, then deleted in one transaction whose storage diff is
// applied on commit. With the UnlinkOnly option, entities are unlinked from
// files without recycling the underlying blobs.
func (f *DBFS) Delete(ctx context.Context, path []*fs.URI, opts ...fs.Option) ([]fs.Entity, error) {
	o := newDbfsOption()
	for _, opt := range opts {
		o.apply(opt)
	}

	// Translate the dbfs-level option into the inventory-level recycle option.
	var opt *types.EntityRecycleOption
	if o.UnlinkOnly {
		opt = &types.EntityRecycleOption{
			UnlinkOnly: true,
		}
	}

	ae := serializer.NewAggregateError()
	fileNavGroup := make(map[Navigator][]*File)
	// Ask loaders to attach entity edges, needed to compute stale entities.
	ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true)

	for _, p := range path {
		// Get navigator
		navigator, err := f.getNavigator(ctx, p, NavigatorCapabilityDeleteFile, NavigatorCapabilityLockFile)
		if err != nil {
			ae.Add(p.String(), err)
			continue
		}

		// Get target file
		target, err := f.getFileByPath(ctx, navigator, p)
		if err != nil {
			ae.Add(p.String(), fmt.Errorf("failed to get target file: %w", err))
			continue
		}

		// Owner check can be skipped via ctx bypass key or the SysSkipSoftDelete option.
		if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !o.SysSkipSoftDelete && !ok && target.Owner().ID != f.user.ID {
			ae.Add(p.String(), fs.ErrOwnerOnly)
			continue
		}

		// Root folder cannot be deleted
		if target.IsRootFolder() {
			ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot delete root folder")))
			continue
		}

		// Group targets by their navigator so deleteFiles can walk each group
		// with the navigator that owns it.
		if _, ok := fileNavGroup[navigator]; !ok {
			fileNavGroup[navigator] = make([]*File, 0)
		}
		fileNavGroup[navigator] = append(fileNavGroup[navigator], target)
	}

	targets := lo.Flatten(lo.Values(fileNavGroup))
	if len(targets) == 0 {
		return nil, ae.Aggregate()
	}
	// Lock all targets
	lockTargets := lo.Map(targets, func(value *File, key int) *LockByPath {
		return &LockByPath{value.Uri(true), value, value.Type(), ""}
	})
	ls, err := f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationDelete), lockTargets...)
	// Deferred before the error check so partial lock sets are released.
	defer func() { _ = f.Release(ctx, ls) }()
	if err != nil {
		return nil, err
	}

	fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient)
	if err != nil {
		return nil, serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)
	}

	// Delete targets
	newStaleEntities, storageDiff, err := f.deleteFiles(ctx, fileNavGroup, fc, opt)
	if err != nil {
		_ = inventory.Rollback(tx)
		return nil, serializer.NewError(serializer.CodeDBError, "failed to delete files", err)
	}

	// Apply the accumulated storage usage change atomically with the commit.
	tx.AppendStorageDiff(storageDiff)
	if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil {
		return nil, serializer.NewError(serializer.CodeDBError, "Failed to commit delete change", err)
	}

	return newStaleEntities, ae.Aggregate()
}
|
||||
|
||||
// VersionControl operates on one version (entity) of a file: when delete is
// true the given version is removed, otherwise it is promoted to be the
// file's current version. The target must be a regular file owned by the
// requesting user (unless the bypass key is set in ctx), and is locked for
// the duration of the operation.
func (f *DBFS) VersionControl(ctx context.Context, path *fs.URI, versionId int, delete bool) error {
	// Get navigator
	navigator, err := f.getNavigator(ctx, path, NavigatorCapabilityVersionControl)
	if err != nil {
		return err
	}

	// Get target file, with entity edges loaded so versions are visible.
	ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true)
	target, err := f.getFileByPath(ctx, navigator, path)
	if err != nil {
		return fmt.Errorf("failed to get target file: %w", err)
	}

	if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.Owner().ID != f.user.ID {
		return fs.ErrOwnerOnly
	}

	// Target must be a file
	if target.Type() != types.FileTypeFile {
		return fs.ErrNotSupportedAction.WithError(fmt.Errorf("target must be a valid file"))
	}

	// Lock file
	ls, err := f.acquireByPath(ctx, -1, f.user, true, fs.LockApp(fs.ApplicationVersionControl),
		&LockByPath{target.Uri(true), target, target.Type(), ""})
	// Deferred before the error check so a partial acquisition is released.
	defer func() { _ = f.Release(ctx, ls) }()
	if err != nil {
		return err
	}

	if delete {
		storageDiff, err := f.deleteEntity(ctx, target, versionId)
		if err != nil {
			return err
		}

		// Storage accounting failure is logged but does not fail the request:
		// the version itself has already been removed.
		if err := f.userClient.ApplyStorageDiff(ctx, storageDiff); err != nil {
			f.l.Error("Failed to apply storage diff after deleting version: %s", err)
		}
		return nil
	} else {
		return f.setCurrentVersion(ctx, target, versionId)
	}
}
|
||||
|
||||
// Restore moves files out of the trash bin back to their original location.
// The original location is read from the MetadataRestoreUri mark written at
// soft-delete time; files lacking that mark (or carrying an unparsable URI)
// are reported in the aggregate error and skipped. Each surviving file is
// moved (not copied) to the parent directory of its restore URI.
func (f *DBFS) Restore(ctx context.Context, path ...*fs.URI) error {
	ae := serializer.NewAggregateError()
	targets := make([]*File, 0, len(path))
	// Public metadata must be loaded to read the restore URI below.
	ctx = context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, true)

	for _, p := range path {
		// Get navigator
		navigator, err := f.getNavigator(ctx, p, NavigatorCapabilityRestore)
		if err != nil {
			ae.Add(p.String(), err)
			continue
		}

		// Get target file
		target, err := f.getFileByPath(ctx, navigator, p)
		if err != nil {
			ae.Add(p.String(), fmt.Errorf("failed to get file: %w", err))
			continue
		}

		targets = append(targets, target)
	}

	if len(targets) == 0 {
		return ae.Aggregate()
	}

	// For each target, produce a [current URI, destination folder URI] pair.
	// NOTE: the filter callback records failures into ae as a side effect.
	allTrashUriStr := lo.FilterMap(targets, func(t *File, key int) ([]*fs.URI, bool) {
		if restoreUri, ok := t.Metadata()[MetadataRestoreUri]; ok {
			srcUrl, err := fs.NewUriFromString(restoreUri)
			if err != nil {
				ae.Add(t.Uri(false).String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("invalid restore uri: %w", err)))
				return nil, false
			}

			return []*fs.URI{t.Uri(false), srcUrl.DirUri()}, true
		}

		ae.Add(t.Uri(false).String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot restore file without required metadata mark")))
		return nil, false
	})

	// Copy each file to its original location
	for _, uris := range allTrashUriStr {
		// isCopy=false: restore is a move; MoveOrCopy also strips trash metadata.
		if err := f.MoveOrCopy(ctx, []*fs.URI{uris[0]}, uris[1], false); err != nil {
			// Merge aggregate errors from MoveOrCopy; fall back to a plain add.
			if !ae.Merge(err) {
				ae.Add(uris[0].String(), err)
			}
		}
	}

	return ae.Aggregate()

}
|
||||
|
||||
func (f *DBFS) MoveOrCopy(ctx context.Context, path []*fs.URI, dst *fs.URI, isCopy bool) error {
|
||||
targets := make([]*File, 0, len(path))
|
||||
dstNavigator, err := f.getNavigator(ctx, dst, NavigatorCapabilityLockFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get destination file
|
||||
destination, err := f.getFileByPath(ctx, dstNavigator, dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("faield to get destination folder: %w", err)
|
||||
}
|
||||
|
||||
if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && destination.Owner().ID != f.user.ID {
|
||||
return fs.ErrOwnerOnly
|
||||
}
|
||||
|
||||
// Target must be a folder
|
||||
if !destination.CanHaveChildren() {
|
||||
return fs.ErrNotSupportedAction.WithError(fmt.Errorf("destination must be a valid folder"))
|
||||
}
|
||||
|
||||
ae := serializer.NewAggregateError()
|
||||
fileNavGroup := make(map[Navigator][]*File)
|
||||
dstRootPath := destination.Uri(true)
|
||||
ctx = context.WithValue(ctx, inventory.LoadFileEntity{}, true)
|
||||
ctx = context.WithValue(ctx, inventory.LoadFileMetadata{}, true)
|
||||
|
||||
for _, p := range path {
|
||||
// Get navigator
|
||||
navigator, err := f.getNavigator(ctx, p, NavigatorCapabilityLockFile)
|
||||
if err != nil {
|
||||
ae.Add(p.String(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check fs capability
|
||||
if !canMoveOrCopyTo(p, dst, isCopy) {
|
||||
ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot move or copy file form %s to %s", p.String(), dst.String())))
|
||||
continue
|
||||
}
|
||||
|
||||
// Get target file
|
||||
target, err := f.getFileByPath(ctx, navigator, p)
|
||||
if err != nil {
|
||||
ae.Add(p.String(), fmt.Errorf("failed to get file: %w", err))
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := ctx.Value(ByPassOwnerCheckCtxKey{}).(bool); !ok && target.Owner().ID != f.user.ID {
|
||||
ae.Add(p.String(), fs.ErrOwnerOnly)
|
||||
continue
|
||||
}
|
||||
|
||||
// Root folder cannot be moved or copied
|
||||
if target.IsRootFolder() {
|
||||
ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot move root folder")))
|
||||
continue
|
||||
}
|
||||
|
||||
// Cannot move or copy folder to its descendant
|
||||
if target.Type() == types.FileTypeFolder &&
|
||||
dstRootPath.EqualOrIsDescendantOf(target.Uri(true), hashid.EncodeUserID(f.hasher, f.user.ID)) {
|
||||
ae.Add(p.String(), fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot move or copy folder to itself or its descendant")))
|
||||
continue
|
||||
}
|
||||
|
||||
targets = append(targets, target)
|
||||
if isCopy {
|
||||
if _, ok := fileNavGroup[navigator]; !ok {
|
||||
fileNavGroup[navigator] = make([]*File, 0)
|
||||
}
|
||||
fileNavGroup[navigator] = append(fileNavGroup[navigator], target)
|
||||
}
|
||||
}
|
||||
|
||||
if len(targets) > 0 {
|
||||
// Lock all targets
|
||||
lockTargets := lo.Map(targets, func(value *File, key int) *LockByPath {
|
||||
return &LockByPath{value.Uri(true), value, value.Type(), ""}
|
||||
})
|
||||
|
||||
// Lock destination
|
||||
dstBase := destination.Uri(true)
|
||||
dstLockTargets := lo.Map(targets, func(value *File, key int) *LockByPath {
|
||||
return &LockByPath{dstBase.Join(value.Name()), destination, value.Type(), ""}
|
||||
})
|
||||
allLockTargets := make([]*LockByPath, 0, len(targets)*2)
|
||||
if !isCopy {
|
||||
// For moving files from trash bin, also lock the dst with restored name.
|
||||
dstRestoreTargets := lo.FilterMap(targets, func(value *File, key int) (*LockByPath, bool) {
|
||||
if _, ok := value.Metadata()[MetadataRestoreUri]; ok {
|
||||
return &LockByPath{dstBase.Join(value.DisplayName()), destination, value.Type(), ""}, true
|
||||
}
|
||||
return nil, false
|
||||
})
|
||||
allLockTargets = append(allLockTargets, lockTargets...)
|
||||
allLockTargets = append(allLockTargets, dstRestoreTargets...)
|
||||
}
|
||||
allLockTargets = append(allLockTargets, dstLockTargets...)
|
||||
ls, err := f.acquireByPath(ctx, -1, f.user, false, fs.LockApp(fs.ApplicationMoveCopy), allLockTargets...)
|
||||
defer func() { _ = f.Release(ctx, ls) }()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start transaction to move files
|
||||
fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient)
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)
|
||||
}
|
||||
|
||||
var (
|
||||
storageDiff inventory.StorageDiff
|
||||
)
|
||||
if isCopy {
|
||||
_, storageDiff, err = f.copyFiles(ctx, fileNavGroup, destination, fc)
|
||||
} else {
|
||||
storageDiff, err = f.moveFiles(ctx, targets, destination, fc, dstNavigator)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
_ = inventory.Rollback(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
tx.AppendStorageDiff(storageDiff)
|
||||
if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil {
|
||||
return serializer.NewError(serializer.CodeDBError, "Failed to commit move change", err)
|
||||
}
|
||||
|
||||
// TODO: after move, dbfs cache should be cleared
|
||||
}
|
||||
|
||||
return ae.Aggregate()
|
||||
}
|
||||
|
||||
func (f *DBFS) deleteEntity(ctx context.Context, target *File, entityId int) (inventory.StorageDiff, error) {
|
||||
if target.PrimaryEntityID() == entityId {
|
||||
return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot delete current version"))
|
||||
}
|
||||
|
||||
targetVersion, found := lo.Find(target.Entities(), func(item fs.Entity) bool {
|
||||
return item.ID() == entityId
|
||||
})
|
||||
if !found {
|
||||
return nil, fs.ErrEntityNotExist.WithError(fmt.Errorf("version not found"))
|
||||
}
|
||||
|
||||
diff, err := f.fileClient.UnlinkEntity(ctx, targetVersion.Model(), target.Model, target.Owner())
|
||||
if err != nil {
|
||||
return nil, serializer.NewError(serializer.CodeDBError, "Failed to unlink entity", err)
|
||||
}
|
||||
|
||||
if targetVersion.UploadSessionID() != nil {
|
||||
err = f.fileClient.RemoveMetadata(ctx, target.Model, MetadataUploadSessionID)
|
||||
if err != nil {
|
||||
return nil, serializer.NewError(serializer.CodeDBError, "Failed to remove upload session metadata", err)
|
||||
}
|
||||
}
|
||||
return diff, nil
|
||||
}
|
||||
|
||||
func (f *DBFS) setCurrentVersion(ctx context.Context, target *File, versionId int) error {
|
||||
if target.PrimaryEntityID() == versionId {
|
||||
return nil
|
||||
}
|
||||
|
||||
targetVersion, found := lo.Find(target.Entities(), func(item fs.Entity) bool {
|
||||
return item.ID() == versionId && item.Type() == types.EntityTypeVersion && item.UploadSessionID() == nil
|
||||
})
|
||||
if !found {
|
||||
return fs.ErrEntityNotExist.WithError(fmt.Errorf("version not found"))
|
||||
}
|
||||
|
||||
fc, tx, ctx, err := inventory.WithTx(ctx, f.fileClient)
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeDBError, "Failed to start transaction", err)
|
||||
}
|
||||
|
||||
if err := f.fileClient.SetPrimaryEntity(ctx, target.Model, targetVersion.ID()); err != nil {
|
||||
return serializer.NewError(serializer.CodeDBError, "Failed to set primary entity", err)
|
||||
}
|
||||
|
||||
// Cap thumbnail entities
|
||||
diff, err := fc.CapEntities(ctx, target.Model, target.Owner(), 0, types.EntityTypeThumbnail)
|
||||
if err != nil {
|
||||
_ = inventory.Rollback(tx)
|
||||
return serializer.NewError(serializer.CodeDBError, "Failed to cap thumbnail entities", err)
|
||||
}
|
||||
|
||||
tx.AppendStorageDiff(diff)
|
||||
if err := inventory.CommitWithStorageDiff(ctx, tx, f.l, f.userClient); err != nil {
|
||||
return serializer.NewError(serializer.CodeDBError, "Failed to commit set current version", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteFiles deletes each navigator group of files plus all walked
// descendants through the (transaction-bound) file client fc. It returns the
// entities that became stale and the accumulated storage diff. Walking is
// bounded by the user group's MaxWalkedFiles setting, shared across all
// groups.
func (f *DBFS) deleteFiles(ctx context.Context, targets map[Navigator][]*File, fc inventory.FileClient, opt *types.EntityRecycleOption) ([]fs.Entity, inventory.StorageDiff, error) {
	if f.user.Edges.Group == nil {
		return nil, nil, fmt.Errorf("user group not loaded")
	}
	// Walk at least one file even if the group setting is zero/negative.
	limit := max(f.user.Edges.Group.Settings.MaxWalkedFiles, 1)
	allStaleEntities := make([]fs.Entity, 0, len(targets))
	storageDiff := make(inventory.StorageDiff)
	for n, files := range targets {
		// Let navigator use tx
		reset, err := n.FollowTx(ctx)
		if err != nil {
			return nil, nil, err
		}

		// Deliberately deferred inside the loop: every navigator stays on the
		// tx until this function returns, then all are reset.
		defer reset()

		// List all files to be deleted
		toBeDeletedFiles := make([]*File, 0, len(files))
		if err := n.Walk(ctx, files, limit, intsets.MaxInt, func(targets []*File, level int) error {
			// Consume the shared walk budget across navigator groups.
			limit -= len(targets)
			toBeDeletedFiles = append(toBeDeletedFiles, targets...)
			return nil
		}); err != nil {
			return nil, nil, fmt.Errorf("failed to walk files: %w", err)
		}

		// Delete files
		staleEntities, diff, err := fc.Delete(ctx, lo.Map(toBeDeletedFiles, func(item *File, index int) *ent.File {
			return item.Model
		}), opt)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to delete files: %w", err)
		}
		storageDiff.Merge(diff)
		allStaleEntities = append(allStaleEntities, lo.Map(staleEntities, func(item *ent.Entity, index int) fs.Entity {
			return fs.NewEntity(item)
		})...)
	}

	return allStaleEntities, storageDiff, nil
}
|
||||
|
||||
// copyFiles copies each navigator group of files (with all walked
// descendants) under destination via the transaction-bound client fc,
// validating the destination owner's storage capacity level by level. It
// returns a map from source file-children ID to the newly created first-layer
// file, and the accumulated storage diff. Walking is bounded by the user
// group's MaxWalkedFiles setting.
func (f *DBFS) copyFiles(ctx context.Context, targets map[Navigator][]*File, destination *File, fc inventory.FileClient) (map[int]*ent.File, inventory.StorageDiff, error) {
	if f.user.Edges.Group == nil {
		return nil, nil, fmt.Errorf("user group not loaded")
	}
	// Walk at least one file even if the group setting is zero/negative.
	limit := max(f.user.Edges.Group.Settings.MaxWalkedFiles, 1)
	capacity, err := f.Capacity(ctx, destination.Owner())
	if err != nil {
		return nil, nil, fmt.Errorf("copy files: failed to destination owner capacity: %w", err)
	}

	dstAncestors := lo.Map(destination.AncestorsChain(), func(item *File, index int) *ent.File {
		return item.Model
	})

	// newTargetsMap is the map of between new target files in first layer, and its src file ID.
	newTargetsMap := make(map[int]*ent.File)
	storageDiff := make(inventory.StorageDiff)
	var diff inventory.StorageDiff
	for n, files := range targets {
		// Seed the copy destination map: every first-layer source maps to the
		// destination's ancestor chain.
		initialDstMap := make(map[int][]*ent.File)
		for _, file := range files {
			initialDstMap[file.Model.FileChildren] = dstAncestors
		}

		firstLayer := true
		// Let navigator use tx
		reset, err := n.FollowTx(ctx)
		if err != nil {
			return nil, nil, err
		}

		// Deliberately deferred inside the loop: navigators stay on the tx
		// until this function returns.
		defer reset()

		// The walk callback mutates initialDstMap/diff/limit/capacity captured
		// from the enclosing scope; fc.Copy threads the dst map layer by layer.
		if err := n.Walk(ctx, files, limit, intsets.MaxInt, func(targets []*File, level int) error {
			// check capacity for each file
			sizeTotal := int64(0)
			for _, file := range targets {
				sizeTotal += file.SizeUsed()
			}

			if err := f.validateUserCapacityRaw(ctx, sizeTotal, capacity); err != nil {
				return fs.ErrInsufficientCapacity
			}

			limit -= len(targets)
			initialDstMap, diff, err = fc.Copy(ctx, lo.Map(targets, func(item *File, index int) *ent.File {
				return item.Model
			}), initialDstMap)
			if err != nil {
				if ent.IsConstraintError(err) {
					return fs.ErrFileExisted.WithError(err)
				}

				return serializer.NewError(serializer.CodeDBError, "Failed to copy files", err)
			}

			storageDiff.Merge(diff)

			// Record the first-layer copies for the caller's result map.
			if firstLayer {
				for k, v := range initialDstMap {
					newTargetsMap[k] = v[0]
				}
			}

			// Track projected usage locally so later layers see earlier ones.
			capacity.Used += sizeTotal
			firstLayer = false

			return nil
		}); err != nil {
			return nil, nil, fmt.Errorf("failed to walk files: %w", err)
		}
	}

	return newTargetsMap, storageDiff, nil
}
|
||||
|
||||
func (f *DBFS) moveFiles(ctx context.Context, targets []*File, destination *File, fc inventory.FileClient, n Navigator) (inventory.StorageDiff, error) {
|
||||
models := lo.Map(targets, func(value *File, key int) *ent.File {
|
||||
return value.Model
|
||||
})
|
||||
|
||||
// Change targets' parent
|
||||
if err := fc.SetParent(ctx, models, destination.Model); err != nil {
|
||||
if ent.IsConstraintError(err) {
|
||||
return nil, fs.ErrFileExisted.WithError(err)
|
||||
}
|
||||
|
||||
return nil, serializer.NewError(serializer.CodeDBError, "Failed to move file", err)
|
||||
}
|
||||
|
||||
var (
|
||||
storageDiff inventory.StorageDiff
|
||||
)
|
||||
|
||||
// For files moved out from trash bin
|
||||
for _, file := range targets {
|
||||
if _, ok := file.Metadata()[MetadataRestoreUri]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// renaming it to its original name
|
||||
if _, err := fc.Rename(ctx, file.Model, file.DisplayName()); err != nil {
|
||||
if ent.IsConstraintError(err) {
|
||||
return nil, fs.ErrFileExisted.WithError(err)
|
||||
}
|
||||
|
||||
return storageDiff, serializer.NewError(serializer.CodeDBError, "Failed to rename file from trash bin to its original name", err)
|
||||
}
|
||||
|
||||
// Remove trash bin metadata
|
||||
if err := fc.RemoveMetadata(ctx, file.Model, MetadataRestoreUri, MetadataExpectedCollectTime); err != nil {
|
||||
return storageDiff, serializer.NewError(serializer.CodeDBError, "Failed to remove trash related metadata", err)
|
||||
}
|
||||
}
|
||||
|
||||
return storageDiff, nil
|
||||
}
|
||||
172
pkg/filemanager/fs/dbfs/my_navigator.go
Normal file
172
pkg/filemanager/fs/dbfs/my_navigator.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package dbfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
)
|
||||
|
||||
// myNavigatorCapability holds the capability bits of the "my" file system
// navigator; the bits themselves are populated in init() (navigator.go).
var myNavigatorCapability = &boolset.BooleanSet{}
|
||||
|
||||
// NewMyNavigator creates a navigator for user's "my" file system — the
// user's own root folder tree. The returned navigator lazily resolves the
// root on first To call and shares the default listing filter via the
// embedded baseNavigator.
func NewMyNavigator(u *ent.User, fileClient inventory.FileClient, userClient inventory.UserClient, l logging.Logger,
	config *setting.DBFS, hasher hashid.Encoder) Navigator {
	return &myNavigator{
		user:          u,
		l:             l,
		fileClient:    fileClient,
		userClient:    userClient,
		config:        config,
		baseNavigator: newBaseNavigator(fileClient, defaultFilter, u, hasher, config),
	}
}
|
||||
|
||||
// myNavigator navigates a single user's own file tree. It caches the resolved
// root *File between calls and can persist/restore that root as navigator
// state.
type myNavigator struct {
	l          logging.Logger
	user       *ent.User
	fileClient inventory.FileClient
	userClient inventory.UserClient

	config *setting.DBFS
	*baseNavigator
	// root is the lazily resolved root folder; nil until the first To call
	// or until RestoreState injects one.
	root *File
	// disableRecycle suppresses recycling of root, set when state is
	// persisted or restored so the cached root outlives this navigator.
	disableRecycle bool
	// persist, when non-nil, writes the navigator state out; invoked once by
	// Recycle.
	persist func()
}
|
||||
|
||||
func (n *myNavigator) Recycle() {
|
||||
if n.persist != nil {
|
||||
n.persist()
|
||||
n.persist = nil
|
||||
}
|
||||
if n.root != nil && !n.disableRecycle {
|
||||
n.root.Recycle()
|
||||
}
|
||||
}
|
||||
|
||||
// PersistState arranges for the cached root to be written to kv under key
// when this navigator is recycled, and disables recycling of the root so the
// persisted state stays valid.
func (n *myNavigator) PersistState(kv cache.Driver, key string) {
	n.disableRecycle = true
	// Capture kv/key; n.root is read at persist time, not now.
	n.persist = func() {
		kv.Set(key, n.root, ContextHintTTL)
	}
}
|
||||
|
||||
func (n *myNavigator) RestoreState(s State) error {
|
||||
n.disableRecycle = true
|
||||
if state, ok := s.(*File); ok {
|
||||
n.root = state
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid state type: %T", s)
|
||||
}
|
||||
|
||||
// To resolves path within the user's own file tree and returns the target
// file. On the first call the root folder is resolved (after validating that
// the path's hashed user ID matches the current user) and cached on the
// navigator. If walking fails partway, the most-recent ancestor is returned
// alongside the error, matching the Navigator interface contract.
func (n *myNavigator) To(ctx context.Context, path *fs.URI) (*File, error) {
	if n.root == nil {
		// Anonymous user does not have a root folder.
		if inventory.IsAnonymousUser(n.user) {
			return nil, ErrLoginRequired
		}

		// Decode the hashed user ID embedded in the path; it must refer to
		// the current user — "my" navigator never serves other users' trees.
		fsUid, err := n.hasher.Decode(path.ID(hashid.EncodeUserID(n.hasher, n.user.ID)), hashid.UserID)
		if err != nil {
			return nil, fs.ErrPathNotExist.WithError(fmt.Errorf("invalid user id"))
		}
		if fsUid != n.user.ID {
			return nil, ErrPermissionDenied
		}

		targetUser, err := n.userClient.GetLoginUserByID(ctx, fsUid)
		if err != nil {
			return nil, fs.ErrPathNotExist.WithError(fmt.Errorf("user not found: %w", err))
		}

		rootFile, err := n.fileClient.Root(ctx, targetUser)
		if err != nil {
			// A missing root means the fs was never initialized; callers
			// react to ErrFsNotInitialized by creating it.
			n.l.Info("User's root folder not found: %s, will initialize it.", err)
			return nil, ErrFsNotInitialized
		}

		// Cache the root with its path, owner and capability set filled in.
		n.root = newFile(nil, rootFile)
		rootPath := path.Root()
		n.root.Path[pathIndexRoot], n.root.Path[pathIndexUser] = rootPath, rootPath
		n.root.OwnerModel = targetUser
		n.root.IsUserRoot = true
		n.root.CapabilitiesBs = n.Capabilities(false).Capability
	}

	// Walk down the path one element at a time, remembering the last
	// successfully resolved ancestor for error reporting.
	current, lastAncestor := n.root, n.root
	elements := path.Elements()
	var err error
	for index, element := range elements {
		lastAncestor = current
		current, err = n.walkNext(ctx, current, element, index == len(elements)-1)
		if err != nil {
			return lastAncestor, fmt.Errorf("failed to walk into %q: %w", element, err)
		}
	}

	return current, nil
}
|
||||
|
||||
// Children lists the children of parent, delegating to the shared
// baseNavigator implementation.
func (n *myNavigator) Children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) {
	return n.baseNavigator.children(ctx, parent, args)
}
|
||||
|
||||
// walkNext resolves the child named next under root, delegating to the shared
// baseNavigator implementation. isLeaf marks the last path element.
func (n *myNavigator) walkNext(ctx context.Context, root *File, next string, isLeaf bool) (*File, error) {
	return n.baseNavigator.walkNext(ctx, root, next, isLeaf)
}
|
||||
|
||||
func (n *myNavigator) Capabilities(isSearching bool) *fs.NavigatorProps {
|
||||
res := &fs.NavigatorProps{
|
||||
Capability: myNavigatorCapability,
|
||||
OrderDirectionOptions: fullOrderDirectionOption,
|
||||
OrderByOptions: fullOrderByOption,
|
||||
MaxPageSize: n.config.MaxPageSize,
|
||||
}
|
||||
if isSearching {
|
||||
res.OrderByOptions = nil
|
||||
res.OrderDirectionOptions = nil
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// Walk traverses the subtree rooted at levelFiles, delegating to the shared
// baseNavigator implementation; f is invoked once per walked level.
func (n *myNavigator) Walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error {
	return n.baseNavigator.walk(ctx, levelFiles, limit, depth, f)
}
|
||||
|
||||
func (n *myNavigator) FollowTx(ctx context.Context) (func(), error) {
|
||||
if _, ok := ctx.Value(inventory.TxCtx{}).(*inventory.Tx); !ok {
|
||||
return nil, fmt.Errorf("navigator: no inherited transaction found in context")
|
||||
}
|
||||
newFileClient, _, _, err := inventory.WithTx(ctx, n.fileClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newUserClient, _, _, err := inventory.WithTx(ctx, n.userClient)
|
||||
|
||||
oldFileClient, oldUserClient := n.fileClient, n.userClient
|
||||
revert := func() {
|
||||
n.fileClient = oldFileClient
|
||||
n.userClient = oldUserClient
|
||||
n.baseNavigator.fileClient = oldFileClient
|
||||
}
|
||||
|
||||
n.fileClient = newFileClient
|
||||
n.userClient = newUserClient
|
||||
n.baseNavigator.fileClient = newFileClient
|
||||
return revert, nil
|
||||
}
|
||||
|
||||
// ExecuteHook implements Navigator; the "my" navigator defines no hook
// behavior and always succeeds.
func (n *myNavigator) ExecuteHook(ctx context.Context, hookType fs.HookType, file *File) error {
	return nil
}
|
||||
536
pkg/filemanager/fs/dbfs/navigator.go
Normal file
536
pkg/filemanager/fs/dbfs/navigator.go
Normal file
@@ -0,0 +1,536 @@
|
||||
package dbfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var (
	// ErrFsNotInitialized signals that the user's root folder does not exist
	// yet; callers react by initializing the file system.
	ErrFsNotInitialized = fmt.Errorf("fs not initialized")
	// ErrPermissionDenied is returned when a path refers to a tree the
	// current user may not access.
	ErrPermissionDenied = serializer.NewError(serializer.CodeNoPermissionErr, "Permission denied", nil)

	ErrShareIncorrectPassword  = serializer.NewError(serializer.CodeIncorrectPassword, "Incorrect share password", nil)
	ErrFileCountLimitedReached = serializer.NewError(serializer.CodeFileCountLimitedReached, "Walked file count reached limit", nil)
	ErrSymbolicFolderFound     = serializer.NewError(serializer.CodeNoPermissionErr, "Symbolic folder cannot be walked into", nil)
	ErrLoginRequired           = serializer.NewError(serializer.CodeCheckLogin, "Login required", nil)

	// Ordering options advertised via Navigator.Capabilities.
	fullOrderByOption          = []string{"name", "size", "updated_at", "created_at"}
	searchLimitedOrderByOption = []string{"created_at"}
	fullOrderDirectionOption   = []string{"asc", "desc"}
)
|
||||
|
||||
type (
	// Navigator is a navigator for database file system.
	Navigator interface {
		// Recycle releases resources held by the navigator (cached files,
		// pending state persistence).
		Recycle()
		// To returns the file by path. If given path is not exist, returns ErrFileNotFound and most-recent ancestor.
		To(ctx context.Context, path *fs.URI) (*File, error)
		// Children returns the children of the parent file.
		Children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error)
		// Capabilities returns the capabilities of the navigator.
		Capabilities(isSearching bool) *fs.NavigatorProps
		// Walk walks the file tree until limit is reached.
		Walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error
		// PersistState tells navigator to persist the state of the navigator before recycle.
		PersistState(kv cache.Driver, key string)
		// RestoreState restores the state of the navigator.
		RestoreState(s State) error
		// FollowTx let the navigator inherit the transaction. Return a function to reset back to previous DB client.
		FollowTx(ctx context.Context) (func(), error)
		// ExecuteHook performs custom operations before or after certain actions.
		ExecuteHook(ctx context.Context, hookType fs.HookType, file *File) error
	}

	// State is an opaque navigator state blob used by PersistState/RestoreState.
	State interface{}

	// NavigatorCapability identifies a single capability bit (see the const
	// block below for the defined values).
	NavigatorCapability int
	// ListArgs carries listing parameters for Navigator.Children.
	ListArgs struct {
		Page           *inventory.PaginationArgs
		Search         *inventory.SearchFileParameters
		SharedWithMe   bool
		StreamCallback func([]*File)
	}
	// ListResult is the result of a list operation.
	ListResult struct {
		Files                 []*File
		MixedType             bool
		Pagination            *inventory.PaginationResults
		RecursionLimitReached bool
		SingleFileView        bool
	}
	// WalkFunc is invoked by Navigator.Walk with one level of files and the
	// current depth.
	WalkFunc func([]*File, int) error
)
|
||||
|
||||
const (
	// Capability bits for navigators. Values are iota-based and referenced by
	// persisted boolsets, so entries must not be reordered or removed; the
	// _CommunityPlacehodler entries reserve slots (presumably for pro-edition
	// capabilities — TODO confirm).
	NavigatorCapabilityCreateFile NavigatorCapability = iota
	NavigatorCapabilityRenameFile
	NavigatorCapability_CommunityPlacehodler1
	NavigatorCapability_CommunityPlacehodler2
	NavigatorCapability_CommunityPlacehodler3
	NavigatorCapability_CommunityPlacehodler4
	NavigatorCapabilityUploadFile
	NavigatorCapabilityDownloadFile
	NavigatorCapabilityUpdateMetadata
	NavigatorCapabilityListChildren
	NavigatorCapabilityGenerateThumb
	NavigatorCapability_CommunityPlacehodler5
	NavigatorCapability_CommunityPlacehodler6
	NavigatorCapability_CommunityPlacehodler7
	NavigatorCapabilityDeleteFile
	NavigatorCapabilityLockFile
	NavigatorCapabilitySoftDelete
	NavigatorCapabilityRestore
	NavigatorCapabilityShare
	NavigatorCapabilityInfo
	NavigatorCapabilityVersionControl
	NavigatorCapability_CommunityPlacehodler8
	NavigatorCapability_CommunityPlacehodler9
	NavigatorCapabilityEnterFolder

	// searchTokenSeparator separates tokens in search strings.
	searchTokenSeparator = "|"
)
|
||||
|
||||
func init() {
|
||||
boolset.Sets(map[NavigatorCapability]bool{
|
||||
NavigatorCapabilityCreateFile: true,
|
||||
NavigatorCapabilityRenameFile: true,
|
||||
NavigatorCapabilityUploadFile: true,
|
||||
NavigatorCapabilityDownloadFile: true,
|
||||
NavigatorCapabilityUpdateMetadata: true,
|
||||
NavigatorCapabilityListChildren: true,
|
||||
NavigatorCapabilityGenerateThumb: true,
|
||||
NavigatorCapabilityDeleteFile: true,
|
||||
NavigatorCapabilityLockFile: true,
|
||||
NavigatorCapabilitySoftDelete: true,
|
||||
NavigatorCapabilityShare: true,
|
||||
NavigatorCapabilityInfo: true,
|
||||
NavigatorCapabilityVersionControl: true,
|
||||
NavigatorCapabilityEnterFolder: true,
|
||||
}, myNavigatorCapability)
|
||||
boolset.Sets(map[NavigatorCapability]bool{
|
||||
NavigatorCapabilityDownloadFile: true,
|
||||
NavigatorCapabilityListChildren: true,
|
||||
NavigatorCapabilityGenerateThumb: true,
|
||||
NavigatorCapabilityLockFile: true,
|
||||
NavigatorCapabilityInfo: true,
|
||||
NavigatorCapabilityVersionControl: true,
|
||||
NavigatorCapabilityEnterFolder: true,
|
||||
}, shareNavigatorCapability)
|
||||
boolset.Sets(map[NavigatorCapability]bool{
|
||||
NavigatorCapabilityListChildren: true,
|
||||
NavigatorCapabilityDeleteFile: true,
|
||||
NavigatorCapabilityLockFile: true,
|
||||
NavigatorCapabilityRestore: true,
|
||||
NavigatorCapabilityInfo: true,
|
||||
}, trashNavigatorCapability)
|
||||
boolset.Sets(map[NavigatorCapability]bool{
|
||||
NavigatorCapabilityListChildren: true,
|
||||
NavigatorCapabilityDownloadFile: true,
|
||||
NavigatorCapabilityEnterFolder: true,
|
||||
}, sharedWithMeNavigatorCapability)
|
||||
}
|
||||
|
||||
// ==================== Base Navigator ====================
|
||||
type (
	// fileFilter post-processes a file from a listing; it may rewrite the
	// file and reports whether it should be kept in the result.
	fileFilter func(ctx context.Context, f *File) (*File, bool)

	// baseNavigator implements the file-walking primitives shared by the
	// concrete navigator implementations in this package.
	baseNavigator struct {
		fileClient inventory.FileClient // data access for files
		listFilter fileFilter           // applied to every listed/searched file
		user       *ent.User            // user the navigation is performed as
		hasher     hashid.Encoder       // hash ID encoder for public IDs
		config     *setting.DBFS        // DBFS runtime settings
	}
)
|
||||
|
||||
// defaultFilter is a fileFilter that keeps every file unchanged.
var defaultFilter = func(ctx context.Context, f *File) (*File, bool) { return f, true }
|
||||
|
||||
func newBaseNavigator(fileClient inventory.FileClient, filterFunc fileFilter, user *ent.User,
|
||||
hasher hashid.Encoder, config *setting.DBFS) *baseNavigator {
|
||||
return &baseNavigator{
|
||||
fileClient: fileClient,
|
||||
listFilter: filterFunc,
|
||||
user: user,
|
||||
hasher: hasher,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *baseNavigator) walkNext(ctx context.Context, root *File, next string, isLeaf bool) (*File, error) {
|
||||
var model *ent.File
|
||||
if root != nil {
|
||||
model = root.Model
|
||||
if root.IsSymbolic() {
|
||||
return nil, ErrSymbolicFolderFound
|
||||
}
|
||||
|
||||
root.mu.Lock()
|
||||
if child, ok := root.Children[next]; ok && !isLeaf {
|
||||
root.mu.Unlock()
|
||||
return child, nil
|
||||
}
|
||||
root.mu.Unlock()
|
||||
}
|
||||
|
||||
child, err := b.fileClient.GetChildFile(ctx, model, b.user.ID, next, isLeaf)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fs.ErrPathNotExist.WithError(err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("faield to get child %q: %w", next, err)
|
||||
}
|
||||
|
||||
return newFile(root, child), nil
|
||||
}
|
||||
|
||||
func (b *baseNavigator) walkUp(ctx context.Context, child *File) (*File, error) {
|
||||
parent, err := b.fileClient.GetParentFile(ctx, child.Model, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("faield to get Parent for %q: %w", child.Name(), err)
|
||||
}
|
||||
|
||||
return newParentFile(parent, child), nil
|
||||
}
|
||||
|
||||
// children lists the direct children of parent. A nil parent means listing
// from the virtual root (used for the trash mega search). When args.Search is
// set the call is delegated to the search path instead.
func (b *baseNavigator) children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) {
	var model *ent.File
	if parent != nil {
		model = parent.Model
		// Only folders can be listed.
		if parent.Model.Type != int(types.FileTypeFolder) {
			return nil, fs.ErrPathNotExist
		}

		if parent.IsSymbolic() {
			return nil, ErrSymbolicFolderFound
		}

		// Keep the user-facing path entry in sync with the parent's current URI
		// before any child URIs are derived from it.
		parent.Path[pathIndexUser] = parent.Uri(false)
	}

	if args.Search != nil {
		return b.search(ctx, parent, args)
	}

	children, err := b.fileClient.GetChildFiles(ctx, &inventory.ListFileParameters{
		PaginationArgs: args.Page,
		SharedWithMe:   args.SharedWithMe,
	}, b.user.ID, model)
	if err != nil {
		return nil, fmt.Errorf("failed to get children: %w", err)
	}

	return &ListResult{
		// Apply the navigator's list filter; filtered-out files are dropped.
		Files: lo.FilterMap(children.Files, func(model *ent.File, index int) (*File, bool) {
			f := newFile(parent, model)
			return b.listFilter(ctx, f)
		}),
		MixedType:  children.MixedType,
		Pagination: children.PaginationResults,
	}, nil
}
|
||||
|
||||
func (b *baseNavigator) walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error {
|
||||
walked := 0
|
||||
if len(levelFiles) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
owner := levelFiles[0].Owner()
|
||||
|
||||
level := 0
|
||||
for walked <= limit && depth >= 0 {
|
||||
if len(levelFiles) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
stop := false
|
||||
depth--
|
||||
if len(levelFiles) > limit-walked {
|
||||
levelFiles = levelFiles[:limit-walked]
|
||||
stop = true
|
||||
}
|
||||
if err := f(levelFiles, level); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stop {
|
||||
return ErrFileCountLimitedReached
|
||||
}
|
||||
|
||||
walked += len(levelFiles)
|
||||
folders := lo.Filter(levelFiles, func(f *File, index int) bool {
|
||||
return f.Model.Type == int(types.FileTypeFolder) && !f.IsSymbolic()
|
||||
})
|
||||
|
||||
if walked >= limit || len(folders) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
levelFiles = levelFiles[:0]
|
||||
leftCredit := limit - walked
|
||||
parents := lo.SliceToMap(folders, func(file *File) (int, *File) {
|
||||
return file.Model.ID, file
|
||||
})
|
||||
for leftCredit > 0 {
|
||||
token := ""
|
||||
res, err := b.fileClient.GetChildFiles(ctx,
|
||||
&inventory.ListFileParameters{
|
||||
PaginationArgs: &inventory.PaginationArgs{
|
||||
UseCursorPagination: true,
|
||||
PageToken: token,
|
||||
PageSize: leftCredit,
|
||||
},
|
||||
MixedType: true,
|
||||
},
|
||||
owner.ID,
|
||||
lo.Map(folders, func(item *File, index int) *ent.File {
|
||||
return item.Model
|
||||
})...)
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeDBError, "Failed to list children", err)
|
||||
}
|
||||
|
||||
leftCredit -= len(res.Files)
|
||||
|
||||
levelFiles = append(levelFiles, lo.Map(res.Files, func(model *ent.File, index int) *File {
|
||||
p := parents[model.FileChildren]
|
||||
return newFile(p, model)
|
||||
})...)
|
||||
|
||||
// All files listed
|
||||
if res.NextPageToken == "" {
|
||||
break
|
||||
}
|
||||
|
||||
token = res.NextPageToken
|
||||
}
|
||||
level++
|
||||
}
|
||||
|
||||
if walked >= limit {
|
||||
return ErrFileCountLimitedReached
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// search lists files matching args.Search. With a nil parent it performs a
// flat ("mega") search across the user's files; otherwise it performs a
// breadth-first recursive search under parent, bounded by
// config.MaxRecursiveSearchedFolder visited folders. Pagination tokens encode
// "<level>|<inner cursor token>" (see parseSearchPageToken) so a follow-up
// request can resume at the same BFS level.
func (b *baseNavigator) search(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) {
	if parent == nil {
		// Performs mega search for all files in trash fs.
		children, err := b.fileClient.GetChildFiles(ctx, &inventory.ListFileParameters{
			PaginationArgs: args.Page,
			MixedType:      true,
			Search:         args.Search,
			SharedWithMe:   args.SharedWithMe,
		}, b.user.ID, nil)
		if err != nil {
			return nil, fmt.Errorf("failed to get children: %w", err)
		}

		return &ListResult{
			Files: lo.FilterMap(children.Files, func(model *ent.File, index int) (*File, bool) {
				f := newFile(parent, model)
				return b.listFilter(ctx, f)
			}),
			MixedType:  children.MixedType,
			Pagination: children.PaginationResults,
		}, nil
	}
	// Performs recursive search for all files under the given folder.
	walkedFolder := 1
	// parents[level] maps folder ID -> folder for that BFS level; level 0 is
	// the search root itself.
	parents := []map[int]*File{{parent.Model.ID: parent}}
	startLevel, innerPageToken, err := parseSearchPageToken(args.Page.PageToken)
	if err != nil {
		return nil, err
	}
	args.Page.PageToken = innerPageToken

	// stepLevel expands parents[level] into the next level by listing all
	// child folders. It returns true when no deeper folders exist.
	stepLevel := func(level int) (bool, error) {
		token := ""
		// We don't need metadata in level search.
		listCtx := context.WithValue(ctx, inventory.LoadFilePublicMetadata{}, nil)
		for walkedFolder <= b.config.MaxRecursiveSearchedFolder {
			// TODO: chunk parents into 30000 per group
			res, err := b.fileClient.GetChildFiles(listCtx,
				&inventory.ListFileParameters{
					PaginationArgs: &inventory.PaginationArgs{
						UseCursorPagination: true,
						PageToken:           token,
					},
					FolderOnly: true,
				},
				parent.Model.OwnerID,
				lo.MapToSlice(parents[level], func(k int, f *File) *ent.File {
					return f.Model
				})...)
			if err != nil {
				return false, serializer.NewError(serializer.CodeDBError, "Failed to list children", err)
			}

			// NOTE(review): each pagination round appends a new map to parents,
			// but only parents[level+1] is read afterwards — verify behavior
			// when a level's folder listing spans multiple cursor pages.
			parents = append(parents, lo.SliceToMap(
				lo.FilterMap(res.Files, func(model *ent.File, index int) (*File, bool) {
					p := parents[level][model.FileChildren]
					f := newFile(p, model)
					f.Path[pathIndexUser] = p.Uri(false).Join(model.Name)
					return f, true
				}),
				func(f *File) (int, *File) {
					return f.Model.ID, f
				}))

			walkedFolder += len(parents[level+1])
			if res.NextPageToken == "" {
				break
			}

			token = res.NextPageToken
		}

		if len(parents) <= level+1 || len(parents[level+1]) == 0 {
			// All possible folders is searched
			return true, nil
		}

		return false, nil
	}

	// We need to walk from root folder to get the correct level.
	for level := 0; level < startLevel; level++ {
		stop, err := stepLevel(level)
		if err != nil {
			return nil, err
		}

		if stop {
			return &ListResult{}, nil
		}
	}

	// Search files starting from current level
	res := make([]*File, 0, args.Page.PageSize)
	args.Page.UseCursorPagination = true
	originalPageSize := args.Page.PageSize
	stop := false
	for len(res) < originalPageSize && walkedFolder <= b.config.MaxRecursiveSearchedFolder {
		// Only requires minimum number of files
		args.Page.PageSize = min(originalPageSize, originalPageSize-len(res))
		searchRes, err := b.fileClient.GetChildFiles(ctx,
			&inventory.ListFileParameters{
				PaginationArgs: args.Page,
				MixedType:      true,
				Search:         args.Search,
			},
			parent.Model.OwnerID,
			lo.MapToSlice(parents[startLevel], func(k int, f *File) *ent.File {
				return f.Model
			})...)

		if err != nil {
			return nil, serializer.NewError(serializer.CodeDBError, "Failed to search files", err)
		}

		// Attach matches to their parent folder and apply the list filter.
		newRes := lo.FilterMap(searchRes.Files, func(model *ent.File, index int) (*File, bool) {
			p := parents[startLevel][model.FileChildren]
			f := newFile(p, model)
			f.Path[pathIndexUser] = p.Uri(false).Join(model.Name)
			return b.listFilter(ctx, f)
		})
		res = append(res, newRes...)
		if args.StreamCallback != nil {
			args.StreamCallback(newRes)
		}

		args.Page.PageToken = searchRes.NextPageToken
		// If no more results under current level, move to next level
		if args.Page.PageToken == "" {
			if len(res) == originalPageSize {
				// Current page is full, no need to search more
				startLevel++
				break
			}

			finished, err := stepLevel(startLevel)
			if err != nil {
				return nil, err
			}

			if finished {
				stop = true
				// No more folders under next level, all result is presented
				break
			}

			startLevel++
		}
	}

	if args.StreamCallback != nil {
		// Clear res if it's streamed
		res = res[:0]
	}

	searchRes := &ListResult{
		Files:                 res,
		MixedType:             true,
		Pagination:            &inventory.PaginationResults{IsCursor: true},
		RecursionLimitReached: walkedFolder > b.config.MaxRecursiveSearchedFolder,
	}

	// Only hand out a next-page token when there is potentially more to search.
	if walkedFolder <= b.config.MaxRecursiveSearchedFolder && !stop {
		searchRes.Pagination.NextPageToken = fmt.Sprintf("%d%s%s", startLevel, searchTokenSeparator, args.Page.PageToken)
	}

	return searchRes, nil
}
|
||||
|
||||
func parseSearchPageToken(token string) (int, string, error) {
|
||||
if token == "" {
|
||||
return 0, "", nil
|
||||
}
|
||||
|
||||
tokens := strings.Split(token, searchTokenSeparator)
|
||||
if len(tokens) != 2 {
|
||||
return 0, "", fmt.Errorf("invalid page token")
|
||||
}
|
||||
|
||||
level, err := strconv.Atoi(tokens[0])
|
||||
if err != nil || level < 0 {
|
||||
return 0, "", fmt.Errorf("invalid page token level")
|
||||
}
|
||||
|
||||
return level, tokens[1], nil
|
||||
}
|
||||
|
||||
func newMyUri() *fs.URI {
|
||||
res, _ := fs.NewUriFromString(constants.CloudreveScheme + "://" + string(constants.FileSystemMy))
|
||||
return res
|
||||
}
|
||||
|
||||
func newMyIDUri(uid string) *fs.URI {
|
||||
res, _ := fs.NewUriFromString(fmt.Sprintf("%s://%s@%s", constants.CloudreveScheme, uid, constants.FileSystemMy))
|
||||
return res
|
||||
}
|
||||
|
||||
func newTrashUri(name string) *fs.URI {
|
||||
res, _ := fs.NewUriFromString(fmt.Sprintf("%s://%s", constants.CloudreveScheme, constants.FileSystemTrash))
|
||||
return res.Join(name)
|
||||
}
|
||||
|
||||
func newSharedWithMeUri(id string) *fs.URI {
|
||||
res, _ := fs.NewUriFromString(fmt.Sprintf("%s://%s", constants.CloudreveScheme, constants.FileSystemSharedWithMe))
|
||||
return res.Join(id)
|
||||
}
|
||||
171
pkg/filemanager/fs/dbfs/options.go
Normal file
171
pkg/filemanager/fs/dbfs/options.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package dbfs
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
)
|
||||
|
||||
// dbfsOption aggregates the generic fs options with dbfs-specific knobs.
// Fields are populated through the With* option constructors below; see each
// constructor's comment for the field's meaning.
type dbfsOption struct {
	*fs.FsOption
	loadFolderSummary      bool // WithLoadFolderSummary
	extendedInfo           bool // WithExtendedInfo
	loadFilePublicMetadata bool // WithFilePublicMetadata
	loadFileShareIfOwned   bool // WithFileShareIfOwned
	loadEntityUser         bool // WithEntityUser
	loadFileEntities       bool // WithFileEntities
	useCursorPagination    bool // WithCursorPagination
	// pageToken is the cursor token paired with useCursorPagination.
	pageToken              string
	preferredStoragePolicy *ent.StoragePolicy // WithPreferredStoragePolicy
	errOnConflict          bool               // WithErrorOnConflict
	previousVersion        string             // WithPreviousVersion
	removeStaleEntities    bool               // WithRemoveStaleEntities
	requiredCapabilities   []NavigatorCapability // WithRequiredCapabilities
	generateContextHint    bool // WithContextHint
	isSymbolicLink         bool // WithSymbolicLink
	noChainedCreation      bool // WithNoChainedCreation
	// streamListResponseCallback receives list results as they stream in.
	streamListResponseCallback func(parent fs.File, file []fs.File)
	ancestor                   *File // WithAncestor
}
|
||||
|
||||
func newDbfsOption() *dbfsOption {
|
||||
return &dbfsOption{
|
||||
FsOption: &fs.FsOption{},
|
||||
}
|
||||
}
|
||||
|
||||
func (o *dbfsOption) apply(opt fs.Option) {
|
||||
if fsOpt, ok := opt.(fs.OptionFunc); ok {
|
||||
fsOpt.Apply(o.FsOption)
|
||||
} else if dbfsOpt, ok := opt.(optionFunc); ok {
|
||||
dbfsOpt.Apply(o)
|
||||
}
|
||||
}
|
||||
|
||||
// optionFunc is a dbfs-specific fs.Option implemented as a function that
// mutates a *dbfsOption.
type optionFunc func(*dbfsOption)
|
||||
|
||||
func (f optionFunc) Apply(o any) {
|
||||
if dbfsO, ok := o.(*dbfsOption); ok {
|
||||
f(dbfsO)
|
||||
}
|
||||
}
|
||||
|
||||
// WithFilePublicMetadata enables loading file public metadata.
func WithFilePublicMetadata() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.loadFilePublicMetadata = true
	})
}

// WithContextHint enables generating a context hint for the list operation.
func WithContextHint() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.generateContextHint = true
	})
}

// WithFileEntities enables loading file entities.
func WithFileEntities() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.loadFileEntities = true
	})
}

// WithCursorPagination enables cursor pagination for the list operation,
// resuming from the given page token (empty for the first page).
func WithCursorPagination(pageToken string) fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.useCursorPagination = true
		o.pageToken = pageToken
	})
}
|
||||
|
||||
// WithPreferredStoragePolicy sets the preferred storage policy for the upload
// operation.
func WithPreferredStoragePolicy(policy *ent.StoragePolicy) fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.preferredStoragePolicy = policy
	})
}

// WithErrorOnConflict sets to throw an error on conflict for the create
// operation.
func WithErrorOnConflict() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.errOnConflict = true
	})
}

// WithPreviousVersion sets the previous version for the update operation.
func WithPreviousVersion(version string) fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.previousVersion = version
	})
}

// WithRemoveStaleEntities sets to remove stale entities for the update
// operation.
func WithRemoveStaleEntities() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.removeStaleEntities = true
	})
}

// WithRequiredCapabilities sets the capabilities required for operations.
func WithRequiredCapabilities(capabilities ...NavigatorCapability) fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.requiredCapabilities = capabilities
	})
}

// WithNoChainedCreation disables chained creation for the create operation.
// This requires the parent folder to exist before creating new files under it.
func WithNoChainedCreation() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.noChainedCreation = true
	})
}
|
||||
|
||||
// WithFileShareIfOwned enables loading the file's share link if the file is
// owned by the user.
func WithFileShareIfOwned() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.loadFileShareIfOwned = true
	})
}

// WithStreamListResponseCallback sets the callback for handling streamed list
// responses.
func WithStreamListResponseCallback(callback func(parent fs.File, file []fs.File)) fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.streamListResponseCallback = callback
	})
}

// WithSymbolicLink marks the file as a symbolic link.
func WithSymbolicLink() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.isSymbolicLink = true
	})
}

// WithExtendedInfo enables loading extended info for the file.
func WithExtendedInfo() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.extendedInfo = true
	})
}

// WithLoadFolderSummary enables loading the folder summary.
func WithLoadFolderSummary() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.loadFolderSummary = true
	})
}

// WithEntityUser enables loading the entity user.
func WithEntityUser() fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.loadEntityUser = true
	})
}

// WithAncestor sets the most recent known ancestor to start from when
// creating files.
func WithAncestor(f *File) fs.Option {
	return optionFunc(func(o *dbfsOption) {
		o.ancestor = f
	})
}
|
||||
324
pkg/filemanager/fs/dbfs/share_navigator.go
Normal file
324
pkg/filemanager/fs/dbfs/share_navigator.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package dbfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrShareNotFound = serializer.NewError(serializer.CodeNotFound, "Shared file does not exist", nil)
|
||||
ErrNotPurchased = serializer.NewError(serializer.CodePurchaseRequired, "You need to purchased this share", nil)
|
||||
)
|
||||
|
||||
const (
	// PurchaseTicketHeader is the HTTP header name carrying a purchase ticket
	// (presumably checked when accessing paid shares; not referenced in this
	// file).
	PurchaseTicketHeader = constants.CrHeaderPrefix + "Purchase-Ticket"
)
|
||||
|
||||
// shareNavigatorCapability holds the capability set of share navigators; it
// is populated once in the package's init function.
var shareNavigatorCapability = &boolset.BooleanSet{}
|
||||
|
||||
// NewShareNavigator creates a navigator for user's "shared" file system.
|
||||
func NewShareNavigator(u *ent.User, fileClient inventory.FileClient, shareClient inventory.ShareClient,
|
||||
l logging.Logger, config *setting.DBFS, hasher hashid.Encoder) Navigator {
|
||||
n := &shareNavigator{
|
||||
user: u,
|
||||
l: l,
|
||||
fileClient: fileClient,
|
||||
shareClient: shareClient,
|
||||
config: config,
|
||||
}
|
||||
n.baseNavigator = newBaseNavigator(fileClient, defaultFilter, u, hasher, config)
|
||||
return n
|
||||
}
|
||||
|
||||
type (
	// shareNavigator navigates the file system exposed through a share link.
	shareNavigator struct {
		l           logging.Logger
		user        *ent.User // visiting user, not necessarily the share owner
		fileClient  inventory.FileClient
		shareClient inventory.ShareClient
		config      *setting.DBFS

		*baseNavigator
		shareRoot       *File      // root presented to the visitor
		singleFileShare bool       // the share points at a single file, not a folder
		ownerRoot       *File      // owner-side top-most ancestor of the shared file
		share           *ent.Share // resolved share record
		owner           *ent.User  // share owner
		disableRecycle  bool       // suppress recycling of cached roots (set by Persist/RestoreState)
		persist         func()     // pending state persistence, executed once by Recycle
	}

	// shareNavigatorState is the serializable snapshot of a shareNavigator,
	// written by PersistState and consumed by RestoreState.
	shareNavigatorState struct {
		ShareRoot       *File
		OwnerRoot       *File
		SingleFileShare bool
		Share           *ent.Share
		Owner           *ent.User
	}
)
|
||||
|
||||
// PersistState arranges for the navigator's resolved share state to be saved
// under key on the next Recycle call, and disables recycling of the cached
// roots so the persisted *File trees remain intact.
func (n *shareNavigator) PersistState(kv cache.Driver, key string) {
	n.disableRecycle = true
	// Deferred to Recycle so the state saved reflects the navigator's final
	// values.
	n.persist = func() {
		kv.Set(key, shareNavigatorState{
			ShareRoot:       n.shareRoot,
			OwnerRoot:       n.ownerRoot,
			SingleFileShare: n.singleFileShare,
			Share:           n.share,
			Owner:           n.owner,
		}, ContextHintTTL)
	}
}
|
||||
|
||||
func (n *shareNavigator) RestoreState(s State) error {
|
||||
n.disableRecycle = true
|
||||
if state, ok := s.(shareNavigatorState); ok {
|
||||
n.shareRoot = state.ShareRoot
|
||||
n.ownerRoot = state.OwnerRoot
|
||||
n.singleFileShare = state.SingleFileShare
|
||||
n.share = state.Share
|
||||
n.owner = state.Owner
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid state type: %T", s)
|
||||
}
|
||||
|
||||
// Recycle runs any pending state persistence and then, unless recycling was
// disabled, recycles the cached file tree. Only one root is recycled: the
// owner-side root when present, otherwise the share root.
func (n *shareNavigator) Recycle() {
	if n.persist != nil {
		n.persist()
		n.persist = nil
	}

	if !n.disableRecycle {
		if n.ownerRoot != nil {
			n.ownerRoot.Recycle()
		} else if n.shareRoot != nil {
			n.shareRoot.Recycle()
		}
	}
}
|
||||
|
||||
func (n *shareNavigator) Root(ctx context.Context, path *fs.URI) (*File, error) {
|
||||
ctx = context.WithValue(ctx, inventory.LoadShareUser{}, true)
|
||||
ctx = context.WithValue(ctx, inventory.LoadUserGroup{}, true)
|
||||
ctx = context.WithValue(ctx, inventory.LoadShareFile{}, true)
|
||||
share, err := n.shareClient.GetByHashID(ctx, path.ID(hashid.EncodeUserID(n.hasher, n.user.ID)))
|
||||
if err != nil {
|
||||
return nil, ErrShareNotFound.WithError(err)
|
||||
}
|
||||
|
||||
if err := inventory.IsValidShare(share); err != nil {
|
||||
return nil, ErrShareNotFound.WithError(err)
|
||||
}
|
||||
|
||||
n.owner = share.Edges.User
|
||||
|
||||
// Check password
|
||||
if share.Password != "" && share.Password != path.Password() {
|
||||
return nil, ErrShareIncorrectPassword
|
||||
}
|
||||
|
||||
// Share permission setting should overwrite root folder's permission
|
||||
n.shareRoot = newFile(nil, share.Edges.File)
|
||||
|
||||
// Find the user side root of the file.
|
||||
ownerRoot, err := n.findRoot(ctx, n.shareRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n.shareRoot.Type() == types.FileTypeFile {
|
||||
n.singleFileShare = true
|
||||
n.shareRoot = n.shareRoot.Parent
|
||||
}
|
||||
|
||||
n.shareRoot.Path[pathIndexUser] = path.Root()
|
||||
n.shareRoot.OwnerModel = n.owner
|
||||
n.shareRoot.IsUserRoot = true
|
||||
n.shareRoot.CapabilitiesBs = n.Capabilities(false).Capability
|
||||
|
||||
// Check if any ancestors is deleted
|
||||
if ownerRoot.Name() != inventory.RootFolderName {
|
||||
return nil, ErrShareNotFound
|
||||
}
|
||||
|
||||
if n.user.ID != n.owner.ID && !n.user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionShareDownload)) {
|
||||
return nil, serializer.NewError(
|
||||
serializer.CodeNoPermissionErr,
|
||||
fmt.Sprintf("You don't have permission to access share links"),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
n.ownerRoot = ownerRoot
|
||||
n.ownerRoot.Path[pathIndexRoot] = newMyIDUri(hashid.EncodeUserID(n.hasher, n.owner.ID))
|
||||
n.share = share
|
||||
return n.shareRoot, nil
|
||||
}
|
||||
|
||||
// To resolves path inside the share, lazily resolving the share root first.
// For single-file shares targeting the root (or the file itself), the latest
// revision of the shared file is returned directly. When a walk step fails,
// the last successfully resolved ancestor is returned alongside the error.
func (n *shareNavigator) To(ctx context.Context, path *fs.URI) (*File, error) {
	if n.shareRoot == nil {
		root, err := n.Root(ctx, path)
		if err != nil {
			return nil, err
		}

		n.shareRoot = root
	}

	current, lastAncestor := n.shareRoot, n.shareRoot
	elements := path.Elements()

	// If target is root of single file share, the root itself is the target.
	if len(elements) <= 1 && n.singleFileShare {
		file, err := n.latestSharedSingleFile(ctx)
		if err != nil {
			return nil, err
		}

		// A one-element path must name the shared file itself.
		if len(elements) == 1 && file.Name() != elements[0] {
			return nil, fs.ErrPathNotExist
		}

		return file, nil
	}

	var err error
	for index, element := range elements {
		lastAncestor = current
		current, err = n.walkNext(ctx, current, element, index == len(elements)-1)
		if err != nil {
			return lastAncestor, fmt.Errorf("failed to walk into %q: %w", element, err)
		}
	}

	return current, nil
}
|
||||
|
||||
func (n *shareNavigator) walkNext(ctx context.Context, root *File, next string, isLeaf bool) (*File, error) {
|
||||
nextFile, err := n.baseNavigator.walkNext(ctx, root, next, isLeaf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nextFile, nil
|
||||
}
|
||||
|
||||
// Children lists the children of parent. A single-file share always presents
// exactly the shared file, regardless of parent; folder shares delegate to
// the base navigator's listing.
func (n *shareNavigator) Children(ctx context.Context, parent *File, args *ListArgs) (*ListResult, error) {
	if n.singleFileShare {
		file, err := n.latestSharedSingleFile(ctx)
		if err != nil {
			return nil, err
		}

		return &ListResult{
			Files:          []*File{file},
			Pagination:     &inventory.PaginationResults{},
			SingleFileView: true,
		}, nil
	}

	return n.baseNavigator.children(ctx, parent, args)
}
|
||||
|
||||
func (n *shareNavigator) latestSharedSingleFile(ctx context.Context) (*File, error) {
|
||||
if n.singleFileShare {
|
||||
file, err := n.fileClient.GetByID(ctx, n.share.Edges.File.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := newFile(n.shareRoot, file)
|
||||
f.OwnerModel = n.shareRoot.OwnerModel
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
return nil, fs.ErrPathNotExist
|
||||
}
|
||||
|
||||
func (n *shareNavigator) Capabilities(isSearching bool) *fs.NavigatorProps {
|
||||
res := &fs.NavigatorProps{
|
||||
Capability: shareNavigatorCapability,
|
||||
OrderDirectionOptions: fullOrderDirectionOption,
|
||||
OrderByOptions: fullOrderByOption,
|
||||
MaxPageSize: n.config.MaxPageSize,
|
||||
}
|
||||
|
||||
if isSearching {
|
||||
res.OrderByOptions = nil
|
||||
res.OrderDirectionOptions = nil
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (n *shareNavigator) FollowTx(ctx context.Context) (func(), error) {
|
||||
if _, ok := ctx.Value(inventory.TxCtx{}).(*inventory.Tx); !ok {
|
||||
return nil, fmt.Errorf("navigator: no inherited transaction found in context")
|
||||
}
|
||||
newFileClient, _, _, err := inventory.WithTx(ctx, n.fileClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newSharClient, _, _, err := inventory.WithTx(ctx, n.shareClient)
|
||||
|
||||
oldFileClient, oldShareClient := n.fileClient, n.shareClient
|
||||
revert := func() {
|
||||
n.fileClient = oldFileClient
|
||||
n.shareClient = oldShareClient
|
||||
n.baseNavigator.fileClient = oldFileClient
|
||||
}
|
||||
|
||||
n.fileClient = newFileClient
|
||||
n.shareClient = newSharClient
|
||||
n.baseNavigator.fileClient = newFileClient
|
||||
return revert, nil
|
||||
}
|
||||
|
||||
func (n *shareNavigator) ExecuteHook(ctx context.Context, hookType fs.HookType, file *File) error {
|
||||
switch hookType {
|
||||
case fs.HookTypeBeforeDownload:
|
||||
if n.singleFileShare {
|
||||
return n.shareClient.Downloaded(ctx, n.share)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// findRoot finds the root folder of the given child.
|
||||
func (n *shareNavigator) findRoot(ctx context.Context, child *File) (*File, error) {
|
||||
root := child
|
||||
for {
|
||||
newRoot, err := n.baseNavigator.walkUp(ctx, root)
|
||||
if err != nil {
|
||||
if !ent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
root = newRoot
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// Walk traverses the shared file tree breadth-first using the base
// navigator's walk implementation; see baseNavigator.walk for the limit and
// depth semantics.
func (n *shareNavigator) Walk(ctx context.Context, levelFiles []*File, limit, depth int, f WalkFunc) error {
	return n.baseNavigator.walk(ctx, levelFiles, limit, depth, f)
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user