Init V4 community edition (#2265)

* Init V4 community edition

* Init V4 community edition
This commit is contained in:
AaronLiu
2025-04-20 17:31:25 +08:00
committed by GitHub
parent da4e44b77a
commit 21d158db07
597 changed files with 119415 additions and 41692 deletions

454
inventory/client.go Normal file
View File

@@ -0,0 +1,454 @@
package inventory
import (
"context"
rawsql "database/sql"
"database/sql/driver"
"fmt"
"os"
"time"
"entgo.io/ent/dialect/sql"
"github.com/cloudreve/Cloudreve/v4/application/constants"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/group"
"github.com/cloudreve/Cloudreve/v4/ent/node"
_ "github.com/cloudreve/Cloudreve/v4/ent/runtime"
"github.com/cloudreve/Cloudreve/v4/ent/setting"
"github.com/cloudreve/Cloudreve/v4/ent/storagepolicy"
"github.com/cloudreve/Cloudreve/v4/inventory/debug"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
"modernc.org/sqlite"
)
const (
	// DBVersionPrefix prefixes the setting name that records which schema
	// version has been installed (see needMigration / migrate).
	DBVersionPrefix = "db_version_"
	// EnvDefaultOverwritePrefix prefixes environment variables that override
	// a default setting value during migration (CR_SETTING_DEFAULT_<name>).
	EnvDefaultOverwritePrefix = "CR_SETTING_DEFAULT_"
	// EnvEnableAria2, when set, pre-configures the master node with a local
	// aria2 RPC endpoint during migration.
	EnvEnableAria2 = "CR_ENABLE_ARIA2"
)
// InitializeDBClient runs migration and returns a new ent.Client with additional configurations
// for hooks and interceptors.
func InitializeDBClient(l logging.Logger,
	client *ent.Client, kv cache.Driver, requiredDbVersion string) (*ent.Client, error) {
	ctx := context.WithValue(context.Background(), logging.LoggerCtx{}, l)

	if !needMigration(client, ctx, requiredDbVersion) {
		l.Info("Database schema is up to date.")
		return client, nil
	}

	// Run the auto migration tool.
	if err := migrate(l, client, ctx, kv, requiredDbVersion); err != nil {
		return nil, fmt.Errorf("failed to migrate database: %w", err)
	}

	//createMockData(client, ctx)
	return client, nil
}
// NewRawEntClient returns a new ent.Client without additional configurations.
// It opens the underlying database connection according to the configured
// database type, sets up the connection pool, and optionally wraps the driver
// with verbose debug logging.
func NewRawEntClient(l logging.Logger, config conf.ConfigProvider) (*ent.Client, error) {
	l.Info("Initializing database connection...")
	dbConfig := config.Database()
	confDBType := dbConfig.Type
	if confDBType == conf.SQLite3DB || confDBType == "" {
		// "sqlite3" and an unset type are both normalized to the canonical
		// SQLite type so all later comparisons are against one constant.
		confDBType = conf.SQLiteDB
	}

	var (
		err    error
		client *sql.Driver
	)
	switch confDBType {
	case conf.SQLiteDB:
		dbFile := util.RelativePath(dbConfig.DBFile)
		l.Info("Connect to SQLite database %q.", dbFile)
		// Reuse the already-computed dbFile (previously recomputed here).
		client, err = sql.Open("sqlite3", dbFile)
	case conf.PostgresDB:
		l.Info("Connect to Postgres database %q.", dbConfig.Host)
		client, err = sql.Open("postgres", fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
			dbConfig.Host,
			dbConfig.User,
			dbConfig.Password,
			dbConfig.Name,
			dbConfig.Port))
	case conf.MySqlDB, conf.MsSqlDB:
		l.Info("Connect to MySQL/SQLServer database %q.", dbConfig.Host)
		var host string
		if dbConfig.UnixSocket {
			host = fmt.Sprintf("unix(%s)",
				dbConfig.Host)
		} else {
			host = fmt.Sprintf("(%s:%d)",
				dbConfig.Host,
				dbConfig.Port)
		}

		client, err = sql.Open(string(confDBType), fmt.Sprintf("%s:%s@%s/%s?charset=%s&parseTime=True&loc=Local",
			dbConfig.User,
			dbConfig.Password,
			host,
			dbConfig.Name,
			dbConfig.Charset))
	default:
		return nil, fmt.Errorf("unsupported database type %q", confDBType)
	}

	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	// Set connection pool
	db := client.DB()
	db.SetMaxIdleConns(50)
	if confDBType == conf.SQLiteDB {
		// SQLite supports a single writer; cap open connections at 1 to
		// avoid "database is locked" errors. The previous comparison against
		// the raw strings "sqlite"/"UNSET" could never match after the
		// normalization above, leaving SQLite with 100 open connections.
		db.SetMaxOpenConns(1)
	} else {
		db.SetMaxOpenConns(100)
	}

	// Set timeout
	db.SetConnMaxLifetime(time.Second * 30)

	driverOpt := ent.Driver(client)
	// Enable verbose logging for debug mode.
	if config.System().Debug {
		l.Debug("Debug mode is enabled for DB client.")
		driverOpt = ent.Driver(debug.DebugWithContext(client, func(ctx context.Context, i ...any) {
			logging.FromContext(ctx).Debug(i[0].(string), i[1:]...)
		}))
	}

	return ent.NewClient(driverOpt), nil
}
// sqlite3Driver wraps the modernc.org/sqlite driver so that every new
// connection has foreign-key enforcement enabled (see Open below).
type sqlite3Driver struct {
	*sqlite.Driver
}

// sqlite3DriverConn is the minimal connection surface needed to execute a raw
// statement right after a connection is opened.
type sqlite3DriverConn interface {
	Exec(string, []driver.Value) (driver.Result, error)
}
// Open opens a new SQLite connection and enables foreign-key enforcement on
// it, which SQLite leaves off by default. The connection is closed if the
// PRAGMA fails.
func (d sqlite3Driver) Open(name string) (conn driver.Conn, err error) {
	if conn, err = d.Driver.Open(name); err != nil {
		return
	}
	if _, err = conn.(sqlite3DriverConn).Exec("PRAGMA foreign_keys = ON;", nil); err != nil {
		_ = conn.Close()
	}
	return
}
// init registers the foreign-key-enabling SQLite wrapper under the driver
// name "sqlite3", matching the name used by NewRawEntClient's sql.Open call.
func init() {
	rawsql.Register("sqlite3", sqlite3Driver{Driver: &sqlite.Driver{}})
}
// needMigration exams if required schema version is satisfied.
func needMigration(client *ent.Client, ctx context.Context, requiredDbVersion string) bool {
	versionMark := DBVersionPrefix + requiredDbVersion
	// Errors are deliberately ignored: on a fresh DB the query fails and the
	// zero count correctly triggers migration.
	count, _ := client.Setting.Query().Where(setting.NameEQ(versionMark)).Count(ctx)
	return count == 0
}
// migrate creates the table schema and seeds default settings, the default
// storage policy, system groups and the master node. On success it persists a
// version marker setting so that needMigration skips future runs.
func migrate(l logging.Logger, client *ent.Client, ctx context.Context, kv cache.Driver, requiredDbVersion string) error {
	l.Info("Start initializing database schema...")
	l.Info("Creating basic table schema...")
	if err := client.Schema.Create(ctx); err != nil {
		// Error strings start lowercase per Go convention.
		return fmt.Errorf("failed creating schema resources: %w", err)
	}

	migrateDefaultSettings(l, client, ctx, kv)

	if err := migrateDefaultStoragePolicy(l, client, ctx); err != nil {
		return fmt.Errorf("failed migrating default storage policy: %w", err)
	}

	if err := migrateSysGroups(l, client, ctx); err != nil {
		// Fixed copy-pasted message: this step seeds system groups, not the
		// storage policy.
		return fmt.Errorf("failed migrating system groups: %w", err)
	}

	// Persist the schema version marker. If this write fails silently the
	// whole migration would re-run on every boot, so surface the error.
	if _, err := client.Setting.Create().SetName(DBVersionPrefix + requiredDbVersion).SetValue("installed").Save(ctx); err != nil {
		return fmt.Errorf("failed to persist schema version marker: %w", err)
	}
	return nil
}
// migrateDefaultSettings flushes the KV cache, then inserts every entry of
// DefaultSettings that is not already present in the settings table. Each
// default may be overridden via the CR_SETTING_DEFAULT_<name> env variable.
func migrateDefaultSettings(l logging.Logger, client *ent.Client, ctx context.Context, kv cache.Driver) {
	// clean kv cache
	if err := kv.DeleteAll(); err != nil {
		l.Warning("Failed to remove all KV entries while schema migration: %s", err)
	}

	// Collect names of settings that already exist so we can skip them.
	existing := make(map[string]struct{})
	rows, err := client.Setting.Query().All(ctx)
	if err != nil {
		l.Warning("Failed to query existing settings: %s", err)
	}
	for _, row := range rows {
		existing[row.Name] = struct{}{}
	}

	l.Info("Insert default settings...")
	for name, value := range DefaultSettings {
		if _, found := existing[name]; found {
			l.Debug("Skip inserting setting %s, already exists.", name)
			continue
		}

		if override, ok := os.LookupEnv(EnvDefaultOverwritePrefix + name); ok {
			l.Info("Override default setting %q with env value %q", name, override)
			value = override
		}

		client.Setting.Create().SetName(name).SetValue(value).SaveX(ctx)
	}
}
// migrateDefaultStoragePolicy seeds the built-in local storage policy (ID=1)
// if it does not exist yet.
func migrateDefaultStoragePolicy(l logging.Logger, client *ent.Client, ctx context.Context) error {
	_, err := client.StoragePolicy.Query().Where(storagepolicy.ID(1)).First(ctx)
	if err == nil {
		l.Info("Default storage policy (ID=1) already exists, skip migrating.")
		return nil
	}

	l.Info("Insert default storage policy...")
	_, err = client.StoragePolicy.Create().
		SetName("Default storage policy").
		SetType(types.PolicyTypeLocal).
		SetDirNameRule(util.DataPath("uploads/{uid}/{path}")).
		SetFileNameRule("{uid}_{randomkey8}_{originname}").
		SetSettings(&types.PolicySetting{
			ChunkSize:   25 << 20, // 25MB
			PreAllocate: true,
		}).
		Save(ctx)
	if err != nil {
		return fmt.Errorf("failed to create default storage policy: %w", err)
	}
	return nil
}
// migrateSysGroups seeds the three built-in groups (admin, user, anonymous)
// and the master node, in that order; the first failing step aborts the chain.
func migrateSysGroups(l logging.Logger, client *ent.Client, ctx context.Context) error {
	steps := []func(logging.Logger, *ent.Client, context.Context) error{
		migrateAdminGroup,
		migrateUserGroup,
		migrateAnonymousGroup,
		migrateMasterNode,
	}
	for _, step := range steps {
		if err := step(l, client, ctx); err != nil {
			return err
		}
	}
	return nil
}
// migrateAdminGroup seeds the built-in administrator group (ID=1) with full
// permissions and a 1 TB quota, unless it already exists.
func migrateAdminGroup(l logging.Logger, client *ent.Client, ctx context.Context) error {
	if _, err := client.Group.Query().Where(group.ID(1)).First(ctx); err == nil {
		l.Info("Default admin group (ID=1) already exists, skip migrating.")
		return nil
	}

	l.Info("Insert default admin group...")
	perms := &boolset.BooleanSet{}
	boolset.Sets(map[types.GroupPermission]bool{
		types.GroupPermissionIsAdmin:             true,
		types.GroupPermissionShare:               true,
		types.GroupPermissionWebDAV:              true,
		types.GroupPermissionWebDAVProxy:         true,
		types.GroupPermissionArchiveDownload:     true,
		types.GroupPermissionArchiveTask:         true,
		types.GroupPermissionShareDownload:       true,
		types.GroupPermissionRemoteDownload:      true,
		types.GroupPermissionRedirectedSource:    true,
		types.GroupPermissionAdvanceDelete:       true,
		types.GroupPermissionIgnoreFileOwnership: true,
		// TODO: review default permission
	}, perms)

	_, err := client.Group.Create().
		SetName("Admin").
		SetStoragePoliciesID(1).
		SetMaxStorage(1 * constants.TB). // 1 TB default storage
		SetPermissions(perms).
		SetSettings(&types.GroupSetting{
			SourceBatchSize:  1000,
			Aria2BatchSize:   50,
			MaxWalkedFiles:   100000,
			TrashRetention:   7 * 24 * 3600,
			RedirectedSource: true,
		}).
		Save(ctx)
	if err != nil {
		return fmt.Errorf("failed to create default admin group: %w", err)
	}
	return nil
}
// migrateUserGroup seeds the built-in default user group (ID=2) with basic
// sharing permissions and a 1 GB quota, unless it already exists.
func migrateUserGroup(l logging.Logger, client *ent.Client, ctx context.Context) error {
	if _, err := client.Group.Query().Where(group.ID(2)).First(ctx); err == nil {
		l.Info("Default user group (ID=2) already exists, skip migrating.")
		return nil
	}

	l.Info("Insert default user group...")
	perms := &boolset.BooleanSet{}
	boolset.Sets(map[types.GroupPermission]bool{
		types.GroupPermissionShare:            true,
		types.GroupPermissionShareDownload:    true,
		types.GroupPermissionRedirectedSource: true,
	}, perms)

	_, err := client.Group.Create().
		SetName("User").
		SetStoragePoliciesID(1).
		SetMaxStorage(1 * constants.GB). // 1 GB default storage
		SetPermissions(perms).
		SetSettings(&types.GroupSetting{
			SourceBatchSize:  10,
			Aria2BatchSize:   1,
			MaxWalkedFiles:   100000,
			TrashRetention:   7 * 24 * 3600,
			RedirectedSource: true,
		}).
		Save(ctx)
	if err != nil {
		return fmt.Errorf("failed to create default user group: %w", err)
	}
	return nil
}
// migrateAnonymousGroup seeds the built-in anonymous group (AnonymousGroupID)
// with share-download permission only, unless it already exists.
func migrateAnonymousGroup(l logging.Logger, client *ent.Client, ctx context.Context) error {
	if _, err := client.Group.Query().Where(group.ID(AnonymousGroupID)).First(ctx); err == nil {
		l.Info("Default anonymous group (ID=3) already exists, skip migrating.")
		return nil
	}

	l.Info("Insert default anonymous group...")
	perms := &boolset.BooleanSet{}
	boolset.Sets(map[types.GroupPermission]bool{
		types.GroupPermissionIsAnonymous:   true,
		types.GroupPermissionShareDownload: true,
	}, perms)

	_, err := client.Group.Create().
		SetName("Anonymous").
		SetPermissions(perms).
		SetSettings(&types.GroupSetting{
			MaxWalkedFiles:   100000,
			RedirectedSource: true,
		}).
		Save(ctx)
	if err != nil {
		return fmt.Errorf("failed to create default anonymous group: %w", err)
	}
	return nil
}
// migrateMasterNode seeds the master node record if none exists. When the
// CR_ENABLE_ARIA2 env variable is present, the node is pre-configured with a
// local aria2 RPC endpoint.
func migrateMasterNode(l logging.Logger, client *ent.Client, ctx context.Context) error {
	if _, err := client.Node.Query().Where(node.TypeEQ(node.TypeMaster)).First(ctx); err == nil {
		l.Info("Default master node already exists, skip migrating.")
		return nil
	}

	capabilities := &boolset.BooleanSet{}
	boolset.Sets(map[types.NodeCapability]bool{
		types.NodeCapabilityCreateArchive:  true,
		types.NodeCapabilityExtractArchive: true,
		types.NodeCapabilityRemoteDownload: true,
	}, capabilities)

	settings := &types.NodeSetting{
		Provider: types.DownloaderProviderAria2,
	}
	if _, enableAria2 := os.LookupEnv(EnvEnableAria2); enableAria2 {
		l.Info("Aria2 is override as enabled.")
		settings.Aria2Setting = &types.Aria2Setting{
			Server: "http://127.0.0.1:6800/jsonrpc",
		}
	}

	l.Info("Insert default master node...")
	if _, err := client.Node.Create().
		SetType(node.TypeMaster).
		SetCapabilities(capabilities).
		SetName("Master").
		SetSettings(settings).
		SetStatus(node.StatusActive).
		Save(ctx); err != nil {
		return fmt.Errorf("failed to create default master node: %w", err)
	}
	return nil
}
// createMockData is a development-only helper intended to seed the database
// with large volumes of fake users, folders and files for benchmarking. Its
// body is intentionally commented out and is kept for reference; the only
// call site (in InitializeDBClient) is also commented out.
func createMockData(client *ent.Client, ctx context.Context) {
	//userCount := 100
	//folderCount := 10000
	//fileCount := 25000
	//
	//// create users
	//pwdDigest, _ := digestPassword("52121225")
	//userCreates := make([]*ent.UserCreate, userCount)
	//for i := 0; i < userCount; i++ {
	//	nick := uuid.Must(uuid.NewV4()).String()
	//	userCreates[i] = client.User.Create().
	//		SetEmail(nick + "@cloudreve.org").
	//		SetNick(nick).
	//		SetPassword(pwdDigest).
	//		SetStatus(user.StatusActive).
	//		SetGroupID(1)
	//}
	//users, err := client.User.CreateBulk(userCreates...).Save(ctx)
	//if err != nil {
	//	panic(err)
	//}
	//
	//// Create root folder
	//rootFolderCreates := make([]*ent.FileCreate, userCount)
	//folderIds := make([][]int, 0, folderCount*userCount+userCount)
	//for i, user := range users {
	//	rootFolderCreates[i] = client.File.Create().
	//		SetName(RootFolderName).
	//		SetOwnerID(user.ID).
	//		SetType(int(FileTypeFolder))
	//}
	//rootFolders, err := client.File.CreateBulk(rootFolderCreates...).Save(ctx)
	//for _, rootFolders := range rootFolders {
	//	folderIds = append(folderIds, []int{rootFolders.ID, rootFolders.OwnerID})
	//}
	//if err != nil {
	//	panic(err)
	//}
	//
	//// create random folder
	//for i := 0; i < folderCount*userCount; i++ {
	//	parent := lo.Sample(folderIds)
	//	res := client.File.Create().
	//		SetName(uuid.Must(uuid.NewV4()).String()).
	//		SetType(int(FileTypeFolder)).
	//		SetOwnerID(parent[1]).
	//		SetFileChildren(parent[0]).
	//		SaveX(ctx)
	//	folderIds = append(folderIds, []int{res.ID, res.OwnerID})
	//}

	// NOTE(review): a leftover debug loop that printed "0/1/2/.../254" to
	// stdout was removed here; it served no purpose and polluted output.
}

124
inventory/common.go Normal file
View File

@@ -0,0 +1,124 @@
package inventory
import (
"encoding/base64"
"encoding/json"
"entgo.io/ent/dialect/sql"
"fmt"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"time"
)
type (
	// OrderDirection is the sort direction ("asc" or "desc") for list queries.
	OrderDirection string

	// PaginationArgs describes how a list query should be paginated: either
	// offset-based (Page/PageSize) or cursor-based (PageToken/PageSize).
	PaginationArgs struct {
		UseCursorPagination bool
		Page                int
		PageToken           string
		PageSize            int
		OrderBy             string
		Order               OrderDirection
	}

	// PaginationResults reports pagination state back to the caller; for
	// cursor pagination IsCursor is true and NextPageToken carries the cursor.
	PaginationResults struct {
		Page          int    `json:"page"`
		PageSize      int    `json:"page_size"`
		TotalItems    int    `json:"total_items,omitempty"`
		NextPageToken string `json:"next_token,omitempty"`
		IsCursor      bool   `json:"is_cursor,omitempty"`
	}

	// PageToken is the decoded form of a cursor token. ID is resolved from
	// IDHash by pageTokenFromString and is never serialized directly.
	PageToken struct {
		Time          *time.Time `json:"time,omitempty"`
		ID            int        `json:"-"`
		IDHash        string     `json:"id,omitempty"`
		String        string     `json:"string,omitempty"`
		Int           int        `json:"int,omitempty"`
		StartWithFile bool       `json:"start_with_file,omitempty"`
	}
)
const (
	// OrderDirectionAsc sorts query results in ascending order.
	OrderDirectionAsc = OrderDirection("asc")
	// OrderDirectionDesc sorts query results in descending order.
	OrderDirectionDesc = OrderDirection("desc")
)

var (
	// ErrTooManyArguments indicates a caller passed more arguments than a
	// variadic helper accepts.
	ErrTooManyArguments = fmt.Errorf("too many arguments")
)
// pageTokenFromString decodes a base64-encoded JSON page token and resolves
// its hashed ID into a numeric one using the given hasher and ID type.
func pageTokenFromString(s string, hasher hashid.Encoder, idType int) (*PageToken, error) {
	raw, err := base64.StdEncoding.DecodeString(s)
	if err != nil {
		return nil, fmt.Errorf("failed to decode base64 for page token: %w", err)
	}

	token := &PageToken{}
	if err = json.Unmarshal(raw, token); err != nil {
		return nil, fmt.Errorf("failed to unmarshal page token: %w", err)
	}

	decodedID, err := hasher.Decode(token.IDHash, idType)
	if err != nil {
		return nil, fmt.Errorf("failed to decode id: %w", err)
	}

	// Normalize a missing timestamp to the zero time so callers can
	// dereference Time unconditionally.
	if token.Time == nil {
		token.Time = &time.Time{}
	}
	token.ID = decodedID
	return token, nil
}
// Encode hashes the token's ID via encodeFunc, serializes the token as JSON
// and returns the base64-encoded result.
func (p *PageToken) Encode(hasher hashid.Encoder, encodeFunc hashid.EncodeFunc) (string, error) {
	p.IDHash = encodeFunc(hasher, p.ID)
	payload, err := json.Marshal(p)
	if err != nil {
		return "", fmt.Errorf("failed to marshal page token: %w", err)
	}
	encoded := base64.StdEncoding.EncodeToString(payload)
	return encoded, nil
}
// sqlParamLimit returns the max number of sql parameters.
func sqlParamLimit(dbType conf.DBType) int {
	if dbType == conf.PostgresDB {
		return 34464
	}
	// SQLite and all other backends share the conservative SQLite cap:
	// https://www.sqlite.org/limits.html
	return 32766
}
// getOrderTerm returns the order term for ent. Anything other than an
// explicit descending direction sorts ascending.
func getOrderTerm(d OrderDirection) sql.OrderTermOption {
	if d == OrderDirectionDesc {
		return sql.OrderDesc()
	}
	return sql.OrderAsc()
}
// capPageSize clamps the preferred page size so a query never exceeds the
// database's SQL parameter limit, keeping `margin` parameters spare for the
// non-paged parts of the statement. A preferred size of 0 means "as large as
// allowed".
func capPageSize(maxSQlParam, preferredSize, margin int) int {
	limit := maxSQlParam - margin
	if (maxSQlParam > 0 && preferredSize > limit) || preferredSize == 0 {
		return limit
	}
	return preferredSize
}
// StorageDiff maps an ID (presumably a file/entity owner — confirm against
// callers) to a pending storage delta in bytes; positive values grow usage,
// negative values shrink it.
type StorageDiff map[int]int64

// Merge folds the entries of diff into s, summing deltas for IDs present in
// both. The receiver's map is allocated lazily, so merging into a zero-value
// StorageDiff no longer panics with "assignment to entry in nil map".
func (s *StorageDiff) Merge(diff StorageDiff) {
	if *s == nil {
		*s = make(StorageDiff, len(diff))
	}
	for k, v := range diff {
		(*s)[k] += v
	}
}

187
inventory/dav_account.go Normal file
View File

@@ -0,0 +1,187 @@
package inventory
import (
"context"
"entgo.io/ent/dialect/sql"
"fmt"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/davaccount"
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/samber/lo"
)
type (
	// DavAccountClient provides CRUD operations over WebDAV accounts. It is a
	// TxOperator, so it can be rebound to a transactional ent client.
	DavAccountClient interface {
		TxOperator
		// List returns a list of dav accounts with the given args.
		List(ctx context.Context, args *ListDavAccountArgs) (*ListDavAccountResult, error)
		// Create creates a new dav account.
		Create(ctx context.Context, params *CreateDavAccountParams) (*ent.DavAccount, error)
		// Update updates a dav account.
		Update(ctx context.Context, id int, params *CreateDavAccountParams) (*ent.DavAccount, error)
		// GetByIDAndUserID returns the dav account with given id and user id.
		GetByIDAndUserID(ctx context.Context, id, userID int) (*ent.DavAccount, error)
		// Delete deletes the dav account.
		Delete(ctx context.Context, id int) error
	}

	// ListDavAccountArgs narrows a List call to one owner (when UserID > 0)
	// and carries the pagination settings.
	ListDavAccountArgs struct {
		*PaginationArgs
		UserID int
	}

	// ListDavAccountResult is one page of accounts plus pagination state.
	ListDavAccountResult struct {
		*PaginationResults
		Accounts []*ent.DavAccount
	}

	// CreateDavAccountParams carries the fields written by Create and Update.
	// Note Update does not write Password (see davAccountClient.Update).
	CreateDavAccountParams struct {
		UserID   int
		Name     string
		URI      string
		Password string
		Options  *boolset.BooleanSet
	}
)
// NewDavAccountClient constructs a DavAccountClient backed by the given ent
// client, deriving the SQL parameter cap from the database type.
func NewDavAccountClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) DavAccountClient {
	c := &davAccountClient{
		client: client,
		hasher: hasher,
	}
	c.maxSQlParam = sqlParamLimit(dbType)
	return c
}
// davAccountClient is the default DavAccountClient implementation backed by
// an ent client; maxSQlParam caps page sizes (see cursorPagination).
type davAccountClient struct {
	maxSQlParam int
	client      *ent.Client
	hasher      hashid.Encoder
}
// SetClient returns a copy of this client bound to newClient, so the same
// configuration can be reused inside a transaction (TxOperator contract).
func (c *davAccountClient) SetClient(newClient *ent.Client) TxOperator {
	return &davAccountClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam}
}

// GetClient returns the underlying ent client.
func (c *davAccountClient) GetClient() *ent.Client {
	return c.client
}
// Create inserts a new dav account owned by params.UserID and returns it.
func (c *davAccountClient) Create(ctx context.Context, params *CreateDavAccountParams) (*ent.DavAccount, error) {
	return c.client.DavAccount.Create().
		SetOwnerID(params.UserID).
		SetName(params.Name).
		SetURI(params.URI).
		SetPassword(params.Password).
		SetOptions(params.Options).
		Save(ctx)
}
// GetByIDAndUserID returns the dav account with the given id that is owned by
// userID, or ent's not-found error.
func (c *davAccountClient) GetByIDAndUserID(ctx context.Context, id, userID int) (*ent.DavAccount, error) {
	query := c.client.DavAccount.Query().
		Where(davaccount.ID(id), davaccount.OwnerID(userID))
	return query.First(ctx)
}
// Update rewrites the name, URI and options of an existing dav account.
// The Password field of params is not written by this method.
func (c *davAccountClient) Update(ctx context.Context, id int, params *CreateDavAccountParams) (*ent.DavAccount, error) {
	updater := c.client.DavAccount.UpdateOneID(id).
		SetName(params.Name).
		SetURI(params.URI).
		SetOptions(params.Options)
	return updater.Save(ctx)
}
// Delete removes the dav account with the given id.
func (c *davAccountClient) Delete(ctx context.Context, id int) error {
	return c.client.DavAccount.DeleteOneID(id).Exec(ctx)
}
// List returns one page of dav accounts matching args. It always uses cursor
// pagination with a parameter margin of 10.
func (c *davAccountClient) List(ctx context.Context, args *ListDavAccountArgs) (*ListDavAccountResult, error) {
	accounts, pagination, err := c.cursorPagination(ctx, c.listQuery(args), args, 10)
	if err != nil {
		return nil, fmt.Errorf("query failed with paginiation: %w", err)
	}

	return &ListDavAccountResult{
		Accounts:          accounts,
		PaginationResults: pagination,
	}, nil
}
// cursorPagination runs query with ID-descending cursor pagination. It
// fetches pageSize+1 rows to detect whether another page exists; when it
// does, a token pointing at the last *returned* row is produced. paramMargin
// reserves SQL parameters for the non-paged parts of the statement.
func (c *davAccountClient) cursorPagination(ctx context.Context, query *ent.DavAccountQuery, args *ListDavAccountArgs, paramMargin int) ([]*ent.DavAccount, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	// Descending ID order pairs with the IDLT cursor predicate below.
	query.Order(davaccount.ByID(sql.OrderDesc()))
	var (
		pageToken *PageToken
		err       error
	)
	if args.PageToken != "" {
		pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.DavAccountID)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
		}
	}

	queryPaged := getDavAccountCursorQuery(pageToken, query)

	// Use page size + 1 to determine if there are more items to come
	queryPaged.Limit(pageSize + 1)

	logs, err := queryPaged.
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	// More items to come
	nextTokenStr := ""
	if len(logs) > pageSize {
		// len(logs) == pageSize+1 here, so index len(logs)-2 is the last row
		// of the returned page (the extra sentinel row is dropped below).
		lastItem := logs[len(logs)-2]
		nextToken, err := getDavAccountNextPageToken(c.hasher, lastItem)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
		}

		nextTokenStr = nextToken
	}

	// Trim the sentinel row so at most pageSize items are returned.
	return lo.Subset(logs, 0, uint(pageSize)), &PaginationResults{
		PageSize:      pageSize,
		NextPageToken: nextTokenStr,
		IsCursor:      true,
	}, nil
}
// listQuery builds the base account query, restricted to one owner when
// args.UserID is positive.
func (c *davAccountClient) listQuery(args *ListDavAccountArgs) *ent.DavAccountQuery {
	q := c.client.DavAccount.Query()
	if args.UserID <= 0 {
		return q
	}
	return q.Where(davaccount.OwnerID(args.UserID))
}
// getDavAccountNextPageToken returns the next page token for the given last dav account.
func getDavAccountNextPageToken(hasher hashid.Encoder, last *ent.DavAccount) (string, error) {
	return (&PageToken{ID: last.ID}).Encode(hasher, hashid.EncodeDavAccountID)
}
// getDavAccountCursorQuery applies the cursor constraint — ID strictly below
// the token's ID, matching the descending sort — when a token is present.
func getDavAccountCursorQuery(token *PageToken, query *ent.DavAccountQuery) *ent.DavAccountQuery {
	if token == nil {
		return query
	}
	return query.Where(davaccount.IDLT(token.ID))
}

188
inventory/debug/debug.go Normal file
View File

@@ -0,0 +1,188 @@
package debug
import (
"context"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"fmt"
"github.com/google/uuid"
"time"
)
// strMaxLen caps the length of string arguments included in log output;
// longer values are truncated (see DebugTx.Exec).
const strMaxLen = 102400

// SkipDbLogging is a context key; when its value is the bool true, query
// logging is suppressed for operations carried by that context.
type SkipDbLogging struct{}

// DebugDriver is a driver that logs all driver operations.
type DebugDriver struct {
	dialect.Driver                                // underlying driver.
	log            func(context.Context, ...any)  // log function. defaults to log.Println.
}
// DebugWithContext gets a driver and a logging function, and returns
// a new debugged-driver that prints all outgoing operations with context.
func DebugWithContext(d dialect.Driver, logger func(context.Context, ...any)) dialect.Driver {
	return &DebugDriver{d, logger}
}
// Exec logs its params and calls the underlying driver Exec method.
// Logging is suppressed when the context carries SkipDbLogging=true.
func (d *DebugDriver) Exec(ctx context.Context, query string, args, v any) error {
	start := time.Now()
	err := d.Driver.Exec(ctx, query, args, v)
	skip, ok := ctx.Value(SkipDbLogging{}).(bool)
	if ok && skip {
		return err
	}
	d.log(ctx, fmt.Sprintf("driver.Exec: query=%v args=%v time=%v", query, args, time.Since(start)))
	return err
}
// ExecContext logs its params and calls the underlying driver ExecContext method if it is supported.
func (d *DebugDriver) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	drv, ok := d.Driver.(interface {
		ExecContext(context.Context, string, ...any) (sql.Result, error)
	})
	if !ok {
		return nil, fmt.Errorf("Driver.ExecContext is not supported")
	}
	if skip, ok := ctx.Value(SkipDbLogging{}).(bool); !ok || !skip {
		d.log(ctx, fmt.Sprintf("driver.ExecContext: query=%v args=%v", query, args))
	}
	return drv.ExecContext(ctx, query, args...)
}
// Query logs its params and calls the underlying driver Query method.
// Logging is suppressed when the context carries SkipDbLogging=true.
func (d *DebugDriver) Query(ctx context.Context, query string, args, v any) error {
	start := time.Now()
	err := d.Driver.Query(ctx, query, args, v)
	skip, ok := ctx.Value(SkipDbLogging{}).(bool)
	if ok && skip {
		return err
	}
	d.log(ctx, fmt.Sprintf("driver.Query: query=%v args=%v time=%v", query, args, time.Since(start)))
	return err
}
// QueryContext logs its params and calls the underlying driver QueryContext method if it is supported.
func (d *DebugDriver) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	drv, ok := d.Driver.(interface {
		QueryContext(context.Context, string, ...any) (*sql.Rows, error)
	})
	if !ok {
		return nil, fmt.Errorf("Driver.QueryContext is not supported")
	}
	if skip, ok := ctx.Value(SkipDbLogging{}).(bool); !ok || !skip {
		d.log(ctx, fmt.Sprintf("driver.QueryContext: query=%v args=%v", query, args))
	}
	return drv.QueryContext(ctx, query, args...)
}
// Tx adds an log-id for the transaction and calls the underlying driver Tx command.
func (d *DebugDriver) Tx(ctx context.Context) (dialect.Tx, error) {
	underlying, err := d.Driver.Tx(ctx)
	if err != nil {
		return nil, err
	}
	txID := uuid.New().String()
	d.log(ctx, fmt.Sprintf("driver.Tx(%s): started", txID))
	return &DebugTx{underlying, txID, d.log, ctx}, nil
}
// BeginTx adds an log-id for the transaction and calls the underlying driver BeginTx command if it is supported.
func (d *DebugDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) {
	drv, ok := d.Driver.(interface {
		BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error)
	})
	if !ok {
		return nil, fmt.Errorf("Driver.BeginTx is not supported")
	}
	underlying, err := drv.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	txID := uuid.New().String()
	d.log(ctx, fmt.Sprintf("driver.BeginTx(%s): started", txID))
	return &DebugTx{underlying, txID, d.log, ctx}, nil
}
// DebugTx is a transaction implementation that logs all transaction operations.
type DebugTx struct {
	dialect.Tx                                    // underlying transaction.
	id         string                             // transaction logging id.
	log        func(context.Context, ...any)      // log function. defaults to fmt.Println.
	ctx        context.Context                    // underlying transaction context.
}
// Exec logs its params and calls the underlying transaction Exec method.
// Over-long string arguments are truncated in the log output only; unlike the
// previous implementation, the caller's args slice is never modified in place
// and no truncation work is done when logging is skipped.
func (d *DebugTx) Exec(ctx context.Context, query string, args, v any) error {
	start := time.Now()
	err := d.Tx.Exec(ctx, query, args, v)
	if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
		return err
	}

	printArgs := args
	if argsArray, ok := args.([]interface{}); ok {
		// Copy before truncating: the original code wrote through an alias of
		// the caller's slice, mutating the arguments the caller still owns.
		truncated := make([]interface{}, len(argsArray))
		copy(truncated, argsArray)
		for i, argVal := range truncated {
			if argValStr, ok := argVal.(string); ok && len(argValStr) > strMaxLen {
				truncated[i] = argValStr[:strMaxLen] + "...[Truncated]..."
			}
		}
		printArgs = truncated
	}

	d.log(ctx, fmt.Sprintf("Tx(%s).Exec: query=%v args=%v time=%v", d.id, query, printArgs, time.Since(start)))
	return err
}
// ExecContext logs its params and calls the underlying transaction ExecContext method if it is supported.
func (d *DebugTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	drv, ok := d.Tx.(interface {
		ExecContext(context.Context, string, ...any) (sql.Result, error)
	})
	if !ok {
		return nil, fmt.Errorf("Tx.ExecContext is not supported")
	}
	if skip, ok := ctx.Value(SkipDbLogging{}).(bool); !ok || !skip {
		d.log(ctx, fmt.Sprintf("Tx(%s).ExecContext: query=%v args=%v", d.id, query, args))
	}
	return drv.ExecContext(ctx, query, args...)
}
// Query logs its params and calls the underlying transaction Query method.
// Logging is suppressed when the context carries SkipDbLogging=true.
func (d *DebugTx) Query(ctx context.Context, query string, args, v any) error {
	start := time.Now()
	err := d.Tx.Query(ctx, query, args, v)
	skip, ok := ctx.Value(SkipDbLogging{}).(bool)
	if ok && skip {
		return err
	}
	d.log(ctx, fmt.Sprintf("Tx(%s).Query: query=%v args=%v time=%v", d.id, query, args, time.Since(start)))
	return err
}
// QueryContext logs its params and calls the underlying transaction QueryContext method if it is supported.
func (d *DebugTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	drv, ok := d.Tx.(interface {
		QueryContext(context.Context, string, ...any) (*sql.Rows, error)
	})
	if !ok {
		return nil, fmt.Errorf("Tx.QueryContext is not supported")
	}
	if skip, ok := ctx.Value(SkipDbLogging{}).(bool); !ok || !skip {
		d.log(ctx, fmt.Sprintf("Tx(%s).QueryContext: query=%v args=%v", d.id, query, args))
	}
	return drv.QueryContext(ctx, query, args...)
}
// Commit logs this step and calls the underlying transaction Commit method.
func (d *DebugTx) Commit() error {
	d.log(d.ctx, fmt.Sprintf("Tx(%s): committed", d.id))
	return d.Tx.Commit()
}

// Rollback logs this step and calls the underlying transaction Rollback method.
func (d *DebugTx) Rollback() error {
	d.log(d.ctx, fmt.Sprintf("Tx(%s): rollbacked", d.id))
	return d.Tx.Rollback()
}

70
inventory/direct_link.go Normal file
View File

@@ -0,0 +1,70 @@
package inventory
import (
"context"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/directlink"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
)
type (
	// DirectLinkClient provides read access to direct (source) links. It is a
	// TxOperator, so it can be rebound to a transactional ent client.
	DirectLinkClient interface {
		TxOperator
		// GetByNameID get direct link by name and id
		GetByNameID(ctx context.Context, id int, name string) (*ent.DirectLink, error)
		// GetByID get direct link by id
		GetByID(ctx context.Context, id int) (*ent.DirectLink, error)
	}

	// LoadDirectLinkFile is a context key; when set to true, queries eagerly
	// load the linked file (see withDirectLinkEagerLoading).
	LoadDirectLinkFile struct{}
)
// NewDirectLinkClient constructs a DirectLinkClient backed by the given ent
// client, deriving the SQL parameter cap from the database type.
func NewDirectLinkClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) DirectLinkClient {
	c := &directLinkClient{
		client: client,
		hasher: hasher,
	}
	c.maxSQlParam = sqlParamLimit(dbType)
	return c
}
// directLinkClient is the default DirectLinkClient implementation backed by
// an ent client; maxSQlParam caps batched query sizes.
type directLinkClient struct {
	maxSQlParam int
	client      *ent.Client
	hasher      hashid.Encoder
}

// SetClient returns a copy of this client bound to newClient, so the same
// configuration can be reused inside a transaction (TxOperator contract).
func (c *directLinkClient) SetClient(newClient *ent.Client) TxOperator {
	return &directLinkClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam}
}

// GetClient returns the underlying ent client.
func (c *directLinkClient) GetClient() *ent.Client {
	return c.client
}
// GetByID get direct link by id, honoring the LoadDirectLinkFile eager
// loading flag from the context.
func (d *directLinkClient) GetByID(ctx context.Context, id int) (*ent.DirectLink, error) {
	q := d.client.DirectLink.Query().Where(directlink.ID(id))
	return withDirectLinkEagerLoading(ctx, q).First(ctx)
}
// GetByNameID get direct link by name and id, bumping its download counter as
// a side effect of a successful lookup.
func (d *directLinkClient) GetByNameID(ctx context.Context, id int, name string) (*ent.DirectLink, error) {
	res, err := withDirectLinkEagerLoading(ctx, d.client.DirectLink.Query().Where(directlink.ID(id), directlink.Name(name))).
		First(ctx)
	if err != nil {
		return nil, err
	}

	// Increase download counter. AddDownloads issues an atomic in-SQL
	// increment, so concurrent requests no longer lose updates (the previous
	// read-modify-write via SetDownloads(res.Downloads+1) did). The error is
	// deliberately ignored: a failed counter bump must not fail the lookup.
	_, _ = d.client.DirectLink.Update().Where(directlink.ID(res.ID)).AddDownloads(1).Save(ctx)
	return res, nil
}
// withDirectLinkEagerLoading eagerly loads the linked file (with the file's
// own eager loadings applied) when the LoadDirectLinkFile context flag is set.
func withDirectLinkEagerLoading(ctx context.Context, q *ent.DirectLinkQuery) *ent.DirectLinkQuery {
	enabled, ok := ctx.Value(LoadDirectLinkFile{}).(bool)
	if !ok || !enabled {
		return q
	}
	return q.WithFile(func(m *ent.FileQuery) {
		withFileEagerLoading(ctx, m)
	})
}

1113
inventory/file.go Normal file

File diff suppressed because it is too large Load Diff

503
inventory/file_utils.go Normal file
View File

@@ -0,0 +1,503 @@
package inventory
import (
"context"
"fmt"
"strings"
"entgo.io/ent/dialect/sql"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/entity"
"github.com/cloudreve/Cloudreve/v4/ent/file"
"github.com/cloudreve/Cloudreve/v4/ent/metadata"
"github.com/cloudreve/Cloudreve/v4/ent/predicate"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/samber/lo"
)
// searchQuery applies the given search parameters to q: scope (owner or
// parent folders), name predicates, type, metadata, size range and
// created/updated time bounds.
func (f *fileClient) searchQuery(q *ent.FileQuery, args *SearchFileParameters, parents []*ent.File, ownerId int) *ent.FileQuery {
	// Scope: a single nil parent means "every file of the owner"; otherwise
	// restrict to direct children of the given parents.
	if len(parents) == 1 && parents[0] == nil {
		q = q.Where(file.OwnerID(ownerId))
	} else {
		q = q.Where(
			file.HasParentWith(
				file.IDIn(lo.Map(parents, func(item *ent.File, index int) int {
					return item.ID
				})...,
				),
			),
		)
	}

	if len(args.Name) > 0 {
		namePredicates := lo.Map(args.Name, func(item string, index int) predicate.File {
			// If start and ends with quotes, treat as exact match
			// NOTE(review): despite the comment this still performs a
			// substring match on the unquoted text — confirm whether exact
			// equality (file.Name) was intended.
			if strings.HasPrefix(item, "\"") && strings.HasSuffix(item, "\"") {
				return file.NameContains(strings.Trim(item, "\""))
			}

			// if contain wildcard, use transform to sql like
			if strings.Contains(item, SearchWildcard) {
				pattern := strings.ReplaceAll(item, SearchWildcard, "%")
				if pattern[0] != '%' && pattern[len(pattern)-1] != '%' {
					// if not start with wildcard, add prefix wildcard
					pattern = "%" + pattern + "%"
				}
				return func(s *sql.Selector) {
					s.Where(sql.Like(file.FieldName, pattern))
				}
			}

			if args.CaseFolding {
				return file.NameContainsFold(item)
			}

			return file.NameContains(item)
		})
		if args.NameOperatorOr {
			q = q.Where(file.Or(namePredicates...))
		} else {
			q = q.Where(file.And(namePredicates...))
		}
	}

	if args.Type != nil {
		q = q.Where(file.TypeEQ(int(*args.Type)))
	}

	if len(args.Metadata) > 0 {
		metaPredicates := lo.MapToSlice(args.Metadata, func(name string, value string) predicate.Metadata {
			// An empty key matches on metadata name only; otherwise both the
			// exact name and a case-folded value substring must match.
			nameEq := metadata.NameEQ(value)
			if name == "" {
				return nameEq
			} else {
				valueContain := metadata.ValueContainsFold(value)
				return metadata.And(metadata.NameEQ(name), valueContain)
			}
		})
		// Only public metadata is searchable.
		metaPredicates = append(metaPredicates, metadata.IsPublic(true))
		q.Where(file.HasMetadataWith(metadata.And(metaPredicates...)))
	}

	// Size range: apply each bound independently. The previous code applied
	// BOTH bounds whenever either was set, so a query with only SizeGte > 0
	// also added SizeLTE(0) and matched nothing.
	if args.SizeGte > 0 {
		q = q.Where(file.SizeGTE(args.SizeGte))
	}
	if args.SizeLte > 0 {
		q = q.Where(file.SizeLTE(args.SizeLte))
	}

	if args.CreatedAtLte != nil {
		q = q.Where(file.CreatedAtLTE(*args.CreatedAtLte))
	}
	if args.CreatedAtGte != nil {
		q = q.Where(file.CreatedAtGTE(*args.CreatedAtGte))
	}
	if args.UpdatedAtLte != nil {
		q = q.Where(file.UpdatedAtLTE(*args.UpdatedAtLte))
	}
	if args.UpdatedAtGte != nil {
		q = q.Where(file.UpdatedAtGTE(*args.UpdatedAtGte))
	}
	return q
}
// childFileQuery generates a query for the child file(s) of a given set of roots.
//
// Three shapes are supported:
//  1. Exactly one non-nil root: query its direct children via the edge helper.
//  2. No root, or a nil first root: query "orphan" files (or, when isSymbolic,
//     symbolic links that still have children), optionally scoped to ownerID.
//  3. Multiple roots: query children whose parent ID is in the root set.
func (f *fileClient) childFileQuery(ownerID int, isSymbolic bool, root ...*ent.File) *ent.FileQuery {
	rawQuery := f.client.File.Query()
	if len(root) == 1 && root[0] != nil {
		// Query children of one single root
		rawQuery = f.client.File.QueryChildren(root[0])
	} else if len(root) == 0 || root[0] == nil {
		// Query orphan files with owner ID. The len(root) == 0 guard avoids an
		// index-out-of-range panic when no root is supplied at all.
		predicates := []predicate.File{
			file.NameNEQ(RootFolderName),
		}
		if ownerID > 0 {
			predicates = append(predicates, file.OwnerIDEQ(ownerID))
		}

		if isSymbolic {
			predicates = append(predicates, file.And(file.IsSymbolic(true), file.FileChildrenNotNil()))
		} else {
			predicates = append(predicates, file.Not(file.HasParent()))
		}
		rawQuery = f.client.File.Query().Where(
			file.And(predicates...),
		)
	} else {
		// Query children of multiple roots
		rawQuery.
			Where(
				file.HasParentWith(
					file.IDIn(lo.Map(root, func(item *ent.File, index int) int {
						return item.ID
					})...),
				),
			)
	}

	return rawQuery
}
// batchInCondition splits ids into chunks small enough to stay within the
// database's SQL parameter limit and returns one IN predicate per chunk,
// along with the chunks themselves.
func (f *fileClient) batchInCondition(pageSize, margin int, multiply int, ids []int) ([]predicate.File, [][]int) {
	limit := capPageSize(f.maxSQlParam, pageSize, margin)
	groups := lo.Chunk(ids, max(limit/multiply, 1))
	predicates := make([]predicate.File, 0, len(groups))
	for _, g := range groups {
		predicates = append(predicates, file.IDIn(g...))
	}
	return predicates, groups
}
// batchInConditionMetadataName splits metadata keys into parameter-limit-safe
// chunks and returns one NameIn predicate per chunk, plus the chunks.
func (f *fileClient) batchInConditionMetadataName(pageSize, margin int, multiply int, keys []string) ([]predicate.Metadata, [][]string) {
	limit := capPageSize(f.maxSQlParam, pageSize, margin)
	groups := lo.Chunk(keys, max(limit/multiply, 1))
	predicates := make([]predicate.Metadata, 0, len(groups))
	for _, g := range groups {
		predicates = append(predicates, metadata.NameIn(g...))
	}
	return predicates, groups
}
// batchInConditionEntityID splits entity ids into parameter-limit-safe chunks
// and returns one IDIn predicate per chunk, plus the chunks.
func (f *fileClient) batchInConditionEntityID(pageSize, margin int, multiply int, keys []int) ([]predicate.Entity, [][]int) {
	limit := capPageSize(f.maxSQlParam, pageSize, margin)
	groups := lo.Chunk(keys, max(limit/multiply, 1))
	predicates := make([]predicate.Entity, 0, len(groups))
	for _, g := range groups {
		predicates = append(predicates, entity.IDIn(g...))
	}
	return predicates, groups
}
// cursorPagination performs pagination with a cursor, which is faster than
// offset pagination, but less flexible.
//
// Folders are listed before files (unless MixedType). Up to pageSize+1 rows
// are fetched so the presence of a following page can be detected; the extra
// row is trimmed from the result.
func (f *fileClient) cursorPagination(ctx context.Context, query *ent.FileQuery,
	args *ListFileParameters, paramMargin int) ([]*ent.File, *PaginationResults, error) {
	pageSize := capPageSize(f.maxSQlParam, args.PageSize, paramMargin)
	query.Order(getFileOrderOption(args)...)
	currentPage := 0
	// Three types of query option: folders only, files only, or both mixed.
	// getFileCursorQuery below narrows this slice to the one that applies.
	queryPaged := []*ent.FileQuery{
		query.Clone().
			Where(file.TypeEQ(int(types.FileTypeFolder))),
		query.Clone().
			Where(file.TypeEQ(int(types.FileTypeFile))),
		query.Clone().
			Where(file.TypeIn(int(types.FileTypeFolder), int(types.FileTypeFile))),
	}
	var (
		pageToken *PageToken
		err       error
	)
	if args.PageToken != "" {
		pageToken, err = pageTokenFromString(args.PageToken, f.hasher, hashid.FileID)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
		}
	}
	queryPaged = getFileCursorQuery(args, pageToken, queryPaged)
	// Use page size + 1 to determine if there are more items to come
	queryPaged[0].Limit(pageSize + 1)
	files, err := queryPaged[0].
		All(ctx)
	if err != nil {
		return nil, nil, err
	}
	nextStartWithFile := false
	if pageToken != nil && pageToken.StartWithFile {
		nextStartWithFile = true
	}
	// Folder query under-filled the page: top it up from the file query so the
	// page is full, and mark the next token to resume from files.
	if len(files) < pageSize+1 && len(queryPaged) > 1 && !args.MixedType && !args.FolderOnly {
		queryPaged[1].Limit(pageSize + 1 - len(files))
		filesContinue, err := queryPaged[1].
			All(ctx)
		if err != nil {
			return nil, nil, err
		}
		nextStartWithFile = true
		files = append(files, filesContinue...)
	}
	// More items to come
	nextTokenStr := ""
	if len(files) > pageSize {
		// files holds pageSize+1 rows here; the last row of the *returned* page
		// is at index len-2, and becomes the cursor for the next page.
		lastItem := files[len(files)-2]
		nextToken, err := getFileNextPageToken(f.hasher, lastItem, args, nextStartWithFile)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
		}
		nextTokenStr = nextToken
	}
	return lo.Subset(files, 0, uint(pageSize)), &PaginationResults{
		Page:          currentPage,
		PageSize:      pageSize,
		NextPageToken: nextTokenStr,
		IsCursor:      true,
	}, nil
}
// offsetPagination performs traditional offset pagination with minor
// optimizations: folders and files are counted and fetched separately so
// folders always come first on a page, with file offsets shifted by the
// total folder count.
func (f *fileClient) offsetPagination(ctx context.Context, query *ent.FileQuery,
	args *ListFileParameters, paramMargin int) ([]*ent.File, *PaginationResults, error) {
	pageSize := capPageSize(f.maxSQlParam, args.PageSize, paramMargin)
	queryWithoutOrder := query.Clone()
	query.Order(getFileOrderOption(args)...)
	// Count total items by type (one grouped aggregate instead of two counts).
	var v []struct {
		Type  int `json:"type"`
		Count int `json:"count"`
	}
	err := queryWithoutOrder.Clone().
		GroupBy(file.FieldType).
		Aggregate(ent.Count()).
		Scan(ctx, &v)
	if err != nil {
		return nil, nil, err
	}
	folderCount := 0
	fileCount := 0
	for _, item := range v {
		if item.Type == int(types.FileTypeFolder) {
			folderCount = item.Count
		} else {
			fileCount = item.Count
		}
	}
	allFiles := make([]*ent.File, 0, pageSize)
	// folderLimit is how many folders belong on this page; the rest of the
	// page (if any) is filled with files.
	folderLimit := 0
	if (args.Page+1)*pageSize > folderCount {
		folderLimit = folderCount - args.Page*pageSize
		if folderLimit < 0 {
			folderLimit = 0
		}
	} else {
		folderLimit = pageSize
	}
	if folderLimit <= pageSize && folderLimit > 0 {
		// Folder still remains
		folders, err := query.Clone().
			Limit(folderLimit).
			Offset(args.Page * pageSize).
			Where(file.TypeEQ(int(types.FileTypeFolder))).All(ctx)
		if err != nil {
			return nil, nil, err
		}
		allFiles = append(allFiles, folders...)
	}
	if folderLimit < pageSize {
		// Fill the remainder of the page with files. The offset is relative to
		// the file-only listing, hence subtracting folderCount.
		files, err := query.Clone().
			Limit(pageSize - folderLimit).
			Offset((args.Page * pageSize) + folderLimit - folderCount).
			Where(file.TypeEQ(int(types.FileTypeFile))).
			All(ctx)
		if err != nil {
			return nil, nil, err
		}
		allFiles = append(allFiles, files...)
	}
	return allFiles, &PaginationResults{
		TotalItems: folderCount + fileCount,
		Page:       args.Page,
		PageSize:   pageSize,
	}, nil
}
// withFileEagerLoading applies file edge eager-loading options that are
// toggled on via context values set to true: entities (ordered by ID
// descending), metadata (all or public-only), shares, owner, and direct
// links.
//
// NOTE(review): LoadFilePublicMetadata calls WithMetadata again and thus
// appears to override an earlier LoadFileMetadata — confirm the two flags are
// never set together.
func withFileEagerLoading(ctx context.Context, q *ent.FileQuery) *ent.FileQuery {
	if v, ok := ctx.Value(LoadFileEntity{}).(bool); ok && v {
		q.WithEntities(func(m *ent.EntityQuery) {
			m.Order(ent.Desc(entity.FieldID))
			withEntityEagerLoading(ctx, m)
		})
	}
	if v, ok := ctx.Value(LoadFileMetadata{}).(bool); ok && v {
		q.WithMetadata()
	}
	if v, ok := ctx.Value(LoadFilePublicMetadata{}).(bool); ok && v {
		q.WithMetadata(func(m *ent.MetadataQuery) {
			m.Where(metadata.IsPublic(true))
		})
	}
	if v, ok := ctx.Value(LoadFileShare{}).(bool); ok && v {
		q.WithShares()
	}
	if v, ok := ctx.Value(LoadFileUser{}).(bool); ok && v {
		q.WithOwner(func(query *ent.UserQuery) {
			withUserEagerLoading(ctx, query)
		})
	}
	if v, ok := ctx.Value(LoadFileDirectLink{}).(bool); ok && v {
		q.WithDirectLinks()
	}
	return q
}
// withEntityEagerLoading applies entity edge eager-loading options that are
// toggled on via context values set to true: user, storage policy, and the
// owning file (recursively applying file eager loading).
func withEntityEagerLoading(ctx context.Context, q *ent.EntityQuery) *ent.EntityQuery {
	if enabled, ok := ctx.Value(LoadEntityUser{}).(bool); ok && enabled {
		q.WithUser()
	}
	if enabled, ok := ctx.Value(LoadEntityStoragePolicy{}).(bool); ok && enabled {
		q.WithStoragePolicy()
	}
	if enabled, ok := ctx.Value(LoadEntityFile{}).(bool); ok && enabled {
		q.WithFile(func(fileQuery *ent.FileQuery) {
			withFileEagerLoading(ctx, fileQuery)
		})
	}
	return q
}
// getFileOrderOption maps the requested sort field to ent order options,
// always appending ID as a tie-breaker for a stable, cursor-safe ordering.
func getFileOrderOption(args *ListFileParameters) []file.OrderOption {
	term := getOrderTerm(args.Order)
	var primary file.OrderOption
	switch args.OrderBy {
	case file.FieldName:
		primary = file.ByName(term)
	case file.FieldSize:
		primary = file.BySize(term)
	case file.FieldUpdatedAt:
		primary = file.ByUpdatedAt(term)
	default:
		return []file.OrderOption{file.ByID(term)}
	}
	return []file.OrderOption{primary, file.ByID(term)}
}
// getEntityOrderOption maps the requested sort field to ent order options,
// always appending ID as a tie-breaker for stable ordering.
func getEntityOrderOption(args *ListEntityParameters) []entity.OrderOption {
	term := getOrderTerm(args.Order)
	var primary entity.OrderOption
	switch args.OrderBy {
	case entity.FieldSize:
		primary = entity.BySize(term)
	case entity.FieldUpdatedAt:
		primary = entity.ByUpdatedAt(term)
	case entity.FieldReferenceCount:
		primary = entity.ByReferenceCount(term)
	default:
		return []entity.OrderOption{entity.ByID(term)}
	}
	return []entity.OrderOption{primary, entity.ByID(term)}
}
// fileCursorQuery maps an order-by field to cursor predicates, keyed by sort
// direction (true = descending). Each predicate selects rows strictly past
// the token's position, using the row ID as a tie-breaker on equal keys.
//
// For created_at only the ID is compared — presumably because IDs increase
// monotonically with creation time; TODO confirm.
var fileCursorQuery = map[string]map[bool]func(token *PageToken) predicate.File{
	file.FieldName: {
		true: func(token *PageToken) predicate.File {
			return file.Or(
				file.NameLT(token.String),
				file.And(file.Name(token.String), file.IDLT(token.ID)),
			)
		},
		false: func(token *PageToken) predicate.File {
			return file.Or(
				file.NameGT(token.String),
				file.And(file.Name(token.String), file.IDGT(token.ID)),
			)
		},
	},
	file.FieldSize: {
		true: func(token *PageToken) predicate.File {
			return file.Or(
				file.SizeLT(int64(token.Int)),
				file.And(file.Size(int64(token.Int)), file.IDLT(token.ID)),
			)
		},
		false: func(token *PageToken) predicate.File {
			return file.Or(
				file.SizeGT(int64(token.Int)),
				file.And(file.Size(int64(token.Int)), file.IDGT(token.ID)),
			)
		},
	},
	file.FieldCreatedAt: {
		true: func(token *PageToken) predicate.File {
			return file.IDLT(token.ID)
		},
		false: func(token *PageToken) predicate.File {
			return file.IDGT(token.ID)
		},
	},
	file.FieldUpdatedAt: {
		true: func(token *PageToken) predicate.File {
			return file.Or(
				file.UpdatedAtLT(*token.Time),
				file.And(file.UpdatedAt(*token.Time), file.IDLT(token.ID)),
			)
		},
		false: func(token *PageToken) predicate.File {
			return file.Or(
				file.UpdatedAtGT(*token.Time),
				file.And(file.UpdatedAt(*token.Time), file.IDGT(token.ID)),
			)
		},
	},
	file.FieldID: {
		true: func(token *PageToken) predicate.File {
			return file.IDLT(token.ID)
		},
		false: func(token *PageToken) predicate.File {
			return file.IDGT(token.ID)
		},
	},
}
// getFileCursorQuery narrows the three pre-built queries (index 0: folders,
// 1: files, 2: mixed) to the one(s) that apply for this page, and attaches
// the cursor predicate derived from the page token to the first of them.
func getFileCursorQuery(args *ListFileParameters, token *PageToken, query []*ent.FileQuery) []*ent.FileQuery {
	// Recover the sort direction (o.Desc) by applying the order term to a
	// scratch options struct.
	o := &sql.OrderTermOptions{}
	getOrderTerm(args.Order)(o)
	predicates, ok := fileCursorQuery[args.OrderBy]
	if !ok {
		// Unknown sort field: fall back to ID-based cursor.
		predicates = fileCursorQuery[file.FieldID]
	}
	// If all folder is already listed in previous page, only query for files.
	if token != nil && token.StartWithFile && !args.MixedType {
		query = query[1:2]
	}
	// Mixing folders and files with one query
	if args.MixedType {
		query = query[2:]
	} else if args.FolderOnly {
		query = query[0:1]
	}
	if token != nil {
		query[0].Where(predicates[o.Desc](token))
	}
	return query
}
// getFileNextPageToken encodes a cursor token pointing just past last,
// carrying whichever sort-key value the current ordering requires.
func getFileNextPageToken(hasher hashid.Encoder, last *ent.File, args *ListFileParameters, nextStartWithFile bool) (string, error) {
	token := &PageToken{
		ID:            last.ID,
		StartWithFile: nextStartWithFile,
	}
	if args.OrderBy == file.FieldName {
		token.String = last.Name
	} else if args.OrderBy == file.FieldSize {
		token.Int = int(last.Size)
	} else if args.OrderBy == file.FieldUpdatedAt {
		token.Time = &last.UpdatedAt
	}
	return token.Encode(hasher, hashid.EncodeFileID)
}

170
inventory/group.go Normal file
View File

@@ -0,0 +1,170 @@
package inventory
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/group"
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
)
type (
	// LoadGroupPolicy is a ctx key enabling eager loading of a group's
	// storage policies edge.
	LoadGroupPolicy struct{}
)

const (
	// AnonymousGroupID is the reserved ID of the built-in anonymous group —
	// presumably seeded by the DB migration; confirm before relying on it.
	AnonymousGroupID = 3
)

type (
	// GroupClient provides CRUD operations over user groups.
	GroupClient interface {
		TxOperator
		// AnonymousGroup returns the anonymous group.
		AnonymousGroup(ctx context.Context) (*ent.Group, error)
		// ListAll returns all groups.
		ListAll(ctx context.Context) ([]*ent.Group, error)
		// GetByID returns the group by id.
		GetByID(ctx context.Context, id int) (*ent.Group, error)
		// ListGroups returns a list of groups with pagination.
		ListGroups(ctx context.Context, args *ListGroupParameters) (*ListGroupResult, error)
		// CountUsers returns the number of users in the group.
		CountUsers(ctx context.Context, id int) (int, error)
		// Upsert upserts a group.
		Upsert(ctx context.Context, group *ent.Group) (*ent.Group, error)
		// Delete deletes a group.
		Delete(ctx context.Context, id int) error
	}
	// ListGroupParameters carries pagination arguments for ListGroups.
	ListGroupParameters struct {
		*PaginationArgs
	}
	// ListGroupResult is one page of groups plus pagination metadata.
	ListGroupResult struct {
		*PaginationResults
		Groups []*ent.Group
	}
)
// NewGroupClient creates a GroupClient backed by the given ent client, with a
// SQL parameter limit derived from the database type.
func NewGroupClient(client *ent.Client, dbType conf.DBType, cache cache.Driver) GroupClient {
	return &groupClient{
		client:      client,
		maxSQlParam: sqlParamLimit(dbType),
		cache:       cache,
	}
}
// groupClient is the default GroupClient implementation.
type groupClient struct {
	client      *ent.Client
	cache       cache.Driver
	maxSQlParam int
}

// SetClient returns a copy of this client bound to newClient (used to run the
// same operations inside a transaction).
func (c *groupClient) SetClient(newClient *ent.Client) TxOperator {
	clone := *c
	clone.client = newClient
	return &clone
}

// GetClient exposes the underlying ent client.
func (c *groupClient) GetClient() *ent.Client {
	return c.client
}
// CountUsers returns the number of users belonging to the group with the
// given id.
func (c *groupClient) CountUsers(ctx context.Context, id int) (int, error) {
	q := c.client.Group.Query().Where(group.ID(id)).QueryUsers()
	return q.Count(ctx)
}

// AnonymousGroup returns the built-in anonymous group.
func (c *groupClient) AnonymousGroup(ctx context.Context) (*ent.Group, error) {
	q := c.client.Group.Query().Where(group.ID(AnonymousGroupID))
	return withGroupEagerLoading(ctx, q).First(ctx)
}

// ListAll returns every group.
func (c *groupClient) ListAll(ctx context.Context) ([]*ent.Group, error) {
	return withGroupEagerLoading(ctx, c.client.Group.Query()).All(ctx)
}
// Upsert creates the group when group.ID is zero, otherwise updates it in
// place. The group's storage policy is taken from the preloaded
// StoragePolicies edge; a missing edge is rejected instead of panicking.
func (c *groupClient) Upsert(ctx context.Context, group *ent.Group) (*ent.Group, error) {
	// Guard against a nil StoragePolicies edge, which would otherwise cause a
	// nil-pointer dereference below.
	if group.Edges.StoragePolicies == nil {
		return nil, fmt.Errorf("group storage policy edge is not loaded")
	}

	if group.ID == 0 {
		return c.client.Group.Create().
			SetName(group.Name).
			SetMaxStorage(group.MaxStorage).
			SetSpeedLimit(group.SpeedLimit).
			SetPermissions(group.Permissions).
			SetSettings(group.Settings).
			SetStoragePoliciesID(group.Edges.StoragePolicies.ID).
			Save(ctx)
	}

	res, err := c.client.Group.UpdateOne(group).
		SetName(group.Name).
		SetMaxStorage(group.MaxStorage).
		SetSpeedLimit(group.SpeedLimit).
		SetPermissions(group.Permissions).
		SetSettings(group.Settings).
		ClearStoragePolicies().
		SetStoragePoliciesID(group.Edges.StoragePolicies.ID).
		Save(ctx)
	if err != nil {
		return nil, err
	}

	return res, nil
}
// Delete removes the group with the given id.
func (c *groupClient) Delete(ctx context.Context, id int) error {
	err := c.client.Group.DeleteOneID(id).Exec(ctx)
	if err != nil {
		return fmt.Errorf("failed to delete group: %w", err)
	}
	return nil
}
// ListGroups returns one page of groups plus pagination metadata. The page
// size is capped to respect the database's SQL parameter limit.
func (c *groupClient) ListGroups(ctx context.Context, args *ListGroupParameters) (*ListGroupResult, error) {
	query := withGroupEagerLoading(ctx, c.client.Group.Query())
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, 10)

	// Take the count on a clone captured before ordering is applied.
	countQuery := query.Clone()
	query.Order(getGroupOrderOption(args)...)

	total, err := countQuery.Clone().Count(ctx)
	if err != nil {
		return nil, err
	}

	groups, err := query.Clone().
		Limit(pageSize).
		Offset(args.Page * pageSize).
		All(ctx)
	if err != nil {
		return nil, err
	}

	res := &ListGroupResult{
		Groups: groups,
		PaginationResults: &PaginationResults{
			TotalItems: total,
			Page:       args.Page,
			PageSize:   pageSize,
		},
	}
	return res, nil
}
// GetByID returns the group with the given id.
func (c *groupClient) GetByID(ctx context.Context, id int) (*ent.Group, error) {
	q := c.client.Group.Query().Where(group.ID(id))
	return withGroupEagerLoading(ctx, q).First(ctx)
}

// getGroupOrderOption maps the requested sort field to ent order options,
// always appending ID as a tie-breaker for stable ordering.
func getGroupOrderOption(args *ListGroupParameters) []group.OrderOption {
	term := getOrderTerm(args.Order)
	var primary group.OrderOption
	switch args.OrderBy {
	case group.FieldName:
		primary = group.ByName(term)
	case group.FieldMaxStorage:
		primary = group.ByMaxStorage(term)
	default:
		return []group.OrderOption{group.ByID(term)}
	}
	return []group.OrderOption{primary, group.ByID(term)}
}
// withGroupEagerLoading applies group edge eager-loading options toggled on
// through context values. The value must be present AND true: the original
// only checked presence, so a stored false still enabled eager loading,
// inconsistent with withFileEagerLoading's `ok && v` pattern.
func withGroupEagerLoading(ctx context.Context, q *ent.GroupQuery) *ent.GroupQuery {
	if v, ok := ctx.Value(LoadGroupPolicy{}).(bool); ok && v {
		q.WithStoragePolicies(func(spq *ent.StoragePolicyQuery) {
			withStoragePolicyEagerLoading(ctx, spq)
		})
	}
	return q
}

156
inventory/node.go Normal file
View File

@@ -0,0 +1,156 @@
package inventory
import (
"context"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/node"
)
type (
	// LoadNodeStoragePolicy is a ctx key enabling eager loading of a node's
	// storage policy edge.
	LoadNodeStoragePolicy struct{}
	// NodeClient provides CRUD operations over node records.
	NodeClient interface {
		TxOperator
		// ListActiveNodes returns the active nodes.
		ListActiveNodes(ctx context.Context, subset []int) ([]*ent.Node, error)
		// ListNodes returns the nodes with pagination.
		ListNodes(ctx context.Context, args *ListNodeParameters) (*ListNodeResult, error)
		// GetNodeById returns the node by id.
		GetNodeById(ctx context.Context, id int) (*ent.Node, error)
		// GetNodeByIds returns the nodes by ids.
		GetNodeByIds(ctx context.Context, ids []int) ([]*ent.Node, error)
		// Upsert upserts a node.
		Upsert(ctx context.Context, n *ent.Node) (*ent.Node, error)
		// Delete deletes a node.
		Delete(ctx context.Context, id int) error
	}
	// ListNodeParameters carries pagination arguments and an optional status
	// filter for ListNodes (empty Status means no filtering).
	ListNodeParameters struct {
		*PaginationArgs
		Status node.Status
	}
	// ListNodeResult is one page of nodes plus pagination metadata.
	ListNodeResult struct {
		*PaginationResults
		Nodes []*ent.Node
	}
)
// NewNodeClient creates a NodeClient backed by the given ent client.
func NewNodeClient(client *ent.Client) NodeClient {
	return &nodeClient{client: client}
}

// nodeClient is the default NodeClient implementation.
type nodeClient struct {
	client *ent.Client
}

// SetClient returns a copy of this client bound to newClient (used to run the
// same operations inside a transaction).
func (c *nodeClient) SetClient(newClient *ent.Client) TxOperator {
	return &nodeClient{client: newClient}
}

// GetClient exposes the underlying ent client.
func (c *nodeClient) GetClient() *ent.Client {
	return c.client
}
// ListActiveNodes returns all active nodes, optionally restricted to the
// given subset of node IDs (an empty subset means no restriction).
func (c *nodeClient) ListActiveNodes(ctx context.Context, subset []int) ([]*ent.Node, error) {
	q := c.client.Node.Query().Where(node.StatusEQ(node.StatusActive))
	if len(subset) > 0 {
		q = q.Where(node.IDIn(subset...))
	}
	return q.All(ctx)
}

// GetNodeByIds returns the nodes matching the given ids.
func (c *nodeClient) GetNodeByIds(ctx context.Context, ids []int) ([]*ent.Node, error) {
	q := c.client.Node.Query().Where(node.IDIn(ids...))
	return withNodeEagerLoading(ctx, q).All(ctx)
}

// GetNodeById returns the node with the given id.
func (c *nodeClient) GetNodeById(ctx context.Context, id int) (*ent.Node, error) {
	q := c.client.Node.Query().Where(node.IDEQ(id))
	return withNodeEagerLoading(ctx, q).First(ctx)
}

// Delete removes the node with the given id.
func (c *nodeClient) Delete(ctx context.Context, id int) error {
	return c.client.Node.DeleteOneID(id).Exec(ctx)
}
// ListNodes returns one page of nodes plus pagination metadata, optionally
// filtered by status (empty status means all).
//
// NOTE(review): unlike the group/file list methods, args.PageSize is not
// capped via capPageSize here — confirm callers bound the page size.
func (c *nodeClient) ListNodes(ctx context.Context, args *ListNodeParameters) (*ListNodeResult, error) {
	query := c.client.Node.Query()
	if string(args.Status) != "" {
		query = query.Where(node.StatusEQ(args.Status))
	}
	query.Order(getNodeOrderOption(args)...)
	// Count total items
	total, err := query.Clone().
		Count(ctx)
	if err != nil {
		return nil, err
	}
	nodes, err := withNodeEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx)
	if err != nil {
		return nil, err
	}
	return &ListNodeResult{
		PaginationResults: &PaginationResults{
			TotalItems: total,
			Page:       args.Page,
			PageSize:   args.PageSize,
		},
		Nodes: nodes,
	}, nil
}
// Upsert creates the node when n.ID is zero, otherwise updates it.
//
// The node type is only set on creation (always node.TypeSlave); the update
// path deliberately omits SetType — presumably so the master node's type can
// never be changed; TODO confirm.
func (c *nodeClient) Upsert(ctx context.Context, n *ent.Node) (*ent.Node, error) {
	if n.ID == 0 {
		return c.client.Node.Create().
			SetName(n.Name).
			SetServer(n.Server).
			SetSlaveKey(n.SlaveKey).
			SetStatus(n.Status).
			SetType(node.TypeSlave).
			SetSettings(n.Settings).
			SetCapabilities(n.Capabilities).
			SetWeight(n.Weight).
			Save(ctx)
	}
	res, err := c.client.Node.UpdateOne(n).
		SetName(n.Name).
		SetServer(n.Server).
		SetSlaveKey(n.SlaveKey).
		SetStatus(n.Status).
		SetSettings(n.Settings).
		SetCapabilities(n.Capabilities).
		SetWeight(n.Weight).
		Save(ctx)
	if err != nil {
		return nil, err
	}
	return res, nil
}
// getNodeOrderOption maps the requested sort field to ent order options,
// always appending ID as a tie-breaker for stable ordering.
func getNodeOrderOption(args *ListNodeParameters) []node.OrderOption {
	term := getOrderTerm(args.Order)
	var primary node.OrderOption
	switch args.OrderBy {
	case node.FieldName:
		primary = node.ByName(term)
	case node.FieldWeight:
		primary = node.ByWeight(term)
	case node.FieldUpdatedAt:
		primary = node.ByUpdatedAt(term)
	default:
		return []node.OrderOption{node.ByID(term)}
	}
	return []node.OrderOption{primary, node.ByID(term)}
}
// withNodeEagerLoading applies node edge eager-loading options toggled on
// through context values. The value must be present AND true: the original
// only checked presence, so a stored false still enabled eager loading,
// inconsistent with withFileEagerLoading's `ok && v` pattern.
func withNodeEagerLoading(ctx context.Context, query *ent.NodeQuery) *ent.NodeQuery {
	if v, ok := ctx.Value(LoadNodeStoragePolicy{}).(bool); ok && v {
		query = query.WithStoragePolicy(func(gq *ent.StoragePolicyQuery) {
			withStoragePolicyEagerLoading(ctx, gq)
		})
	}
	return query
}

243
inventory/policy.go Normal file
View File

@@ -0,0 +1,243 @@
package inventory
import (
"context"
"encoding/gob"
"fmt"
"strconv"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/storagepolicy"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
)
const (
	// StoragePolicyCacheKey is the cache key prefix of storage policy entries;
	// the policy ID is appended to form the full key.
	StoragePolicyCacheKey = "storage_policy_"
)

// init registers storage policy types with gob so cached policies can be
// serialized by gob-backed cache drivers.
func init() {
	gob.Register(ent.StoragePolicy{})
	gob.Register([]ent.StoragePolicy{})
}
type (
	// LoadStoragePolicyGroup is a ctx key enabling eager loading of the
	// policy's groups edge.
	LoadStoragePolicyGroup struct{}
	// SkipStoragePolicyCache is a ctx key; set to true to bypass the KV cache
	// for policy lookups.
	SkipStoragePolicyCache struct{}
	// StoragePolicyClient provides CRUD operations over storage policies.
	StoragePolicyClient interface {
		// GetByGroup returns the storage policies of the group.
		GetByGroup(ctx context.Context, group *ent.Group) (*ent.StoragePolicy, error)
		// GetPolicyByID returns the storage policy by id.
		GetPolicyByID(ctx context.Context, id int) (*ent.StoragePolicy, error)
		// UpdateAccessKey updates the access key of the storage policy. It also clears related cache in KV.
		UpdateAccessKey(ctx context.Context, policy *ent.StoragePolicy, token string) error
		// ListPolicyByType returns the storage policies by type.
		ListPolicyByType(ctx context.Context, t types.PolicyType) ([]*ent.StoragePolicy, error)
		// ListPolicies returns the storage policies with pagination.
		ListPolicies(ctx context.Context, args *ListPolicyParameters) (*ListPolicyResult, error)
		// Upsert upserts the storage policy.
		Upsert(ctx context.Context, policy *ent.StoragePolicy) (*ent.StoragePolicy, error)
		// Delete deletes the storage policy.
		Delete(ctx context.Context, policy *ent.StoragePolicy) error
	}
	// ListPolicyParameters carries pagination arguments and an optional type
	// filter for ListPolicies (empty Type means all).
	ListPolicyParameters struct {
		*PaginationArgs
		Type types.PolicyType
	}
	// ListPolicyResult is one page of policies plus pagination metadata.
	ListPolicyResult struct {
		*PaginationResults
		Policies []*ent.StoragePolicy
	}
)
// NewStoragePolicyClient returns a new StoragePolicyClient backed by the
// given ent client and cache driver.
func NewStoragePolicyClient(client *ent.Client, cache cache.Driver) StoragePolicyClient {
	return &storagePolicyClient{
		client: client,
		cache:  cache,
	}
}

// storagePolicyClient is the default StoragePolicyClient implementation.
type storagePolicyClient struct {
	client *ent.Client
	cache  cache.Driver
}
// Delete removes the storage policy and evicts its cache entry.
func (c *storagePolicyClient) Delete(ctx context.Context, policy *ent.StoragePolicy) error {
	if err := c.client.StoragePolicy.DeleteOne(policy).Exec(ctx); err != nil {
		return fmt.Errorf("failed to delete storage policy: %w", err)
	}
	// Clear cache — Delete presumably takes (prefix, keys...) and removes
	// prefix+key entries; verify against the cache.Driver interface.
	if err := c.cache.Delete(StoragePolicyCacheKey, strconv.Itoa(policy.ID)); err != nil {
		return fmt.Errorf("failed to clear storage policy cache: %w", err)
	}
	return nil
}
// Upsert creates the storage policy when policy.ID is zero, otherwise updates
// it and evicts its cache entry.
//
// For OneDrive (PolicyTypeOd) policies the access key is not overwritten on
// update — presumably it holds an OAuth token maintained via UpdateAccessKey;
// confirm before changing.
func (c *storagePolicyClient) Upsert(ctx context.Context, policy *ent.StoragePolicy) (*ent.StoragePolicy, error) {
	var nodeId *int
	if policy.NodeID != 0 {
		nodeId = &policy.NodeID
	}

	if policy.ID == 0 {
		p, err := c.client.StoragePolicy.
			Create().
			SetName(policy.Name).
			SetType(policy.Type).
			SetServer(policy.Server).
			SetBucketName(policy.BucketName).
			SetIsPrivate(policy.IsPrivate).
			SetAccessKey(policy.AccessKey).
			SetSecretKey(policy.SecretKey).
			SetMaxSize(policy.MaxSize).
			SetDirNameRule(policy.DirNameRule).
			SetFileNameRule(policy.FileNameRule).
			SetSettings(policy.Settings).
			SetNillableNodeID(nodeId).
			Save(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to create storage policy: %w", err)
		}

		return p, nil
	}

	updateQuery := c.client.StoragePolicy.UpdateOne(policy).
		SetName(policy.Name).
		SetType(policy.Type).
		SetServer(policy.Server).
		SetBucketName(policy.BucketName).
		SetIsPrivate(policy.IsPrivate).
		SetSecretKey(policy.SecretKey).
		SetMaxSize(policy.MaxSize).
		SetDirNameRule(policy.DirNameRule).
		SetFileNameRule(policy.FileNameRule).
		SetSettings(policy.Settings).
		SetNillableNodeID(nodeId)

	if policy.Type != types.PolicyTypeOd {
		updateQuery.SetAccessKey(policy.AccessKey)
	}

	p, err := updateQuery.Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to update storage policy: %w", err)
	}

	// Invalidate the cache only after a successful update, so a cache failure
	// cannot mask the underlying update error (the original cleared the cache
	// before checking the Save error).
	if err := c.cache.Delete(StoragePolicyCacheKey, strconv.Itoa(policy.ID)); err != nil {
		return nil, fmt.Errorf("failed to clear storage policy cache: %w", err)
	}

	return p, nil
}
// GetByGroup returns the primary storage policy of the given group, with its
// node edge loaded. This method always queries the database directly — the
// original computed a skip-cache flag from SkipStoragePolicyCache here but
// never used it (no cache is involved), so that dead code is removed.
func (c *storagePolicyClient) GetByGroup(ctx context.Context, group *ent.Group) (*ent.StoragePolicy, error) {
	res, err := withStoragePolicyEagerLoading(ctx, c.client.Group.QueryStoragePolicies(group)).WithNode().First(ctx)
	if err != nil {
		return nil, fmt.Errorf("get storage policies: %w", err)
	}

	return res, nil
}
// GetPolicyByID returns the storage policy by id, serving from the KV cache
// unless SkipStoragePolicyCache is set to true in ctx. Cache misses are
// back-filled with no expiry.
func (c *storagePolicyClient) GetPolicyByID(ctx context.Context, id int) (*ent.StoragePolicy, error) {
	skipRequested, present := ctx.Value(SkipStoragePolicyCache{}).(bool)
	skipCache := present && skipRequested
	cacheKey := StoragePolicyCacheKey + strconv.Itoa(id)

	// Fast path: serve from cache.
	if c.cache != nil && !skipCache {
		if res, ok := c.cache.Get(cacheKey); ok {
			cached := res.(ent.StoragePolicy)
			return &cached, nil
		}
	}

	res, err := withStoragePolicyEagerLoading(ctx, c.client.StoragePolicy.Query().Where(storagepolicy.ID(id))).WithNode().First(ctx)
	if err != nil {
		return nil, fmt.Errorf("get storage policy: %w", err)
	}

	// Back-fill the cache for subsequent lookups; -1 means no expiry.
	if c.cache != nil && !skipCache {
		_ = c.cache.Set(cacheKey, *res, -1)
	}
	return res, nil
}
// ListPolicyByType returns all storage policies of the given type.
func (c *storagePolicyClient) ListPolicyByType(ctx context.Context, t types.PolicyType) ([]*ent.StoragePolicy, error) {
	res, err := c.client.StoragePolicy.Query().Where(storagepolicy.TypeEQ(string(t))).All(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to list storage policies: %w", err)
	}
	return res, nil
}
// UpdateAccessKey updates the access key of the storage policy and clears the
// related cache entry in KV.
func (c *storagePolicyClient) UpdateAccessKey(ctx context.Context, policy *ent.StoragePolicy, token string) error {
	_, err := c.client.StoragePolicy.UpdateOne(policy).SetAccessKey(token).Save(ctx)
	if err != nil {
		// Error message typo fixed ("faield" -> "failed").
		return fmt.Errorf("failed to update access key in DB: %w", err)
	}

	// Clear cache
	if err := c.cache.Delete(StoragePolicyCacheKey, strconv.Itoa(policy.ID)); err != nil {
		return fmt.Errorf("failed to clear storage policy cache: %w", err)
	}

	return nil
}
// ListPolicies returns one page of storage policies plus pagination metadata,
// optionally filtered by policy type (empty Type means all).
func (c *storagePolicyClient) ListPolicies(ctx context.Context, args *ListPolicyParameters) (*ListPolicyResult, error) {
	query := c.client.StoragePolicy.Query().WithNode()
	if args.Type != "" {
		query = query.Where(storagepolicy.TypeEQ(string(args.Type)))
	}
	query.Order(getStoragePolicyOrderOption(args)...)

	// Count before applying limit/offset.
	total, err := query.Clone().Count(ctx)
	if err != nil {
		return nil, err
	}

	policies, err := withStoragePolicyEagerLoading(ctx, query).
		Limit(args.PageSize).
		Offset(args.Page * args.PageSize).
		All(ctx)
	if err != nil {
		return nil, err
	}

	return &ListPolicyResult{
		Policies: policies,
		PaginationResults: &PaginationResults{
			TotalItems: total,
			Page:       args.Page,
			PageSize:   args.PageSize,
		},
	}, nil
}
// getStoragePolicyOrderOption maps the requested sort field to ent order
// options, always appending ID as a tie-breaker for stable ordering.
func getStoragePolicyOrderOption(args *ListPolicyParameters) []storagepolicy.OrderOption {
	term := getOrderTerm(args.Order)
	if args.OrderBy == storagepolicy.FieldUpdatedAt {
		return []storagepolicy.OrderOption{storagepolicy.ByUpdatedAt(term), storagepolicy.ByID(term)}
	}
	return []storagepolicy.OrderOption{storagepolicy.ByID(term)}
}
// withStoragePolicyEagerLoading applies policy edge eager-loading options
// toggled on through context values. The value must be present AND true: the
// original only checked presence, so a stored false still enabled eager
// loading, inconsistent with withFileEagerLoading's `ok && v` pattern.
func withStoragePolicyEagerLoading(ctx context.Context, query *ent.StoragePolicyQuery) *ent.StoragePolicyQuery {
	if v, ok := ctx.Value(LoadStoragePolicyGroup{}).(bool); ok && v {
		query = query.WithGroups(func(gq *ent.GroupQuery) {
			withGroupEagerLoading(ctx, gq)
		})
	}
	return query
}

248
inventory/setting.go Normal file

File diff suppressed because one or more lines are too long

401
inventory/share.go Normal file
View File

@@ -0,0 +1,401 @@
package inventory
import (
"context"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/file"
"github.com/cloudreve/Cloudreve/v4/ent/predicate"
"github.com/cloudreve/Cloudreve/v4/ent/share"
"github.com/cloudreve/Cloudreve/v4/ent/user"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/samber/lo"
)
type (
	// Ctx keys for eager loading options.
	LoadShareFile struct{}
	LoadShareUser struct{}
)

var (
	// ErrShareLinkExpired indicates the share expired by time or exhausted its
	// download quota.
	ErrShareLinkExpired = fmt.Errorf("share link expired")
	// ErrOwnerInactive indicates the share's owner is deleted or not active.
	ErrOwnerInactive = fmt.Errorf("owner is inactive")
	// ErrSourceFileInvalid indicates the shared file is deleted or no longer
	// owned by the share's owner.
	ErrSourceFileInvalid = fmt.Errorf("source file is deleted")
)

type (
	// ShareClient provides CRUD operations over share links.
	ShareClient interface {
		TxOperator
		// GetByID returns the share with given id.
		GetByID(ctx context.Context, id int) (*ent.Share, error)
		// GetByIDUser returns the share with given id and user id.
		GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error)
		// GetByHashID returns the share with given hash id.
		GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error)
		// Upsert creates or update a new share record.
		Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error)
		// Viewed increase the view count of the share.
		Viewed(ctx context.Context, share *ent.Share) error
		// Downloaded increase the download count of the share.
		Downloaded(ctx context.Context, share *ent.Share) error
		// Delete deletes the share.
		Delete(ctx context.Context, shareId int) error
		// List returns a list of shares with the given args.
		List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error)
		// CountByTimeRange counts the number of shares created in the given time range.
		CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error)
		// DeleteBatch deletes the shares with the given ids.
		DeleteBatch(ctx context.Context, shareIds []int) error
	}
	// CreateShareParams bundles the inputs of Upsert. When Existed is set,
	// that share is updated instead of creating a new one.
	CreateShareParams struct {
		Existed         *ent.Share
		Password        string
		RemainDownloads int
		Expires         *time.Time
		OwnerID         int
		FileID          int
	}
	// ListShareArgs filters and paginates List.
	ListShareArgs struct {
		*PaginationArgs
		UserID     int
		FileID     int
		PublicOnly bool
	}
	// ListShareResult is one page of shares plus pagination metadata.
	ListShareResult struct {
		*PaginationResults
		Shares []*ent.Share
	}
)
// NewShareClient creates a ShareClient backed by the given ent client, with a
// SQL parameter limit derived from the database type.
func NewShareClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) ShareClient {
	return &shareClient{
		client:      client,
		hasher:      hasher,
		maxSQlParam: sqlParamLimit(dbType),
	}
}

// shareClient is the default ShareClient implementation.
type shareClient struct {
	maxSQlParam int
	client      *ent.Client
	hasher      hashid.Encoder
}

// SetClient returns a copy of this client bound to newClient (used to run the
// same operations inside a transaction).
func (c *shareClient) SetClient(newClient *ent.Client) TxOperator {
	clone := *c
	clone.client = newClient
	return &clone
}

// GetClient exposes the underlying ent client.
func (c *shareClient) GetClient() *ent.Client {
	return c.client
}
// CountByTimeRange counts shares created within [start, end); a nil bound on
// either side means all shares are counted.
func (c *shareClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) {
	q := c.client.Share.Query()
	if start != nil && end != nil {
		q = q.Where(share.CreatedAtGTE(*start), share.CreatedAtLT(*end))
	}
	return q.Count(ctx)
}
// Upsert updates params.Existed when it is set, otherwise creates a new
// share. On update, a non-positive RemainDownloads clears the quota and a nil
// Expires clears the expiry; on create those fields are simply omitted.
func (c *shareClient) Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error) {
	if params.Existed != nil {
		update := c.client.Share.UpdateOne(params.Existed)
		if params.RemainDownloads > 0 {
			update.SetRemainDownloads(params.RemainDownloads)
		} else {
			update.ClearRemainDownloads()
		}
		if params.Expires != nil {
			update.SetNillableExpires(params.Expires)
		} else {
			update.ClearExpires()
		}
		return update.Save(ctx)
	}

	create := c.client.Share.
		Create().
		SetUserID(params.OwnerID).
		SetFileID(params.FileID)
	if params.Password != "" {
		create.SetPassword(params.Password)
	}
	if params.RemainDownloads > 0 {
		create.SetRemainDownloads(params.RemainDownloads)
	}
	if params.Expires != nil {
		create.SetNillableExpires(params.Expires)
	}
	return create.Save(ctx)
}
// GetByHashID decodes the hash id and returns the matching share.
func (c *shareClient) GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error) {
	id, err := c.hasher.Decode(idRaw, hashid.ShareID)
	if err != nil {
		return nil, fmt.Errorf("failed to decode hash id %q: %w", idRaw, err)
	}
	return c.GetByID(ctx, id)
}

// GetByID returns the share with the given id.
func (c *shareClient) GetByID(ctx context.Context, id int) (*ent.Share, error) {
	q := withShareEagerLoading(ctx, c.client.Share.Query().Where(share.ID(id)))
	s, err := q.First(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to query share %d: %w", id, err)
	}
	return s, nil
}

// GetByIDUser returns the share with the given id that is owned by user uid.
func (c *shareClient) GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error) {
	q := withShareEagerLoading(ctx, c.client.Share.Query().
		Where(share.ID(id))).
		Where(share.HasUserWith(user.ID(uid)))
	s, err := q.First(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to query share %d: %w", id, err)
	}
	return s, nil
}

// DeleteBatch removes all shares whose id is in shareIds.
func (c *shareClient) DeleteBatch(ctx context.Context, shareIds []int) error {
	_, err := c.client.Share.Delete().Where(share.IDIn(shareIds...)).Exec(ctx)
	return err
}

// Delete removes the share with the given id.
func (c *shareClient) Delete(ctx context.Context, shareId int) error {
	return c.client.Share.DeleteOneID(shareId).Exec(ctx)
}
// Viewed increments the view count of the share.
func (c *shareClient) Viewed(ctx context.Context, share *ent.Share) error {
	_, err := c.client.Share.UpdateOneID(share.ID).AddViews(1).Save(ctx)
	return err
}

// Downloaded increments the download count of the share. When the share has a
// bounded quota (RemainDownloads non-nil and >= 0), the quota is decremented
// in the same update.
//
// NOTE(review): a quota of exactly 0 is still decremented here (to -1),
// though IsShareExpired already treats <= 0 as expired — confirm callers
// check expiry first.
func (c *shareClient) Downloaded(ctx context.Context, share *ent.Share) error {
	stm := c.client.Share.
		UpdateOneID(share.ID).
		AddDownloads(1)
	if share.RemainDownloads != nil && *share.RemainDownloads >= 0 {
		stm.AddRemainDownloads(-1)
	}
	_, err := stm.Save(ctx)
	return err
}
// IsValidShare reports whether the given share can still be accessed.
// It returns nil for a valid share, or one of ErrShareLinkExpired,
// ErrOwnerInactive or ErrSourceFileInvalid describing why it is not.
// The share must have its User and File edges eager-loaded; a missing edge
// is treated the same as an invalid owner/file.
func IsValidShare(share *ent.Share) error {
	// Check if share is expired
	if err := IsShareExpired(share); err != nil {
		return err
	}

	// Check owner status
	owner, err := share.Edges.UserOrErr()
	if err != nil || owner.Status != user.StatusActive {
		// Owner already deleted, or not active.
		return ErrOwnerInactive
	}

	// Check source file status: the file must still exist (FileChildren != 0)
	// and still belong to the share owner.
	file, err := share.Edges.FileOrErr()
	if err != nil || file.FileChildren == 0 || file.OwnerID != owner.ID {
		// Source file already deleted
		return ErrSourceFileInvalid
	}

	return nil
}
// IsShareExpired returns ErrShareLinkExpired when the share has passed its
// expiration time or exhausted its remaining-download quota, nil otherwise.
func IsShareExpired(share *ent.Share) error {
	timedOut := share.Expires != nil && share.Expires.Before(time.Now())
	exhausted := share.RemainDownloads != nil && *share.RemainDownloads <= 0
	if timedOut || exhausted {
		return ErrShareLinkExpired
	}
	return nil
}
// List returns shares matching args, using cursor pagination when
// args.UseCursorPagination is set and offset pagination otherwise.
func (c *shareClient) List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error) {
	rawQuery := c.listQuery(args)
	query := withShareEagerLoading(ctx, rawQuery)

	var (
		shares        []*ent.Share
		err           error
		paginationRes *PaginationResults
	)

	if args.UseCursorPagination {
		shares, paginationRes, err = c.cursorPagination(ctx, query, args, 10)
	} else {
		shares, paginationRes, err = c.offsetPagination(ctx, query, args, 10)
	}

	if err != nil {
		// Fixed misspelled error message ("paginiation" -> "pagination").
		return nil, fmt.Errorf("query failed with pagination: %w", err)
	}

	return &ListShareResult{
		Shares:            shares,
		PaginationResults: paginationRes,
	}, nil
}
// cursorPagination lists shares with keyset (cursor) pagination. paramMargin
// reserves SQL placeholders so the capped page size stays within the driver's
// parameter limit.
func (c *shareClient) cursorPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(getShareOrderOption(args)...)
	var (
		pageToken *PageToken
		err       error
	)
	if args.PageToken != "" {
		pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.ShareID)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
		}
	}

	queryPaged := getShareCursorQuery(args, pageToken, query)

	// Use page size + 1 to determine if there are more items to come
	queryPaged.Limit(pageSize + 1)

	logs, err := queryPaged.
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	// More items to come
	nextTokenStr := ""
	if len(logs) > pageSize {
		// len(logs) == pageSize+1 here, so index len-2 (== pageSize-1) is the
		// last element of the page actually returned below.
		lastItem := logs[len(logs)-2]
		nextToken, err := getShareNextPageToken(c.hasher, lastItem, args)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
		}

		nextTokenStr = nextToken
	}

	return lo.Subset(logs, 0, uint(pageSize)), &PaginationResults{
		PageSize:      pageSize,
		NextPageToken: nextTokenStr,
		IsCursor:      true,
	}, nil
}
// offsetPagination lists shares with classic limit/offset pagination and
// returns a total item count alongside the page.
func (c *shareClient) offsetPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(getShareOrderOption(args)...)

	// Count total items before applying limit/offset.
	total, err := query.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	// Use the capped page size for the offset as well; multiplying by the raw
	// args.PageSize would skip rows whenever the requested size exceeds the
	// cap (Limit and Offset must be based on the same page size).
	logs, err := query.Limit(pageSize).Offset(args.Page * pageSize).All(ctx)
	if err != nil {
		return nil, nil, err
	}

	return logs, &PaginationResults{
		PageSize:   pageSize,
		TotalItems: total,
		Page:       args.Page,
	}, nil
}
// listQuery builds the base share query from the filters in args. All
// filters are optional and combined with AND.
func (c *shareClient) listQuery(args *ListShareArgs) *ent.ShareQuery {
	q := c.client.Share.Query()
	if args.PublicOnly {
		q.Where(share.PasswordIsNil())
	}
	if args.UserID > 0 {
		q.Where(share.HasUserWith(user.ID(args.UserID)))
	}
	if args.FileID > 0 {
		q.Where(share.HasFileWith(file.ID(args.FileID)))
	}
	return q
}
// getShareNextPageToken encodes a cursor token pointing at the given last
// share of the current page.
func getShareNextPageToken(hasher hashid.Encoder, last *ent.Share, args *ListShareArgs) (string, error) {
	return (&PageToken{ID: last.ID}).Encode(hasher, hashid.EncodeShareID)
}
// getShareCursorQuery applies the cursor predicate matching the requested
// order field and direction to query. When token is nil (first page), the
// query is returned unmodified.
func getShareCursorQuery(args *ListShareArgs, token *PageToken, query *ent.ShareQuery) *ent.ShareQuery {
	// Extract the ascending/descending flag from the order term.
	o := &sql.OrderTermOptions{}
	getOrderTerm(args.Order)(o)

	// Fall back to ID-based cursoring for unknown order fields.
	predicates, ok := shareCursorQuery[args.OrderBy]
	if !ok {
		predicates = shareCursorQuery[share.FieldID]
	}

	if token != nil {
		query.Where(predicates[o.Desc](token))
	}

	return query
}
// shareCursorQuery maps an order-by field to its cursor predicates, keyed by
// whether the ordering is descending (true) or ascending (false).
var shareCursorQuery = map[string]map[bool]func(token *PageToken) predicate.Share{
	share.FieldID: {
		true: func(token *PageToken) predicate.Share {
			return share.IDLT(token.ID)
		},
		false: func(token *PageToken) predicate.Share {
			return share.IDGT(token.ID)
		},
	},
}
// getShareOrderOption translates the list arguments into ent order options.
// The share ID is always included as a tiebreaker for stable ordering.
func getShareOrderOption(args *ListShareArgs) []share.OrderOption {
	direction := getOrderTerm(args.Order)
	var primary share.OrderOption
	switch args.OrderBy {
	case share.FieldViews:
		primary = share.ByViews(direction)
	case share.FieldDownloads:
		primary = share.ByDownloads(direction)
	case share.FieldRemainDownloads:
		primary = share.ByRemainDownloads(direction)
	default:
		return []share.OrderOption{share.ByID(direction)}
	}
	return []share.OrderOption{primary, share.ByID(direction)}
}
// withShareEagerLoading attaches optional File/User eager loading to q based
// on the LoadShareFile / LoadShareUser flags stored in ctx.
func withShareEagerLoading(ctx context.Context, q *ent.ShareQuery) *ent.ShareQuery {
	if load, ok := ctx.Value(LoadShareFile{}).(bool); ok && load {
		q.WithFile(func(fq *ent.FileQuery) {
			withFileEagerLoading(ctx, fq)
		})
	}
	if load, ok := ctx.Value(LoadShareUser{}).(bool); ok && load {
		q.WithUser(func(uq *ent.UserQuery) {
			withUserEagerLoading(ctx, uq)
		})
	}
	return q
}

314
inventory/task.go Normal file
View File

@@ -0,0 +1,314 @@
package inventory
import (
"context"
"fmt"
"entgo.io/ent/dialect/sql"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/task"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/gofrs/uuid"
"github.com/samber/lo"
)
type (
	// Ctx keys for eager loading options.
	LoadTaskUser struct{}

	// TaskArgs carries the fields used to create or update a task.
	TaskArgs struct {
		Status        task.Status            // Target task status; empty leaves the default/current status.
		Type          string                 // Task type identifier.
		PublicState   *types.TaskPublicState // State visible to the task owner.
		PrivateState  string                 // Serialized internal state; empty means "do not set".
		OwnerID       int                    // Owning user ID; 0 for anonymous/no owner.
		CorrelationID uuid.UUID              // Groups related tasks; uuid.Nil means "do not set".
	}
)
// TaskClient persists and queries background tasks.
type TaskClient interface {
	TxOperator
	// New creates a new task with the given args.
	New(ctx context.Context, task *TaskArgs) (*ent.Task, error)
	// Update updates the task with the given args.
	Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error)
	// GetPendingTasks returns all unfinished (processing/queued/suspending) tasks of given type.
	GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error)
	// GetTaskByID returns the task with the given ID.
	GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error)
	// SetCompleteByID sets the task with the given ID to complete.
	SetCompleteByID(ctx context.Context, taskID int) error
	// List returns a list of tasks with the given args.
	List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error)
	// DeleteByIDs deletes the tasks with the given IDs.
	DeleteByIDs(ctx context.Context, ids ...int) error
}
type (
	// ListTaskArgs holds filters and pagination options for listing tasks.
	ListTaskArgs struct {
		*PaginationArgs
		Types         []string    // Restrict to these task types; nil means all.
		Status        []task.Status // Restrict to these statuses; nil means all.
		UserID        int         // Restrict to tasks owned by this user; 0 means all.
		CorrelationID *uuid.UUID  // Restrict to this correlation ID; nil means all.
	}

	// ListTaskResult is the outcome of a task listing, pairing the page of
	// tasks with its pagination metadata.
	ListTaskResult struct {
		*PaginationResults
		Tasks []*ent.Task
	}
)
// NewTaskClient constructs a TaskClient backed by the given ent client.
// dbType determines the SQL placeholder cap used when sizing pages.
func NewTaskClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) TaskClient {
	return &taskClient{client: client, maxSQlParam: sqlParamLimit(dbType), hasher: hasher}
}

// taskClient is the default TaskClient implementation.
type taskClient struct {
	maxSQlParam int            // Maximum SQL placeholders supported by the driver.
	hasher      hashid.Encoder // Encoder used for opaque page tokens.
	client      *ent.Client
}

// SetClient returns a copy of the client bound to newClient,
// typically a transactional ent client (TxOperator).
func (c *taskClient) SetClient(newClient *ent.Client) TxOperator {
	return &taskClient{client: newClient, maxSQlParam: c.maxSQlParam, hasher: c.hasher}
}

// GetClient returns the ent client currently backing this task client (TxOperator).
func (c *taskClient) GetClient() *ent.Client {
	return c.client
}
// New creates a new task row from the given args. Optional fields
// (PrivateState, OwnerID, Status, CorrelationID) are only set when non-zero.
func (c *taskClient) New(ctx context.Context, task *TaskArgs) (*ent.Task, error) {
	stm := c.client.Task.
		Create().
		SetType(task.Type).
		SetPublicState(task.PublicState)

	if task.PrivateState != "" {
		stm.SetPrivateState(task.PrivateState)
	}

	if task.OwnerID != 0 {
		stm.SetUserID(task.OwnerID)
	}

	if task.Status != "" {
		stm.SetStatus(task.Status)
	}

	// UUIDs are comparable values; compare directly instead of allocating
	// two strings via String() on every call.
	if task.CorrelationID != uuid.Nil {
		stm.SetCorrelationID(task.CorrelationID)
	}

	newTask, err := stm.Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to create task: %w", err)
	}

	return newTask, nil
}
// DeleteByIDs removes all tasks whose IDs appear in ids.
func (c *taskClient) DeleteByIDs(ctx context.Context, ids ...int) error {
	_, deleteErr := c.client.Task.Delete().Where(task.IDIn(ids...)).Exec(ctx)
	return deleteErr
}
// Update persists the fields in args onto the given task and mirrors the
// changes into the in-memory entity, which is returned.
func (c *taskClient) Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error) {
	stm := c.client.Task.UpdateOne(task).
		SetPublicState(args.PublicState)
	task.PublicState = args.PublicState

	// Bug fix: guard on and write the NEW values from args. The old code
	// tested the stale entity fields and wrote task.PrivateState (the old
	// value) back to the database while updating the in-memory copy.
	if args.PrivateState != "" {
		stm.SetPrivateState(args.PrivateState)
		task.PrivateState = args.PrivateState
	}

	if args.Status != "" {
		stm.SetStatus(args.Status)
		task.Status = args.Status
	}

	if err := stm.Exec(ctx); err != nil {
		// Fixed error message: this is an update, not a create.
		return nil, fmt.Errorf("failed to update task: %w", err)
	}

	return task, nil
}
// GetPendingTasks returns all unfinished tasks (processing, queued or
// suspending) of the given types, with eager-loading options from ctx applied.
// Tasks without an owner get the anonymous user attached manually.
func (c *taskClient) GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error) {
	tasks, err := withTaskEagerLoading(ctx, c.client.Task.Query()).
		Where(task.StatusIn(task.StatusProcessing, task.StatusQueued, task.StatusSuspending)).
		Where(task.TypeIn(taskType...)).
		All(ctx)
	if err != nil {
		return nil, err
	}

	// Anonymous user is not loaded by default, so we need to load it manually.
	userClient := NewUserClient(c.client)
	anonymous, err := userClient.AnonymousUser(ctx)
	for _, t := range tasks {
		if t.UserTasks == 0 {
			// Only surface the AnonymousUser error when an ownerless task
			// actually needs the anonymous user.
			if err != nil {
				return nil, err
			}

			t.SetUser(anonymous)
		}
	}

	return tasks, nil
}
// GetTaskByID returns the task with the given ID, honoring eager-loading
// options carried in ctx.
func (c *taskClient) GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error) {
	return withTaskEagerLoading(ctx, c.client.Task.Query()).
		Where(task.ID(taskID)).
		First(ctx)
}

// SetCompleteByID marks the task with the given ID as completed.
func (c *taskClient) SetCompleteByID(ctx context.Context, taskID int) error {
	_, err := c.client.Task.UpdateOneID(taskID).
		SetStatus(task.StatusCompleted).
		Save(ctx)
	return err
}
// List returns tasks matching args, using cursor pagination when
// args.UseCursorPagination is set and offset pagination otherwise.
func (c *taskClient) List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error) {
	q := c.client.Task.Query()
	if args.UserID != 0 {
		q.Where(task.UserTasks(args.UserID))
	}

	if args.Types != nil {
		q.Where(task.TypeIn(args.Types...))
	}

	if args.Status != nil {
		q.Where(task.StatusIn(args.Status...))
	}

	if args.CorrelationID != nil {
		q.Where(task.CorrelationID(*args.CorrelationID))
	}

	q = withTaskEagerLoading(ctx, q)
	var (
		tasks         []*ent.Task
		err           error
		paginationRes *PaginationResults
	)

	if args.UseCursorPagination {
		tasks, paginationRes, err = c.cursorPagination(ctx, q, args, 1)
	} else {
		tasks, paginationRes, err = c.offsetPagination(ctx, q, args, 1)
	}

	if err != nil {
		// Fixed misspelled error message ("paginiation" -> "pagination").
		return nil, fmt.Errorf("query failed with pagination: %w", err)
	}

	return &ListTaskResult{
		Tasks:             tasks,
		PaginationResults: paginationRes,
	}, nil
}
// cursorPagination lists tasks with keyset (cursor) pagination, ordered by
// descending ID. paramMargin reserves SQL placeholders so the capped page
// size stays within the driver's parameter limit.
func (c *taskClient) cursorPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(task.ByID(sql.OrderDesc()))
	var (
		pageToken *PageToken
		err       error
		queryPaged = query
	)
	if args.PageToken != "" {
		pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.TaskID)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
		}

		// Resume after the cursor: only IDs below the token (desc order).
		queryPaged = query.Where(task.IDLT(pageToken.ID))
	}

	// Use page size + 1 to determine if there are more items to come
	queryPaged.Limit(pageSize + 1)

	tasks, err := queryPaged.
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	// More items to come
	nextTokenStr := ""
	if len(tasks) > pageSize {
		// len(tasks) == pageSize+1 here, so index len-2 (== pageSize-1) is
		// the last element of the page actually returned below.
		lastItem := tasks[len(tasks)-2]
		nextToken, err := getTaskNextPageToken(c.hasher, lastItem)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
		}

		nextTokenStr = nextToken
	}

	return lo.Subset(tasks, 0, uint(pageSize)), &PaginationResults{
		PageSize:      pageSize,
		NextPageToken: nextTokenStr,
		IsCursor:      true,
	}, nil
}
// offsetPagination lists tasks with classic limit/offset pagination and
// returns a total item count alongside the page.
func (c *taskClient) offsetPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(getTaskOrderOption(args)...)

	// Count total items
	total, err := query.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	// Use the capped page size for the offset as well; multiplying by the raw
	// args.PageSize would skip rows whenever the requested size exceeds the
	// cap (Limit and Offset must be based on the same page size).
	logs, err := query.
		Limit(pageSize).
		Offset(args.Page * pageSize).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	return logs, &PaginationResults{
		PageSize:   pageSize,
		TotalItems: total,
		Page:       args.Page,
	}, nil
}
// getTaskOrderOption translates the list arguments into ent order options.
// Tasks currently only support ordering by ID.
func getTaskOrderOption(args *ListTaskArgs) []task.OrderOption {
	return []task.OrderOption{task.ByID(getOrderTerm(args.Order))}
}
// getTaskNextPageToken encodes a cursor token pointing at the given last
// task of the current page.
func getTaskNextPageToken(hasher hashid.Encoder, last *ent.Task) (string, error) {
	return (&PageToken{ID: last.ID}).Encode(hasher, hashid.EncodeTaskID)
}
// withTaskEagerLoading attaches User eager loading to q when the
// LoadTaskUser flag in ctx is set.
func withTaskEagerLoading(ctx context.Context, q *ent.TaskQuery) *ent.TaskQuery {
	if load, ok := ctx.Value(LoadTaskUser{}).(bool); ok && load {
		q.WithUser(func(uq *ent.UserQuery) {
			withUserEagerLoading(ctx, uq)
		})
	}
	return q
}

101
inventory/tx.go Normal file
View File

@@ -0,0 +1,101 @@
package inventory
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
)
// TxOperator is implemented by inventory clients that can be re-bound to a
// different ent client (e.g. a transactional one).
type TxOperator interface {
	// SetClient returns a copy of the operator backed by newClient.
	SetClient(newClient *ent.Client) TxOperator
	// GetClient returns the ent client currently backing the operator.
	GetClient() *ent.Client
}

type (
	// Tx wraps an ent transaction, tracking nesting and a pending storage diff.
	Tx struct {
		tx          *ent.Tx
		parent      *Tx         // Outer wrapper when inherited.
		inherited   bool        // True when reusing a transaction found in ctx.
		finished    bool        // Set once the owning wrapper commits or rolls back.
		storageDiff StorageDiff // Accumulated on the root wrapper only.
	}

	// TxCtx is the context key for inherited transaction
	TxCtx struct{}
)
// AppendStorageDiff appends the given storage diff to the transaction.
// Diffs are always accumulated on the outermost (root) transaction so they
// are applied exactly once, after the real commit.
func (t *Tx) AppendStorageDiff(diff StorageDiff) {
	// Walk up to the root wrapper; inherited wrappers share its state.
	root := t
	for root.inherited {
		root = root.parent
	}

	if root.storageDiff == nil {
		root.storageDiff = diff
	} else {
		root.storageDiff.Merge(diff)
	}
}
// WithTx wraps the given inventory client with a transaction.
// If ctx already carries an unfinished transaction, it is inherited (nested)
// instead of opening a new one; otherwise a fresh ent transaction is started
// and stored in the returned context so nested calls can reuse it.
// The returned client is bound to the transactional ent client.
func WithTx[T TxOperator](ctx context.Context, c T) (T, *Tx, context.Context, error) {
	var txClient *ent.Client
	var txWrapper *Tx
	if txInherited, ok := ctx.Value(TxCtx{}).(*Tx); ok && !txInherited.finished {
		// Reuse the in-flight transaction; commit/rollback on this wrapper
		// are no-ops and remain owned by the outermost caller.
		txWrapper = &Tx{inherited: true, tx: txInherited.tx, parent: txInherited}
	} else {
		tx, err := c.GetClient().Tx(ctx)
		if err != nil {
			return c, nil, ctx, fmt.Errorf("failed to create transaction: %w", err)
		}

		txWrapper = &Tx{inherited: false, tx: tx}
		ctx = context.WithValue(ctx, TxCtx{}, txWrapper)
	}

	txClient = txWrapper.tx.Client()
	return c.SetClient(txClient).(T), txWrapper, ctx, nil
}
// Rollback rolls back the transaction. Inherited (nested) wrappers are a
// no-op: only the outermost owner may roll back the underlying ent.Tx.
func Rollback(tx *Tx) error {
	if tx.inherited {
		return nil
	}
	tx.finished = true
	return tx.tx.Rollback()
}
// commit commits the underlying transaction when tx is the outermost owner.
// The boolean result reports whether a real commit was performed.
func commit(tx *Tx) (bool, error) {
	if tx.inherited {
		return false, nil
	}
	tx.finished = true
	return true, tx.tx.Commit()
}
// Commit commits the transaction; a no-op for inherited (nested) wrappers.
func Commit(tx *Tx) error {
	_, err := commit(tx)
	return err
}
// CommitWithStorageDiff commits the transaction and applies the storage diff, only if the transaction is not inherited.
// Failures while applying the diff are logged but do not fail the commit.
func CommitWithStorageDiff(ctx context.Context, tx *Tx, l logging.Logger, uc UserClient) error {
	performed, commitErr := commit(tx)
	if commitErr != nil {
		return commitErr
	}

	if !performed {
		// Nested transaction: the outermost owner applies the diff later.
		return nil
	}

	if err := uc.ApplyStorageDiff(ctx, tx.storageDiff); err != nil {
		l.Error("Failed to apply storage diff", "error", err)
	}

	return nil
}

224
inventory/types/types.go Normal file
View File

@@ -0,0 +1,224 @@
package types
import (
"time"
)
// UserSetting holds miscellaneous per-user preferences.
type (
	UserSetting struct {
		ProfileOff          bool        `json:"profile_off,omitempty"`
		PreferredTheme      string      `json:"preferred_theme,omitempty"`
		VersionRetention    bool        `json:"version_retention,omitempty"`
		VersionRetentionExt []string    `json:"version_retention_ext,omitempty"`
		VersionRetentionMax int         `json:"version_retention_max,omitempty"`
		Pined               []PinedFile `json:"pined,omitempty"`
		Language            string      `json:"email_language,omitempty"`
	}

	// PinedFile is a file or folder pinned by the user.
	PinedFile struct {
		Uri  string `json:"uri"`
		Name string `json:"name,omitempty"`
	}

	// GroupSetting holds miscellaneous per-group options.
	GroupSetting struct {
		CompressSize          int64                  `json:"compress_size,omitempty"`   // Maximum size eligible for compression.
		DecompressSize        int64                  `json:"decompress_size,omitempty"` // Maximum size eligible for decompression.
		RemoteDownloadOptions map[string]interface{} `json:"remote_download_options,omitempty"` // Remote-download options for this group.
		SourceBatchSize       int                    `json:"source_batch,omitempty"`
		Aria2BatchSize        int                    `json:"aria2_batch,omitempty"`
		MaxWalkedFiles        int                    `json:"max_walked_files,omitempty"`
		TrashRetention        int                    `json:"trash_retention,omitempty"`
		RedirectedSource      bool                   `json:"redirected_source,omitempty"`
	}

	// PolicySetting holds the non-public attributes of a storage policy.
	PolicySetting struct {
		// Token is the Upyun access token.
		Token string `json:"token"`
		// FileType lists the allowed file extensions.
		FileType []string `json:"file_type"`
		// OauthRedirect is the OAuth redirect URL.
		OauthRedirect string `json:"od_redirect,omitempty"`
		// CustomProxy whether to use custom-proxy to get file content
		CustomProxy bool `json:"custom_proxy,omitempty"`
		// ProxyServer is the reverse-proxy address.
		ProxyServer string `json:"proxy_server,omitempty"`
		// InternalProxy whether to use Cloudreve internal proxy to get file content
		InternalProxy bool `json:"internal_proxy,omitempty"`
		// OdDriver is the OneDrive drive locator.
		OdDriver string `json:"od_driver,omitempty"`
		// Region is the region code.
		Region string `json:"region,omitempty"`
		// ServerSideEndpoint is the endpoint used for server-side requests;
		// when empty, the Policy.Server field is used instead.
		ServerSideEndpoint string `json:"server_side_endpoint,omitempty"`
		// ChunkSize is the chunk size for multipart uploads.
		ChunkSize int64 `json:"chunk_size,omitempty"`
		// TPSLimit is the per-second API request limit against the storage backend.
		TPSLimit float64 `json:"tps_limit,omitempty"`
		// TPSLimitBurst is the per-second API request burst limit.
		TPSLimitBurst int `json:"tps_limit_burst,omitempty"`
		// Set this to `true` to force the request to use path-style addressing,
		// i.e., `http://s3.amazonaws.com/BUCKET/KEY `
		S3ForcePathStyle bool `json:"s3_path_style"`
		// File extensions that support thumbnail generation using native policy API.
		ThumbExts []string `json:"thumb_exts,omitempty"`
		// Whether to support all file extensions for thumbnail generation.
		ThumbSupportAllExts bool `json:"thumb_support_all_exts,omitempty"`
		// ThumbMaxSize indicates the maximum allowed size of a thumbnail. 0 indicates that no limit is set.
		ThumbMaxSize int64 `json:"thumb_max_size,omitempty"`
		// Whether to upload file through server's relay.
		Relay bool `json:"relay,omitempty"`
		// Whether to pre allocate space for file before upload in physical disk.
		PreAllocate bool `json:"pre_allocate,omitempty"`
		// MediaMetaExts file extensions that support media meta generation using native policy API.
		MediaMetaExts []string `json:"media_meta_exts,omitempty"`
		// MediaMetaGeneratorProxy whether to use local proxy to generate media meta.
		MediaMetaGeneratorProxy bool `json:"media_meta_generator_proxy,omitempty"`
		// ThumbGeneratorProxy whether to use local proxy to generate thumbnail.
		ThumbGeneratorProxy bool `json:"thumb_generator_proxy,omitempty"`
		// NativeMediaProcessing whether to use native media processing API from storage provider.
		NativeMediaProcessing bool `json:"native_media_processing"`
		// S3DeleteBatchSize the number of objects to delete in each batch.
		S3DeleteBatchSize int `json:"s3_delete_batch_size,omitempty"`
		// StreamSaver whether to use stream saver to download file in Web.
		StreamSaver bool `json:"stream_saver,omitempty"`
		// UseCname whether to use CNAME for endpoint (OSS).
		UseCname bool `json:"use_cname,omitempty"`
		// CDN domain does not need to be signed.
		SourceAuth bool `json:"source_auth,omitempty"`
	}

	FileType         int
	EntityType       int
	GroupPermission  int
	FilePermission   int
	DavAccountOption int
	NodeCapability   int

	// NodeSetting holds per-node download/offload configuration.
	NodeSetting struct {
		Provider            DownloaderProvider `json:"provider,omitempty"`
		*QBittorrentSetting `json:"qbittorrent,omitempty"`
		*Aria2Setting       `json:"aria2,omitempty"`
		// Interval is the download monitoring interval.
		Interval       int  `json:"interval,omitempty"`
		WaitForSeeding bool `json:"wait_for_seeding,omitempty"`
	}

	DownloaderProvider string

	// QBittorrentSetting holds connection options for a qBittorrent downloader.
	QBittorrentSetting struct {
		Server   string         `json:"server,omitempty"`
		User     string         `json:"user,omitempty"`
		Password string         `json:"password,omitempty"`
		Options  map[string]any `json:"options,omitempty"`
		TempPath string         `json:"temp_path,omitempty"`
	}

	// Aria2Setting holds connection options for an aria2 downloader.
	Aria2Setting struct {
		Server   string         `json:"server,omitempty"`
		Token    string         `json:"token,omitempty"`
		Options  map[string]any `json:"options,omitempty"`
		TempPath string         `json:"temp_path,omitempty"`
	}

	// TaskPublicState is the owner-visible execution state of a task.
	TaskPublicState struct {
		Error            string          `json:"error,omitempty"`
		ErrorHistory     []string        `json:"error_history,omitempty"`
		ExecutedDuration time.Duration   `json:"executed_duration,omitempty"`
		RetryCount       int             `json:"retry_count,omitempty"`
		ResumeTime       int64           `json:"resume_time,omitempty"`
		SlaveTaskProps   *SlaveTaskProps `json:"slave_task_props,omitempty"`
	}

	// SlaveTaskProps identifies the master site a slave task reports to.
	SlaveTaskProps struct {
		NodeID            int    `json:"node_id,omitempty"`
		MasterSiteURl     string `json:"master_site_u_rl,omitempty"`
		MasterSiteID      string `json:"master_site_id,omitempty"`
		MasterSiteVersion string `json:"master_site_version,omitempty"`
	}

	// EntityRecycleOption controls how an entity is recycled.
	EntityRecycleOption struct {
		UnlinkOnly bool `json:"unlink_only,omitempty"`
	}

	DavAccountProps struct {
	}

	PolicyType string

	FileProps struct {
	}
)
// GroupPermission bit positions used in a group's permission set. Placeholder
// entries keep community-edition values aligned with other editions.
const (
	GroupPermissionIsAdmin = GroupPermission(iota)
	GroupPermissionIsAnonymous
	GroupPermissionShare
	GroupPermissionWebDAV
	GroupPermissionArchiveDownload
	GroupPermissionArchiveTask
	GroupPermissionWebDAVProxy
	GroupPermissionShareDownload
	GroupPermission_CommunityPlaceholder1
	GroupPermissionRemoteDownload
	GroupPermission_CommunityPlaceholder2
	GroupPermissionRedirectedSource // not used
	GroupPermissionAdvanceDelete
	GroupPermission_CommunityPlaceholder3
	GroupPermission_CommunityPlaceholder4
	GroupPermissionSetExplicitUser_placeholder
	GroupPermissionIgnoreFileOwnership // not used
)
// NodeCapability values describe optional features a node can perform.
const (
	NodeCapabilityNone NodeCapability = iota
	NodeCapabilityCreateArchive
	NodeCapabilityExtractArchive
	NodeCapabilityRemoteDownload
	NodeCapability_CommunityPlaceholder
)
// File object kinds.
const (
	FileTypeFile FileType = iota
	FileTypeFolder
)

// Entity (blob) kinds that can be attached to a file.
const (
	EntityTypeVersion EntityType = iota
	EntityTypeThumbnail
	EntityTypeLivePhoto
)
// FileTypeFromString maps the string representation of a file type to its
// FileType value, returning -1 for unrecognized input.
func FileTypeFromString(s string) FileType {
	switch s {
	case "file":
		return FileTypeFile
	case "folder":
		return FileTypeFolder
	default:
		return -1
	}
}
// DavAccountOption bit positions for WebDAV account flags.
const (
	DavAccountReadOnly DavAccountOption = iota
	DavAccountProxy
)

// Known storage policy type identifiers.
const (
	PolicyTypeLocal  = "local"
	PolicyTypeQiniu  = "qiniu"
	PolicyTypeUpyun  = "upyun"
	PolicyTypeOss    = "oss"
	PolicyTypeCos    = "cos"
	PolicyTypeS3     = "s3"
	PolicyTypeOd     = "onedrive"
	PolicyTypeRemote = "remote"
	PolicyTypeObs    = "obs"
)

// Supported remote-download providers.
const (
	DownloaderProviderAria2       = DownloaderProvider("aria2")
	DownloaderProviderQBittorrent = DownloaderProvider("qbittorrent")
)

594
inventory/user.go Normal file
View File

@@ -0,0 +1,594 @@
package inventory
import (
"context"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"hash"
"strings"
"time"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/davaccount"
"github.com/cloudreve/Cloudreve/v4/ent/file"
"github.com/cloudreve/Cloudreve/v4/ent/passkey"
"github.com/cloudreve/Cloudreve/v4/ent/schema"
"github.com/cloudreve/Cloudreve/v4/ent/task"
"github.com/cloudreve/Cloudreve/v4/ent/user"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/go-webauthn/webauthn/webauthn"
)
type (
	// Ctx keys for eager loading options.
	LoadUserGroup   struct{}
	LoadUserPasskey struct{}
	// UserCtx and UserIDCtx are context keys carrying the current user
	// entity and user ID, respectively.
	UserCtx   struct{}
	UserIDCtx struct{}
)

// Sentinel errors returned by user operations.
var (
	ErrUserEmailExisted      = errors.New("user email has been registered")
	ErrInactiveUserExisted   = errors.New("email already registered but not activated")
	ErrorUnknownPasswordType = errors.New("unknown password type")
	ErrorIncorrectPassword   = errors.New("incorrect password")
	ErrInsufficientPoints    = errors.New("insufficient points")
)
type (
	// UserClient persists and queries users, their passkeys and settings.
	UserClient interface {
		TxOperator
		// Create creates a new user. If the email is already registered, the existing
		// user is returned together with ErrUserEmailExisted or ErrInactiveUserExisted.
		Create(ctx context.Context, args *NewUserArgs) (*ent.User, error)
		// GetByEmail get the user with given email, user status is ignored.
		GetByEmail(ctx context.Context, email string) (*ent.User, error)
		// GetByID get user by its ID, user status is ignored.
		GetByID(ctx context.Context, id int) (*ent.User, error)
		// GetActiveByID get user by its ID, only active user will be returned.
		GetActiveByID(ctx context.Context, id int) (*ent.User, error)
		// SetStatus Set user to given status
		SetStatus(ctx context.Context, u *ent.User, status user.Status) (*ent.User, error)
		// AnonymousUser returns the anonymous user.
		AnonymousUser(ctx context.Context) (*ent.User, error)
		// GetLoginUserByID returns the login user by its ID. It emits some errors and fallback to anonymous user.
		GetLoginUserByID(ctx context.Context, uid int) (*ent.User, error)
		// GetActiveByDavAccount returns the active user matching the given WebDAV credentials.
		GetActiveByDavAccount(ctx context.Context, email, pwd string) (*ent.User, error)
		// SaveSettings saves user settings.
		SaveSettings(ctx context.Context, u *ent.User) error
		// SearchActive search active users by Email or nickname.
		SearchActive(ctx context.Context, limit int, keyword string) ([]*ent.User, error)
		// ApplyStorageDiff apply storage diff to user.
		ApplyStorageDiff(ctx context.Context, diffs StorageDiff) error
		// UpdateAvatar updates user avatar.
		UpdateAvatar(ctx context.Context, u *ent.User, avatar string) (*ent.User, error)
		// UpdateNickname updates user nickname.
		UpdateNickname(ctx context.Context, u *ent.User, name string) (*ent.User, error)
		// UpdatePassword updates user password.
		UpdatePassword(ctx context.Context, u *ent.User, newPassword string) (*ent.User, error)
		// UpdateTwoFASecret updates user two factor secret.
		UpdateTwoFASecret(ctx context.Context, u *ent.User, secret string) (*ent.User, error)
		// ListPasskeys list user's passkeys.
		ListPasskeys(ctx context.Context, uid int) ([]*ent.Passkey, error)
		// AddPasskey add passkey to user.
		AddPasskey(ctx context.Context, uid int, name string, credential *webauthn.Credential) (*ent.Passkey, error)
		// RemovePasskey remove passkey from user.
		RemovePasskey(ctx context.Context, uid int, keyId string) error
		// MarkPasskeyUsed updates passkey used at.
		MarkPasskeyUsed(ctx context.Context, uid int, keyId string) error
		// CountByTimeRange count users by time range. Will return all records if start or end is nil.
		CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error)
		// ListUsers list users with pagination.
		ListUsers(ctx context.Context, args *ListUserParameters) (*ListUserResult, error)
		// Upsert upserts a user.
		Upsert(ctx context.Context, u *ent.User, password, twoFa string) (*ent.User, error)
		// Delete deletes a user.
		Delete(ctx context.Context, uid int) error
		// CalculateStorage calculate user's storage from scratch and update user's storage.
		CalculateStorage(ctx context.Context, uid int) (int64, error)
	}

	// ListUserParameters holds filters and pagination options for ListUsers.
	ListUserParameters struct {
		*PaginationArgs
		GroupID int
		Status  user.Status
		Nick    string
		Email   string
	}

	// ListUserResult pairs a page of users with its pagination metadata.
	ListUserResult struct {
		*PaginationResults
		Users []*ent.User
	}
)
// NewUserClient constructs a UserClient backed by the given ent client.
func NewUserClient(client *ent.Client) UserClient {
	return &userClient{client: client}
}

// userClient is the default UserClient implementation.
type userClient struct {
	client *ent.Client
}
type (
	// NewUserArgs args to create a new user
	NewUserArgs struct {
		Email         string
		Nick          string // Optional; defaults to the email's local part.
		PlainPassword string // Optional; empty leaves the user without a password.
		Status        user.Status
		GroupID       int
		Avatar        string // Optional
		Language      string // Optional
	}

	// CreateStoragePackArgs describes a storage pack grant for a user.
	CreateStoragePackArgs struct {
		UserID   int
		Name     string
		Size     int64
		ExpireAt time.Time
	}
)
// CountByTimeRange counts users created within [*start, *end). When either
// bound is nil, every user is counted.
func (c *userClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) {
	q := c.client.User.Query()
	if start != nil && end != nil {
		q = q.Where(user.CreatedAtGTE(*start), user.CreatedAtLT(*end))
	}
	return q.Count(ctx)
}
// UpdateNickname updates the user's nickname.
func (c *userClient) UpdateNickname(ctx context.Context, u *ent.User, name string) (*ent.User, error) {
	return c.client.User.UpdateOne(u).SetNick(name).Save(ctx)
}

// UpdateAvatar updates the user's avatar.
func (c *userClient) UpdateAvatar(ctx context.Context, u *ent.User, avatar string) (*ent.User, error) {
	return c.client.User.UpdateOne(u).SetAvatar(avatar).Save(ctx)
}

// UpdateTwoFASecret updates the user's two-factor secret; an empty secret
// clears (disables) two-factor authentication.
func (c *userClient) UpdateTwoFASecret(ctx context.Context, u *ent.User, secret string) (*ent.User, error) {
	if secret == "" {
		return c.client.User.UpdateOne(u).ClearTwoFactorSecret().Save(ctx)
	}

	return c.client.User.UpdateOne(u).SetTwoFactorSecret(secret).Save(ctx)
}

// UpdatePassword hashes newPassword and stores the digest on the user.
func (c *userClient) UpdatePassword(ctx context.Context, u *ent.User, newPassword string) (*ent.User, error) {
	digest, err := digestPassword(newPassword)
	if err != nil {
		return nil, err
	}

	return c.client.User.UpdateOne(u).SetPassword(digest).Save(ctx)
}
// SetClient returns a copy of the client bound to newClient,
// typically a transactional ent client (TxOperator).
func (c *userClient) SetClient(newClient *ent.Client) TxOperator {
	return &userClient{client: newClient}
}

// GetClient returns the ent client currently backing this user client (TxOperator).
func (c *userClient) GetClient() *ent.Client {
	return c.client
}

// ListPasskeys lists all passkeys registered by the given user.
func (c *userClient) ListPasskeys(ctx context.Context, uid int) ([]*ent.Passkey, error) {
	return c.client.Passkey.Query().Where(passkey.UserID(uid)).All(ctx)
}

// AddPasskey stores a new WebAuthn credential for the user; the credential ID
// is persisted base64-encoded.
func (c *userClient) AddPasskey(ctx context.Context, uid int, name string, credential *webauthn.Credential) (*ent.Passkey, error) {
	return c.client.Passkey.Create().
		SetName(name).
		SetCredentialID(base64.StdEncoding.EncodeToString(credential.ID)).
		SetUserID(uid).
		SetCredential(credential).
		Save(ctx)
}

// RemovePasskey hard-deletes (bypassing soft delete) the user's credential
// with the given ID.
func (c *userClient) RemovePasskey(ctx context.Context, uid int, keyId string) error {
	ctx = schema.SkipSoftDelete(ctx)
	_, err := c.client.Passkey.Delete().Where(passkey.UserID(uid), passkey.CredentialID(keyId)).Exec(ctx)
	return err
}

// MarkPasskeyUsed records the time the credential was last used.
func (c *userClient) MarkPasskeyUsed(ctx context.Context, uid int, keyId string) error {
	_, err := c.client.Passkey.Update().Where(passkey.UserID(uid), passkey.CredentialID(keyId)).SetUsedAt(time.Now()).Save(ctx)
	return err
}
// Delete hard-deletes a user together with their WebDAV accounts, passkeys
// and tasks. Soft delete is bypassed, so the rows are removed permanently.
func (c *userClient) Delete(ctx context.Context, uid int) error {
	// Dav accounts
	if _, err := c.client.DavAccount.Delete().Where(davaccount.OwnerID(uid)).Exec(schema.SkipSoftDelete(ctx)); err != nil {
		return fmt.Errorf("failed to delete dav accounts: %w", err)
	}

	// Passkeys
	if _, err := c.client.Passkey.Delete().Where(passkey.UserID(uid)).Exec(schema.SkipSoftDelete(ctx)); err != nil {
		return fmt.Errorf("failed to delete passkeys: %w", err)
	}

	// Tasks
	if _, err := c.client.Task.Delete().Where(task.UserTasks(uid)).Exec(ctx); err != nil {
		return fmt.Errorf("failed to delete tasks: %w", err)
	}

	return c.client.User.DeleteOneID(uid).Exec(schema.SkipSoftDelete(ctx))
}
// ApplyStorageDiff applies per-user storage deltas. Failures for individual
// users are collected and returned as a single aggregate error; one failure
// does not stop the remaining updates.
func (c *userClient) ApplyStorageDiff(ctx context.Context, diffs StorageDiff) error {
	ae := serializer.NewAggregateError()
	for uid, diff := range diffs {
		if err := c.client.User.Update().Where(user.ID(uid)).AddStorage(diff).Exec(ctx); err != nil {
			ae.Add(fmt.Sprintf("%d", uid), fmt.Errorf("failed to apply storage diff for user %d: %w", uid, err))
		}
	}

	return ae.Aggregate()
}
// CalculateStorage recomputes the user's consumed storage from scratch by
// summing the sizes of all entities of all owned files in batches, persists
// the total on the user, and returns it.
func (c *userClient) CalculateStorage(ctx context.Context, uid int) (int64, error) {
	const batchSize = 5000
	var (
		total  int64
		offset int
	)

	for {
		batch, err := c.client.File.Query().
			Where(file.HasOwnerWith(user.ID(uid))).
			WithEntities().
			Offset(offset).
			Limit(batchSize).
			All(ctx)
		if err != nil {
			return 0, fmt.Errorf("failed to list user files: %w", err)
		}

		if len(batch) == 0 {
			break
		}

		// Renamed loop variable to avoid shadowing the imported "file" package.
		for _, f := range batch {
			for _, e := range f.Edges.Entities {
				total += e.Size
			}
		}

		offset += batchSize
	}

	if _, err := c.client.User.UpdateOneID(uid).SetStorage(total).Save(ctx); err != nil {
		return 0, err
	}

	return total, nil
}
// SetStatus sets the user to the given status.
func (c *userClient) SetStatus(ctx context.Context, u *ent.User, status user.Status) (*ent.User, error) {
	return c.client.User.UpdateOne(u).SetStatus(status).Save(ctx)
}
// Create creates a new user from args. If the email is already registered the
// existing user is returned with ErrUserEmailExisted (or ErrInactiveUserExisted
// for a not-yet-activated account). The very first user (ID 1) is elevated to
// the admin group.
func (c *userClient) Create(ctx context.Context, args *NewUserArgs) (*ent.User, error) {
	// Try to check if there's user with same email.
	if existedUser, err := c.GetByEmail(ctx, args.Email); err == nil {
		if existedUser.Status == user.StatusInactive {
			return existedUser, ErrInactiveUserExisted
		}

		return existedUser, ErrUserEmailExisted
	}

	// Default nickname: the local part of the email address.
	nick := args.Nick
	if nick == "" {
		nick = strings.Split(args.Email, "@")[0]
	}

	// Default settings: keep up to 10 historic versions per file.
	userSetting := &types.UserSetting{VersionRetention: true, VersionRetentionMax: 10}
	query := c.client.User.Create().
		SetEmail(args.Email).
		SetNick(nick).
		SetStatus(args.Status).
		SetGroupID(args.GroupID).
		SetAvatar(args.Avatar)

	if args.PlainPassword != "" {
		pwdDigest, err := digestPassword(args.PlainPassword)
		if err != nil {
			return nil, fmt.Errorf("failed to sha256 password: %w", err)
		}

		query.SetPassword(pwdDigest)
	}

	if args.Language != "" {
		userSetting.Language = args.Language
	}

	query.SetSettings(userSetting)

	// Create user
	newUser, err := query.
		Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to create user: %w", err)
	}

	if newUser.ID == 1 {
		// For the first user registered, elevate it to admin group.
		if _, err := newUser.Update().SetGroupID(1).Save(ctx); err != nil {
			return newUser, fmt.Errorf("failed to elevate user to admin: %w", err)
		}
	}

	return newUser, nil
}
// GetByEmail finds the first user whose email matches case-insensitively.
func (c *userClient) GetByEmail(ctx context.Context, email string) (*ent.User, error) {
	q := c.client.User.Query().Where(user.EmailEqualFold(email))
	return withUserEagerLoading(ctx, q).First(ctx)
}
// GetByID finds a user by primary key, regardless of status.
func (c *userClient) GetByID(ctx context.Context, id int) (*ent.User, error) {
	q := c.client.User.Query().Where(user.ID(id))
	return withUserEagerLoading(ctx, q).First(ctx)
}
// GetActiveByID finds a user by primary key, restricted to active status.
func (c *userClient) GetActiveByID(ctx context.Context, id int) (*ent.User, error) {
	q := c.client.User.Query().
		Where(user.ID(id), user.StatusEQ(user.StatusActive))
	return withUserEagerLoading(ctx, q).First(ctx)
}
// GetActiveByDavAccount finds an active user by email, eager-loading only the
// WebDAV accounts whose password matches pwd. The group edge is always loaded.
// NOTE(review): the DAV password is matched by an equality predicate in the
// DB, which implies it is stored in a directly comparable form — confirm.
func (c *userClient) GetActiveByDavAccount(ctx context.Context, email, pwd string) (*ent.User, error) {
	ctx = context.WithValue(ctx, LoadUserGroup{}, true)
	q := c.client.User.Query().
		Where(user.EmailEqualFold(email), user.StatusEQ(user.StatusActive)).
		WithDavAccounts(func(dq *ent.DavAccountQuery) {
			dq.Where(davaccount.Password(pwd))
		})
	return withUserEagerLoading(ctx, q).First(ctx)
}
// GetLoginUserByID resolves the session user: an active user for a positive
// uid, or the synthetic anonymous user otherwise. The group edge is loaded.
func (c *userClient) GetLoginUserByID(ctx context.Context, uid int) (*ent.User, error) {
	ctx = context.WithValue(ctx, LoadUserGroup{}, true)
	if uid <= 0 {
		anonymous, err := c.AnonymousUser(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to construct anonymous user: %w", err)
		}
		return anonymous, nil
	}
	expected, err := c.GetActiveByID(ctx, uid)
	if err != nil {
		return nil, fmt.Errorf("failed to get user by id: %w", err)
	}
	return expected, nil
}
// SearchActive finds up to limit active users whose email or nick contains
// the keyword (case-insensitive). The group edge is eager-loaded.
func (c *userClient) SearchActive(ctx context.Context, limit int, keyword string) ([]*ent.User, error) {
	ctx = context.WithValue(ctx, LoadUserGroup{}, true)
	return withUserEagerLoading(
		ctx,
		c.client.User.Query().
			// Restrict to active users — previously the status filter was
			// missing, so inactive/banned users leaked into search results.
			Where(user.StatusEQ(user.StatusActive)).
			Where(user.Or(user.EmailContainsFold(keyword), user.NickContainsFold(keyword))).
			Limit(limit),
	).All(ctx)
}
// SaveSettings persists the user's in-memory Settings blob back to the DB.
func (c *userClient) SaveSettings(ctx context.Context, u *ent.User) error {
	updater := c.client.User.UpdateOne(u).SetSettings(u.Settings)
	return updater.Exec(ctx)
}
// UserFromContext returns the user stored under UserCtx, or nil when absent.
func UserFromContext(ctx context.Context) *ent.User {
	if u, ok := ctx.Value(UserCtx{}).(*ent.User); ok {
		return u
	}
	return nil
}
// UserIDFromContext returns the user ID stored under UserIDCtx, falling back
// to the ID of the user stored under UserCtx, and 0 when neither is present.
func UserIDFromContext(ctx context.Context) int {
	if uid, ok := ctx.Value(UserIDCtx{}).(int); ok {
		return uid
	}
	if u := UserFromContext(ctx); u != nil {
		return u.ID
	}
	return 0
}
// AnonymousUser constructs a transient (never persisted) user attached to the
// anonymous group, used to represent unauthenticated requests.
func (c *userClient) AnonymousUser(ctx context.Context) (*ent.User, error) {
	groupClient := NewGroupClient(c.client, "", nil)
	anonymousGroup, err := groupClient.AnonymousGroup(ctx)
	if err != nil {
		// Fixed typo in the error message ("anyonymous" -> "anonymous").
		return nil, fmt.Errorf("anonymous group not found: %w", err)
	}

	// TODO: save into cache
	anonymous := &ent.User{
		Settings: &types.UserSetting{},
	}
	anonymous.SetGroup(anonymousGroup)
	return anonymous, nil
}
// ListUsers returns one page of users matching the given filters, together
// with pagination metadata.
func (c *userClient) ListUsers(ctx context.Context, args *ListUserParameters) (*ListUserResult, error) {
	q := c.client.User.Query()

	// Apply optional filters only when the caller supplied them.
	if args.GroupID != 0 {
		q = q.Where(user.GroupUsers(args.GroupID))
	}
	if args.Status != "" {
		q = q.Where(user.StatusEQ(args.Status))
	}
	if args.Nick != "" {
		q = q.Where(user.NickContainsFold(args.Nick))
	}
	if args.Email != "" {
		q = q.Where(user.EmailContainsFold(args.Email))
	}
	q.Order(getUserOrderOption(args)...)

	// Count on a clone so the pagination below is unaffected.
	total, err := q.Clone().Count(ctx)
	if err != nil {
		return nil, err
	}

	offset := args.Page * args.PageSize
	items, err := withUserEagerLoading(ctx, q).Limit(args.PageSize).Offset(offset).All(ctx)
	if err != nil {
		return nil, err
	}

	return &ListUserResult{
		Users: items,
		PaginationResults: &PaginationResults{
			TotalItems: total,
			Page:       args.Page,
			PageSize:   args.PageSize,
		},
	}, nil
}
// Upsert creates the user when its ID is zero, otherwise updates the existing
// record. A non-empty password is digested and stored in both cases.
func (c *userClient) Upsert(ctx context.Context, u *ent.User, password, twoFa string) (*ent.User, error) {
	// Digest the plaintext password once; previously this block was
	// duplicated verbatim in both the create and update branches.
	pwdDigest := ""
	if password != "" {
		digest, err := digestPassword(password)
		if err != nil {
			return nil, fmt.Errorf("failed to sha256 password: %w", err)
		}
		pwdDigest = digest
	}

	if u.ID == 0 {
		q := c.client.User.Create().
			SetEmail(u.Email).
			SetNick(u.Nick).
			SetAvatar(u.Avatar).
			SetStatus(u.Status).
			SetGroupID(u.GroupUsers).
			SetPassword(u.Password).
			SetSettings(&types.UserSetting{})
		if pwdDigest != "" {
			q.SetPassword(pwdDigest)
		}
		return q.Save(ctx)
	}

	q := c.client.User.UpdateOne(u).
		SetEmail(u.Email).
		SetNick(u.Nick).
		SetAvatar(u.Avatar).
		SetStatus(u.Status).
		SetGroupID(u.GroupUsers)
	if pwdDigest != "" {
		q.SetPassword(pwdDigest)
	}
	// NOTE(review): a non-empty twoFa *clears* the stored two-factor secret
	// rather than setting it — presumably an admin "disable 2FA" sentinel;
	// confirm against callers.
	if twoFa != "" {
		q.ClearTwoFactorSecret()
	}
	return q.Save(ctx)
}
// getUserOrderOption maps the requested sort column to ent ordering options,
// always appending ID as a tiebreaker so pagination stays stable.
func getUserOrderOption(args *ListUserParameters) []user.OrderOption {
	term := getOrderTerm(args.Order)
	var primary user.OrderOption
	switch args.OrderBy {
	case user.FieldNick:
		primary = user.ByNick(term)
	case user.FieldStorage:
		primary = user.ByStorage(term)
	case user.FieldEmail:
		primary = user.ByEmail(term)
	case user.FieldUpdatedAt:
		primary = user.ByUpdatedAt(term)
	default:
		// Unknown column: fall back to ordering by ID alone.
		return []user.OrderOption{user.ByID(term)}
	}
	return []user.OrderOption{primary, user.ByID(term)}
}
// IsAnonymousUser reports whether the given user is the synthetic anonymous
// user (never persisted, so its ID stays at the zero value).
func IsAnonymousUser(u *ent.User) bool {
	const anonymousID = 0
	return u.ID == anonymousID
}
// CheckPassword verifies a plaintext password against the user's stored
// credential. Two stored formats are supported:
//
//	"salt:digest"     — digest is hex SHA-256 (64 chars, V4) or SHA-1 (V3).
//	"md5:digest:salt" — credentials migrated from V2.
//
// It returns ErrorIncorrectPassword on mismatch and ErrorUnknownPasswordType
// when the stored value cannot be parsed.
func CheckPassword(u *ent.User, password string) error {
	// Split the stored credential into its salt/digest components.
	parts := strings.Split(u.Password, ":")
	if len(parts) != 2 && len(parts) != 3 {
		return ErrorUnknownPasswordType
	}

	// V2-migrated credentials, stored as md5:$HASH:$SALT.
	if len(parts) == 3 {
		if parts[0] != "md5" {
			return ErrorUnknownPasswordType
		}
		h := md5.New()
		if _, err := h.Write([]byte(parts[2] + password)); err != nil {
			return err
		}
		if hex.EncodeToString(h.Sum(nil)) != parts[1] {
			return ErrorIncorrectPassword
		}
		// Bug fix: the legacy check is complete here. Previously execution
		// fell through to the salted-SHA comparison below, which computed
		// sha1(password + "md5") against the md5 digest and therefore
		// rejected every valid V2 password.
		return nil
	}

	// Salted digest: 64 hex chars means SHA-256 (V4), otherwise legacy SHA-1
	// for V3 compatibility.
	var hasher hash.Hash
	if len(parts[1]) == 64 {
		hasher = sha256.New()
	} else {
		// Compatible with V3
		hasher = sha1.New()
	}
	if _, err := hasher.Write([]byte(password + parts[0])); err != nil {
		return err
	}
	if hex.EncodeToString(hasher.Sum(nil)) != parts[1] {
		return ErrorIncorrectPassword
	}
	return nil
}
// withUserEagerLoading decorates a user query with the relation loading
// requested through context flags (group and passkey edges).
func withUserEagerLoading(ctx context.Context, q *ent.UserQuery) *ent.UserQuery {
	loadGroup, _ := ctx.Value(LoadUserGroup{}).(bool)
	if loadGroup {
		q.WithGroup(func(gq *ent.GroupQuery) {
			withGroupEagerLoading(ctx, gq)
		})
	}
	loadPasskey, _ := ctx.Value(LoadUserPasskey{}).(bool)
	if loadPasskey {
		q.WithPasskey()
	}
	return q
}
// digestPassword hashes the plaintext password with a fresh random 16-char
// salt and returns "salt:digest", where digest is the hex-encoded SHA-256 of
// password+salt (see CheckPassword for the verification side).
func digestPassword(password string) (string, error) {
	salt := util.RandStringRunes(16)
	h := sha256.New()
	// Check the Write error before consuming the digest; the original read
	// the sum first and tested err afterwards. Also dropped a redundant
	// string() conversion of an already-string hex value.
	if _, err := h.Write([]byte(password + salt)); err != nil {
		return "", err
	}
	return salt + ":" + hex.EncodeToString(h.Sum(nil)), nil
}