Init V4 community edition (#2265)

* Init V4 community edition

* Init V4 community edition
This commit is contained in:
AaronLiu
2025-04-20 17:31:25 +08:00
committed by GitHub
parent da4e44b77a
commit 21d158db07
597 changed files with 119415 additions and 41692 deletions

View File

@@ -1,24 +1,21 @@
package middleware
import (
"bytes"
"context"
"crypto/md5"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"io/ioutil"
"net/http"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-contrib/sessions"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/inventory"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/oss"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/gin-gonic/gin"
)
@@ -31,14 +28,14 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
return func(c *gin.Context) {
var err error
switch c.Request.Method {
case "PUT", "POST", "PATCH":
err = auth.CheckRequest(authInstance, c.Request)
case http.MethodPut, http.MethodPost, http.MethodPatch:
err = auth.CheckRequest(c, authInstance, c.Request)
default:
err = auth.CheckURI(authInstance, c.Request.URL)
err = auth.CheckURI(c, authInstance, c.Request.URL)
}
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err))
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, err.Error(), err))
c.Abort()
return
}
@@ -50,29 +47,55 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
// CurrentUser 获取登录用户
func CurrentUser() gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
uid := session.Get("user_id")
if uid != nil {
user, err := model.GetActiveUserByID(uid)
if err == nil {
c.Set("user", &user)
}
dep := dependency.FromContext(c)
shouldContinue, err := dep.TokenAuth().VerifyAndRetrieveUser(c)
if err != nil {
c.JSON(200, serializer.Err(c, err))
c.Abort()
return
}
if shouldContinue {
// TODO: Logto handler
}
uid := inventory.UserIDFromContext(c)
if err := SetUserCtx(c, uid); err != nil {
c.JSON(200, serializer.Err(c, err))
c.Abort()
return
}
c.Next()
}
}
// AuthRequired 需要登录
func AuthRequired() gin.HandlerFunc {
// SetUserCtx set the current login user via uid
func SetUserCtx(c *gin.Context, uid int) error {
dep := dependency.FromContext(c)
userClient := dep.UserClient()
loginUser, err := userClient.GetLoginUserByID(c, uid)
if err != nil {
return serializer.NewError(serializer.CodeDBError, "failed to get login user", err)
}
SetUserCtxByUser(c, loginUser)
return nil
}
func SetUserCtxByUser(c *gin.Context, user *ent.User) {
util.WithValue(c, inventory.UserCtx{}, user)
}
// LoginRequired 需要登录
func LoginRequired() gin.HandlerFunc {
return func(c *gin.Context) {
if user, _ := c.Get("user"); user != nil {
if _, ok := user.(*model.User); ok {
c.Next()
return
}
if u := inventory.UserFromContext(c); u != nil && !inventory.IsAnonymousUser(u) {
c.Next()
return
}
c.JSON(200, serializer.CheckLogin())
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCheckLogin, "Login required", nil))
c.Abort()
}
}
@@ -80,60 +103,84 @@ func AuthRequired() gin.HandlerFunc {
// WebDAVAuth 验证WebDAV登录及权限
func WebDAVAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// OPTIONS 请求不需要鉴权否则Windows10下无法保存文档
if c.Request.Method == "OPTIONS" {
c.Next()
return
}
username, password, ok := c.Request.BasicAuth()
if !ok {
// OPTIONS 请求不需要鉴权
if c.Request.Method == http.MethodOptions {
c.Next()
return
}
c.Writer.Header()["WWW-Authenticate"] = []string{`Basic realm="cloudreve"`}
c.Status(http.StatusUnauthorized)
c.Abort()
return
}
expectedUser, err := model.GetActiveUserByEmail(username)
dep := dependency.FromContext(c)
l := dep.Logger()
userClient := dep.UserClient()
expectedUser, err := userClient.GetActiveByDavAccount(c, username, password)
if err != nil {
if username == "" {
if u, err := userClient.GetByEmail(c, username); err == nil {
// Try login with known user but incorrect password, record audit log
SetUserCtxByUser(c, u)
}
}
l.Debug("WebDAVAuth: failed to get user %q with provided credential: %s", username, err)
c.Status(http.StatusUnauthorized)
c.Abort()
return
}
// 密码正确?
webdav, err := model.GetWebdavByPassword(password, expectedUser.ID)
if err != nil {
// Validate dav account
accounts, err := expectedUser.Edges.DavAccountsOrErr()
if err != nil || len(accounts) == 0 {
l.Debug("WebDAVAuth: failed to get user dav accounts %q with provided credential: %s", username, err)
c.Status(http.StatusUnauthorized)
c.Abort()
return
}
// 用户组已启用WebDAV
if !expectedUser.Group.WebDAVEnabled {
c.Status(http.StatusForbidden)
group, err := expectedUser.Edges.GroupOrErr()
if err != nil {
l.Debug("WebDAVAuth: user group not found: %s", err)
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
// 用户组已启用WebDAV代理
if !expectedUser.Group.OptionsSerialized.WebDAVProxy {
webdav.UseProxy = false
if !group.Permissions.Enabled(int(types.GroupPermissionWebDAV)) {
c.Status(http.StatusForbidden)
l.Debug("WebDAVAuth: user %q does not have WebDAV permission.", expectedUser.Email)
c.Abort()
return
}
c.Set("user", &expectedUser)
c.Set("webdav", webdav)
// 检查是否只读
if expectedUser.Edges.DavAccounts[0].Options.Enabled(int(types.DavAccountReadOnly)) {
switch c.Request.Method {
case http.MethodDelete, http.MethodPut, "MKCOL", "COPY", "MOVE", "LOCK", "UNLOCK":
c.Status(http.StatusForbidden)
c.Abort()
return
}
}
SetUserCtxByUser(c, expectedUser)
c.Next()
}
}
// 对上传会话进行验证
func UseUploadSession(policyType string) gin.HandlerFunc {
func UseUploadSession(policyType types.PolicyType) gin.HandlerFunc {
return func(c *gin.Context) {
// 验证key并查找用户
resp := uploadCallbackCheck(c, policyType)
if resp.Code != 0 {
c.JSON(CallbackFailedStatusCode, resp)
err := uploadCallbackCheck(c, policyType)
if err != nil {
c.JSON(CallbackFailedStatusCode, serializer.Err(c, err))
c.Abort()
return
}
@@ -143,44 +190,46 @@ func UseUploadSession(policyType string) gin.HandlerFunc {
}
// uploadCallbackCheck 对上传回调请求的 callback key 进行验证,如果成功则返回上传用户
func uploadCallbackCheck(c *gin.Context, policyType string) serializer.Response {
func uploadCallbackCheck(c *gin.Context, policyType types.PolicyType) error {
// 验证 Callback Key
sessionID := c.Param("sessionID")
if sessionID == "" {
return serializer.ParamErr("Session ID cannot be empty", nil)
return serializer.NewError(serializer.CodeParamErr, "Session ID cannot be empty", nil)
}
callbackSessionRaw, exist := cache.Get(filesystem.UploadSessionCachePrefix + sessionID)
dep := dependency.FromContext(c)
callbackSessionRaw, exist := dep.KV().Get(manager.UploadSessionCachePrefix + sessionID)
if !exist {
return serializer.Err(serializer.CodeUploadSessionExpired, "上传会话不存在或已过期", nil)
return serializer.NewError(serializer.CodeUploadSessionExpired, "Upload session does not exist or expired", nil)
}
callbackSession := callbackSessionRaw.(serializer.UploadSession)
c.Set(filesystem.UploadSessionCtx, &callbackSession)
if callbackSession.Policy.Type != policyType {
return serializer.Err(serializer.CodePolicyNotAllowed, "", nil)
callbackSession := callbackSessionRaw.(fs.UploadSession)
c.Set(manager.UploadSessionCtx, &callbackSession)
if callbackSession.Policy.Type != string(policyType) {
return serializer.NewError(serializer.CodePolicyNotAllowed, "", nil)
}
// 清理回调会话
_ = cache.Deletes([]string{sessionID}, filesystem.UploadSessionCachePrefix)
// 查找用户
user, err := model.GetActiveUserByID(callbackSession.UID)
if err != nil {
return serializer.Err(serializer.CodeUserNotFound, "", err)
if err := SetUserCtx(c, callbackSession.UID); err != nil {
return err
}
c.Set(filesystem.UserCtx, &user)
return serializer.Response{}
return nil
}
// RemoteCallbackAuth 远程回调签名验证
func RemoteCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 验证签名
session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
authInstance := auth.HMACAuth{SecretKey: []byte(session.Policy.SecretKey)}
if err := auth.CheckRequest(authInstance, c.Request); err != nil {
c.JSON(CallbackFailedStatusCode, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err))
session := c.MustGet(manager.UploadSessionCtx).(*fs.UploadSession)
if session.Policy.Edges.Node == nil {
c.JSON(CallbackFailedStatusCode, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, "Node not found", nil))
c.Abort()
return
}
authInstance := auth.HMACAuth{SecretKey: []byte(session.Policy.Edges.Node.SlaveKey)}
if err := auth.CheckRequest(c, authInstance, c.Request); err != nil {
c.JSON(CallbackFailedStatusCode, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, err.Error(), err))
c.Abort()
return
}
@@ -190,37 +239,16 @@ func RemoteCallbackAuth() gin.HandlerFunc {
}
}
// QiniuCallbackAuth 七牛回调签名验证
func QiniuCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
// 验证回调是否来自qiniu
mac := qbox.NewMac(session.Policy.AccessKey, session.Policy.SecretKey)
ok, err := mac.VerifyCallback(c.Request)
if err != nil {
util.Log().Debug("Failed to verify callback request: %s", err)
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."})
c.Abort()
return
}
if !ok {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Invalid signature."})
c.Abort()
return
}
c.Next()
}
}
// OSSCallbackAuth 阿里云OSS回调签名验证
func OSSCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
err := oss.VerifyCallbackSignature(c.Request)
dep := dependency.FromContext(c)
err := oss.VerifyCallbackSignature(c.Request, dep.KV(), dep.RequestClient(
request.WithContext(c),
request.WithLogger(logging.FromContext(c)),
))
if err != nil {
util.Log().Debug("Failed to verify callback request: %s", err)
dep.Logger().Debug("Failed to verify callback request: %s", err)
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."})
c.Abort()
return
@@ -230,71 +258,12 @@ func OSSCallbackAuth() gin.HandlerFunc {
}
}
// UpyunCallbackAuth 又拍云回调签名验证
func UpyunCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
// 获取请求正文
body, err := ioutil.ReadAll(c.Request.Body)
c.Request.Body.Close()
if err != nil {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: err.Error()})
c.Abort()
return
}
c.Request.Body = ioutil.NopCloser(bytes.NewReader(body))
// 准备验证Upyun回调签名
handler := upyun.Driver{Policy: &session.Policy}
contentMD5 := c.Request.Header.Get("Content-Md5")
date := c.Request.Header.Get("Date")
actualSignature := c.Request.Header.Get("Authorization")
// 计算正文MD5
actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body))
if actualContentMD5 != contentMD5 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5 mismatch."})
c.Abort()
return
}
// 计算理论签名
signature := handler.Sign(context.Background(), []string{
"POST",
c.Request.URL.Path,
date,
contentMD5,
})
// 对比签名
if signature != actualSignature {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Signature not match"})
c.Abort()
return
}
c.Next()
}
}
// OneDriveCallbackAuth OneDrive回调签名验证
func OneDriveCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 发送回调结束信号
mq.GlobalMQ.Publish(c.Param("sessionID"), mq.Message{})
c.Next()
}
}
// IsAdmin 必须为管理员用户组
func IsAdmin() gin.HandlerFunc {
return func(c *gin.Context) {
user, _ := c.Get("user")
if user.(*model.User).Group.ID != 1 && user.(*model.User).ID != 1 {
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "", nil))
user := inventory.UserFromContext(c)
if !user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionIsAdmin)) {
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeNoPermissionErr, "", nil))
c.Abort()
return
}

View File

@@ -1,605 +0,0 @@
package middleware
import (
"database/sql"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
var mock sqlmock.Sqlmock
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestCurrentUser(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
//session为空
sessionFunc := Session("233")
sessionFunc(c)
CurrentUser()(c)
user, _ := c.Get("user")
asserts.Nil(user)
//session正确
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
util.SetSession(c, map[string]interface{}{"user_id": 1})
rows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options"}).
AddRow(1, nil, "admin@cloudreve.org", "{}")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(rows)
CurrentUser()(c)
user, _ = c.Get("user")
asserts.NotNil(user)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestAuthRequired(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
AuthRequiredFunc := AuthRequired()
// 未登录
AuthRequiredFunc(c)
asserts.NotNil(c)
// 类型错误
c.Set("user", 123)
AuthRequiredFunc(c)
asserts.NotNil(c)
// 正常
c.Set("user", &model.User{})
AuthRequiredFunc(c)
asserts.NotNil(c)
}
func TestSignRequired(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
SignRequiredFunc := SignRequired(authInstance)
// 鉴权失败
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
// Sign verify success
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
c.Request = auth.SignRequest(authInstance, c.Request, 0)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.False(c.IsAborted())
}
func TestWebDAVAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := WebDAVAuth()
// options请求跳过验证
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("OPTIONS", "/test", nil)
AuthFunc(c)
}
// 请求HTTP Basic Auth
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
AuthFunc(c)
asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"])
}
// 用户名不存在
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows([]string{"id", "password", "email"}),
)
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
}
// 密码错误
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"),
)
// 查找密码
mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
}
//未启用 WebDAV
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "password", "email", "group_id", "options"}).
AddRow(1,
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
"who@cloudreve.org",
1,
"{}",
),
)
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, false))
// 查找密码
mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), http.StatusForbidden)
}
//正常
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "password", "email", "group_id", "options"}).
AddRow(1,
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
"who@cloudreve.org",
1,
"{}",
),
)
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, true))
// 查找密码
mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), 200)
_, ok := c.Get("user")
asserts.True(ok)
}
}
func TestUseUploadSession(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := UseUploadSession("local")
// sessionID 为空
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/sessionID", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 成功
{
cache.Set(
filesystem.UploadSessionCachePrefix+"testCallBackRemote",
serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{Type: "local"},
},
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[513]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testCallBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
}
func TestUploadCallbackCheck(t *testing.T) {
a := assert.New(t)
rec := httptest.NewRecorder()
// 上传会话不存在
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testSessionNotExist"},
}
res := uploadCallbackCheck(c, "local")
a.Contains("上传会话不存在或已过期", res.Msg)
}
// 上传策略不一致
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testPolicyNotMatch"},
}
cache.Set(
filesystem.UploadSessionCachePrefix+"testPolicyNotMatch",
serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{Type: "remote"},
},
0,
)
res := uploadCallbackCheck(c, "local")
a.Contains("Policy not supported", res.Msg)
}
// 用户不存在
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testUserNotExist"},
}
cache.Set(
filesystem.UploadSessionCachePrefix+"testUserNotExist",
serializer.UploadSession{
UID: 313,
VirtualPath: "/",
Policy: model.Policy{Type: "remote"},
},
0,
)
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}))
res := uploadCallbackCheck(c, "remote")
a.Contains("找不到用户", res.Msg)
a.NoError(mock.ExpectationsWereMet())
_, ok := cache.Get(filesystem.UploadSessionCachePrefix + "testUserNotExist")
a.False(ok)
}
}
func TestRemoteCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := RemoteCallbackAuth()
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{SecretKey: "123"},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.False(c.IsAborted())
}
// 签名错误
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{SecretKey: "123"},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
AuthFunc(c)
asserts.True(c.IsAborted())
}
}
func TestQiniuCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := QiniuCallbackAuth()
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
mac := qbox.NewMac("123", "123")
token, err := mac.SignRequest(c.Request)
asserts.NoError(err)
c.Request.Header["Authorization"] = []string{"QBox " + token}
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
// 验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
mac := qbox.NewMac("123", "1213")
token, err := mac.SignRequest(c.Request)
asserts.NoError(err)
c.Request.Header["Authorization"] = []string{"QBox " + token}
AuthFunc(c)
asserts.True(c.IsAborted())
}
}
func TestOSSCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := OSSCallbackAuth()
// 签名验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/testCallBackOSS", nil)
mac := qbox.NewMac("123", "123")
token, err := mac.SignRequest(c.Request)
asserts.NoError(err)
c.Request.Header["Authorization"] = []string{"QBox " + token}
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH", ioutil.NopCloser(strings.NewReader(`{"name":"2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","source_name":"1/1_hFRtDLgM_2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","size":114020,"pic_info":"810,539"}`)))
c.Request.Header["Authorization"] = []string{"e5LwzwTkP9AFAItT4YzvdJOHd0Y0wqTMWhsV/h5SG90JYGAmMd+8LQyj96R+9qUfJWjMt6suuUh7LaOryR87Dw=="}
c.Request.Header["X-Oss-Pub-Key-Url"] = []string{"aHR0cHM6Ly9nb3NzcHVibGljLmFsaWNkbi5jb20vY2FsbGJhY2tfcHViX2tleV92MS5wZW0="}
AuthFunc(c)
asserts.False(c.IsAborted())
}
}
type fakeRead string
func (r fakeRead) Read(p []byte) (int, error) {
return 0, errors.New("error")
}
func TestUpyunCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := UpyunCallbackAuth()
// 无法获取请求正文
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(fakeRead("")))
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 正文MD5不一致
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
c.Request.Header["Content-Md5"] = []string{"123"}
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 签名不一致
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
c.Request.Header["Authorization"] = []string{"UPYUN 123:GWueK9x493BKFFk5gmfdO2Mn6EM="}
AuthFunc(c)
asserts.False(c.IsAborted())
}
}
func TestOneDriveCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := OneDriveCallbackAuth()
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "TestOneDriveCallbackAuth"},
}
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/TestOneDriveCallbackAuth", ioutil.NopCloser(strings.NewReader("1")))
res := mq.GlobalMQ.Subscribe("TestOneDriveCallbackAuth", 1)
AuthFunc(c)
select {
case <-res:
case <-time.After(time.Millisecond * 500):
asserts.Fail("mq message should be published")
}
asserts.False(c.IsAborted())
}
}
func TestIsAdmin(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := IsAdmin()
// 非管理员
{
c, _ := gin.CreateTestContext(rec)
c.Set("user", &model.User{})
testFunc(c)
asserts.True(c.IsAborted())
}
// 是管理员
{
c, _ := gin.CreateTestContext(rec)
user := &model.User{}
user.Group.ID = 1
c.Set("user", user)
testFunc(c)
asserts.False(c.IsAborted())
}
// 初始用户,非管理组
{
c, _ := gin.CreateTestContext(rec)
user := &model.User{}
user.Group.ID = 2
user.ID = 1
c.Set("user", user)
testFunc(c)
asserts.False(c.IsAborted())
}
}

View File

@@ -3,52 +3,56 @@ package middleware
import (
"bytes"
"encoding/json"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/recaptcha"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/recaptcha"
request2 "github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"github.com/gin-gonic/gin"
"github.com/mojocn/base64Captcha"
captcha "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha/v20190722"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
"io"
"io/ioutil"
"strconv"
"net/http"
"net/url"
"strings"
"time"
)
type req struct {
CaptchaCode string `json:"captchaCode"`
Ticket string `json:"ticket"`
Randstr string `json:"randstr"`
Captcha string `json:"captcha"`
Ticket string `json:"ticket"`
Randstr string `json:"randstr"`
}
const (
captchaNotMatch = "CAPTCHA not match."
captchaRefresh = "Verification failed, please refresh the page and retry."
tcCaptchaEndpoint = "captcha.tencentcloudapi.com"
turnstileEndpoint = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
)
// CaptchaIDCtx defines keys for captcha ID
type (
CaptchaIDCtx struct{}
turnstileResponse struct {
Success bool `json:"success"`
}
)
// CaptchaRequired 验证请求签名
func CaptchaRequired(configName string) gin.HandlerFunc {
func CaptchaRequired(enabled func(c *gin.Context) bool) gin.HandlerFunc {
return func(c *gin.Context) {
// 相关设定
options := model.GetSettingByNames(configName,
"captcha_type",
"captcha_ReCaptchaSecret",
"captcha_TCaptcha_SecretId",
"captcha_TCaptcha_SecretKey",
"captcha_TCaptcha_CaptchaAppId",
"captcha_TCaptcha_AppSecretKey")
// 检查验证码
isCaptchaRequired := model.IsTrueVal(options[configName])
if enabled(c) {
dep := dependency.FromContext(c)
settings := dep.SettingProvider()
l := logging.FromContext(c)
if isCaptchaRequired {
var service req
bodyCopy := new(bytes.Buffer)
_, err := io.Copy(bodyCopy, c.Request.Body)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err))
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaNotMatch, err))
c.Abort()
return
}
@@ -56,65 +60,69 @@ func CaptchaRequired(configName string) gin.HandlerFunc {
bodyData := bodyCopy.Bytes()
err = json.Unmarshal(bodyData, &service)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err))
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaNotMatch, err))
c.Abort()
return
}
c.Request.Body = ioutil.NopCloser(bytes.NewReader(bodyData))
switch options["captcha_type"] {
case "normal":
captchaID := util.GetSession(c, "captchaID")
util.DeleteSession(c, "captchaID")
if captchaID == nil || !base64Captcha.VerifyCaptcha(captchaID.(string), service.CaptchaCode) {
c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err))
c.Request.Body = io.NopCloser(bytes.NewReader(bodyData))
switch settings.CaptchaType(c) {
case setting.CaptchaNormal, setting.CaptchaTcaptcha:
if service.Ticket == "" || !base64Captcha.VerifyCaptcha(service.Ticket, service.Captcha) {
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaNotMatch, err))
c.Abort()
return
}
break
case "recaptcha":
reCAPTCHA, err := recaptcha.NewReCAPTCHA(options["captcha_ReCaptchaSecret"], recaptcha.V2, 10*time.Second)
case setting.CaptchaReCaptcha:
captchaSetting := settings.ReCaptcha(c)
reCAPTCHA, err := recaptcha.NewReCAPTCHA(captchaSetting.Secret, recaptcha.V2, 10*time.Second)
if err != nil {
util.Log().Warning("reCAPTCHA verification failed, %s", err)
l.Warning("reCAPTCHA verification failed, %s", err)
c.Abort()
break
}
err = reCAPTCHA.Verify(service.CaptchaCode)
err = reCAPTCHA.Verify(service.Captcha)
if err != nil {
util.Log().Warning("reCAPTCHA verification failed, %s", err)
c.JSON(200, serializer.Err(serializer.CodeCaptchaRefreshNeeded, captchaRefresh, nil))
l.Warning("reCAPTCHA verification failed, %s", err)
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, captchaRefresh, err))
c.Abort()
return
}
break
case "tcaptcha":
credential := common.NewCredential(
options["captcha_TCaptcha_SecretId"],
options["captcha_TCaptcha_SecretKey"],
case setting.CaptchaTurnstile:
captchaSetting := settings.TurnstileCaptcha(c)
r := dep.RequestClient(
request2.WithContext(c),
request2.WithLogger(logging.FromContext(c)),
request2.WithHeader(http.Header{"Content-Type": []string{"application/x-www-form-urlencoded"}}),
)
cpf := profile.NewClientProfile()
cpf.HttpProfile.Endpoint = "captcha.tencentcloudapi.com"
client, _ := captcha.NewClient(credential, "", cpf)
request := captcha.NewDescribeCaptchaResultRequest()
request.CaptchaType = common.Uint64Ptr(9)
appid, _ := strconv.Atoi(options["captcha_TCaptcha_CaptchaAppId"])
request.CaptchaAppId = common.Uint64Ptr(uint64(appid))
request.AppSecretKey = common.StringPtr(options["captcha_TCaptcha_AppSecretKey"])
request.Ticket = common.StringPtr(service.Ticket)
request.Randstr = common.StringPtr(service.Randstr)
request.UserIp = common.StringPtr(c.ClientIP())
response, err := client.DescribeCaptchaResult(request)
formData := url.Values{}
formData.Set("secret", captchaSetting.Secret)
formData.Set("response", service.Ticket)
res, err := r.Request("POST", turnstileEndpoint, strings.NewReader(formData.Encode())).
CheckHTTPResponse(http.StatusOK).
GetResponse()
if err != nil {
util.Log().Warning("TCaptcha verification failed, %s", err)
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, "Captcha validation failed", err))
c.Abort()
break
return
}
if *response.Response.CaptchaCode != int64(1) {
c.JSON(200, serializer.Err(serializer.CodeCaptchaRefreshNeeded, captchaRefresh, nil))
var trunstileRes turnstileResponse
err = json.Unmarshal([]byte(res), &trunstileRes)
if err != nil {
l.Warning("Turnstile verification failed, %s", err)
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, "Captcha validation failed", err))
c.Abort()
return
}
if !trunstileRes.Success {
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeCaptchaError, "Captcha validation failed", err))
c.Abort()
return
}

View File

@@ -1,177 +0,0 @@
package middleware
import (
"bytes"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
type errReader int
func (errReader) Read(p []byte) (n int, err error) {
return 0, errors.New("test error")
}
func TestCaptchaRequired_General(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 未启用验证码
{
cache.SetSettings(map[string]string{
"login_captcha": "0",
"captcha_type": "1",
"captcha_ReCaptchaSecret": "1",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
// body 无法读取
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "1",
"captcha_ReCaptchaSecret": "1",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", errReader(1))
TestFunc(c)
asserts.True(c.IsAborted())
}
// body JSON 解析失败
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "1",
"captcha_ReCaptchaSecret": "1",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("123"))
c.Request, _ = http.NewRequest("GET", "/", r)
TestFunc(c)
asserts.True(c.IsAborted())
}
}
func TestCaptchaRequired_Normal(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 验证码错误
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "normal",
"captcha_ReCaptchaSecret": "1",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("{}"))
c.Request, _ = http.NewRequest("GET", "/", r)
Session("233")(c)
TestFunc(c)
asserts.True(c.IsAborted())
}
}
func TestCaptchaRequired_Recaptcha(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 无法初始化reCaptcha实例
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "recaptcha",
"captcha_ReCaptchaSecret": "",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("{}"))
c.Request, _ = http.NewRequest("GET", "/", r)
TestFunc(c)
asserts.True(c.IsAborted())
}
// 验证码错误
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "recaptcha",
"captcha_ReCaptchaSecret": "233",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("{}"))
c.Request, _ = http.NewRequest("GET", "/", r)
TestFunc(c)
asserts.True(c.IsAborted())
}
}
func TestCaptchaRequired_Tcaptcha(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 验证出错
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "tcaptcha",
"captcha_ReCaptchaSecret": "",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("{}"))
c.Request, _ = http.NewRequest("GET", "/", r)
TestFunc(c)
asserts.True(c.IsAborted())
}
}

View File

@@ -1,62 +1,75 @@
package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"sync"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/cluster"
"github.com/cloudreve/Cloudreve/v4/pkg/downloader"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/routers/controllers"
"github.com/gin-gonic/gin"
"strconv"
)
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
func MasterMetadata() gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("MasterSiteID", c.GetHeader(auth.CrHeaderPrefix+"Site-Id"))
c.Set("MasterSiteURL", c.GetHeader(auth.CrHeaderPrefix+"Site-Url"))
c.Set("MasterVersion", c.GetHeader(auth.CrHeaderPrefix+"Cloudreve-Version"))
c.Next()
}
type SlaveNodeSettingGetter interface {
// GetNodeSetting returns the node settings and its hash
GetNodeSetting() (*types.NodeSetting, string)
}
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc {
return func(c *gin.Context) {
if siteID, exist := c.Get("MasterSiteID"); exist {
// 获取对应主机节点的从机Aria2实例
caller, err := clusterController.GetAria2Instance(siteID.(string))
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeNotSet, "Failed to get Aria2 instance", err))
c.Abort()
return
}
var downloaderPool = sync.Map{}
c.Set("MasterAria2Instance", caller)
// PrepareSlaveDownloader creates or resume a downloader based on input node settings
func PrepareSlaveDownloader(dep dependency.Dep, ctxKey interface{}) gin.HandlerFunc {
return func(c *gin.Context) {
nodeSettings, hash := controllers.ParametersFromContext[SlaveNodeSettingGetter](c, ctxKey).GetNodeSetting()
// try to get downloader from pool
if d, ok := downloaderPool.Load(hash); ok {
c.Set(downloader.DownloaderCtxKey, d)
c.Next()
return
}
c.JSON(200, serializer.ParamErr("Unknown master node ID", nil))
c.Abort()
}
}
func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc {
return func(c *gin.Context) {
nodeID, err := strconv.ParseUint(c.GetHeader(auth.CrHeaderPrefix+"Node-Id"), 10, 64)
// create a new downloader
d, err := cluster.NewDownloader(c, dep.RequestClient(request.WithContext(c), request.WithLogger(dep.Logger())), dep.SettingProvider(), nodeSettings)
if err != nil {
c.JSON(200, serializer.ParamErr("Unknown master node ID", err))
c.JSON(200, serializer.ParamErr(c, "Failed to create downloader", err))
c.Abort()
return
}
slaveNode := nodePool.GetNodeByID(uint(nodeID))
if slaveNode == nil {
c.JSON(200, serializer.ParamErr("Unknown master node ID", err))
c.Abort()
return
}
SignRequired(slaveNode.MasterAuthInstance())(c)
// save downloader to pool
downloaderPool.Store(hash, d)
c.Set(downloader.DownloaderCtxKey, d)
c.Next()
}
}
func SlaveRPCSignRequired() gin.HandlerFunc {
return func(c *gin.Context) {
nodeId := cluster.NodeIdFromContext(c)
if nodeId == 0 {
c.JSON(200, serializer.ParamErr(c, "Unknown node ID", nil))
c.Abort()
return
}
np, err := dependency.FromContext(c).NodePool(c)
if err != nil {
c.JSON(200, serializer.NewError(serializer.CodeInternalSetting, "Failed to get node pool", err))
c.Abort()
return
}
slaveNode, err := np.Get(c, types.NodeCapabilityNone, nodeId)
if slaveNode == nil || slaveNode.IsMaster() {
c.JSON(200, serializer.ParamErr(c, "Unknown node ID", err))
c.Abort()
return
}
SignRequired(slaveNode.AuthInstance())(c)
}
}

View File

@@ -1,120 +0,0 @@
package middleware
import (
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestMasterMetadata(t *testing.T) {
a := assert.New(t)
masterMetaDataFunc := MasterMetadata()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header = map[string][]string{
"X-Cr-Site-Id": {"expectedSiteID"},
"X-Cr-Site-Url": {"expectedSiteURL"},
"X-Cr-Cloudreve-Version": {"expectedMasterVersion"},
}
masterMetaDataFunc(c)
siteID, _ := c.Get("MasterSiteID")
siteURL, _ := c.Get("MasterSiteURL")
siteVersion, _ := c.Get("MasterVersion")
a.Equal("expectedSiteID", siteID.(string))
a.Equal("expectedSiteURL", siteURL.(string))
a.Equal("expectedMasterVersion", siteVersion.(string))
}
func TestSlaveRPCSignRequired(t *testing.T) {
a := assert.New(t)
np := &cluster.NodePool{}
np.Init()
slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
rec := httptest.NewRecorder()
// id parse failed
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Cr-Node-Id", "unknown")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// node id not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Cr-Node-Id", "38")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// success
{
authInstance := auth.HMACAuth{SecretKey: []byte("")}
np.Add(&model.Node{Model: gorm.Model{
ID: 38,
}})
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("POST", "/", nil)
c.Request.Header.Set("X-Cr-Node-Id", "38")
c.Request = auth.SignRequest(authInstance, c.Request, 0)
slaveRPCSignRequiredFunc(c)
a.False(c.IsAborted())
}
}
func TestUseSlaveAria2Instance(t *testing.T) {
a := assert.New(t)
// MasterSiteID not set
{
testController := &controllermock.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
useSlaveAria2InstanceFunc(c)
a.True(c.IsAborted())
}
// Cannot get aria2 instances
{
testController := &controllermock.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("MasterSiteID", "expectedSiteID")
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error"))
useSlaveAria2InstanceFunc(c)
a.True(c.IsAborted())
testController.AssertExpectations(t)
}
// Success
{
testController := &controllermock.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("MasterSiteID", "expectedSiteID")
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil)
useSlaveAria2InstanceFunc(c)
a.False(c.IsAborted())
res, _ := c.Get("MasterAria2Instance")
a.NotNil(res)
testController.AssertExpectations(t)
}
}

View File

@@ -1,26 +1,35 @@
package middleware
import (
"context"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/application/constants"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/pkg/auth/requestinfo"
"github.com/cloudreve/Cloudreve/v4/pkg/cluster"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"net/http"
"time"
)
// HashID 将给定对象的HashID转换为真实ID
func HashID(IDType int) gin.HandlerFunc {
return func(c *gin.Context) {
dep := dependency.FromContext(c)
if c.Param("id") != "" {
id, err := hashid.DecodeHashID(c.Param("id"), IDType)
id, err := dep.HashIDEncoder().Decode(c.Param("id"), IDType)
if err == nil {
c.Set("object_id", id)
util.WithValue(c, hashid.ObjectIDCtx{}, id)
c.Next()
return
}
c.JSON(200, serializer.ParamErr("Failed to parse object ID", nil))
c.JSON(200, serializer.ParamErr(c, "Failed to parse object ID", err))
c.Abort()
return
@@ -30,10 +39,10 @@ func HashID(IDType int) gin.HandlerFunc {
}
// IsFunctionEnabled 当功能未开启时阻止访问
func IsFunctionEnabled(key string) gin.HandlerFunc {
func IsFunctionEnabled(check func(c *gin.Context) bool) gin.HandlerFunc {
return func(c *gin.Context) {
if !model.IsTrueVal(model.GetSettingByName(key)) {
c.JSON(200, serializer.Err(serializer.CodeFeatureNotEnabled, "This feature is not enabled", nil))
if !check(c) {
c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeFeatureNotEnabled, "This feature is not enabled", nil))
c.Abort()
return
}
@@ -56,9 +65,10 @@ func Sandbox() gin.HandlerFunc {
}
// StaticResourceCache 使用静态资源缓存策略
func StaticResourceCache() gin.HandlerFunc {
func StaticResourceCache(dep dependency.Dep) gin.HandlerFunc {
settings := dep.SettingProvider()
return func(c *gin.Context) {
c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", model.GetIntSetting("public_resource_maxage", 86400)))
c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", settings.PublicResourceMaxAge(c)))
}
}
@@ -66,8 +76,9 @@ func StaticResourceCache() gin.HandlerFunc {
// MobileRequestOnly
func MobileRequestOnly() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader(auth.CrHeaderPrefix+"ios") == "" {
c.Redirect(http.StatusMovedPermanently, model.GetSiteURL().String())
dep := dependency.FromContext(c)
if c.GetHeader(constants.CrHeaderPrefix+"ios") == "" {
c.Redirect(http.StatusMovedPermanently, dep.SettingProvider().SiteURL(c).String())
c.Abort()
return
}
@@ -75,3 +86,66 @@ func MobileRequestOnly() gin.HandlerFunc {
c.Next()
}
}
// InitializeHandling is added at the beginning of handler chain, it did following setups:
// 1. Inject dependency manager into request context
// 2. Generate and inject correlation ID for diagnostic.
func InitializeHandling(dep dependency.Dep) gin.HandlerFunc {
return func(c *gin.Context) {
reqInfo := &requestinfo.RequestInfo{
IP: c.ClientIP(),
Host: c.Request.Host,
UserAgent: c.Request.UserAgent(),
}
cid := uuid.FromStringOrNil(c.GetHeader(request.CorrelationHeader))
if cid == uuid.Nil {
cid = uuid.Must(uuid.NewV4())
}
l := dep.Logger().CopyWithPrefix(fmt.Sprintf("[Cid: %s]", cid))
ctx := dep.ForkWithLogger(c.Request.Context(), l)
ctx = context.WithValue(ctx, logging.CorrelationIDCtx{}, cid)
ctx = context.WithValue(ctx, requestinfo.RequestInfoCtx{}, reqInfo)
ctx = context.WithValue(ctx, logging.LoggerCtx{}, l)
if id := c.Param("nodeId"); id != "" {
ctx = context.WithValue(ctx, cluster.SlaveNodeIDCtx{}, id)
} else {
ctx = context.WithValue(ctx, cluster.SlaveNodeIDCtx{}, c.GetHeader(request.SlaveNodeIDHeader))
}
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}
// InitializeHandlingSlave retrieves coll correlation ID and other metadata from request header
func InitializeHandlingSlave() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := context.WithValue(c.Request.Context(), cluster.MasterSiteIDCtx{}, c.GetHeader(request.SiteIDHeader))
ctx = context.WithValue(ctx, cluster.MasterSiteUrlCtx{}, c.GetHeader(request.SiteURLHeader))
ctx = context.WithValue(ctx, cluster.MasterSiteVersionCtx{}, c.GetHeader(request.SiteVersionHeader))
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}
// Logging logs incoming request info
func Logging() gin.HandlerFunc {
return func(c *gin.Context) {
// Start timer
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
// Process request
c.Next()
if raw != "" {
path = path + "?" + raw
}
l := logging.FromContext(c)
logging.Request(l, true, c.Writer.Status(), c.Request.Method, c.ClientIP(), path,
c.Errors.ByType(gin.ErrorTypePrivate).String(), start)
}
}

View File

@@ -1,105 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestHashID(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
TestFunc := HashID(hashid.FolderID)
// 未给定ID对象跳过
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
// 给定ID解析失败
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"id", "2333"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
// 给定ID解析成功
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"id", hashid.HashID(1, hashid.FolderID)},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
}
func TestIsFunctionEnabled(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
TestFunc := IsFunctionEnabled("TestIsFunctionEnabled")
// 未开启
{
cache.Set("setting_TestIsFunctionEnabled", "0", 0)
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.True(c.IsAborted())
}
// 开启
{
cache.Set("setting_TestIsFunctionEnabled", "1", 0)
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
}
func TestCacheControl(t *testing.T) {
a := assert.New(t)
TestFunc := CacheControl()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache")
}
func TestSandbox(t *testing.T) {
a := assert.New(t)
TestFunc := Sandbox()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox")
}
func TestStaticResourceCache(t *testing.T) {
a := assert.New(t)
TestFunc := StaticResourceCache()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age")
}

View File

@@ -1,30 +1,49 @@
package middleware
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/dbfs"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/cloudreve/Cloudreve/v4/routers/controllers"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
)
// ValidateSourceLink validates if the perm source link is a valid redirect link
func ValidateSourceLink() gin.HandlerFunc {
// UrisService is a wrapper for service supports batch file operations
type UrisService interface {
GetUris() []string
}
// ValidateBatchFileCount validates if the batch file count is within the limit
func ValidateBatchFileCount(dep dependency.Dep, ctxKey interface{}) gin.HandlerFunc {
settings := dep.SettingProvider()
return func(c *gin.Context) {
linkID, ok := c.Get("object_id")
if !ok {
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
uris := controllers.ParametersFromContext[UrisService](c, ctxKey)
limit := settings.MaxBatchedFile(c)
if len((uris).GetUris()) > limit {
c.JSON(200, serializer.ErrWithDetails(
c,
serializer.CodeTooManyUris,
fmt.Sprintf("Maximum allowed batch size: %d", limit),
nil,
))
c.Abort()
return
}
sourceLink, err := model.GetSourceLinkByID(linkID)
if err != nil || sourceLink.File.ID == 0 || sourceLink.File.Name != c.Param("name") {
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
c.Abort()
return
}
sourceLink.Downloaded()
c.Set("source_link", sourceLink)
c.Next()
}
}
// ContextHint parses the context hint header and set it to context
func ContextHint() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader(dbfs.ContextHintHeader) != "" {
util.WithValue(c, dbfs.ContextHintCtxKey{}, uuid.FromStringOrNil(c.GetHeader(dbfs.ContextHintHeader)))
}
c.Next()
}
}

View File

@@ -1,57 +0,0 @@
package middleware
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestValidateSourceLink(t *testing.T) {
a := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ValidateSourceLink()
// ID 不存在
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
a.True(c.IsAborted())
}
// SourceLink 不存在
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
testFunc(c)
a.True(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
// 原文件不存在
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"}))
testFunc(c)
a.True(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1))
testFunc(c)
a.False(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
}

View File

@@ -1,73 +1,82 @@
package middleware
import (
"github.com/cloudreve/Cloudreve/v3/bootstrap"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/gin-gonic/gin"
"io/ioutil"
"io"
"net/http"
"strings"
)
// FrontendFileHandler 前端静态文件处理
func FrontendFileHandler() gin.HandlerFunc {
func FrontendFileHandler(dep dependency.Dep) gin.HandlerFunc {
fs := dep.ServerStaticFS()
l := dep.Logger()
ignoreFunc := func(c *gin.Context) {
c.Next()
}
if bootstrap.StaticFS == nil {
if fs == nil {
return ignoreFunc
}
// 读取index.html
file, err := bootstrap.StaticFS.Open("/index.html")
file, err := fs.Open("/index.html")
if err != nil {
util.Log().Warning("Static file \"index.html\" does not exist, it might affect the display of the homepage.")
l.Warning("Static file \"index.html\" does not exist, it might affect the display of the homepage.")
return ignoreFunc
}
fileContentBytes, err := ioutil.ReadAll(file)
fileContentBytes, err := io.ReadAll(file)
if err != nil {
util.Log().Warning("Cannot read static file \"index.html\", it might affect the display of the homepage.")
l.Warning("Cannot read static file \"index.html\", it might affect the display of the homepage.")
return ignoreFunc
}
fileContent := string(fileContentBytes)
fileServer := http.FileServer(bootstrap.StaticFS)
fileServer := http.FileServer(fs)
return func(c *gin.Context) {
path := c.Request.URL.Path
// API 跳过
// Skipping routers handled by backend
if strings.HasPrefix(path, "/api") ||
strings.HasPrefix(path, "/custom") ||
strings.HasPrefix(path, "/dav") ||
strings.HasPrefix(path, "/f") ||
strings.HasPrefix(path, "/f/") ||
strings.HasPrefix(path, "/s/") ||
path == "/manifest.json" {
c.Next()
return
}
// 不存在的路径和index.html均返回index.html
if (path == "/index.html") || (path == "/") || !bootstrap.StaticFS.Exists("/", path) {
if (path == "/index.html") || (path == "/") || !fs.Exists("/", path) {
// 读取、替换站点设置
options := model.GetSettingByNames("siteName", "siteKeywords", "siteScript",
"pwa_small_icon")
settingClient := dep.SettingProvider()
siteBasic := settingClient.SiteBasic(c)
pwaOpts := settingClient.PWA(c)
theme := settingClient.Theme(c)
finalHTML := util.Replace(map[string]string{
"{siteName}": options["siteName"],
"{siteDes}": options["siteDes"],
"{siteScript}": options["siteScript"],
"{pwa_small_icon}": options["pwa_small_icon"],
"{siteName}": siteBasic.Name,
"{siteDes}": siteBasic.Description,
"{siteScript}": siteBasic.Script,
"{pwa_small_icon}": pwaOpts.SmallIcon,
"{pwa_medium_icon}": pwaOpts.MediumIcon,
"var(--defaultThemeColor)": theme.DefaultTheme,
}, fileContent)
c.Header("Content-Type", "text/html")
c.Header("Cache-Control", "public, no-cache")
c.String(200, finalHTML)
c.Abort()
return
}
if path == "/service-worker.js" {
if path == "/sw.js" || strings.HasPrefix(path, "/locales/") {
c.Header("Cache-Control", "public, no-cache")
} else if strings.HasPrefix(path, "/assets/") {
c.Header("Cache-Control", "public, max-age=31536000")
}
// 存在的静态文件

View File

@@ -1,144 +0,0 @@
package middleware
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/bootstrap"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"net/http"
"net/http/httptest"
"os"
"testing"
)
type StaticMock struct {
testMock.Mock
}
func (m StaticMock) Open(name string) (http.File, error) {
args := m.Called(name)
return args.Get(0).(http.File), args.Error(1)
}
func (m StaticMock) Exists(prefix string, filepath string) bool {
args := m.Called(prefix, filepath)
return args.Bool(0)
}
func TestFrontendFileHandler(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 静态资源未加载
{
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
// index.html 不存在
{
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(&os.File{}, errors.New("error"))
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
// index.html 读取失败
{
file, _ := util.CreatNestedFile("tests/index.html")
file.Close()
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(file, nil)
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
// 成功且命中
{
file, _ := util.CreatNestedFile("tests/index.html")
defer file.Close()
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(file, nil)
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
cache.Set("setting_siteName", "cloudreve", 0)
cache.Set("setting_siteKeywords", "cloudreve", 0)
cache.Set("setting_siteScript", "cloudreve", 0)
cache.Set("setting_pwa_small_icon", "cloudreve", 0)
TestFunc(c)
asserts.True(c.IsAborted())
}
// 成功且命中静态文件
{
file, _ := util.CreatNestedFile("tests/index.html")
defer file.Close()
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(file, nil)
testStatic.On("Exists", "/", "/2").
Return(true)
testStatic.On("Open", "/2").
Return(file, nil)
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/2", nil)
TestFunc(c)
asserts.True(c.IsAborted())
testStatic.AssertExpectations(t)
}
// API 相关跳过
{
for _, reqPath := range []string{"/api/user", "/manifest.json", "/dav/path"} {
file, _ := util.CreatNestedFile("tests/index.html")
defer file.Close()
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(file, nil)
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", reqPath, nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
}
}

View File

@@ -1,7 +1,7 @@
package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/gin-gonic/gin"
)

View File

@@ -1,37 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestMockHelper(t *testing.T) {
asserts := assert.New(t)
MockHelperFunc := MockHelper()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
// 写入session
{
SessionMock["test"] = "pass"
Session("test")(c)
MockHelperFunc(c)
asserts.Equal("pass", util.GetSession(c, "test").(string))
}
// 写入context
{
ContextMock["test"] = "pass"
MockHelperFunc(c)
test, exist := c.Get("test")
asserts.True(exist)
asserts.Equal("pass", test.(string))
}
}

View File

@@ -1,14 +1,14 @@
package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/sessionstore"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/pkg/sessionstore"
"net/http"
"strings"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
@@ -16,11 +16,11 @@ import (
// Store session存储
var Store sessions.Store
// Session 初始化session
func Session(secret string) gin.HandlerFunc {
// Redis设置不为空且非测试模式时使用Redis
Store = sessionstore.NewStore(cache.Store, []byte(secret))
const SessionName = "cloudreve-session"
// Session 初始化session
func Session(dep dependency.Dep) gin.HandlerFunc {
Store = sessionstore.NewStore(dep.KV(), []byte(dep.ConfigProvider().System().SessionSecret))
sameSiteMode := http.SameSiteDefaultMode
switch strings.ToLower(conf.CORSConfig.SameSite) {
case "default":
@@ -42,7 +42,7 @@ func Session(secret string) gin.HandlerFunc {
Secure: conf.CORSConfig.Secure,
})
return sessions.Sessions("cloudreve-session", Store)
return sessions.Sessions(SessionName, Store)
}
// CSRFInit 初始化CSRF标记
@@ -61,7 +61,7 @@ func CSRFCheck() gin.HandlerFunc {
return
}
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "Invalid origin", nil))
c.JSON(200, serializer.ErrDeprecated(serializer.CodeNoPermissionErr, "Invalid origin", nil))
c.Abort()
}
}

View File

@@ -1,64 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestSession(t *testing.T) {
asserts := assert.New(t)
{
handler := Session("2333")
asserts.NotNil(handler)
asserts.NotNil(Store)
asserts.IsType(emptyFunc(), handler)
}
}
func emptyFunc() gin.HandlerFunc {
return func(c *gin.Context) {}
}
func TestCSRFInit(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
sessionFunc := Session("233")
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
CSRFInit()(c)
asserts.True(util.GetSession(c, "CSRF").(bool))
}
}
func TestCSRFCheck(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
sessionFunc := Session("233")
// 通过检查
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
CSRFInit()(c)
CSRFCheck()(c)
asserts.False(c.IsAborted())
}
// 未通过检查
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
CSRFCheck()(c)
asserts.True(c.IsAborted())
}
}

View File

@@ -1,133 +0,0 @@
package middleware
import (
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
)
// ShareOwner 检查当前登录用户是否为分享所有者
func ShareOwner() gin.HandlerFunc {
return func(c *gin.Context) {
var user *model.User
if userCtx, ok := c.Get("user"); ok {
user = userCtx.(*model.User)
} else {
c.JSON(200, serializer.Err(serializer.CodeCheckLogin, "", nil))
c.Abort()
return
}
if share, ok := c.Get("share"); ok {
if share.(*model.Share).Creator().ID != user.ID {
c.JSON(200, serializer.Err(serializer.CodeShareLinkNotFound, "", nil))
c.Abort()
return
}
}
c.Next()
}
}
// ShareAvailable 检查分享是否可用
func ShareAvailable() gin.HandlerFunc {
return func(c *gin.Context) {
var user *model.User
if userCtx, ok := c.Get("user"); ok {
user = userCtx.(*model.User)
} else {
user = model.NewAnonymousUser()
}
share := model.GetShareByHashID(c.Param("id"))
if share == nil || !share.IsAvailable() {
c.JSON(200, serializer.Err(serializer.CodeShareLinkNotFound, "", nil))
c.Abort()
return
}
c.Set("user", user)
c.Set("share", share)
c.Next()
}
}
// ShareCanPreview 检查分享是否可被预览
func ShareCanPreview() gin.HandlerFunc {
return func(c *gin.Context) {
if share, ok := c.Get("share"); ok {
if share.(*model.Share).PreviewEnabled {
c.Next()
return
}
c.JSON(200, serializer.Err(serializer.CodeDisabledSharePreview, "",
nil))
c.Abort()
return
}
c.Abort()
}
}
// CheckShareUnlocked 检查分享是否已解锁
func CheckShareUnlocked() gin.HandlerFunc {
return func(c *gin.Context) {
if shareCtx, ok := c.Get("share"); ok {
share := shareCtx.(*model.Share)
// 分享是否已解锁
if share.Password != "" {
sessionKey := fmt.Sprintf("share_unlock_%d", share.ID)
unlocked := util.GetSession(c, sessionKey) != nil
if !unlocked {
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr,
"", nil))
c.Abort()
return
}
}
c.Next()
return
}
c.Abort()
}
}
// BeforeShareDownload 分享被下载前的检查
func BeforeShareDownload() gin.HandlerFunc {
return func(c *gin.Context) {
if shareCtx, ok := c.Get("share"); ok {
if userCtx, ok := c.Get("user"); ok {
share := shareCtx.(*model.Share)
user := userCtx.(*model.User)
// 检查用户是否可以下载此分享的文件
err := share.CanBeDownloadBy(user)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
nil))
c.Abort()
return
}
// 对积分、下载次数进行更新
err = share.DownloadBy(user, c)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
nil))
c.Abort()
return
}
c.Next()
return
}
}
c.Abort()
}
}

View File

@@ -1,190 +0,0 @@
package middleware
import (
"net/http/httptest"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestShareAvailable(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ShareAvailable()
// 分享不存在
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"id", "empty"},
}
testFunc(c)
asserts.True(c.IsAborted())
}
// 通过
{
conf.SystemConfig.HashIDSalt = ""
// 用户组
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
mock.ExpectQuery("SELECT(.+)shares(.+)").
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "remain_downloads", "source_id"}).
AddRow(1, 1, 2),
)
mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"id", "x9T4"},
}
testFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
asserts.NotNil(c.Get("user"))
asserts.NotNil(c.Get("share"))
}
}
func TestShareCanPreview(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ShareCanPreview()
// 无分享上下文
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
}
// 可以预览
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{PreviewEnabled: true})
testFunc(c)
asserts.False(c.IsAborted())
}
// 未开启预览
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{PreviewEnabled: false})
testFunc(c)
asserts.True(c.IsAborted())
}
}
func TestCheckShareUnlocked(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := CheckShareUnlocked()
// 无分享上下文
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
}
// 无密码
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
testFunc(c)
asserts.False(c.IsAborted())
}
}
func TestBeforeShareDownload(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := BeforeShareDownload()
// 无分享上下文
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
testFunc(c)
asserts.True(c.IsAborted())
}
// 用户不能下载
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
c.Set("user", &model.User{
Group: model.Group{OptionsSerialized: model.GroupOption{}},
})
testFunc(c)
asserts.True(c.IsAborted())
}
// 可以下载
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
c.Set("user", &model.User{
Model: gorm.Model{ID: 1},
Group: model.Group{OptionsSerialized: model.GroupOption{
ShareDownload: true,
}},
})
testFunc(c)
asserts.False(c.IsAborted())
}
}
func TestShareOwner(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ShareOwner()
// 未登录
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
testFunc(c)
asserts.True(c.IsAborted())
}
// 非用户所创建分享
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Set("share", &model.Share{User: model.User{Model: gorm.Model{ID: 1}}})
c.Set("user", &model.User{})
testFunc(c)
asserts.True(c.IsAborted())
}
// 正常
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
c.Set("user", &model.User{})
testFunc(c)
asserts.False(c.IsAborted())
}
}

View File

@@ -1,22 +1,21 @@
package middleware
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/cloudreve/Cloudreve/v4/pkg/wopi"
"github.com/gin-gonic/gin"
"net/http"
"strings"
)
const (
WopiSessionCtx = "wopi_session"
)
// WopiWriteAccess validates if write access is obtained.
func WopiWriteAccess() gin.HandlerFunc {
return func(c *gin.Context) {
session := c.MustGet(WopiSessionCtx).(*wopi.SessionCache)
session := c.MustGet(wopi.WopiSessionCtx).(*wopi.SessionCache)
if session.Action != wopi.ActionEdit {
c.Status(http.StatusNotFound)
c.Header(wopi.ServerErrorHeader, "read-only access")
@@ -28,8 +27,12 @@ func WopiWriteAccess() gin.HandlerFunc {
}
}
func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc {
func ViewerSessionValidation() gin.HandlerFunc {
return func(c *gin.Context) {
dep := dependency.FromContext(c)
store := dep.KV()
settings := dep.SettingProvider()
accessToken := strings.Split(c.Query(wopi.AccessTokenQuery), ".")
if len(accessToken) != 2 {
c.Status(http.StatusForbidden)
@@ -38,7 +41,7 @@ func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc {
return
}
sessionRaw, exist := store.Get(wopi.SessionCachePrefix + accessToken[0])
sessionRaw, exist := store.Get(manager.ViewerSessionCachePrefix + accessToken[0])
if !exist {
c.Status(http.StatusForbidden)
c.Header(wopi.ServerErrorHeader, "invalid access token")
@@ -46,25 +49,47 @@ func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc {
return
}
session := sessionRaw.(wopi.SessionCache)
user, err := model.GetActiveUserByID(session.UserID)
if err != nil {
session := sessionRaw.(manager.ViewerSessionCache)
if err := SetUserCtx(c, session.UserID); err != nil {
c.Status(http.StatusInternalServerError)
c.Header(wopi.ServerErrorHeader, "user not found")
c.Abort()
return
}
fileID := c.MustGet("object_id").(uint)
if fileID != session.FileID {
c.Status(http.StatusInternalServerError)
c.Header(wopi.ServerErrorHeader, "file not found")
fileId := hashid.FromContext(c)
if fileId != session.FileID {
c.Status(http.StatusForbidden)
c.Header(wopi.ServerErrorHeader, "invalid file")
c.Abort()
return
}
c.Set("user", &user)
c.Set(WopiSessionCtx, &session)
// Check if the viewer is still available
viewers := settings.FileViewers(c)
var v *setting.Viewer
for _, group := range viewers {
for _, viewer := range group.Viewers {
if viewer.ID == session.ViewerID && !viewer.Disabled {
v = &viewer
break
}
}
if v != nil {
break
}
}
if v == nil {
c.Status(http.StatusInternalServerError)
c.Header(wopi.ServerErrorHeader, "viewer not found")
c.Abort()
return
}
util.WithValue(c, manager.ViewerCtx{}, v)
util.WithValue(c, manager.ViewerSessionCacheCtx{}, &session)
c.Next()
}
}

View File

@@ -1,112 +0,0 @@
package middleware
import (
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/wopimock"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestWopiWriteAccess(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := WopiWriteAccess()
// deny preview only session
{
c, _ := gin.CreateTestContext(rec)
c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionPreview})
testFunc(c)
asserts.True(c.IsAborted())
}
// pass
{
c, _ := gin.CreateTestContext(rec)
c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionEdit})
testFunc(c)
asserts.False(c.IsAborted())
}
}
func TestWopiAccessValidation(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
mockWopi := &wopimock.WopiClientMock{}
mockCache := cache.NewMemoStore()
testFunc := WopiAccessValidation(mockWopi, mockCache)
// malformed access token
{
c, _ := gin.CreateTestContext(rec)
c.AddParam(wopi.AccessTokenQuery, "000")
testFunc(c)
asserts.True(c.IsAborted())
}
// session key not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
testFunc(c)
asserts.True(c.IsAborted())
}
// user key not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
testFunc(c)
asserts.True(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
// file not found
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
c.Set("object_id", uint(0))
testFunc(c)
asserts.True(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
// all pass
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
c.Set("object_id", uint(1))
testFunc(c)
asserts.False(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NotPanics(func() {
c.MustGet(WopiSessionCtx)
})
asserts.NotPanics(func() {
c.MustGet("user")
})
}
}