Init V4 community edition (#2265)
* Init V4 community edition * Init V4 community edition
This commit is contained in:
@@ -1,24 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/oss"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -31,14 +28,14 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var err error
|
||||
switch c.Request.Method {
|
||||
case "PUT", "POST", "PATCH":
|
||||
err = auth.CheckRequest(authInstance, c.Request)
|
||||
case http.MethodPut, http.MethodPost, http.MethodPatch:
|
||||
err = auth.CheckRequest(c, authInstance, c.Request)
|
||||
default:
|
||||
err = auth.CheckURI(authInstance, c.Request.URL)
|
||||
err = auth.CheckURI(c, authInstance, c.Request.URL)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err))
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, err.Error(), err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -50,29 +47,55 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
|
||||
// CurrentUser 获取登录用户
|
||||
func CurrentUser() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
uid := session.Get("user_id")
|
||||
if uid != nil {
|
||||
user, err := model.GetActiveUserByID(uid)
|
||||
if err == nil {
|
||||
c.Set("user", &user)
|
||||
}
|
||||
dep := dependency.FromContext(c)
|
||||
shouldContinue, err := dep.TokenAuth().VerifyAndRetrieveUser(c)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(c, err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if shouldContinue {
|
||||
// TODO: Logto handler
|
||||
}
|
||||
|
||||
uid := inventory.UserIDFromContext(c)
|
||||
if err := SetUserCtx(c, uid); err != nil {
|
||||
c.JSON(200, serializer.Err(c, err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRequired 需要登录
|
||||
func AuthRequired() gin.HandlerFunc {
|
||||
// SetUserCtx set the current login user via uid
|
||||
func SetUserCtx(c *gin.Context, uid int) error {
|
||||
dep := dependency.FromContext(c)
|
||||
userClient := dep.UserClient()
|
||||
loginUser, err := userClient.GetLoginUserByID(c, uid)
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeDBError, "failed to get login user", err)
|
||||
}
|
||||
|
||||
SetUserCtxByUser(c, loginUser)
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetUserCtxByUser(c *gin.Context, user *ent.User) {
|
||||
util.WithValue(c, inventory.UserCtx{}, user)
|
||||
}
|
||||
|
||||
// LoginRequired 需要登录
|
||||
func LoginRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if user, _ := c.Get("user"); user != nil {
|
||||
if _, ok := user.(*model.User); ok {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if u := inventory.UserFromContext(c); u != nil && !inventory.IsAnonymousUser(u) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, serializer.CheckLogin())
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCheckLogin, "Login required", nil))
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
@@ -80,60 +103,84 @@ func AuthRequired() gin.HandlerFunc {
|
||||
// WebDAVAuth 验证WebDAV登录及权限
|
||||
func WebDAVAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// OPTIONS 请求不需要鉴权,否则Windows10下无法保存文档
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
username, password, ok := c.Request.BasicAuth()
|
||||
if !ok {
|
||||
// OPTIONS 请求不需要鉴权
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
c.Writer.Header()["WWW-Authenticate"] = []string{`Basic realm="cloudreve"`}
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
expectedUser, err := model.GetActiveUserByEmail(username)
|
||||
dep := dependency.FromContext(c)
|
||||
l := dep.Logger()
|
||||
userClient := dep.UserClient()
|
||||
expectedUser, err := userClient.GetActiveByDavAccount(c, username, password)
|
||||
if err != nil {
|
||||
if username == "" {
|
||||
if u, err := userClient.GetByEmail(c, username); err == nil {
|
||||
// Try login with known user but incorrect password, record audit log
|
||||
SetUserCtxByUser(c, u)
|
||||
}
|
||||
}
|
||||
|
||||
l.Debug("WebDAVAuth: failed to get user %q with provided credential: %s", username, err)
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 密码正确?
|
||||
webdav, err := model.GetWebdavByPassword(password, expectedUser.ID)
|
||||
if err != nil {
|
||||
// Validate dav account
|
||||
accounts, err := expectedUser.Edges.DavAccountsOrErr()
|
||||
if err != nil || len(accounts) == 0 {
|
||||
l.Debug("WebDAVAuth: failed to get user dav accounts %q with provided credential: %s", username, err)
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 用户组已启用WebDAV?
|
||||
if !expectedUser.Group.WebDAVEnabled {
|
||||
c.Status(http.StatusForbidden)
|
||||
group, err := expectedUser.Edges.GroupOrErr()
|
||||
if err != nil {
|
||||
l.Debug("WebDAVAuth: user group not found: %s", err)
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 用户组已启用WebDAV代理?
|
||||
if !expectedUser.Group.OptionsSerialized.WebDAVProxy {
|
||||
webdav.UseProxy = false
|
||||
if !group.Permissions.Enabled(int(types.GroupPermissionWebDAV)) {
|
||||
c.Status(http.StatusForbidden)
|
||||
l.Debug("WebDAVAuth: user %q does not have WebDAV permission.", expectedUser.Email)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user", &expectedUser)
|
||||
c.Set("webdav", webdav)
|
||||
// 检查是否只读
|
||||
if expectedUser.Edges.DavAccounts[0].Options.Enabled(int(types.DavAccountReadOnly)) {
|
||||
switch c.Request.Method {
|
||||
case http.MethodDelete, http.MethodPut, "MKCOL", "COPY", "MOVE", "LOCK", "UNLOCK":
|
||||
c.Status(http.StatusForbidden)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
SetUserCtxByUser(c, expectedUser)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 对上传会话进行验证
|
||||
func UseUploadSession(policyType string) gin.HandlerFunc {
|
||||
func UseUploadSession(policyType types.PolicyType) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 验证key并查找用户
|
||||
resp := uploadCallbackCheck(c, policyType)
|
||||
if resp.Code != 0 {
|
||||
c.JSON(CallbackFailedStatusCode, resp)
|
||||
err := uploadCallbackCheck(c, policyType)
|
||||
if err != nil {
|
||||
c.JSON(CallbackFailedStatusCode, serializer.Err(c, err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -143,44 +190,46 @@ func UseUploadSession(policyType string) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
// uploadCallbackCheck 对上传回调请求的 callback key 进行验证,如果成功则返回上传用户
|
||||
func uploadCallbackCheck(c *gin.Context, policyType string) serializer.Response {
|
||||
func uploadCallbackCheck(c *gin.Context, policyType types.PolicyType) error {
|
||||
// 验证 Callback Key
|
||||
sessionID := c.Param("sessionID")
|
||||
if sessionID == "" {
|
||||
return serializer.ParamErr("Session ID cannot be empty", nil)
|
||||
return serializer.NewError(serializer.CodeParamErr, "Session ID cannot be empty", nil)
|
||||
}
|
||||
|
||||
callbackSessionRaw, exist := cache.Get(filesystem.UploadSessionCachePrefix + sessionID)
|
||||
dep := dependency.FromContext(c)
|
||||
callbackSessionRaw, exist := dep.KV().Get(manager.UploadSessionCachePrefix + sessionID)
|
||||
if !exist {
|
||||
return serializer.Err(serializer.CodeUploadSessionExpired, "上传会话不存在或已过期", nil)
|
||||
return serializer.NewError(serializer.CodeUploadSessionExpired, "Upload session does not exist or expired", nil)
|
||||
}
|
||||
|
||||
callbackSession := callbackSessionRaw.(serializer.UploadSession)
|
||||
c.Set(filesystem.UploadSessionCtx, &callbackSession)
|
||||
if callbackSession.Policy.Type != policyType {
|
||||
return serializer.Err(serializer.CodePolicyNotAllowed, "", nil)
|
||||
callbackSession := callbackSessionRaw.(fs.UploadSession)
|
||||
c.Set(manager.UploadSessionCtx, &callbackSession)
|
||||
if callbackSession.Policy.Type != string(policyType) {
|
||||
return serializer.NewError(serializer.CodePolicyNotAllowed, "", nil)
|
||||
}
|
||||
|
||||
// 清理回调会话
|
||||
_ = cache.Deletes([]string{sessionID}, filesystem.UploadSessionCachePrefix)
|
||||
|
||||
// 查找用户
|
||||
user, err := model.GetActiveUserByID(callbackSession.UID)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeUserNotFound, "", err)
|
||||
if err := SetUserCtx(c, callbackSession.UID); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Set(filesystem.UserCtx, &user)
|
||||
return serializer.Response{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteCallbackAuth 远程回调签名验证
|
||||
func RemoteCallbackAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 验证签名
|
||||
session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte(session.Policy.SecretKey)}
|
||||
if err := auth.CheckRequest(authInstance, c.Request); err != nil {
|
||||
c.JSON(CallbackFailedStatusCode, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err))
|
||||
session := c.MustGet(manager.UploadSessionCtx).(*fs.UploadSession)
|
||||
if session.Policy.Edges.Node == nil {
|
||||
c.JSON(CallbackFailedStatusCode, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, "Node not found", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte(session.Policy.Edges.Node.SlaveKey)}
|
||||
if err := auth.CheckRequest(c, authInstance, c.Request); err != nil {
|
||||
c.JSON(CallbackFailedStatusCode, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, err.Error(), err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -190,37 +239,16 @@ func RemoteCallbackAuth() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// QiniuCallbackAuth 七牛回调签名验证
|
||||
func QiniuCallbackAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
|
||||
|
||||
// 验证回调是否来自qiniu
|
||||
mac := qbox.NewMac(session.Policy.AccessKey, session.Policy.SecretKey)
|
||||
ok, err := mac.VerifyCallback(c.Request)
|
||||
if err != nil {
|
||||
util.Log().Debug("Failed to verify callback request: %s", err)
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Invalid signature."})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// OSSCallbackAuth 阿里云OSS回调签名验证
|
||||
func OSSCallbackAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
err := oss.VerifyCallbackSignature(c.Request)
|
||||
dep := dependency.FromContext(c)
|
||||
err := oss.VerifyCallbackSignature(c.Request, dep.KV(), dep.RequestClient(
|
||||
request.WithContext(c),
|
||||
request.WithLogger(logging.FromContext(c)),
|
||||
))
|
||||
if err != nil {
|
||||
util.Log().Debug("Failed to verify callback request: %s", err)
|
||||
dep.Logger().Debug("Failed to verify callback request: %s", err)
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."})
|
||||
c.Abort()
|
||||
return
|
||||
@@ -230,71 +258,12 @@ func OSSCallbackAuth() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UpyunCallbackAuth 又拍云回调签名验证
|
||||
func UpyunCallbackAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
|
||||
|
||||
// 获取请求正文
|
||||
body, err := ioutil.ReadAll(c.Request.Body)
|
||||
c.Request.Body.Close()
|
||||
if err != nil {
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||
|
||||
// 准备验证Upyun回调签名
|
||||
handler := upyun.Driver{Policy: &session.Policy}
|
||||
contentMD5 := c.Request.Header.Get("Content-Md5")
|
||||
date := c.Request.Header.Get("Date")
|
||||
actualSignature := c.Request.Header.Get("Authorization")
|
||||
|
||||
// 计算正文MD5
|
||||
actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body))
|
||||
if actualContentMD5 != contentMD5 {
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5 mismatch."})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 计算理论签名
|
||||
signature := handler.Sign(context.Background(), []string{
|
||||
"POST",
|
||||
c.Request.URL.Path,
|
||||
date,
|
||||
contentMD5,
|
||||
})
|
||||
|
||||
// 对比签名
|
||||
if signature != actualSignature {
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Signature not match"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// OneDriveCallbackAuth OneDrive回调签名验证
|
||||
func OneDriveCallbackAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 发送回调结束信号
|
||||
mq.GlobalMQ.Publish(c.Param("sessionID"), mq.Message{})
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsAdmin 必须为管理员用户组
|
||||
func IsAdmin() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
user, _ := c.Get("user")
|
||||
if user.(*model.User).Group.ID != 1 && user.(*model.User).ID != 1 {
|
||||
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "", nil))
|
||||
user := inventory.UserFromContext(c)
|
||||
if !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionIsAdmin)) {
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeNoPermissionErr, "", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,605 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
// TestMain 初始化数据库Mock
|
||||
func TestMain(m *testing.M) {
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
model.DB, _ = gorm.Open("mysql", db)
|
||||
defer db.Close()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestCurrentUser(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
|
||||
//session为空
|
||||
sessionFunc := Session("233")
|
||||
sessionFunc(c)
|
||||
CurrentUser()(c)
|
||||
user, _ := c.Get("user")
|
||||
asserts.Nil(user)
|
||||
|
||||
//session正确
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
util.SetSession(c, map[string]interface{}{"user_id": 1})
|
||||
rows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options"}).
|
||||
AddRow(1, nil, "admin@cloudreve.org", "{}")
|
||||
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(rows)
|
||||
CurrentUser()(c)
|
||||
user, _ = c.Get("user")
|
||||
asserts.NotNil(user)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestAuthRequired(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
AuthRequiredFunc := AuthRequired()
|
||||
|
||||
// 未登录
|
||||
AuthRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
|
||||
// 类型错误
|
||||
c.Set("user", 123)
|
||||
AuthRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
|
||||
// 正常
|
||||
c.Set("user", &model.User{})
|
||||
AuthRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
}
|
||||
|
||||
func TestSignRequired(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
SignRequiredFunc := SignRequired(authInstance)
|
||||
|
||||
// 鉴权失败
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("PUT", "/test", nil)
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
// Sign verify success
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("PUT", "/test", nil)
|
||||
c.Request = auth.SignRequest(authInstance, c.Request, 0)
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
func TestWebDAVAuth(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
AuthFunc := WebDAVAuth()
|
||||
|
||||
// options请求跳过验证
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("OPTIONS", "/test", nil)
|
||||
AuthFunc(c)
|
||||
}
|
||||
|
||||
// 请求HTTP Basic Auth
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
AuthFunc(c)
|
||||
asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"])
|
||||
}
|
||||
|
||||
// 用户名不存在
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
c.Request.Header = map[string][]string{
|
||||
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
|
||||
}
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"id", "password", "email"}),
|
||||
)
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// 密码错误
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
c.Request.Header = map[string][]string{
|
||||
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
|
||||
}
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"),
|
||||
)
|
||||
// 查找密码
|
||||
mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
//未启用 WebDAV
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
c.Request.Header = map[string][]string{
|
||||
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
|
||||
}
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows(
|
||||
[]string{"id", "password", "email", "group_id", "options"}).
|
||||
AddRow(1,
|
||||
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
|
||||
"who@cloudreve.org",
|
||||
1,
|
||||
"{}",
|
||||
),
|
||||
)
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, false))
|
||||
// 查找密码
|
||||
mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Equal(c.Writer.Status(), http.StatusForbidden)
|
||||
}
|
||||
|
||||
//正常
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
c.Request.Header = map[string][]string{
|
||||
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
|
||||
}
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows(
|
||||
[]string{"id", "password", "email", "group_id", "options"}).
|
||||
AddRow(1,
|
||||
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
|
||||
"who@cloudreve.org",
|
||||
1,
|
||||
"{}",
|
||||
),
|
||||
)
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, true))
|
||||
// 查找密码
|
||||
mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Equal(c.Writer.Status(), 200)
|
||||
_, ok := c.Get("user")
|
||||
asserts.True(ok)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestUseUploadSession(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
AuthFunc := UseUploadSession("local")
|
||||
|
||||
// sessionID 为空
|
||||
{
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/sessionID", nil)
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
|
||||
auth.SignRequest(authInstance, c.Request, 0)
|
||||
AuthFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
cache.Set(
|
||||
filesystem.UploadSessionCachePrefix+"testCallBackRemote",
|
||||
serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{Type: "local"},
|
||||
},
|
||||
0,
|
||||
)
|
||||
cache.Deletes([]string{"1"}, "policy_")
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[513]"))
|
||||
mock.ExpectQuery("SELECT(.+)policies(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123"))
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"sessionID", "testCallBackRemote"},
|
||||
}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
|
||||
auth.SignRequest(authInstance, c.Request, 0)
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadCallbackCheck(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 上传会话不存在
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"sessionID", "testSessionNotExist"},
|
||||
}
|
||||
res := uploadCallbackCheck(c, "local")
|
||||
a.Contains("上传会话不存在或已过期", res.Msg)
|
||||
}
|
||||
|
||||
// 上传策略不一致
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"sessionID", "testPolicyNotMatch"},
|
||||
}
|
||||
cache.Set(
|
||||
filesystem.UploadSessionCachePrefix+"testPolicyNotMatch",
|
||||
serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{Type: "remote"},
|
||||
},
|
||||
0,
|
||||
)
|
||||
res := uploadCallbackCheck(c, "local")
|
||||
a.Contains("Policy not supported", res.Msg)
|
||||
}
|
||||
|
||||
// 用户不存在
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"sessionID", "testUserNotExist"},
|
||||
}
|
||||
cache.Set(
|
||||
filesystem.UploadSessionCachePrefix+"testUserNotExist",
|
||||
serializer.UploadSession{
|
||||
UID: 313,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{Type: "remote"},
|
||||
},
|
||||
0,
|
||||
)
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}))
|
||||
res := uploadCallbackCheck(c, "remote")
|
||||
a.Contains("找不到用户", res.Msg)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
_, ok := cache.Get(filesystem.UploadSessionCachePrefix + "testUserNotExist")
|
||||
a.False(ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteCallbackAuth(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
AuthFunc := RemoteCallbackAuth()
|
||||
|
||||
// 成功
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{SecretKey: "123"},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
|
||||
auth.SignRequest(authInstance, c.Request, 0)
|
||||
AuthFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 签名错误
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{SecretKey: "123"},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
|
||||
AuthFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQiniuCallbackAuth(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
AuthFunc := QiniuCallbackAuth()
|
||||
|
||||
// 成功
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{
|
||||
SecretKey: "123",
|
||||
AccessKey: "123",
|
||||
},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
|
||||
mac := qbox.NewMac("123", "123")
|
||||
token, err := mac.SignRequest(c.Request)
|
||||
asserts.NoError(err)
|
||||
c.Request.Header["Authorization"] = []string{"QBox " + token}
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 验证失败
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{
|
||||
SecretKey: "123",
|
||||
AccessKey: "123",
|
||||
},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
|
||||
mac := qbox.NewMac("123", "1213")
|
||||
token, err := mac.SignRequest(c.Request)
|
||||
asserts.NoError(err)
|
||||
c.Request.Header["Authorization"] = []string{"QBox " + token}
|
||||
AuthFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOSSCallbackAuth(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
AuthFunc := OSSCallbackAuth()
|
||||
|
||||
// 签名验证失败
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{
|
||||
SecretKey: "123",
|
||||
AccessKey: "123",
|
||||
},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/testCallBackOSS", nil)
|
||||
mac := qbox.NewMac("123", "123")
|
||||
token, err := mac.SignRequest(c.Request)
|
||||
asserts.NoError(err)
|
||||
c.Request.Header["Authorization"] = []string{"QBox " + token}
|
||||
AuthFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{
|
||||
SecretKey: "123",
|
||||
AccessKey: "123",
|
||||
},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH", ioutil.NopCloser(strings.NewReader(`{"name":"2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","source_name":"1/1_hFRtDLgM_2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","size":114020,"pic_info":"810,539"}`)))
|
||||
c.Request.Header["Authorization"] = []string{"e5LwzwTkP9AFAItT4YzvdJOHd0Y0wqTMWhsV/h5SG90JYGAmMd+8LQyj96R+9qUfJWjMt6suuUh7LaOryR87Dw=="}
|
||||
c.Request.Header["X-Oss-Pub-Key-Url"] = []string{"aHR0cHM6Ly9nb3NzcHVibGljLmFsaWNkbi5jb20vY2FsbGJhY2tfcHViX2tleV92MS5wZW0="}
|
||||
AuthFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type fakeRead string
|
||||
|
||||
func (r fakeRead) Read(p []byte) (int, error) {
|
||||
return 0, errors.New("error")
|
||||
}
|
||||
|
||||
func TestUpyunCallbackAuth(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
AuthFunc := UpyunCallbackAuth()
|
||||
|
||||
// 无法获取请求正文
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{
|
||||
SecretKey: "123",
|
||||
AccessKey: "123",
|
||||
},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(fakeRead("")))
|
||||
AuthFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 正文MD5不一致
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{
|
||||
SecretKey: "123",
|
||||
AccessKey: "123",
|
||||
},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
|
||||
c.Request.Header["Content-Md5"] = []string{"123"}
|
||||
AuthFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 签名不一致
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{
|
||||
SecretKey: "123",
|
||||
AccessKey: "123",
|
||||
},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
|
||||
c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
|
||||
AuthFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{
|
||||
SecretKey: "123",
|
||||
AccessKey: "123",
|
||||
},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
|
||||
c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
|
||||
c.Request.Header["Authorization"] = []string{"UPYUN 123:GWueK9x493BKFFk5gmfdO2Mn6EM="}
|
||||
AuthFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOneDriveCallbackAuth(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
AuthFunc := OneDriveCallbackAuth()
|
||||
|
||||
// 成功
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"sessionID", "TestOneDriveCallbackAuth"},
|
||||
}
|
||||
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
|
||||
UID: 1,
|
||||
VirtualPath: "/",
|
||||
Policy: model.Policy{
|
||||
SecretKey: "123",
|
||||
AccessKey: "123",
|
||||
},
|
||||
})
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/TestOneDriveCallbackAuth", ioutil.NopCloser(strings.NewReader("1")))
|
||||
res := mq.GlobalMQ.Subscribe("TestOneDriveCallbackAuth", 1)
|
||||
AuthFunc(c)
|
||||
select {
|
||||
case <-res:
|
||||
case <-time.After(time.Millisecond * 500):
|
||||
asserts.Fail("mq message should be published")
|
||||
}
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAdmin(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
testFunc := IsAdmin()
|
||||
|
||||
// 非管理员
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("user", &model.User{})
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 是管理员
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
user := &model.User{}
|
||||
user.Group.ID = 1
|
||||
c.Set("user", user)
|
||||
testFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 初始用户,非管理组
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
user := &model.User{}
|
||||
user.Group.ID = 2
|
||||
user.ID = 1
|
||||
c.Set("user", user)
|
||||
testFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
@@ -3,52 +3,56 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/recaptcha"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/recaptcha"
|
||||
request2 "github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
captcha "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha/v20190722"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type req struct {
|
||||
CaptchaCode string `json:"captchaCode"`
|
||||
Ticket string `json:"ticket"`
|
||||
Randstr string `json:"randstr"`
|
||||
Captcha string `json:"captcha"`
|
||||
Ticket string `json:"ticket"`
|
||||
Randstr string `json:"randstr"`
|
||||
}
|
||||
|
||||
const (
|
||||
captchaNotMatch = "CAPTCHA not match."
|
||||
captchaRefresh = "Verification failed, please refresh the page and retry."
|
||||
|
||||
tcCaptchaEndpoint = "captcha.tencentcloudapi.com"
|
||||
turnstileEndpoint = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
)
|
||||
|
||||
// CaptchaIDCtx defines keys for captcha ID
|
||||
type (
|
||||
CaptchaIDCtx struct{}
|
||||
turnstileResponse struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
)
|
||||
|
||||
// CaptchaRequired 验证请求签名
|
||||
func CaptchaRequired(configName string) gin.HandlerFunc {
|
||||
func CaptchaRequired(enabled func(c *gin.Context) bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 相关设定
|
||||
options := model.GetSettingByNames(configName,
|
||||
"captcha_type",
|
||||
"captcha_ReCaptchaSecret",
|
||||
"captcha_TCaptcha_SecretId",
|
||||
"captcha_TCaptcha_SecretKey",
|
||||
"captcha_TCaptcha_CaptchaAppId",
|
||||
"captcha_TCaptcha_AppSecretKey")
|
||||
// 检查验证码
|
||||
isCaptchaRequired := model.IsTrueVal(options[configName])
|
||||
if enabled(c) {
|
||||
dep := dependency.FromContext(c)
|
||||
settings := dep.SettingProvider()
|
||||
l := logging.FromContext(c)
|
||||
|
||||
if isCaptchaRequired {
|
||||
var service req
|
||||
bodyCopy := new(bytes.Buffer)
|
||||
_, err := io.Copy(bodyCopy, c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err))
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaNotMatch, err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -56,65 +60,69 @@ func CaptchaRequired(configName string) gin.HandlerFunc {
|
||||
bodyData := bodyCopy.Bytes()
|
||||
err = json.Unmarshal(bodyData, &service)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err))
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaNotMatch, err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = ioutil.NopCloser(bytes.NewReader(bodyData))
|
||||
switch options["captcha_type"] {
|
||||
case "normal":
|
||||
captchaID := util.GetSession(c, "captchaID")
|
||||
util.DeleteSession(c, "captchaID")
|
||||
if captchaID == nil || !base64Captcha.VerifyCaptcha(captchaID.(string), service.CaptchaCode) {
|
||||
c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err))
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyData))
|
||||
switch settings.CaptchaType(c) {
|
||||
case setting.CaptchaNormal, setting.CaptchaTcaptcha:
|
||||
if service.Ticket == "" || !base64Captcha.VerifyCaptcha(service.Ticket, service.Captcha) {
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaNotMatch, err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
case "recaptcha":
|
||||
reCAPTCHA, err := recaptcha.NewReCAPTCHA(options["captcha_ReCaptchaSecret"], recaptcha.V2, 10*time.Second)
|
||||
case setting.CaptchaReCaptcha:
|
||||
captchaSetting := settings.ReCaptcha(c)
|
||||
reCAPTCHA, err := recaptcha.NewReCAPTCHA(captchaSetting.Secret, recaptcha.V2, 10*time.Second)
|
||||
if err != nil {
|
||||
util.Log().Warning("reCAPTCHA verification failed, %s", err)
|
||||
l.Warning("reCAPTCHA verification failed, %s", err)
|
||||
c.Abort()
|
||||
break
|
||||
}
|
||||
|
||||
err = reCAPTCHA.Verify(service.CaptchaCode)
|
||||
err = reCAPTCHA.Verify(service.Captcha)
|
||||
if err != nil {
|
||||
util.Log().Warning("reCAPTCHA verification failed, %s", err)
|
||||
c.JSON(200, serializer.Err(serializer.CodeCaptchaRefreshNeeded, captchaRefresh, nil))
|
||||
l.Warning("reCAPTCHA verification failed, %s", err)
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaRefresh, err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
case "tcaptcha":
|
||||
credential := common.NewCredential(
|
||||
options["captcha_TCaptcha_SecretId"],
|
||||
options["captcha_TCaptcha_SecretKey"],
|
||||
case setting.CaptchaTurnstile:
|
||||
captchaSetting := settings.TurnstileCaptcha(c)
|
||||
r := dep.RequestClient(
|
||||
request2.WithContext(c),
|
||||
request2.WithLogger(logging.FromContext(c)),
|
||||
request2.WithHeader(http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}}),
|
||||
)
|
||||
cpf := profile.NewClientProfile()
|
||||
cpf.HttpProfile.Endpoint = "captcha.tencentcloudapi.com"
|
||||
client, _ := captcha.NewClient(credential, "", cpf)
|
||||
request := captcha.NewDescribeCaptchaResultRequest()
|
||||
request.CaptchaType = common.Uint64Ptr(9)
|
||||
appid, _ := strconv.Atoi(options["captcha_TCaptcha_CaptchaAppId"])
|
||||
request.CaptchaAppId = common.Uint64Ptr(uint64(appid))
|
||||
request.AppSecretKey = common.StringPtr(options["captcha_TCaptcha_AppSecretKey"])
|
||||
request.Ticket = common.StringPtr(service.Ticket)
|
||||
request.Randstr = common.StringPtr(service.Randstr)
|
||||
request.UserIp = common.StringPtr(c.ClientIP())
|
||||
response, err := client.DescribeCaptchaResult(request)
|
||||
formData := url.Values{}
|
||||
formData.Set("secret", captchaSetting.Secret)
|
||||
formData.Set("response", service.Ticket)
|
||||
res, err := r.Request("POST", turnstileEndpoint, strings.NewReader(formData.Encode())).
|
||||
CheckHTTPResponse(http.StatusOK).
|
||||
GetResponse()
|
||||
if err != nil {
|
||||
util.Log().Warning("TCaptcha verification failed, %s", err)
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, "Captcha validation failed", err))
|
||||
c.Abort()
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
if *response.Response.CaptchaCode != int64(1) {
|
||||
c.JSON(200, serializer.Err(serializer.CodeCaptchaRefreshNeeded, captchaRefresh, nil))
|
||||
var trunstileRes turnstileResponse
|
||||
err = json.Unmarshal([]byte(res), &trunstileRes)
|
||||
if err != nil {
|
||||
l.Warning("Turnstile verification failed, %s", err)
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, "Captcha validation failed", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if !trunstileRes.Success {
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, "Captcha validation failed", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type errReader int
|
||||
|
||||
func (errReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("test error")
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_General(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 未启用验证码
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "0",
|
||||
"captcha_type": "1",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// body 无法读取
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "1",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", errReader(1))
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// body JSON 解析失败
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "1",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("123"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_Normal(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 验证码错误
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "normal",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
Session("233")(c)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_Recaptcha(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 无法初始化reCaptcha实例
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "recaptcha",
|
||||
"captcha_ReCaptchaSecret": "",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 验证码错误
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "recaptcha",
|
||||
"captcha_ReCaptchaSecret": "233",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_Tcaptcha(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 验证出错
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "tcaptcha",
|
||||
"captcha_ReCaptchaSecret": "",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
@@ -1,62 +1,75 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/routers/controllers"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
|
||||
func MasterMetadata() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set("MasterSiteID", c.GetHeader(auth.CrHeaderPrefix+"Site-Id"))
|
||||
c.Set("MasterSiteURL", c.GetHeader(auth.CrHeaderPrefix+"Site-Url"))
|
||||
c.Set("MasterVersion", c.GetHeader(auth.CrHeaderPrefix+"Cloudreve-Version"))
|
||||
c.Next()
|
||||
}
|
||||
type SlaveNodeSettingGetter interface {
|
||||
// GetNodeSetting returns the node settings and its hash
|
||||
GetNodeSetting() (*types.NodeSetting, string)
|
||||
}
|
||||
|
||||
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
|
||||
func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if siteID, exist := c.Get("MasterSiteID"); exist {
|
||||
// 获取对应主机节点的从机Aria2实例
|
||||
caller, err := clusterController.GetAria2Instance(siteID.(string))
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeNotSet, "Failed to get Aria2 instance", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
var downloaderPool = sync.Map{}
|
||||
|
||||
c.Set("MasterAria2Instance", caller)
|
||||
// PrepareSlaveDownloader creates or resume a downloader based on input node settings
|
||||
func PrepareSlaveDownloader(dep dependency.Dep, ctxKey interface{}) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
nodeSettings, hash := controllers.ParametersFromContext[SlaveNodeSettingGetter](c, ctxKey).GetNodeSetting()
|
||||
|
||||
// try to get downloader from pool
|
||||
if d, ok := downloaderPool.Load(hash); ok {
|
||||
c.Set(downloader.DownloaderCtxKey, d)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, serializer.ParamErr("Unknown master node ID", nil))
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
nodeID, err := strconv.ParseUint(c.GetHeader(auth.CrHeaderPrefix+"Node-Id"), 10, 64)
|
||||
// create a new downloader
|
||||
d, err := cluster.NewDownloader(c, dep.RequestClient(request.WithContext(c), request.WithLogger(dep.Logger())), dep.SettingProvider(), nodeSettings)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.ParamErr("Unknown master node ID", err))
|
||||
c.JSON(200, serializer.ParamErr(c, "Failed to create downloader", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
slaveNode := nodePool.GetNodeByID(uint(nodeID))
|
||||
if slaveNode == nil {
|
||||
c.JSON(200, serializer.ParamErr("Unknown master node ID", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
SignRequired(slaveNode.MasterAuthInstance())(c)
|
||||
|
||||
// save downloader to pool
|
||||
downloaderPool.Store(hash, d)
|
||||
c.Set(downloader.DownloaderCtxKey, d)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func SlaveRPCSignRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
nodeId := cluster.NodeIdFromContext(c)
|
||||
if nodeId == 0 {
|
||||
c.JSON(200, serializer.ParamErr(c, "Unknown node ID", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
np, err := dependency.FromContext(c).NodePool(c)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.NewError(serializer.CodeInternalSetting, "Failed to get node pool", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
slaveNode, err := np.Get(c, types.NodeCapabilityNone, nodeId)
|
||||
if slaveNode == nil || slaveNode.IsMaster() {
|
||||
c.JSON(200, serializer.ParamErr(c, "Unknown node ID", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
SignRequired(slaveNode.AuthInstance())(c)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMasterMetadata(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
masterMetaDataFunc := MasterMetadata()
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
c.Request.Header = map[string][]string{
|
||||
"X-Cr-Site-Id": {"expectedSiteID"},
|
||||
"X-Cr-Site-Url": {"expectedSiteURL"},
|
||||
"X-Cr-Cloudreve-Version": {"expectedMasterVersion"},
|
||||
}
|
||||
masterMetaDataFunc(c)
|
||||
siteID, _ := c.Get("MasterSiteID")
|
||||
siteURL, _ := c.Get("MasterSiteURL")
|
||||
siteVersion, _ := c.Get("MasterVersion")
|
||||
|
||||
a.Equal("expectedSiteID", siteID.(string))
|
||||
a.Equal("expectedSiteURL", siteURL.(string))
|
||||
a.Equal("expectedMasterVersion", siteVersion.(string))
|
||||
}
|
||||
|
||||
func TestSlaveRPCSignRequired(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
np := &cluster.NodePool{}
|
||||
np.Init()
|
||||
slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// id parse failed
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Request.Header.Set("X-Cr-Node-Id", "unknown")
|
||||
slaveRPCSignRequiredFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// node id not exist
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Request.Header.Set("X-Cr-Node-Id", "38")
|
||||
slaveRPCSignRequiredFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte("")}
|
||||
np.Add(&model.Node{Model: gorm.Model{
|
||||
ID: 38,
|
||||
}})
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
c.Request.Header.Set("X-Cr-Node-Id", "38")
|
||||
c.Request = auth.SignRequest(authInstance, c.Request, 0)
|
||||
slaveRPCSignRequiredFunc(c)
|
||||
a.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseSlaveAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
// MasterSiteID not set
|
||||
{
|
||||
testController := &controllermock.SlaveControllerMock{}
|
||||
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
useSlaveAria2InstanceFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// Cannot get aria2 instances
|
||||
{
|
||||
testController := &controllermock.SlaveControllerMock{}
|
||||
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Set("MasterSiteID", "expectedSiteID")
|
||||
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error"))
|
||||
useSlaveAria2InstanceFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
testController.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Success
|
||||
{
|
||||
testController := &controllermock.SlaveControllerMock{}
|
||||
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Set("MasterSiteID", "expectedSiteID")
|
||||
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil)
|
||||
useSlaveAria2InstanceFunc(c)
|
||||
a.False(c.IsAborted())
|
||||
res, _ := c.Get("MasterAria2Instance")
|
||||
a.NotNil(res)
|
||||
testController.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
@@ -1,26 +1,35 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth/requestinfo"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HashID 将给定对象的HashID转换为真实ID
|
||||
func HashID(IDType int) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
dep := dependency.FromContext(c)
|
||||
if c.Param("id") != "" {
|
||||
id, err := hashid.DecodeHashID(c.Param("id"), IDType)
|
||||
id, err := dep.HashIDEncoder().Decode(c.Param("id"), IDType)
|
||||
if err == nil {
|
||||
c.Set("object_id", id)
|
||||
util.WithValue(c, hashid.ObjectIDCtx{}, id)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
c.JSON(200, serializer.ParamErr("Failed to parse object ID", nil))
|
||||
c.JSON(200, serializer.ParamErr(c, "Failed to parse object ID", err))
|
||||
c.Abort()
|
||||
return
|
||||
|
||||
@@ -30,10 +39,10 @@ func HashID(IDType int) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
// IsFunctionEnabled 当功能未开启时阻止访问
|
||||
func IsFunctionEnabled(key string) gin.HandlerFunc {
|
||||
func IsFunctionEnabled(check func(c *gin.Context) bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !model.IsTrueVal(model.GetSettingByName(key)) {
|
||||
c.JSON(200, serializer.Err(serializer.CodeFeatureNotEnabled, "This feature is not enabled", nil))
|
||||
if !check(c) {
|
||||
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeFeatureNotEnabled, "This feature is not enabled", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -56,9 +65,10 @@ func Sandbox() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
// StaticResourceCache 使用静态资源缓存策略
|
||||
func StaticResourceCache() gin.HandlerFunc {
|
||||
func StaticResourceCache(dep dependency.Dep) gin.HandlerFunc {
|
||||
settings := dep.SettingProvider()
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", model.GetIntSetting("public_resource_maxage", 86400)))
|
||||
c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", settings.PublicResourceMaxAge(c)))
|
||||
|
||||
}
|
||||
}
|
||||
@@ -66,8 +76,9 @@ func StaticResourceCache() gin.HandlerFunc {
|
||||
// MobileRequestOnly
|
||||
func MobileRequestOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader(auth.CrHeaderPrefix+"ios") == "" {
|
||||
c.Redirect(http.StatusMovedPermanently, model.GetSiteURL().String())
|
||||
dep := dependency.FromContext(c)
|
||||
if c.GetHeader(constants.CrHeaderPrefix+"ios") == "" {
|
||||
c.Redirect(http.StatusMovedPermanently, dep.SettingProvider().SiteURL(c).String())
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -75,3 +86,66 @@ func MobileRequestOnly() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeHandling is added at the beginning of handler chain, it did following setups:
|
||||
// 1. Inject dependency manager into request context
|
||||
// 2. Generate and inject correlation ID for diagnostic.
|
||||
func InitializeHandling(dep dependency.Dep) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
reqInfo := &requestinfo.RequestInfo{
|
||||
IP: c.ClientIP(),
|
||||
Host: c.Request.Host,
|
||||
UserAgent: c.Request.UserAgent(),
|
||||
}
|
||||
cid := uuid.FromStringOrNil(c.GetHeader(request.CorrelationHeader))
|
||||
if cid == uuid.Nil {
|
||||
cid = uuid.Must(uuid.NewV4())
|
||||
}
|
||||
|
||||
l := dep.Logger().CopyWithPrefix(fmt.Sprintf("[Cid: %s]", cid))
|
||||
ctx := dep.ForkWithLogger(c.Request.Context(), l)
|
||||
ctx = context.WithValue(ctx, logging.CorrelationIDCtx{}, cid)
|
||||
ctx = context.WithValue(ctx, requestinfo.RequestInfoCtx{}, reqInfo)
|
||||
ctx = context.WithValue(ctx, logging.LoggerCtx{}, l)
|
||||
if id := c.Param("nodeId"); id != "" {
|
||||
ctx = context.WithValue(ctx, cluster.SlaveNodeIDCtx{}, id)
|
||||
} else {
|
||||
ctx = context.WithValue(ctx, cluster.SlaveNodeIDCtx{}, c.GetHeader(request.SlaveNodeIDHeader))
|
||||
}
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeHandlingSlave retrieves coll correlation ID and other metadata from request header
|
||||
func InitializeHandlingSlave() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := context.WithValue(c.Request.Context(), cluster.MasterSiteIDCtx{}, c.GetHeader(request.SiteIDHeader))
|
||||
ctx = context.WithValue(ctx, cluster.MasterSiteUrlCtx{}, c.GetHeader(request.SiteURLHeader))
|
||||
ctx = context.WithValue(ctx, cluster.MasterSiteVersionCtx{}, c.GetHeader(request.SiteVersionHeader))
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Logging logs incoming request info
|
||||
func Logging() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
l := logging.FromContext(c)
|
||||
logging.Request(l, true, c.Writer.Status(), c.Request.Method, c.ClientIP(), path,
|
||||
c.Errors.ByType(gin.ErrorTypePrivate).String(), start)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHashID(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
TestFunc := HashID(hashid.FolderID)
|
||||
|
||||
// 未给定ID对象,跳过
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
|
||||
TestFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 给定ID,解析失败
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"id", "2333"},
|
||||
}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
|
||||
TestFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 给定ID,解析成功
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"id", hashid.HashID(1, hashid.FolderID)},
|
||||
}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
|
||||
TestFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsFunctionEnabled(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
TestFunc := IsFunctionEnabled("TestIsFunctionEnabled")
|
||||
|
||||
// 未开启
|
||||
{
|
||||
cache.Set("setting_TestIsFunctionEnabled", "0", 0)
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
// 开启
|
||||
{
|
||||
cache.Set("setting_TestIsFunctionEnabled", "1", 0)
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestCacheControl(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
TestFunc := CacheControl()
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
TestFunc(c)
|
||||
a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache")
|
||||
}
|
||||
|
||||
func TestSandbox(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
TestFunc := Sandbox()
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
TestFunc(c)
|
||||
a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox")
|
||||
}
|
||||
|
||||
func TestStaticResourceCache(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
TestFunc := StaticResourceCache()
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
TestFunc(c)
|
||||
a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age")
|
||||
}
|
||||
@@ -1,30 +1,49 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/routers/controllers"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
)
|
||||
|
||||
// ValidateSourceLink validates if the perm source link is a valid redirect link
|
||||
func ValidateSourceLink() gin.HandlerFunc {
|
||||
// UrisService is a wrapper for service supports batch file operations
|
||||
type UrisService interface {
|
||||
GetUris() []string
|
||||
}
|
||||
|
||||
// ValidateBatchFileCount validates if the batch file count is within the limit
|
||||
func ValidateBatchFileCount(dep dependency.Dep, ctxKey interface{}) gin.HandlerFunc {
|
||||
settings := dep.SettingProvider()
|
||||
return func(c *gin.Context) {
|
||||
linkID, ok := c.Get("object_id")
|
||||
if !ok {
|
||||
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
|
||||
uris := controllers.ParametersFromContext[UrisService](c, ctxKey)
|
||||
limit := settings.MaxBatchedFile(c)
|
||||
if len((uris).GetUris()) > limit {
|
||||
c.JSON(200, serializer.ErrWithDetails(
|
||||
c,
|
||||
serializer.CodeTooManyUris,
|
||||
fmt.Sprintf("Maximum allowed batch size: %d", limit),
|
||||
nil,
|
||||
))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
sourceLink, err := model.GetSourceLinkByID(linkID)
|
||||
if err != nil || sourceLink.File.ID == 0 || sourceLink.File.Name != c.Param("name") {
|
||||
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
sourceLink.Downloaded()
|
||||
c.Set("source_link", sourceLink)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ContextHint parses the context hint header and set it to context
|
||||
func ContextHint() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader(dbfs.ContextHintHeader) != "" {
|
||||
util.WithValue(c, dbfs.ContextHintCtxKey{}, uuid.FromStringOrNil(c.GetHeader(dbfs.ContextHintHeader)))
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateSourceLink(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
testFunc := ValidateSourceLink()
|
||||
|
||||
// ID 不存在
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
testFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// SourceLink 不存在
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("object_id", 1)
|
||||
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
testFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 原文件不存在
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("object_id", 1)
|
||||
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
testFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("object_id", 1)
|
||||
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
|
||||
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
testFunc(c)
|
||||
a.False(c.IsAborted())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,73 +1,82 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/bootstrap"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FrontendFileHandler 前端静态文件处理
|
||||
func FrontendFileHandler() gin.HandlerFunc {
|
||||
func FrontendFileHandler(dep dependency.Dep) gin.HandlerFunc {
|
||||
fs := dep.ServerStaticFS()
|
||||
l := dep.Logger()
|
||||
|
||||
ignoreFunc := func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
|
||||
if bootstrap.StaticFS == nil {
|
||||
if fs == nil {
|
||||
return ignoreFunc
|
||||
}
|
||||
|
||||
// 读取index.html
|
||||
file, err := bootstrap.StaticFS.Open("/index.html")
|
||||
file, err := fs.Open("/index.html")
|
||||
if err != nil {
|
||||
util.Log().Warning("Static file \"index.html\" does not exist, it might affect the display of the homepage.")
|
||||
l.Warning("Static file \"index.html\" does not exist, it might affect the display of the homepage.")
|
||||
return ignoreFunc
|
||||
}
|
||||
|
||||
fileContentBytes, err := ioutil.ReadAll(file)
|
||||
fileContentBytes, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
util.Log().Warning("Cannot read static file \"index.html\", it might affect the display of the homepage.")
|
||||
l.Warning("Cannot read static file \"index.html\", it might affect the display of the homepage.")
|
||||
return ignoreFunc
|
||||
}
|
||||
fileContent := string(fileContentBytes)
|
||||
|
||||
fileServer := http.FileServer(bootstrap.StaticFS)
|
||||
fileServer := http.FileServer(fs)
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// API 跳过
|
||||
// Skipping routers handled by backend
|
||||
if strings.HasPrefix(path, "/api") ||
|
||||
strings.HasPrefix(path, "/custom") ||
|
||||
strings.HasPrefix(path, "/dav") ||
|
||||
strings.HasPrefix(path, "/f") ||
|
||||
strings.HasPrefix(path, "/f/") ||
|
||||
strings.HasPrefix(path, "/s/") ||
|
||||
path == "/manifest.json" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 不存在的路径和index.html均返回index.html
|
||||
if (path == "/index.html") || (path == "/") || !bootstrap.StaticFS.Exists("/", path) {
|
||||
if (path == "/index.html") || (path == "/") || !fs.Exists("/", path) {
|
||||
// 读取、替换站点设置
|
||||
options := model.GetSettingByNames("siteName", "siteKeywords", "siteScript",
|
||||
"pwa_small_icon")
|
||||
settingClient := dep.SettingProvider()
|
||||
siteBasic := settingClient.SiteBasic(c)
|
||||
pwaOpts := settingClient.PWA(c)
|
||||
theme := settingClient.Theme(c)
|
||||
finalHTML := util.Replace(map[string]string{
|
||||
"{siteName}": options["siteName"],
|
||||
"{siteDes}": options["siteDes"],
|
||||
"{siteScript}": options["siteScript"],
|
||||
"{pwa_small_icon}": options["pwa_small_icon"],
|
||||
"{siteName}": siteBasic.Name,
|
||||
"{siteDes}": siteBasic.Description,
|
||||
"{siteScript}": siteBasic.Script,
|
||||
"{pwa_small_icon}": pwaOpts.SmallIcon,
|
||||
"{pwa_medium_icon}": pwaOpts.MediumIcon,
|
||||
"var(--defaultThemeColor)": theme.DefaultTheme,
|
||||
}, fileContent)
|
||||
|
||||
c.Header("Content-Type", "text/html")
|
||||
c.Header("Cache-Control", "public, no-cache")
|
||||
c.String(200, finalHTML)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if path == "/service-worker.js" {
|
||||
if path == "/sw.js" || strings.HasPrefix(path, "/locales/") {
|
||||
c.Header("Cache-Control", "public, no-cache")
|
||||
} else if strings.HasPrefix(path, "/assets/") {
|
||||
c.Header("Cache-Control", "public, max-age=31536000")
|
||||
}
|
||||
|
||||
// 存在的静态文件
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/bootstrap"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type StaticMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (m StaticMock) Open(name string) (http.File, error) {
|
||||
args := m.Called(name)
|
||||
return args.Get(0).(http.File), args.Error(1)
|
||||
}
|
||||
|
||||
func (m StaticMock) Exists(prefix string, filepath string) bool {
|
||||
args := m.Called(prefix, filepath)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func TestFrontendFileHandler(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 静态资源未加载
|
||||
{
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// index.html 不存在
|
||||
{
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(&os.File{}, errors.New("error"))
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// index.html 读取失败
|
||||
{
|
||||
file, _ := util.CreatNestedFile("tests/index.html")
|
||||
file.Close()
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(file, nil)
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 成功且命中
|
||||
{
|
||||
file, _ := util.CreatNestedFile("tests/index.html")
|
||||
defer file.Close()
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(file, nil)
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
|
||||
cache.Set("setting_siteName", "cloudreve", 0)
|
||||
cache.Set("setting_siteKeywords", "cloudreve", 0)
|
||||
cache.Set("setting_siteScript", "cloudreve", 0)
|
||||
cache.Set("setting_pwa_small_icon", "cloudreve", 0)
|
||||
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 成功且命中静态文件
|
||||
{
|
||||
file, _ := util.CreatNestedFile("tests/index.html")
|
||||
defer file.Close()
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(file, nil)
|
||||
testStatic.On("Exists", "/", "/2").
|
||||
Return(true)
|
||||
testStatic.On("Open", "/2").
|
||||
Return(file, nil)
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/2", nil)
|
||||
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
testStatic.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// API 相关跳过
|
||||
{
|
||||
for _, reqPath := range []string{"/api/user", "/manifest.json", "/dav/path"} {
|
||||
file, _ := util.CreatNestedFile("tests/index.html")
|
||||
defer file.Close()
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(file, nil)
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", reqPath, nil)
|
||||
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMockHelper(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
MockHelperFunc := MockHelper()
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
|
||||
// 写入session
|
||||
{
|
||||
SessionMock["test"] = "pass"
|
||||
Session("test")(c)
|
||||
MockHelperFunc(c)
|
||||
asserts.Equal("pass", util.GetSession(c, "test").(string))
|
||||
}
|
||||
|
||||
// 写入context
|
||||
{
|
||||
ContextMock["test"] = "pass"
|
||||
MockHelperFunc(c)
|
||||
test, exist := c.Get("test")
|
||||
asserts.True(exist)
|
||||
asserts.Equal("pass", test.(string))
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/sessionstore"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/sessionstore"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -16,11 +16,11 @@ import (
|
||||
// Store session存储
|
||||
var Store sessions.Store
|
||||
|
||||
// Session 初始化session
|
||||
func Session(secret string) gin.HandlerFunc {
|
||||
// Redis设置不为空,且非测试模式时使用Redis
|
||||
Store = sessionstore.NewStore(cache.Store, []byte(secret))
|
||||
const SessionName = "cloudreve-session"
|
||||
|
||||
// Session 初始化session
|
||||
func Session(dep dependency.Dep) gin.HandlerFunc {
|
||||
Store = sessionstore.NewStore(dep.KV(), []byte(dep.ConfigProvider().System().SessionSecret))
|
||||
sameSiteMode := http.SameSiteDefaultMode
|
||||
switch strings.ToLower(conf.CORSConfig.SameSite) {
|
||||
case "default":
|
||||
@@ -42,7 +42,7 @@ func Session(secret string) gin.HandlerFunc {
|
||||
Secure: conf.CORSConfig.Secure,
|
||||
})
|
||||
|
||||
return sessions.Sessions("cloudreve-session", Store)
|
||||
return sessions.Sessions(SessionName, Store)
|
||||
}
|
||||
|
||||
// CSRFInit 初始化CSRF标记
|
||||
@@ -61,7 +61,7 @@ func CSRFCheck() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "Invalid origin", nil))
|
||||
c.JSON(200, serializer.ErrDeprecated(serializer.CodeNoPermissionErr, "Invalid origin", nil))
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSession(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
{
|
||||
handler := Session("2333")
|
||||
asserts.NotNil(handler)
|
||||
asserts.NotNil(Store)
|
||||
asserts.IsType(emptyFunc(), handler)
|
||||
}
|
||||
}
|
||||
|
||||
func emptyFunc() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {}
|
||||
}
|
||||
|
||||
func TestCSRFInit(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
sessionFunc := Session("233")
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
CSRFInit()(c)
|
||||
asserts.True(util.GetSession(c, "CSRF").(bool))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFCheck(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
sessionFunc := Session("233")
|
||||
|
||||
// 通过检查
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
CSRFInit()(c)
|
||||
CSRFCheck()(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 未通过检查
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
CSRFCheck()(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ShareOwner 检查当前登录用户是否为分享所有者
|
||||
func ShareOwner() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var user *model.User
|
||||
if userCtx, ok := c.Get("user"); ok {
|
||||
user = userCtx.(*model.User)
|
||||
} else {
|
||||
c.JSON(200, serializer.Err(serializer.CodeCheckLogin, "", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if share, ok := c.Get("share"); ok {
|
||||
if share.(*model.Share).Creator().ID != user.ID {
|
||||
c.JSON(200, serializer.Err(serializer.CodeShareLinkNotFound, "", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ShareAvailable 检查分享是否可用
|
||||
func ShareAvailable() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var user *model.User
|
||||
if userCtx, ok := c.Get("user"); ok {
|
||||
user = userCtx.(*model.User)
|
||||
} else {
|
||||
user = model.NewAnonymousUser()
|
||||
}
|
||||
|
||||
share := model.GetShareByHashID(c.Param("id"))
|
||||
|
||||
if share == nil || !share.IsAvailable() {
|
||||
c.JSON(200, serializer.Err(serializer.CodeShareLinkNotFound, "", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user", user)
|
||||
c.Set("share", share)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ShareCanPreview 检查分享是否可被预览
|
||||
func ShareCanPreview() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if share, ok := c.Get("share"); ok {
|
||||
if share.(*model.Share).PreviewEnabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
c.JSON(200, serializer.Err(serializer.CodeDisabledSharePreview, "",
|
||||
nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// CheckShareUnlocked 检查分享是否已解锁
|
||||
func CheckShareUnlocked() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if shareCtx, ok := c.Get("share"); ok {
|
||||
share := shareCtx.(*model.Share)
|
||||
// 分享是否已解锁
|
||||
if share.Password != "" {
|
||||
sessionKey := fmt.Sprintf("share_unlock_%d", share.ID)
|
||||
unlocked := util.GetSession(c, sessionKey) != nil
|
||||
if !unlocked {
|
||||
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr,
|
||||
"", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// BeforeShareDownload 分享被下载前的检查
|
||||
func BeforeShareDownload() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if shareCtx, ok := c.Get("share"); ok {
|
||||
if userCtx, ok := c.Get("user"); ok {
|
||||
share := shareCtx.(*model.Share)
|
||||
user := userCtx.(*model.User)
|
||||
|
||||
// 检查用户是否可以下载此分享的文件
|
||||
err := share.CanBeDownloadBy(user)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
|
||||
nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 对积分、下载次数进行更新
|
||||
err = share.DownloadBy(user, c)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
|
||||
nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
@@ -1,190 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShareAvailable(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
testFunc := ShareAvailable()
|
||||
|
||||
// 分享不存在
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"id", "empty"},
|
||||
}
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 通过
|
||||
{
|
||||
conf.SystemConfig.HashIDSalt = ""
|
||||
// 用户组
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
|
||||
mock.ExpectQuery("SELECT(.+)shares(.+)").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows(
|
||||
[]string{"id", "remain_downloads", "source_id"}).
|
||||
AddRow(1, 1, 2),
|
||||
)
|
||||
mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"id", "x9T4"},
|
||||
}
|
||||
testFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.False(c.IsAborted())
|
||||
asserts.NotNil(c.Get("user"))
|
||||
asserts.NotNil(c.Get("share"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestShareCanPreview(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
testFunc := ShareCanPreview()
|
||||
|
||||
// 无分享上下文
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 可以预览
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("share", &model.Share{PreviewEnabled: true})
|
||||
testFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 未开启预览
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("share", &model.Share{PreviewEnabled: false})
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckShareUnlocked(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
testFunc := CheckShareUnlocked()
|
||||
|
||||
// 无分享上下文
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 无密码
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("share", &model.Share{})
|
||||
testFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestBeforeShareDownload(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
testFunc := BeforeShareDownload()
|
||||
|
||||
// 无分享上下文
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Set("share", &model.Share{})
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 用户不能下载
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("share", &model.Share{})
|
||||
c.Set("user", &model.User{
|
||||
Group: model.Group{OptionsSerialized: model.GroupOption{}},
|
||||
})
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 可以下载
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("share", &model.Share{})
|
||||
c.Set("user", &model.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Group: model.Group{OptionsSerialized: model.GroupOption{
|
||||
ShareDownload: true,
|
||||
}},
|
||||
})
|
||||
testFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShareOwner(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
testFunc := ShareOwner()
|
||||
|
||||
// 未登录
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Set("share", &model.Share{})
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 非用户所创建分享
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Set("share", &model.Share{User: model.User{Model: gorm.Model{ID: 1}}})
|
||||
c.Set("user", &model.User{})
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 正常
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Set("share", &model.Share{})
|
||||
c.Set("user", &model.User{})
|
||||
testFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/wopi"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
WopiSessionCtx = "wopi_session"
|
||||
)
|
||||
|
||||
// WopiWriteAccess validates if write access is obtained.
|
||||
func WopiWriteAccess() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session := c.MustGet(WopiSessionCtx).(*wopi.SessionCache)
|
||||
session := c.MustGet(wopi.WopiSessionCtx).(*wopi.SessionCache)
|
||||
if session.Action != wopi.ActionEdit {
|
||||
c.Status(http.StatusNotFound)
|
||||
c.Header(wopi.ServerErrorHeader, "read-only access")
|
||||
@@ -28,8 +27,12 @@ func WopiWriteAccess() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc {
|
||||
func ViewerSessionValidation() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
dep := dependency.FromContext(c)
|
||||
store := dep.KV()
|
||||
settings := dep.SettingProvider()
|
||||
|
||||
accessToken := strings.Split(c.Query(wopi.AccessTokenQuery), ".")
|
||||
if len(accessToken) != 2 {
|
||||
c.Status(http.StatusForbidden)
|
||||
@@ -38,7 +41,7 @@ func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
sessionRaw, exist := store.Get(wopi.SessionCachePrefix + accessToken[0])
|
||||
sessionRaw, exist := store.Get(manager.ViewerSessionCachePrefix + accessToken[0])
|
||||
if !exist {
|
||||
c.Status(http.StatusForbidden)
|
||||
c.Header(wopi.ServerErrorHeader, "invalid access token")
|
||||
@@ -46,25 +49,47 @@ func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
session := sessionRaw.(wopi.SessionCache)
|
||||
user, err := model.GetActiveUserByID(session.UserID)
|
||||
if err != nil {
|
||||
session := sessionRaw.(manager.ViewerSessionCache)
|
||||
if err := SetUserCtx(c, session.UserID); err != nil {
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Header(wopi.ServerErrorHeader, "user not found")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
fileID := c.MustGet("object_id").(uint)
|
||||
if fileID != session.FileID {
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Header(wopi.ServerErrorHeader, "file not found")
|
||||
fileId := hashid.FromContext(c)
|
||||
if fileId != session.FileID {
|
||||
c.Status(http.StatusForbidden)
|
||||
c.Header(wopi.ServerErrorHeader, "invalid file")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user", &user)
|
||||
c.Set(WopiSessionCtx, &session)
|
||||
// Check if the viewer is still available
|
||||
viewers := settings.FileViewers(c)
|
||||
var v *setting.Viewer
|
||||
for _, group := range viewers {
|
||||
for _, viewer := range group.Viewers {
|
||||
if viewer.ID == session.ViewerID && !viewer.Disabled {
|
||||
v = &viewer
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if v != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if v == nil {
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Header(wopi.ServerErrorHeader, "viewer not found")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
util.WithValue(c, manager.ViewerCtx{}, v)
|
||||
util.WithValue(c, manager.ViewerSessionCacheCtx{}, &session)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/wopimock"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWopiWriteAccess(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
testFunc := WopiWriteAccess()
|
||||
|
||||
// deny preview only session
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionPreview})
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// pass
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionEdit})
|
||||
testFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWopiAccessValidation(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
mockWopi := &wopimock.WopiClientMock{}
|
||||
mockCache := cache.NewMemoStore()
|
||||
testFunc := WopiAccessValidation(mockWopi, mockCache)
|
||||
|
||||
// malformed access token
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.AddParam(wopi.AccessTokenQuery, "000")
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// session key not exist
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
|
||||
query := c.Request.URL.Query()
|
||||
query.Set(wopi.AccessTokenQuery, "sessionID.key")
|
||||
c.Request.URL.RawQuery = query.Encode()
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// user key not exist
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
|
||||
query := c.Request.URL.Query()
|
||||
query.Set(wopi.AccessTokenQuery, "sessionID.key")
|
||||
c.Request.URL.RawQuery = query.Encode()
|
||||
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// file not found
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
|
||||
query := c.Request.URL.Query()
|
||||
query.Set(wopi.AccessTokenQuery, "sessionID.key")
|
||||
c.Request.URL.RawQuery = query.Encode()
|
||||
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
c.Set("object_id", uint(0))
|
||||
testFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// all pass
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
|
||||
query := c.Request.URL.Query()
|
||||
query.Set(wopi.AccessTokenQuery, "sessionID.key")
|
||||
c.Request.URL.RawQuery = query.Encode()
|
||||
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
c.Set("object_id", uint(1))
|
||||
testFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.NotPanics(func() {
|
||||
c.MustGet(WopiSessionCtx)
|
||||
})
|
||||
asserts.NotPanics(func() {
|
||||
c.MustGet("user")
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user