Init V4 community edition (#2265)
* Init V4 community edition * Init V4 community edition
This commit is contained in:
454
inventory/client.go
Normal file
454
inventory/client.go
Normal file
@@ -0,0 +1,454 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
rawsql "database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/group"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/node"
|
||||
_ "github.com/cloudreve/Cloudreve/v4/ent/runtime"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/storagepolicy"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/debug"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
"modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const (
|
||||
DBVersionPrefix = "db_version_"
|
||||
EnvDefaultOverwritePrefix = "CR_SETTING_DEFAULT_"
|
||||
EnvEnableAria2 = "CR_ENABLE_ARIA2"
|
||||
)
|
||||
|
||||
// InitializeDBClient runs migration and returns a new ent.Client with additional configurations
|
||||
// for hooks and interceptors.
|
||||
func InitializeDBClient(l logging.Logger,
|
||||
client *ent.Client, kv cache.Driver, requiredDbVersion string) (*ent.Client, error) {
|
||||
ctx := context.WithValue(context.Background(), logging.LoggerCtx{}, l)
|
||||
if needMigration(client, ctx, requiredDbVersion) {
|
||||
// Run the auto migration tool.
|
||||
if err := migrate(l, client, ctx, kv, requiredDbVersion); err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||
}
|
||||
} else {
|
||||
l.Info("Database schema is up to date.")
|
||||
}
|
||||
|
||||
//createMockData(client, ctx)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// NewRawEntClient returns a new ent.Client without additional configurations.
|
||||
func NewRawEntClient(l logging.Logger, config conf.ConfigProvider) (*ent.Client, error) {
|
||||
l.Info("Initializing database connection...")
|
||||
dbConfig := config.Database()
|
||||
confDBType := dbConfig.Type
|
||||
if confDBType == conf.SQLite3DB || confDBType == "" {
|
||||
confDBType = conf.SQLiteDB
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
client *sql.Driver
|
||||
)
|
||||
|
||||
switch confDBType {
|
||||
case conf.SQLiteDB:
|
||||
dbFile := util.RelativePath(dbConfig.DBFile)
|
||||
l.Info("Connect to SQLite database %q.", dbFile)
|
||||
client, err = sql.Open("sqlite3", util.RelativePath(dbConfig.DBFile))
|
||||
case conf.PostgresDB:
|
||||
l.Info("Connect to Postgres database %q.", dbConfig.Host)
|
||||
client, err = sql.Open("postgres", fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
|
||||
dbConfig.Host,
|
||||
dbConfig.User,
|
||||
dbConfig.Password,
|
||||
dbConfig.Name,
|
||||
dbConfig.Port))
|
||||
case conf.MySqlDB, conf.MsSqlDB:
|
||||
l.Info("Connect to MySQL/SQLServer database %q.", dbConfig.Host)
|
||||
var host string
|
||||
if dbConfig.UnixSocket {
|
||||
host = fmt.Sprintf("unix(%s)",
|
||||
dbConfig.Host)
|
||||
} else {
|
||||
host = fmt.Sprintf("(%s:%d)",
|
||||
dbConfig.Host,
|
||||
dbConfig.Port)
|
||||
}
|
||||
|
||||
client, err = sql.Open(string(confDBType), fmt.Sprintf("%s:%s@%s/%s?charset=%s&parseTime=True&loc=Local",
|
||||
dbConfig.User,
|
||||
dbConfig.Password,
|
||||
host,
|
||||
dbConfig.Name,
|
||||
dbConfig.Charset))
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type %q", confDBType)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Set connection pool
|
||||
db := client.DB()
|
||||
db.SetMaxIdleConns(50)
|
||||
if confDBType == "sqlite" || confDBType == "UNSET" {
|
||||
db.SetMaxOpenConns(1)
|
||||
} else {
|
||||
db.SetMaxOpenConns(100)
|
||||
}
|
||||
|
||||
// Set timeout
|
||||
db.SetConnMaxLifetime(time.Second * 30)
|
||||
|
||||
driverOpt := ent.Driver(client)
|
||||
|
||||
// Enable verbose logging for debug mode.
|
||||
if config.System().Debug {
|
||||
l.Debug("Debug mode is enabled for DB client.")
|
||||
driverOpt = ent.Driver(debug.DebugWithContext(client, func(ctx context.Context, i ...any) {
|
||||
logging.FromContext(ctx).Debug(i[0].(string), i[1:]...)
|
||||
}))
|
||||
}
|
||||
|
||||
return ent.NewClient(driverOpt), nil
|
||||
}
|
||||
|
||||
type sqlite3Driver struct {
|
||||
*sqlite.Driver
|
||||
}
|
||||
|
||||
type sqlite3DriverConn interface {
|
||||
Exec(string, []driver.Value) (driver.Result, error)
|
||||
}
|
||||
|
||||
func (d sqlite3Driver) Open(name string) (conn driver.Conn, err error) {
|
||||
conn, err = d.Driver.Open(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = conn.(sqlite3DriverConn).Exec("PRAGMA foreign_keys = ON;", nil)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func init() {
|
||||
rawsql.Register("sqlite3", sqlite3Driver{Driver: &sqlite.Driver{}})
|
||||
}
|
||||
|
||||
// needMigration exams if required schema version is satisfied.
|
||||
func needMigration(client *ent.Client, ctx context.Context, requiredDbVersion string) bool {
|
||||
c, _ := client.Setting.Query().Where(setting.NameEQ(DBVersionPrefix + requiredDbVersion)).Count(ctx)
|
||||
return c == 0
|
||||
}
|
||||
|
||||
func migrate(l logging.Logger, client *ent.Client, ctx context.Context, kv cache.Driver, requiredDbVersion string) error {
|
||||
l.Info("Start initializing database schema...")
|
||||
l.Info("Creating basic table schema...")
|
||||
if err := client.Schema.Create(ctx); err != nil {
|
||||
return fmt.Errorf("Failed creating schema resources: %w", err)
|
||||
}
|
||||
|
||||
migrateDefaultSettings(l, client, ctx, kv)
|
||||
|
||||
if err := migrateDefaultStoragePolicy(l, client, ctx); err != nil {
|
||||
return fmt.Errorf("failed migrating default storage policy: %w", err)
|
||||
}
|
||||
|
||||
if err := migrateSysGroups(l, client, ctx); err != nil {
|
||||
return fmt.Errorf("failed migrating default storage policy: %w", err)
|
||||
}
|
||||
|
||||
client.Setting.Create().SetName(DBVersionPrefix + requiredDbVersion).SetValue("installed").Save(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateDefaultSettings(l logging.Logger, client *ent.Client, ctx context.Context, kv cache.Driver) {
|
||||
// clean kv cache
|
||||
if err := kv.DeleteAll(); err != nil {
|
||||
l.Warning("Failed to remove all KV entries while schema migration: %s", err)
|
||||
}
|
||||
|
||||
// List existing settings into a map
|
||||
existingSettings := make(map[string]struct{})
|
||||
settings, err := client.Setting.Query().All(ctx)
|
||||
if err != nil {
|
||||
l.Warning("Failed to query existing settings: %s", err)
|
||||
}
|
||||
|
||||
for _, s := range settings {
|
||||
existingSettings[s.Name] = struct{}{}
|
||||
}
|
||||
|
||||
l.Info("Insert default settings...")
|
||||
for k, v := range DefaultSettings {
|
||||
if _, ok := existingSettings[k]; ok {
|
||||
l.Debug("Skip inserting setting %s, already exists.", k)
|
||||
continue
|
||||
}
|
||||
|
||||
if override, ok := os.LookupEnv(EnvDefaultOverwritePrefix + k); ok {
|
||||
l.Info("Override default setting %q with env value %q", k, override)
|
||||
v = override
|
||||
}
|
||||
|
||||
client.Setting.Create().SetName(k).SetValue(v).SaveX(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func migrateDefaultStoragePolicy(l logging.Logger, client *ent.Client, ctx context.Context) error {
|
||||
if _, err := client.StoragePolicy.Query().Where(storagepolicy.ID(1)).First(ctx); err == nil {
|
||||
l.Info("Default storage policy (ID=1) already exists, skip migrating.")
|
||||
return nil
|
||||
}
|
||||
|
||||
l.Info("Insert default storage policy...")
|
||||
if _, err := client.StoragePolicy.Create().
|
||||
SetName("Default storage policy").
|
||||
SetType(types.PolicyTypeLocal).
|
||||
SetDirNameRule(util.DataPath("uploads/{uid}/{path}")).
|
||||
SetFileNameRule("{uid}_{randomkey8}_{originname}").
|
||||
SetSettings(&types.PolicySetting{
|
||||
ChunkSize: 25 << 20, // 25MB
|
||||
PreAllocate: true,
|
||||
}).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("failed to create default storage policy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateSysGroups(l logging.Logger, client *ent.Client, ctx context.Context) error {
|
||||
if err := migrateAdminGroup(l, client, ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := migrateUserGroup(l, client, ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := migrateAnonymousGroup(l, client, ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := migrateMasterNode(l, client, ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateAdminGroup(l logging.Logger, client *ent.Client, ctx context.Context) error {
|
||||
if _, err := client.Group.Query().Where(group.ID(1)).First(ctx); err == nil {
|
||||
l.Info("Default admin group (ID=1) already exists, skip migrating.")
|
||||
return nil
|
||||
}
|
||||
|
||||
l.Info("Insert default admin group...")
|
||||
permissions := &boolset.BooleanSet{}
|
||||
boolset.Sets(map[types.GroupPermission]bool{
|
||||
types.GroupPermissionIsAdmin: true,
|
||||
types.GroupPermissionShare: true,
|
||||
types.GroupPermissionWebDAV: true,
|
||||
types.GroupPermissionWebDAVProxy: true,
|
||||
types.GroupPermissionArchiveDownload: true,
|
||||
types.GroupPermissionArchiveTask: true,
|
||||
types.GroupPermissionShareDownload: true,
|
||||
types.GroupPermissionRemoteDownload: true,
|
||||
types.GroupPermissionRedirectedSource: true,
|
||||
types.GroupPermissionAdvanceDelete: true,
|
||||
types.GroupPermissionIgnoreFileOwnership: true,
|
||||
// TODO: review default permission
|
||||
}, permissions)
|
||||
if _, err := client.Group.Create().
|
||||
SetName("Admin").
|
||||
SetStoragePoliciesID(1).
|
||||
SetMaxStorage(1 * constants.TB). // 1 TB default storage
|
||||
SetPermissions(permissions).
|
||||
SetSettings(&types.GroupSetting{
|
||||
SourceBatchSize: 1000,
|
||||
Aria2BatchSize: 50,
|
||||
MaxWalkedFiles: 100000,
|
||||
TrashRetention: 7 * 24 * 3600,
|
||||
RedirectedSource: true,
|
||||
}).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("failed to create default admin group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateUserGroup(l logging.Logger, client *ent.Client, ctx context.Context) error {
|
||||
if _, err := client.Group.Query().Where(group.ID(2)).First(ctx); err == nil {
|
||||
l.Info("Default user group (ID=2) already exists, skip migrating.")
|
||||
return nil
|
||||
}
|
||||
|
||||
l.Info("Insert default user group...")
|
||||
permissions := &boolset.BooleanSet{}
|
||||
boolset.Sets(map[types.GroupPermission]bool{
|
||||
types.GroupPermissionShare: true,
|
||||
types.GroupPermissionShareDownload: true,
|
||||
types.GroupPermissionRedirectedSource: true,
|
||||
}, permissions)
|
||||
if _, err := client.Group.Create().
|
||||
SetName("User").
|
||||
SetStoragePoliciesID(1).
|
||||
SetMaxStorage(1 * constants.GB). // 1 GB default storage
|
||||
SetPermissions(permissions).
|
||||
SetSettings(&types.GroupSetting{
|
||||
SourceBatchSize: 10,
|
||||
Aria2BatchSize: 1,
|
||||
MaxWalkedFiles: 100000,
|
||||
TrashRetention: 7 * 24 * 3600,
|
||||
RedirectedSource: true,
|
||||
}).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("failed to create default user group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateAnonymousGroup(l logging.Logger, client *ent.Client, ctx context.Context) error {
|
||||
if _, err := client.Group.Query().Where(group.ID(AnonymousGroupID)).First(ctx); err == nil {
|
||||
l.Info("Default anonymous group (ID=3) already exists, skip migrating.")
|
||||
return nil
|
||||
}
|
||||
|
||||
l.Info("Insert default anonymous group...")
|
||||
permissions := &boolset.BooleanSet{}
|
||||
boolset.Sets(map[types.GroupPermission]bool{
|
||||
types.GroupPermissionIsAnonymous: true,
|
||||
types.GroupPermissionShareDownload: true,
|
||||
}, permissions)
|
||||
if _, err := client.Group.Create().
|
||||
SetName("Anonymous").
|
||||
SetPermissions(permissions).
|
||||
SetSettings(&types.GroupSetting{
|
||||
MaxWalkedFiles: 100000,
|
||||
RedirectedSource: true,
|
||||
}).
|
||||
Save(ctx); err != nil {
|
||||
return fmt.Errorf("failed to create default anonymous group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateMasterNode(l logging.Logger, client *ent.Client, ctx context.Context) error {
|
||||
if _, err := client.Node.Query().Where(node.TypeEQ(node.TypeMaster)).First(ctx); err == nil {
|
||||
l.Info("Default master node already exists, skip migrating.")
|
||||
return nil
|
||||
}
|
||||
|
||||
capabilities := &boolset.BooleanSet{}
|
||||
boolset.Sets(map[types.NodeCapability]bool{
|
||||
types.NodeCapabilityCreateArchive: true,
|
||||
types.NodeCapabilityExtractArchive: true,
|
||||
types.NodeCapabilityRemoteDownload: true,
|
||||
}, capabilities)
|
||||
|
||||
stm := client.Node.Create().
|
||||
SetType(node.TypeMaster).
|
||||
SetCapabilities(capabilities).
|
||||
SetName("Master").
|
||||
SetSettings(&types.NodeSetting{
|
||||
Provider: types.DownloaderProviderAria2,
|
||||
}).
|
||||
SetStatus(node.StatusActive)
|
||||
|
||||
_, enableAria2 := os.LookupEnv(EnvEnableAria2)
|
||||
if enableAria2 {
|
||||
l.Info("Aria2 is override as enabled.")
|
||||
stm.SetSettings(&types.NodeSetting{
|
||||
Provider: types.DownloaderProviderAria2,
|
||||
Aria2Setting: &types.Aria2Setting{
|
||||
Server: "http://127.0.0.1:6800/jsonrpc",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
l.Info("Insert default master node...")
|
||||
if _, err := stm.Save(ctx); err != nil {
|
||||
return fmt.Errorf("failed to create default master node: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createMockData(client *ent.Client, ctx context.Context) {
|
||||
//userCount := 100
|
||||
//folderCount := 10000
|
||||
//fileCount := 25000
|
||||
//
|
||||
//// create users
|
||||
//pwdDigest, _ := digestPassword("52121225")
|
||||
//userCreates := make([]*ent.UserCreate, userCount)
|
||||
//for i := 0; i < userCount; i++ {
|
||||
// nick := uuid.Must(uuid.NewV4()).String()
|
||||
// userCreates[i] = client.User.Create().
|
||||
// SetEmail(nick + "@cloudreve.org").
|
||||
// SetNick(nick).
|
||||
// SetPassword(pwdDigest).
|
||||
// SetStatus(user.StatusActive).
|
||||
// SetGroupID(1)
|
||||
//}
|
||||
//users, err := client.User.CreateBulk(userCreates...).Save(ctx)
|
||||
//if err != nil {
|
||||
// panic(err)
|
||||
//}
|
||||
//
|
||||
//// Create root folder
|
||||
//rootFolderCreates := make([]*ent.FileCreate, userCount)
|
||||
//folderIds := make([][]int, 0, folderCount*userCount+userCount)
|
||||
//for i, user := range users {
|
||||
// rootFolderCreates[i] = client.File.Create().
|
||||
// SetName(RootFolderName).
|
||||
// SetOwnerID(user.ID).
|
||||
// SetType(int(FileTypeFolder))
|
||||
//}
|
||||
//rootFolders, err := client.File.CreateBulk(rootFolderCreates...).Save(ctx)
|
||||
//for _, rootFolders := range rootFolders {
|
||||
// folderIds = append(folderIds, []int{rootFolders.ID, rootFolders.OwnerID})
|
||||
//}
|
||||
//if err != nil {
|
||||
// panic(err)
|
||||
//}
|
||||
//
|
||||
//// create random folder
|
||||
//for i := 0; i < folderCount*userCount; i++ {
|
||||
// parent := lo.Sample(folderIds)
|
||||
// res := client.File.Create().
|
||||
// SetName(uuid.Must(uuid.NewV4()).String()).
|
||||
// SetType(int(FileTypeFolder)).
|
||||
// SetOwnerID(parent[1]).
|
||||
// SetFileChildren(parent[0]).
|
||||
// SaveX(ctx)
|
||||
// folderIds = append(folderIds, []int{res.ID, res.OwnerID})
|
||||
//}
|
||||
|
||||
for i := 0; i < 255; i++ {
|
||||
fmt.Printf("%d/", i)
|
||||
}
|
||||
}
|
||||
124
inventory/common.go
Normal file
124
inventory/common.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
OrderDirection string
|
||||
PaginationArgs struct {
|
||||
UseCursorPagination bool
|
||||
Page int
|
||||
PageToken string
|
||||
PageSize int
|
||||
OrderBy string
|
||||
Order OrderDirection
|
||||
}
|
||||
|
||||
PaginationResults struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalItems int `json:"total_items,omitempty"`
|
||||
NextPageToken string `json:"next_token,omitempty"`
|
||||
IsCursor bool `json:"is_cursor,omitempty"`
|
||||
}
|
||||
|
||||
PageToken struct {
|
||||
Time *time.Time `json:"time,omitempty"`
|
||||
ID int `json:"-"`
|
||||
IDHash string `json:"id,omitempty"`
|
||||
String string `json:"string,omitempty"`
|
||||
Int int `json:"int,omitempty"`
|
||||
StartWithFile bool `json:"start_with_file,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
OrderDirectionAsc = OrderDirection("asc")
|
||||
OrderDirectionDesc = OrderDirection("desc")
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTooManyArguments = fmt.Errorf("too many arguments")
|
||||
)
|
||||
|
||||
func pageTokenFromString(s string, hasher hashid.Encoder, idType int) (*PageToken, error) {
|
||||
sB64Decoded, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64 for page token: %w", err)
|
||||
}
|
||||
|
||||
token := &PageToken{}
|
||||
if err := json.Unmarshal(sB64Decoded, token); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal page token: %w", err)
|
||||
}
|
||||
|
||||
id, err := hasher.Decode(token.IDHash, idType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode id: %w", err)
|
||||
}
|
||||
|
||||
if token.Time == nil {
|
||||
token.Time = &time.Time{}
|
||||
}
|
||||
|
||||
token.ID = id
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (p *PageToken) Encode(hasher hashid.Encoder, encodeFunc hashid.EncodeFunc) (string, error) {
|
||||
p.IDHash = encodeFunc(hasher, p.ID)
|
||||
res, err := json.Marshal(p)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal page token: %w", err)
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(res), nil
|
||||
}
|
||||
|
||||
// sqlParamLimit returns the max number of sql parameters.
|
||||
func sqlParamLimit(dbType conf.DBType) int {
|
||||
switch dbType {
|
||||
case conf.PostgresDB:
|
||||
return 34464
|
||||
case conf.SQLiteDB, conf.SQLite3DB:
|
||||
// https://www.sqlite.org/limits.html
|
||||
return 32766
|
||||
default:
|
||||
return 32766
|
||||
}
|
||||
}
|
||||
|
||||
// getOrderTerm returns the order term for ent.
|
||||
func getOrderTerm(d OrderDirection) sql.OrderTermOption {
|
||||
switch d {
|
||||
case OrderDirectionDesc:
|
||||
return sql.OrderDesc()
|
||||
default:
|
||||
return sql.OrderAsc()
|
||||
}
|
||||
}
|
||||
|
||||
func capPageSize(maxSQlParam, preferredSize, margin int) int {
|
||||
// Page size should not be bigger than max SQL parameter
|
||||
pageSize := preferredSize
|
||||
if maxSQlParam > 0 && pageSize > maxSQlParam-margin || pageSize == 0 {
|
||||
pageSize = maxSQlParam - margin
|
||||
}
|
||||
|
||||
return pageSize
|
||||
}
|
||||
|
||||
type StorageDiff map[int]int64
|
||||
|
||||
func (s *StorageDiff) Merge(diff StorageDiff) {
|
||||
for k, v := range diff {
|
||||
(*s)[k] += v
|
||||
}
|
||||
}
|
||||
187
inventory/dav_account.go
Normal file
187
inventory/dav_account.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/davaccount"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type (
|
||||
DavAccountClient interface {
|
||||
TxOperator
|
||||
// List returns a list of dav accounts with the given args.
|
||||
List(ctx context.Context, args *ListDavAccountArgs) (*ListDavAccountResult, error)
|
||||
// Create creates a new dav account.
|
||||
Create(ctx context.Context, params *CreateDavAccountParams) (*ent.DavAccount, error)
|
||||
// Update updates a dav account.
|
||||
Update(ctx context.Context, id int, params *CreateDavAccountParams) (*ent.DavAccount, error)
|
||||
// GetByIDAndUserID returns the dav account with given id and user id.
|
||||
GetByIDAndUserID(ctx context.Context, id, userID int) (*ent.DavAccount, error)
|
||||
// Delete deletes the dav account.
|
||||
Delete(ctx context.Context, id int) error
|
||||
}
|
||||
|
||||
ListDavAccountArgs struct {
|
||||
*PaginationArgs
|
||||
UserID int
|
||||
}
|
||||
ListDavAccountResult struct {
|
||||
*PaginationResults
|
||||
Accounts []*ent.DavAccount
|
||||
}
|
||||
CreateDavAccountParams struct {
|
||||
UserID int
|
||||
Name string
|
||||
URI string
|
||||
Password string
|
||||
Options *boolset.BooleanSet
|
||||
}
|
||||
)
|
||||
|
||||
func NewDavAccountClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) DavAccountClient {
|
||||
return &davAccountClient{
|
||||
client: client,
|
||||
hasher: hasher,
|
||||
maxSQlParam: sqlParamLimit(dbType),
|
||||
}
|
||||
}
|
||||
|
||||
type davAccountClient struct {
|
||||
maxSQlParam int
|
||||
client *ent.Client
|
||||
hasher hashid.Encoder
|
||||
}
|
||||
|
||||
func (c *davAccountClient) SetClient(newClient *ent.Client) TxOperator {
|
||||
return &davAccountClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam}
|
||||
}
|
||||
|
||||
func (c *davAccountClient) GetClient() *ent.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (c *davAccountClient) Create(ctx context.Context, params *CreateDavAccountParams) (*ent.DavAccount, error) {
|
||||
account := c.client.DavAccount.Create().
|
||||
SetOwnerID(params.UserID).
|
||||
SetName(params.Name).
|
||||
SetURI(params.URI).
|
||||
SetPassword(params.Password).
|
||||
SetOptions(params.Options)
|
||||
|
||||
return account.Save(ctx)
|
||||
}
|
||||
|
||||
func (c *davAccountClient) GetByIDAndUserID(ctx context.Context, id, userID int) (*ent.DavAccount, error) {
|
||||
return c.client.DavAccount.Query().
|
||||
Where(davaccount.ID(id), davaccount.OwnerID(userID)).
|
||||
First(ctx)
|
||||
}
|
||||
|
||||
func (c *davAccountClient) Update(ctx context.Context, id int, params *CreateDavAccountParams) (*ent.DavAccount, error) {
|
||||
account := c.client.DavAccount.UpdateOneID(id).
|
||||
SetName(params.Name).
|
||||
SetURI(params.URI).
|
||||
SetOptions(params.Options)
|
||||
|
||||
return account.Save(ctx)
|
||||
}
|
||||
|
||||
func (c *davAccountClient) Delete(ctx context.Context, id int) error {
|
||||
return c.client.DavAccount.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
func (c *davAccountClient) List(ctx context.Context, args *ListDavAccountArgs) (*ListDavAccountResult, error) {
|
||||
query := c.listQuery(args)
|
||||
|
||||
var (
|
||||
accounts []*ent.DavAccount
|
||||
err error
|
||||
paginationRes *PaginationResults
|
||||
)
|
||||
accounts, paginationRes, err = c.cursorPagination(ctx, query, args, 10)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed with paginiation: %w", err)
|
||||
}
|
||||
|
||||
return &ListDavAccountResult{
|
||||
Accounts: accounts,
|
||||
PaginationResults: paginationRes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *davAccountClient) cursorPagination(ctx context.Context, query *ent.DavAccountQuery, args *ListDavAccountArgs, paramMargin int) ([]*ent.DavAccount, *PaginationResults, error) {
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(davaccount.ByID(sql.OrderDesc()))
|
||||
|
||||
var (
|
||||
pageToken *PageToken
|
||||
err error
|
||||
)
|
||||
if args.PageToken != "" {
|
||||
pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.DavAccountID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
|
||||
}
|
||||
}
|
||||
queryPaged := getDavAccountCursorQuery(pageToken, query)
|
||||
|
||||
// Use page size + 1 to determine if there are more items to come
|
||||
queryPaged.Limit(pageSize + 1)
|
||||
|
||||
logs, err := queryPaged.
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// More items to come
|
||||
nextTokenStr := ""
|
||||
if len(logs) > pageSize {
|
||||
lastItem := logs[len(logs)-2]
|
||||
nextToken, err := getDavAccountNextPageToken(c.hasher, lastItem)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
|
||||
}
|
||||
|
||||
nextTokenStr = nextToken
|
||||
}
|
||||
|
||||
return lo.Subset(logs, 0, uint(pageSize)), &PaginationResults{
|
||||
PageSize: pageSize,
|
||||
NextPageToken: nextTokenStr,
|
||||
IsCursor: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *davAccountClient) listQuery(args *ListDavAccountArgs) *ent.DavAccountQuery {
|
||||
query := c.client.DavAccount.Query()
|
||||
if args.UserID > 0 {
|
||||
query.Where(davaccount.OwnerID(args.UserID))
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// getDavAccountNextPageToken returns the next page token for the given last dav account.
|
||||
func getDavAccountNextPageToken(hasher hashid.Encoder, last *ent.DavAccount) (string, error) {
|
||||
token := &PageToken{
|
||||
ID: last.ID,
|
||||
}
|
||||
|
||||
return token.Encode(hasher, hashid.EncodeDavAccountID)
|
||||
}
|
||||
|
||||
func getDavAccountCursorQuery(token *PageToken, query *ent.DavAccountQuery) *ent.DavAccountQuery {
|
||||
if token != nil {
|
||||
query.Where(davaccount.IDLT(token.ID))
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
188
inventory/debug/debug.go
Normal file
188
inventory/debug/debug.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
const strMaxLen = 102400
|
||||
|
||||
type SkipDbLogging struct{}
|
||||
|
||||
// DebugDriver is a driver that logs all driver operations.
|
||||
type DebugDriver struct {
|
||||
dialect.Driver // underlying driver.
|
||||
log func(context.Context, ...any) // log function. defaults to log.Println.
|
||||
}
|
||||
|
||||
// DebugWithContext gets a driver and a logging function, and returns
|
||||
// a new debugged-driver that prints all outgoing operations with context.
|
||||
func DebugWithContext(d dialect.Driver, logger func(context.Context, ...any)) dialect.Driver {
|
||||
drv := &DebugDriver{d, logger}
|
||||
return drv
|
||||
}
|
||||
|
||||
// Exec logs its params and calls the underlying driver Exec method.
|
||||
func (d *DebugDriver) Exec(ctx context.Context, query string, args, v any) error {
|
||||
start := time.Now()
|
||||
err := d.Driver.Exec(ctx, query, args, v)
|
||||
if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
|
||||
return err
|
||||
}
|
||||
|
||||
d.log(ctx, fmt.Sprintf("driver.Exec: query=%v args=%v time=%v", query, args, time.Since(start)))
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecContext logs its params and calls the underlying driver ExecContext method if it is supported.
|
||||
func (d *DebugDriver) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
drv, ok := d.Driver.(interface {
|
||||
ExecContext(context.Context, string, ...any) (sql.Result, error)
|
||||
})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Driver.ExecContext is not supported")
|
||||
}
|
||||
if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
|
||||
return drv.ExecContext(ctx, query, args...)
|
||||
}
|
||||
d.log(ctx, fmt.Sprintf("driver.ExecContext: query=%v args=%v", query, args))
|
||||
return drv.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Query logs its params and calls the underlying driver Query method.
|
||||
func (d *DebugDriver) Query(ctx context.Context, query string, args, v any) error {
|
||||
start := time.Now()
|
||||
err := d.Driver.Query(ctx, query, args, v)
|
||||
if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
|
||||
return err
|
||||
}
|
||||
d.log(ctx, fmt.Sprintf("driver.Query: query=%v args=%v time=%v", query, args, time.Since(start)))
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryContext logs its params and calls the underlying driver QueryContext method if it is supported.
|
||||
func (d *DebugDriver) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
drv, ok := d.Driver.(interface {
|
||||
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
|
||||
})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Driver.QueryContext is not supported")
|
||||
}
|
||||
if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
|
||||
return drv.QueryContext(ctx, query, args...)
|
||||
}
|
||||
d.log(ctx, fmt.Sprintf("driver.QueryContext: query=%v args=%v", query, args))
|
||||
return drv.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Tx adds an log-id for the transaction and calls the underlying driver Tx command.
|
||||
func (d *DebugDriver) Tx(ctx context.Context) (dialect.Tx, error) {
|
||||
tx, err := d.Driver.Tx(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := uuid.New().String()
|
||||
d.log(ctx, fmt.Sprintf("driver.Tx(%s): started", id))
|
||||
return &DebugTx{tx, id, d.log, ctx}, nil
|
||||
}
|
||||
|
||||
// BeginTx adds an log-id for the transaction and calls the underlying driver BeginTx command if it is supported.
|
||||
func (d *DebugDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) {
|
||||
drv, ok := d.Driver.(interface {
|
||||
BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error)
|
||||
})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Driver.BeginTx is not supported")
|
||||
}
|
||||
tx, err := drv.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := uuid.New().String()
|
||||
d.log(ctx, fmt.Sprintf("driver.BeginTx(%s): started", id))
|
||||
return &DebugTx{tx, id, d.log, ctx}, nil
|
||||
}
|
||||
|
||||
// DebugTx is a transaction implementation that logs all transaction operations.
|
||||
type DebugTx struct {
|
||||
dialect.Tx // underlying transaction.
|
||||
id string // transaction logging id.
|
||||
log func(context.Context, ...any) // log function. defaults to fmt.Println.
|
||||
ctx context.Context // underlying transaction context.
|
||||
}
|
||||
|
||||
// Exec logs its params and calls the underlying transaction Exec method.
|
||||
func (d *DebugTx) Exec(ctx context.Context, query string, args, v any) error {
|
||||
start := time.Now()
|
||||
err := d.Tx.Exec(ctx, query, args, v)
|
||||
printArgs := args
|
||||
if argsArray, ok := args.([]interface{}); ok {
|
||||
for i, argVal := range argsArray {
|
||||
if argValStr, ok := argVal.(string); ok && len(argValStr) > strMaxLen {
|
||||
printArgs.([]interface{})[i] = argValStr[:strMaxLen] + "...[Truncated]..."
|
||||
}
|
||||
}
|
||||
}
|
||||
if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
|
||||
return err
|
||||
}
|
||||
d.log(ctx, fmt.Sprintf("Tx(%s).Exec: query=%v args=%v time=%v", d.id, query, args, time.Since(start)))
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecContext logs its params and calls the underlying transaction ExecContext method if it is supported.
|
||||
func (d *DebugTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
drv, ok := d.Tx.(interface {
|
||||
ExecContext(context.Context, string, ...any) (sql.Result, error)
|
||||
})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Tx.ExecContext is not supported")
|
||||
}
|
||||
if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
|
||||
return drv.ExecContext(ctx, query, args...)
|
||||
}
|
||||
d.log(ctx, fmt.Sprintf("Tx(%s).ExecContext: query=%v args=%v", d.id, query, args))
|
||||
return drv.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Query logs its params and calls the underlying transaction Query method.
|
||||
func (d *DebugTx) Query(ctx context.Context, query string, args, v any) error {
|
||||
start := time.Now()
|
||||
err := d.Tx.Query(ctx, query, args, v)
|
||||
if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
|
||||
return err
|
||||
}
|
||||
d.log(ctx, fmt.Sprintf("Tx(%s).Query: query=%v args=%v time=%v", d.id, query, args, time.Since(start)))
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryContext logs its params and calls the underlying transaction QueryContext method if it is supported.
|
||||
func (d *DebugTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
drv, ok := d.Tx.(interface {
|
||||
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
|
||||
})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Tx.QueryContext is not supported")
|
||||
}
|
||||
if skip, ok := ctx.Value(SkipDbLogging{}).(bool); ok && skip {
|
||||
return drv.QueryContext(ctx, query, args...)
|
||||
}
|
||||
d.log(ctx, fmt.Sprintf("Tx(%s).QueryContext: query=%v args=%v", d.id, query, args))
|
||||
return drv.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Commit logs this step and calls the underlying transaction Commit method.
|
||||
func (d *DebugTx) Commit() error {
|
||||
d.log(d.ctx, fmt.Sprintf("Tx(%s): committed", d.id))
|
||||
return d.Tx.Commit()
|
||||
}
|
||||
|
||||
// Rollback logs this step and calls the underlying transaction Rollback method.
|
||||
func (d *DebugTx) Rollback() error {
|
||||
d.log(d.ctx, fmt.Sprintf("Tx(%s): rollbacked", d.id))
|
||||
return d.Tx.Rollback()
|
||||
}
|
||||
70
inventory/direct_link.go
Normal file
70
inventory/direct_link.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/directlink"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
)
|
||||
|
||||
type (
|
||||
DirectLinkClient interface {
|
||||
TxOperator
|
||||
// GetByNameID get direct link by name and id
|
||||
GetByNameID(ctx context.Context, id int, name string) (*ent.DirectLink, error)
|
||||
// GetByID get direct link by id
|
||||
GetByID(ctx context.Context, id int) (*ent.DirectLink, error)
|
||||
}
|
||||
LoadDirectLinkFile struct{}
|
||||
)
|
||||
|
||||
func NewDirectLinkClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) DirectLinkClient {
|
||||
return &directLinkClient{
|
||||
client: client,
|
||||
hasher: hasher,
|
||||
maxSQlParam: sqlParamLimit(dbType),
|
||||
}
|
||||
}
|
||||
|
||||
type directLinkClient struct {
|
||||
maxSQlParam int
|
||||
client *ent.Client
|
||||
hasher hashid.Encoder
|
||||
}
|
||||
|
||||
func (c *directLinkClient) SetClient(newClient *ent.Client) TxOperator {
|
||||
return &directLinkClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam}
|
||||
}
|
||||
|
||||
func (c *directLinkClient) GetClient() *ent.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (d *directLinkClient) GetByID(ctx context.Context, id int) (*ent.DirectLink, error) {
|
||||
return withDirectLinkEagerLoading(ctx, d.client.DirectLink.Query().Where(directlink.ID(id))).
|
||||
First(ctx)
|
||||
}
|
||||
|
||||
func (d *directLinkClient) GetByNameID(ctx context.Context, id int, name string) (*ent.DirectLink, error) {
|
||||
res, err := withDirectLinkEagerLoading(ctx, d.client.DirectLink.Query().Where(directlink.ID(id), directlink.Name(name))).
|
||||
First(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Increase download counter
|
||||
_, _ = d.client.DirectLink.Update().Where(directlink.ID(res.ID)).SetDownloads(res.Downloads + 1).Save(ctx)
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func withDirectLinkEagerLoading(ctx context.Context, q *ent.DirectLinkQuery) *ent.DirectLinkQuery {
|
||||
if v, ok := ctx.Value(LoadDirectLinkFile{}).(bool); ok && v {
|
||||
q.WithFile(func(m *ent.FileQuery) {
|
||||
withFileEagerLoading(ctx, m)
|
||||
})
|
||||
}
|
||||
return q
|
||||
}
|
||||
1113
inventory/file.go
Normal file
1113
inventory/file.go
Normal file
File diff suppressed because it is too large
Load Diff
503
inventory/file_utils.go
Normal file
503
inventory/file_utils.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/entity"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/file"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/metadata"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/predicate"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func (f *fileClient) searchQuery(q *ent.FileQuery, args *SearchFileParameters, parents []*ent.File, ownerId int) *ent.FileQuery {
|
||||
if len(parents) == 1 && parents[0] == nil {
|
||||
q = q.Where(file.OwnerID(ownerId))
|
||||
} else {
|
||||
q = q.Where(
|
||||
file.HasParentWith(
|
||||
file.IDIn(lo.Map(parents, func(item *ent.File, index int) int {
|
||||
return item.ID
|
||||
})...,
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
if len(args.Name) > 0 {
|
||||
namePredicates := lo.Map(args.Name, func(item string, index int) predicate.File {
|
||||
// If start and ends with quotes, treat as exact match
|
||||
if strings.HasPrefix(item, "\"") && strings.HasSuffix(item, "\"") {
|
||||
return file.NameContains(strings.Trim(item, "\""))
|
||||
}
|
||||
|
||||
// if contain wildcard, use transform to sql like
|
||||
if strings.Contains(item, SearchWildcard) {
|
||||
pattern := strings.ReplaceAll(item, SearchWildcard, "%")
|
||||
if pattern[0] != '%' && pattern[len(pattern)-1] != '%' {
|
||||
// if not start with wildcard, add prefix wildcard
|
||||
pattern = "%" + pattern + "%"
|
||||
}
|
||||
|
||||
return func(s *sql.Selector) {
|
||||
s.Where(sql.Like(file.FieldName, pattern))
|
||||
}
|
||||
}
|
||||
|
||||
if args.CaseFolding {
|
||||
return file.NameContainsFold(item)
|
||||
}
|
||||
|
||||
return file.NameContains(item)
|
||||
})
|
||||
|
||||
if args.NameOperatorOr {
|
||||
q = q.Where(file.Or(namePredicates...))
|
||||
} else {
|
||||
q = q.Where(file.And(namePredicates...))
|
||||
}
|
||||
}
|
||||
|
||||
if args.Type != nil {
|
||||
q = q.Where(file.TypeEQ(int(*args.Type)))
|
||||
}
|
||||
|
||||
if len(args.Metadata) > 0 {
|
||||
metaPredicates := lo.MapToSlice(args.Metadata, func(name string, value string) predicate.Metadata {
|
||||
nameEq := metadata.NameEQ(value)
|
||||
if name == "" {
|
||||
return nameEq
|
||||
} else {
|
||||
valueContain := metadata.ValueContainsFold(value)
|
||||
return metadata.And(metadata.NameEQ(name), valueContain)
|
||||
}
|
||||
})
|
||||
metaPredicates = append(metaPredicates, metadata.IsPublic(true))
|
||||
q.Where(file.HasMetadataWith(metadata.And(metaPredicates...)))
|
||||
}
|
||||
|
||||
if args.SizeLte > 0 || args.SizeGte > 0 {
|
||||
q = q.Where(file.SizeGTE(args.SizeGte), file.SizeLTE(args.SizeLte))
|
||||
}
|
||||
|
||||
if args.CreatedAtLte != nil {
|
||||
q = q.Where(file.CreatedAtLTE(*args.CreatedAtLte))
|
||||
}
|
||||
|
||||
if args.CreatedAtGte != nil {
|
||||
q = q.Where(file.CreatedAtGTE(*args.CreatedAtGte))
|
||||
}
|
||||
|
||||
if args.UpdatedAtLte != nil {
|
||||
q = q.Where(file.UpdatedAtLTE(*args.UpdatedAtLte))
|
||||
}
|
||||
|
||||
if args.UpdatedAtGte != nil {
|
||||
q = q.Where(file.UpdatedAtGTE(*args.UpdatedAtGte))
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// ChildFileQuery generates query for child file(s) of a given set of root
|
||||
func (f *fileClient) childFileQuery(ownerID int, isSymbolic bool, root ...*ent.File) *ent.FileQuery {
|
||||
rawQuery := f.client.File.Query()
|
||||
if len(root) == 1 && root[0] != nil {
|
||||
// Query children of one single root
|
||||
rawQuery = f.client.File.QueryChildren(root[0])
|
||||
} else if root[0] == nil {
|
||||
// Query orphan files with owner ID
|
||||
predicates := []predicate.File{
|
||||
file.NameNEQ(RootFolderName),
|
||||
}
|
||||
|
||||
if ownerID > 0 {
|
||||
predicates = append(predicates, file.OwnerIDEQ(ownerID))
|
||||
}
|
||||
|
||||
if isSymbolic {
|
||||
predicates = append(predicates, file.And(file.IsSymbolic(true), file.FileChildrenNotNil()))
|
||||
} else {
|
||||
predicates = append(predicates, file.Not(file.HasParent()))
|
||||
}
|
||||
|
||||
rawQuery = f.client.File.Query().Where(
|
||||
file.And(predicates...),
|
||||
)
|
||||
} else {
|
||||
// Query children of multiple roots
|
||||
rawQuery.
|
||||
Where(
|
||||
file.HasParentWith(
|
||||
file.IDIn(lo.Map(root, func(item *ent.File, index int) int {
|
||||
return item.ID
|
||||
})...),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
return rawQuery
|
||||
}
|
||||
|
||||
// batchInCondition returns a list of predicates that divide original group into smaller ones
|
||||
// to bypass DB limitations.
|
||||
func (f *fileClient) batchInCondition(pageSize, margin int, multiply int, ids []int) ([]predicate.File, [][]int) {
|
||||
pageSize = capPageSize(f.maxSQlParam, pageSize, margin)
|
||||
chunks := lo.Chunk(ids, max(pageSize/multiply, 1))
|
||||
return lo.Map(chunks, func(item []int, index int) predicate.File {
|
||||
return file.IDIn(item...)
|
||||
}), chunks
|
||||
}
|
||||
|
||||
func (f *fileClient) batchInConditionMetadataName(pageSize, margin int, multiply int, keys []string) ([]predicate.Metadata, [][]string) {
|
||||
pageSize = capPageSize(f.maxSQlParam, pageSize, margin)
|
||||
chunks := lo.Chunk(keys, max(pageSize/multiply, 1))
|
||||
return lo.Map(chunks, func(item []string, index int) predicate.Metadata {
|
||||
return metadata.NameIn(item...)
|
||||
}), chunks
|
||||
}
|
||||
|
||||
func (f *fileClient) batchInConditionEntityID(pageSize, margin int, multiply int, keys []int) ([]predicate.Entity, [][]int) {
|
||||
pageSize = capPageSize(f.maxSQlParam, pageSize, margin)
|
||||
chunks := lo.Chunk(keys, max(pageSize/multiply, 1))
|
||||
return lo.Map(chunks, func(item []int, index int) predicate.Entity {
|
||||
return entity.IDIn(item...)
|
||||
}), chunks
|
||||
}
|
||||
|
||||
// cursorPagination perform pagination with cursor, which is faster than fast pagination, but less flexible.
|
||||
func (f *fileClient) cursorPagination(ctx context.Context, query *ent.FileQuery,
|
||||
args *ListFileParameters, paramMargin int) ([]*ent.File, *PaginationResults, error) {
|
||||
pageSize := capPageSize(f.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(getFileOrderOption(args)...)
|
||||
currentPage := 0
|
||||
// Three types of query option
|
||||
queryPaged := []*ent.FileQuery{
|
||||
query.Clone().
|
||||
Where(file.TypeEQ(int(types.FileTypeFolder))),
|
||||
query.Clone().
|
||||
Where(file.TypeEQ(int(types.FileTypeFile))),
|
||||
query.Clone().
|
||||
Where(file.TypeIn(int(types.FileTypeFolder), int(types.FileTypeFile))),
|
||||
}
|
||||
|
||||
var (
|
||||
pageToken *PageToken
|
||||
err error
|
||||
)
|
||||
if args.PageToken != "" {
|
||||
pageToken, err = pageTokenFromString(args.PageToken, f.hasher, hashid.FileID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
|
||||
}
|
||||
}
|
||||
queryPaged = getFileCursorQuery(args, pageToken, queryPaged)
|
||||
|
||||
// Use page size + 1 to determine if there are more items to come
|
||||
queryPaged[0].Limit(pageSize + 1)
|
||||
|
||||
files, err := queryPaged[0].
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
nextStartWithFile := false
|
||||
if pageToken != nil && pageToken.StartWithFile {
|
||||
nextStartWithFile = true
|
||||
}
|
||||
if len(files) < pageSize+1 && len(queryPaged) > 1 && !args.MixedType && !args.FolderOnly {
|
||||
queryPaged[1].Limit(pageSize + 1 - len(files))
|
||||
filesContinue, err := queryPaged[1].
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
nextStartWithFile = true
|
||||
files = append(files, filesContinue...)
|
||||
}
|
||||
|
||||
// More items to come
|
||||
nextTokenStr := ""
|
||||
if len(files) > pageSize {
|
||||
lastItem := files[len(files)-2]
|
||||
nextToken, err := getFileNextPageToken(f.hasher, lastItem, args, nextStartWithFile)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
|
||||
}
|
||||
|
||||
nextTokenStr = nextToken
|
||||
}
|
||||
|
||||
return lo.Subset(files, 0, uint(pageSize)), &PaginationResults{
|
||||
Page: currentPage,
|
||||
PageSize: pageSize,
|
||||
NextPageToken: nextTokenStr,
|
||||
IsCursor: true,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// offsetPagination perform traditional pagination with minor optimizations.
|
||||
func (f *fileClient) offsetPagination(ctx context.Context, query *ent.FileQuery,
|
||||
args *ListFileParameters, paramMargin int) ([]*ent.File, *PaginationResults, error) {
|
||||
pageSize := capPageSize(f.maxSQlParam, args.PageSize, paramMargin)
|
||||
queryWithoutOrder := query.Clone()
|
||||
query.Order(getFileOrderOption(args)...)
|
||||
|
||||
// Count total items by type
|
||||
var v []struct {
|
||||
Type int `json:"type"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
err := queryWithoutOrder.Clone().
|
||||
GroupBy(file.FieldType).
|
||||
Aggregate(ent.Count()).
|
||||
Scan(ctx, &v)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
folderCount := 0
|
||||
fileCount := 0
|
||||
for _, item := range v {
|
||||
if item.Type == int(types.FileTypeFolder) {
|
||||
folderCount = item.Count
|
||||
} else {
|
||||
fileCount = item.Count
|
||||
}
|
||||
}
|
||||
|
||||
allFiles := make([]*ent.File, 0, pageSize)
|
||||
folderLimit := 0
|
||||
if (args.Page+1)*pageSize > folderCount {
|
||||
folderLimit = folderCount - args.Page*pageSize
|
||||
if folderLimit < 0 {
|
||||
folderLimit = 0
|
||||
}
|
||||
} else {
|
||||
folderLimit = pageSize
|
||||
}
|
||||
|
||||
if folderLimit <= pageSize && folderLimit > 0 {
|
||||
// Folder still remains
|
||||
folders, err := query.Clone().
|
||||
Limit(folderLimit).
|
||||
Offset(args.Page * pageSize).
|
||||
Where(file.TypeEQ(int(types.FileTypeFolder))).All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
allFiles = append(allFiles, folders...)
|
||||
}
|
||||
|
||||
if folderLimit < pageSize {
|
||||
files, err := query.Clone().
|
||||
Limit(pageSize - folderLimit).
|
||||
Offset((args.Page * pageSize) + folderLimit - folderCount).
|
||||
Where(file.TypeEQ(int(types.FileTypeFile))).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
allFiles = append(allFiles, files...)
|
||||
}
|
||||
|
||||
return allFiles, &PaginationResults{
|
||||
TotalItems: folderCount + fileCount,
|
||||
Page: args.Page,
|
||||
PageSize: pageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func withFileEagerLoading(ctx context.Context, q *ent.FileQuery) *ent.FileQuery {
|
||||
if v, ok := ctx.Value(LoadFileEntity{}).(bool); ok && v {
|
||||
q.WithEntities(func(m *ent.EntityQuery) {
|
||||
m.Order(ent.Desc(entity.FieldID))
|
||||
withEntityEagerLoading(ctx, m)
|
||||
})
|
||||
}
|
||||
if v, ok := ctx.Value(LoadFileMetadata{}).(bool); ok && v {
|
||||
q.WithMetadata()
|
||||
}
|
||||
if v, ok := ctx.Value(LoadFilePublicMetadata{}).(bool); ok && v {
|
||||
q.WithMetadata(func(m *ent.MetadataQuery) {
|
||||
m.Where(metadata.IsPublic(true))
|
||||
})
|
||||
}
|
||||
if v, ok := ctx.Value(LoadFileShare{}).(bool); ok && v {
|
||||
q.WithShares()
|
||||
}
|
||||
if v, ok := ctx.Value(LoadFileUser{}).(bool); ok && v {
|
||||
q.WithOwner(func(query *ent.UserQuery) {
|
||||
withUserEagerLoading(ctx, query)
|
||||
})
|
||||
}
|
||||
if v, ok := ctx.Value(LoadFileDirectLink{}).(bool); ok && v {
|
||||
q.WithDirectLinks()
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
func withEntityEagerLoading(ctx context.Context, q *ent.EntityQuery) *ent.EntityQuery {
|
||||
if v, ok := ctx.Value(LoadEntityUser{}).(bool); ok && v {
|
||||
q.WithUser()
|
||||
}
|
||||
|
||||
if v, ok := ctx.Value(LoadEntityStoragePolicy{}).(bool); ok && v {
|
||||
q.WithStoragePolicy()
|
||||
}
|
||||
|
||||
if v, ok := ctx.Value(LoadEntityFile{}).(bool); ok && v {
|
||||
q.WithFile(func(fq *ent.FileQuery) {
|
||||
withFileEagerLoading(ctx, fq)
|
||||
})
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
func getFileOrderOption(args *ListFileParameters) []file.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
case file.FieldName:
|
||||
return []file.OrderOption{file.ByName(orderTerm), file.ByID(orderTerm)}
|
||||
case file.FieldSize:
|
||||
return []file.OrderOption{file.BySize(orderTerm), file.ByID(orderTerm)}
|
||||
case file.FieldUpdatedAt:
|
||||
return []file.OrderOption{file.ByUpdatedAt(orderTerm), file.ByID(orderTerm)}
|
||||
default:
|
||||
return []file.OrderOption{file.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
func getEntityOrderOption(args *ListEntityParameters) []entity.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
case entity.FieldSize:
|
||||
return []entity.OrderOption{entity.BySize(orderTerm), entity.ByID(orderTerm)}
|
||||
case entity.FieldUpdatedAt:
|
||||
return []entity.OrderOption{entity.ByUpdatedAt(orderTerm), entity.ByID(orderTerm)}
|
||||
case entity.FieldReferenceCount:
|
||||
return []entity.OrderOption{entity.ByReferenceCount(orderTerm), entity.ByID(orderTerm)}
|
||||
default:
|
||||
return []entity.OrderOption{entity.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
var fileCursorQuery = map[string]map[bool]func(token *PageToken) predicate.File{
|
||||
file.FieldName: {
|
||||
true: func(token *PageToken) predicate.File {
|
||||
return file.Or(
|
||||
file.NameLT(token.String),
|
||||
file.And(file.Name(token.String), file.IDLT(token.ID)),
|
||||
)
|
||||
},
|
||||
false: func(token *PageToken) predicate.File {
|
||||
return file.Or(
|
||||
file.NameGT(token.String),
|
||||
file.And(file.Name(token.String), file.IDGT(token.ID)),
|
||||
)
|
||||
},
|
||||
},
|
||||
file.FieldSize: {
|
||||
true: func(token *PageToken) predicate.File {
|
||||
return file.Or(
|
||||
file.SizeLT(int64(token.Int)),
|
||||
file.And(file.Size(int64(token.Int)), file.IDLT(token.ID)),
|
||||
)
|
||||
},
|
||||
false: func(token *PageToken) predicate.File {
|
||||
return file.Or(
|
||||
file.SizeGT(int64(token.Int)),
|
||||
file.And(file.Size(int64(token.Int)), file.IDGT(token.ID)),
|
||||
)
|
||||
},
|
||||
},
|
||||
file.FieldCreatedAt: {
|
||||
true: func(token *PageToken) predicate.File {
|
||||
return file.IDLT(token.ID)
|
||||
},
|
||||
false: func(token *PageToken) predicate.File {
|
||||
return file.IDGT(token.ID)
|
||||
},
|
||||
},
|
||||
file.FieldUpdatedAt: {
|
||||
true: func(token *PageToken) predicate.File {
|
||||
return file.Or(
|
||||
file.UpdatedAtLT(*token.Time),
|
||||
file.And(file.UpdatedAt(*token.Time), file.IDLT(token.ID)),
|
||||
)
|
||||
},
|
||||
false: func(token *PageToken) predicate.File {
|
||||
return file.Or(
|
||||
file.UpdatedAtGT(*token.Time),
|
||||
file.And(file.UpdatedAt(*token.Time), file.IDGT(token.ID)),
|
||||
)
|
||||
},
|
||||
},
|
||||
file.FieldID: {
|
||||
true: func(token *PageToken) predicate.File {
|
||||
return file.IDLT(token.ID)
|
||||
},
|
||||
false: func(token *PageToken) predicate.File {
|
||||
return file.IDGT(token.ID)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func getFileCursorQuery(args *ListFileParameters, token *PageToken, query []*ent.FileQuery) []*ent.FileQuery {
|
||||
o := &sql.OrderTermOptions{}
|
||||
getOrderTerm(args.Order)(o)
|
||||
|
||||
predicates, ok := fileCursorQuery[args.OrderBy]
|
||||
if !ok {
|
||||
predicates = fileCursorQuery[file.FieldID]
|
||||
}
|
||||
|
||||
// If all folder is already listed in previous page, only query for files.
|
||||
if token != nil && token.StartWithFile && !args.MixedType {
|
||||
query = query[1:2]
|
||||
}
|
||||
|
||||
// Mixing folders and files with one query
|
||||
if args.MixedType {
|
||||
query = query[2:]
|
||||
} else if args.FolderOnly {
|
||||
query = query[0:1]
|
||||
}
|
||||
|
||||
if token != nil {
|
||||
query[0].Where(predicates[o.Desc](token))
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// getFileNextPageToken returns the next page token for the given last file.
|
||||
func getFileNextPageToken(hasher hashid.Encoder, last *ent.File, args *ListFileParameters, nextStartWithFile bool) (string, error) {
|
||||
token := &PageToken{
|
||||
ID: last.ID,
|
||||
StartWithFile: nextStartWithFile,
|
||||
}
|
||||
|
||||
switch args.OrderBy {
|
||||
case file.FieldName:
|
||||
token.String = last.Name
|
||||
case file.FieldSize:
|
||||
token.Int = int(last.Size)
|
||||
case file.FieldUpdatedAt:
|
||||
token.Time = &last.UpdatedAt
|
||||
}
|
||||
|
||||
return token.Encode(hasher, hashid.EncodeFileID)
|
||||
}
|
||||
170
inventory/group.go
Normal file
170
inventory/group.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/group"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
)
|
||||
|
||||
type (
|
||||
// Ctx keys for eager loading options.
|
||||
LoadGroupPolicy struct{}
|
||||
)
|
||||
|
||||
const (
|
||||
AnonymousGroupID = 3
|
||||
)
|
||||
|
||||
type (
|
||||
GroupClient interface {
|
||||
TxOperator
|
||||
// AnonymousGroup returns the anonymous group.
|
||||
AnonymousGroup(ctx context.Context) (*ent.Group, error)
|
||||
// ListAll returns all groups.
|
||||
ListAll(ctx context.Context) ([]*ent.Group, error)
|
||||
// GetByID returns the group by id.
|
||||
GetByID(ctx context.Context, id int) (*ent.Group, error)
|
||||
// ListGroups returns a list of groups with pagination.
|
||||
ListGroups(ctx context.Context, args *ListGroupParameters) (*ListGroupResult, error)
|
||||
// CountUsers returns the number of users in the group.
|
||||
CountUsers(ctx context.Context, id int) (int, error)
|
||||
// Upsert upserts a group.
|
||||
Upsert(ctx context.Context, group *ent.Group) (*ent.Group, error)
|
||||
// Delete deletes a group.
|
||||
Delete(ctx context.Context, id int) error
|
||||
}
|
||||
ListGroupParameters struct {
|
||||
*PaginationArgs
|
||||
}
|
||||
ListGroupResult struct {
|
||||
*PaginationResults
|
||||
Groups []*ent.Group
|
||||
}
|
||||
)
|
||||
|
||||
func NewGroupClient(client *ent.Client, dbType conf.DBType, cache cache.Driver) GroupClient {
|
||||
return &groupClient{client: client, maxSQlParam: sqlParamLimit(dbType), cache: cache}
|
||||
}
|
||||
|
||||
type groupClient struct {
|
||||
client *ent.Client
|
||||
cache cache.Driver
|
||||
maxSQlParam int
|
||||
}
|
||||
|
||||
func (c *groupClient) SetClient(newClient *ent.Client) TxOperator {
|
||||
return &groupClient{client: newClient, maxSQlParam: c.maxSQlParam, cache: c.cache}
|
||||
}
|
||||
|
||||
func (c *groupClient) GetClient() *ent.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (c *groupClient) CountUsers(ctx context.Context, id int) (int, error) {
|
||||
return c.client.Group.Query().Where(group.ID(id)).QueryUsers().Count(ctx)
|
||||
}
|
||||
|
||||
func (c *groupClient) AnonymousGroup(ctx context.Context) (*ent.Group, error) {
|
||||
return withGroupEagerLoading(ctx, c.client.Group.Query().Where(group.ID(AnonymousGroupID))).First(ctx)
|
||||
}
|
||||
|
||||
func (c *groupClient) ListAll(ctx context.Context) ([]*ent.Group, error) {
|
||||
return withGroupEagerLoading(ctx, c.client.Group.Query()).All(ctx)
|
||||
}
|
||||
|
||||
func (c *groupClient) Upsert(ctx context.Context, group *ent.Group) (*ent.Group, error) {
|
||||
|
||||
if group.ID == 0 {
|
||||
return c.client.Group.Create().
|
||||
SetName(group.Name).
|
||||
SetMaxStorage(group.MaxStorage).
|
||||
SetSpeedLimit(group.SpeedLimit).
|
||||
SetPermissions(group.Permissions).
|
||||
SetSettings(group.Settings).
|
||||
SetStoragePoliciesID(group.Edges.StoragePolicies.ID).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
res, err := c.client.Group.UpdateOne(group).
|
||||
SetName(group.Name).
|
||||
SetMaxStorage(group.MaxStorage).
|
||||
SetSpeedLimit(group.SpeedLimit).
|
||||
SetPermissions(group.Permissions).
|
||||
SetSettings(group.Settings).
|
||||
ClearStoragePolicies().
|
||||
SetStoragePoliciesID(group.Edges.StoragePolicies.ID).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *groupClient) Delete(ctx context.Context, id int) error {
|
||||
if err := c.client.Group.DeleteOneID(id).Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *groupClient) ListGroups(ctx context.Context, args *ListGroupParameters) (*ListGroupResult, error) {
|
||||
query := withGroupEagerLoading(ctx, c.client.Group.Query())
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, 10)
|
||||
queryWithoutOrder := query.Clone()
|
||||
query.Order(getGroupOrderOption(args)...)
|
||||
|
||||
// Count total items
|
||||
total, err := queryWithoutOrder.Clone().
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups, err := query.Clone().
|
||||
Limit(pageSize).
|
||||
Offset(args.Page * pageSize).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ListGroupResult{
|
||||
Groups: groups,
|
||||
PaginationResults: &PaginationResults{
|
||||
TotalItems: total,
|
||||
Page: args.Page,
|
||||
PageSize: pageSize,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *groupClient) GetByID(ctx context.Context, id int) (*ent.Group, error) {
|
||||
return withGroupEagerLoading(ctx, c.client.Group.Query().Where(group.ID(id))).First(ctx)
|
||||
}
|
||||
|
||||
func getGroupOrderOption(args *ListGroupParameters) []group.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
case group.FieldName:
|
||||
return []group.OrderOption{group.ByName(orderTerm), group.ByID(orderTerm)}
|
||||
case group.FieldMaxStorage:
|
||||
return []group.OrderOption{group.ByMaxStorage(orderTerm), group.ByID(orderTerm)}
|
||||
default:
|
||||
return []group.OrderOption{group.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
func withGroupEagerLoading(ctx context.Context, q *ent.GroupQuery) *ent.GroupQuery {
|
||||
if _, ok := ctx.Value(LoadGroupPolicy{}).(bool); ok {
|
||||
q.WithStoragePolicies(func(spq *ent.StoragePolicyQuery) {
|
||||
withStoragePolicyEagerLoading(ctx, spq)
|
||||
})
|
||||
}
|
||||
return q
|
||||
}
|
||||
156
inventory/node.go
Normal file
156
inventory/node.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/node"
|
||||
)
|
||||
|
||||
type (
|
||||
LoadNodeStoragePolicy struct{}
|
||||
NodeClient interface {
|
||||
TxOperator
|
||||
// ListActiveNodes returns the active nodes.
|
||||
ListActiveNodes(ctx context.Context, subset []int) ([]*ent.Node, error)
|
||||
// ListNodes returns the nodes with pagination.
|
||||
ListNodes(ctx context.Context, args *ListNodeParameters) (*ListNodeResult, error)
|
||||
// GetNodeById returns the node by id.
|
||||
GetNodeById(ctx context.Context, id int) (*ent.Node, error)
|
||||
// GetNodeByIds returns the nodes by ids.
|
||||
GetNodeByIds(ctx context.Context, ids []int) ([]*ent.Node, error)
|
||||
// Upsert upserts a node.
|
||||
Upsert(ctx context.Context, n *ent.Node) (*ent.Node, error)
|
||||
// Delete deletes a node.
|
||||
Delete(ctx context.Context, id int) error
|
||||
}
|
||||
ListNodeParameters struct {
|
||||
*PaginationArgs
|
||||
Status node.Status
|
||||
}
|
||||
ListNodeResult struct {
|
||||
*PaginationResults
|
||||
Nodes []*ent.Node
|
||||
}
|
||||
)
|
||||
|
||||
func NewNodeClient(client *ent.Client) NodeClient {
|
||||
return &nodeClient{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
type nodeClient struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func (c *nodeClient) SetClient(newClient *ent.Client) TxOperator {
|
||||
return &nodeClient{client: newClient}
|
||||
}
|
||||
|
||||
func (c *nodeClient) GetClient() *ent.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (c *nodeClient) ListActiveNodes(ctx context.Context, subset []int) ([]*ent.Node, error) {
|
||||
stm := c.client.Node.Query().Where(node.StatusEQ(node.StatusActive))
|
||||
if len(subset) > 0 {
|
||||
stm = stm.Where(node.IDIn(subset...))
|
||||
}
|
||||
return stm.All(ctx)
|
||||
}
|
||||
|
||||
func (c *nodeClient) GetNodeByIds(ctx context.Context, ids []int) ([]*ent.Node, error) {
|
||||
return withNodeEagerLoading(ctx, c.client.Node.Query().Where(node.IDIn(ids...))).All(ctx)
|
||||
}
|
||||
|
||||
func (c *nodeClient) GetNodeById(ctx context.Context, id int) (*ent.Node, error) {
|
||||
return withNodeEagerLoading(ctx, c.client.Node.Query().Where(node.IDEQ(id))).First(ctx)
|
||||
}
|
||||
|
||||
func (c *nodeClient) Delete(ctx context.Context, id int) error {
|
||||
return c.client.Node.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
func (c *nodeClient) ListNodes(ctx context.Context, args *ListNodeParameters) (*ListNodeResult, error) {
|
||||
query := c.client.Node.Query()
|
||||
if string(args.Status) != "" {
|
||||
query = query.Where(node.StatusEQ(args.Status))
|
||||
}
|
||||
query.Order(getNodeOrderOption(args)...)
|
||||
|
||||
// Count total items
|
||||
total, err := query.Clone().
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes, err := withNodeEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ListNodeResult{
|
||||
PaginationResults: &PaginationResults{
|
||||
TotalItems: total,
|
||||
Page: args.Page,
|
||||
PageSize: args.PageSize,
|
||||
},
|
||||
Nodes: nodes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *nodeClient) Upsert(ctx context.Context, n *ent.Node) (*ent.Node, error) {
|
||||
if n.ID == 0 {
|
||||
return c.client.Node.Create().
|
||||
SetName(n.Name).
|
||||
SetServer(n.Server).
|
||||
SetSlaveKey(n.SlaveKey).
|
||||
SetStatus(n.Status).
|
||||
SetType(node.TypeSlave).
|
||||
SetSettings(n.Settings).
|
||||
SetCapabilities(n.Capabilities).
|
||||
SetWeight(n.Weight).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
res, err := c.client.Node.UpdateOne(n).
|
||||
SetName(n.Name).
|
||||
SetServer(n.Server).
|
||||
SetSlaveKey(n.SlaveKey).
|
||||
SetStatus(n.Status).
|
||||
SetSettings(n.Settings).
|
||||
SetCapabilities(n.Capabilities).
|
||||
SetWeight(n.Weight).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func getNodeOrderOption(args *ListNodeParameters) []node.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
case node.FieldName:
|
||||
return []node.OrderOption{node.ByName(orderTerm), node.ByID(orderTerm)}
|
||||
case node.FieldWeight:
|
||||
return []node.OrderOption{node.ByWeight(orderTerm), node.ByID(orderTerm)}
|
||||
case node.FieldUpdatedAt:
|
||||
return []node.OrderOption{node.ByUpdatedAt(orderTerm), node.ByID(orderTerm)}
|
||||
default:
|
||||
return []node.OrderOption{node.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
func withNodeEagerLoading(ctx context.Context, query *ent.NodeQuery) *ent.NodeQuery {
|
||||
if _, ok := ctx.Value(LoadNodeStoragePolicy{}).(bool); ok {
|
||||
query = query.WithStoragePolicy(func(gq *ent.StoragePolicyQuery) {
|
||||
withStoragePolicyEagerLoading(ctx, gq)
|
||||
})
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
243
inventory/policy.go
Normal file
243
inventory/policy.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/storagepolicy"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
// StoragePolicyCacheKey is the cache key of storage policy.
|
||||
StoragePolicyCacheKey = "storage_policy_"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(ent.StoragePolicy{})
|
||||
gob.Register([]ent.StoragePolicy{})
|
||||
}
|
||||
|
||||
type (
|
||||
LoadStoragePolicyGroup struct{}
|
||||
SkipStoragePolicyCache struct{}
|
||||
|
||||
StoragePolicyClient interface {
|
||||
// GetByGroup returns the storage policies of the group.
|
||||
GetByGroup(ctx context.Context, group *ent.Group) (*ent.StoragePolicy, error)
|
||||
// GetPolicyByID returns the storage policy by id.
|
||||
GetPolicyByID(ctx context.Context, id int) (*ent.StoragePolicy, error)
|
||||
// UpdateAccessKey updates the access key of the storage policy. It also clear related cache in KV.
|
||||
UpdateAccessKey(ctx context.Context, policy *ent.StoragePolicy, token string) error
|
||||
// ListPolicyByType returns the storage policies by type.
|
||||
ListPolicyByType(ctx context.Context, t types.PolicyType) ([]*ent.StoragePolicy, error)
|
||||
// ListPolicies returns the storage policies with pagination.
|
||||
ListPolicies(ctx context.Context, args *ListPolicyParameters) (*ListPolicyResult, error)
|
||||
// Upsert upserts the storage policy.
|
||||
Upsert(ctx context.Context, policy *ent.StoragePolicy) (*ent.StoragePolicy, error)
|
||||
// Delete deletes the storage policy.
|
||||
Delete(ctx context.Context, policy *ent.StoragePolicy) error
|
||||
}
|
||||
|
||||
ListPolicyParameters struct {
|
||||
*PaginationArgs
|
||||
Type types.PolicyType
|
||||
}
|
||||
|
||||
ListPolicyResult struct {
|
||||
*PaginationResults
|
||||
Policies []*ent.StoragePolicy
|
||||
}
|
||||
)
|
||||
|
||||
// NewStoragePolicyClient returns a new StoragePolicyClient.
|
||||
func NewStoragePolicyClient(client *ent.Client, cache cache.Driver) StoragePolicyClient {
|
||||
return &storagePolicyClient{client: client, cache: cache}
|
||||
}
|
||||
|
||||
type storagePolicyClient struct {
|
||||
client *ent.Client
|
||||
cache cache.Driver
|
||||
}
|
||||
|
||||
func (c *storagePolicyClient) Delete(ctx context.Context, policy *ent.StoragePolicy) error {
|
||||
if err := c.client.StoragePolicy.DeleteOne(policy).Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete storage policy: %w", err)
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
if err := c.cache.Delete(StoragePolicyCacheKey, strconv.Itoa(policy.ID)); err != nil {
|
||||
return fmt.Errorf("failed to clear storage policy cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *storagePolicyClient) Upsert(ctx context.Context, policy *ent.StoragePolicy) (*ent.StoragePolicy, error) {
|
||||
var nodeId *int
|
||||
if policy.NodeID != 0 {
|
||||
nodeId = &policy.NodeID
|
||||
}
|
||||
if policy.ID == 0 {
|
||||
p, err := c.client.StoragePolicy.
|
||||
Create().
|
||||
SetName(policy.Name).
|
||||
SetType(policy.Type).
|
||||
SetServer(policy.Server).
|
||||
SetBucketName(policy.BucketName).
|
||||
SetIsPrivate(policy.IsPrivate).
|
||||
SetAccessKey(policy.AccessKey).
|
||||
SetSecretKey(policy.SecretKey).
|
||||
SetMaxSize(policy.MaxSize).
|
||||
SetDirNameRule(policy.DirNameRule).
|
||||
SetFileNameRule(policy.FileNameRule).
|
||||
SetSettings(policy.Settings).
|
||||
SetNillableNodeID(nodeId).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create storage policy: %w", err)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
updateQuery := c.client.StoragePolicy.UpdateOne(policy).
|
||||
SetName(policy.Name).
|
||||
SetType(policy.Type).
|
||||
SetServer(policy.Server).
|
||||
SetBucketName(policy.BucketName).
|
||||
SetIsPrivate(policy.IsPrivate).
|
||||
SetSecretKey(policy.SecretKey).
|
||||
SetMaxSize(policy.MaxSize).
|
||||
SetDirNameRule(policy.DirNameRule).
|
||||
SetFileNameRule(policy.FileNameRule).
|
||||
SetSettings(policy.Settings).
|
||||
SetNillableNodeID(nodeId)
|
||||
|
||||
if policy.Type != types.PolicyTypeOd {
|
||||
updateQuery.SetAccessKey(policy.AccessKey)
|
||||
}
|
||||
|
||||
p, err := updateQuery.Save(ctx)
|
||||
|
||||
// Clear cache
|
||||
if err := c.cache.Delete(StoragePolicyCacheKey, strconv.Itoa(policy.ID)); err != nil {
|
||||
return nil, fmt.Errorf("failed to clear storage policy cache: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update storage policy: %w", err)
|
||||
}
|
||||
return p, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *storagePolicyClient) GetByGroup(ctx context.Context, group *ent.Group) (*ent.StoragePolicy, error) {
|
||||
val, skipCache := ctx.Value(SkipStoragePolicyCache{}).(bool)
|
||||
skipCache = skipCache && val
|
||||
|
||||
res, err := withStoragePolicyEagerLoading(ctx, c.client.Group.QueryStoragePolicies(group)).WithNode().First(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get storage policies: %w", err)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// GetPolicyByID returns the storage policy by id.
|
||||
func (c *storagePolicyClient) GetPolicyByID(ctx context.Context, id int) (*ent.StoragePolicy, error) {
|
||||
val, skipCache := ctx.Value(SkipStoragePolicyCache{}).(bool)
|
||||
skipCache = skipCache && val
|
||||
|
||||
// Try to read from cache
|
||||
if c.cache != nil && !skipCache {
|
||||
if res, ok := c.cache.Get(StoragePolicyCacheKey + strconv.Itoa(id)); ok {
|
||||
cached := res.(ent.StoragePolicy)
|
||||
return &cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
res, err := withStoragePolicyEagerLoading(ctx, c.client.StoragePolicy.Query().Where(storagepolicy.ID(id))).WithNode().First(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get storage policy: %w", err)
|
||||
}
|
||||
|
||||
// Write to cache
|
||||
if c.cache != nil && !skipCache {
|
||||
_ = c.cache.Set(StoragePolicyCacheKey+strconv.Itoa(id), *res, -1)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *storagePolicyClient) ListPolicyByType(ctx context.Context, t types.PolicyType) ([]*ent.StoragePolicy, error) {
|
||||
policies, err := c.client.StoragePolicy.Query().Where(storagepolicy.TypeEQ(string(t))).All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list storage policies: %w", err)
|
||||
}
|
||||
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
func (c *storagePolicyClient) UpdateAccessKey(ctx context.Context, policy *ent.StoragePolicy, token string) error {
|
||||
_, err := c.client.StoragePolicy.UpdateOne(policy).SetAccessKey(token).Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("faield to update access key in DB: %w", err)
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
if err := c.cache.Delete(StoragePolicyCacheKey, strconv.Itoa(policy.ID)); err != nil {
|
||||
return fmt.Errorf("failed to clear storage policy cache: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *storagePolicyClient) ListPolicies(ctx context.Context, args *ListPolicyParameters) (*ListPolicyResult, error) {
|
||||
query := c.client.StoragePolicy.Query().WithNode()
|
||||
if args.Type != "" {
|
||||
query = query.Where(storagepolicy.TypeEQ(string(args.Type)))
|
||||
}
|
||||
query.Order(getStoragePolicyOrderOption(args)...)
|
||||
|
||||
// Count total items
|
||||
total, err := query.Clone().
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policies, err := withStoragePolicyEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ListPolicyResult{
|
||||
PaginationResults: &PaginationResults{
|
||||
TotalItems: total,
|
||||
Page: args.Page,
|
||||
PageSize: args.PageSize,
|
||||
},
|
||||
Policies: policies,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getStoragePolicyOrderOption(args *ListPolicyParameters) []storagepolicy.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
case storagepolicy.FieldUpdatedAt:
|
||||
return []storagepolicy.OrderOption{storagepolicy.ByUpdatedAt(orderTerm), storagepolicy.ByID(orderTerm)}
|
||||
default:
|
||||
return []storagepolicy.OrderOption{storagepolicy.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
func withStoragePolicyEagerLoading(ctx context.Context, query *ent.StoragePolicyQuery) *ent.StoragePolicyQuery {
|
||||
if _, ok := ctx.Value(LoadStoragePolicyGroup{}).(bool); ok {
|
||||
query = query.WithGroups(func(gq *ent.GroupQuery) {
|
||||
withGroupEagerLoading(ctx, gq)
|
||||
})
|
||||
}
|
||||
return query
|
||||
}
|
||||
248
inventory/setting.go
Normal file
248
inventory/setting.go
Normal file
File diff suppressed because one or more lines are too long
401
inventory/share.go
Normal file
401
inventory/share.go
Normal file
@@ -0,0 +1,401 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/file"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/predicate"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/share"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/user"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type (
|
||||
// Ctx keys for eager loading options.
|
||||
LoadShareFile struct{}
|
||||
LoadShareUser struct{}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrShareLinkExpired = fmt.Errorf("share link expired")
|
||||
ErrOwnerInactive = fmt.Errorf("owner is inactive")
|
||||
ErrSourceFileInvalid = fmt.Errorf("source file is deleted")
|
||||
)
|
||||
|
||||
type (
|
||||
ShareClient interface {
|
||||
TxOperator
|
||||
// GetByID returns the share with given id.
|
||||
GetByID(ctx context.Context, id int) (*ent.Share, error)
|
||||
// GetByIDUser returns the share with given id and user id.
|
||||
GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error)
|
||||
// GetByHashID returns the share with given hash id.
|
||||
GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error)
|
||||
// Upsert creates or update a new share record.
|
||||
Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error)
|
||||
// Viewed increase the view count of the share.
|
||||
Viewed(ctx context.Context, share *ent.Share) error
|
||||
// Downloaded increase the download count of the share.
|
||||
Downloaded(ctx context.Context, share *ent.Share) error
|
||||
// Delete deletes the share.
|
||||
Delete(ctx context.Context, shareId int) error
|
||||
// List returns a list of shares with the given args.
|
||||
List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error)
|
||||
// CountByTimeRange counts the number of shares created in the given time range.
|
||||
CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error)
|
||||
// DeleteBatch deletes the shares with the given ids.
|
||||
DeleteBatch(ctx context.Context, shareIds []int) error
|
||||
}
|
||||
|
||||
CreateShareParams struct {
|
||||
Existed *ent.Share
|
||||
Password string
|
||||
RemainDownloads int
|
||||
Expires *time.Time
|
||||
OwnerID int
|
||||
FileID int
|
||||
}
|
||||
|
||||
ListShareArgs struct {
|
||||
*PaginationArgs
|
||||
UserID int
|
||||
FileID int
|
||||
PublicOnly bool
|
||||
}
|
||||
ListShareResult struct {
|
||||
*PaginationResults
|
||||
Shares []*ent.Share
|
||||
}
|
||||
)
|
||||
|
||||
func NewShareClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) ShareClient {
|
||||
return &shareClient{
|
||||
client: client,
|
||||
hasher: hasher,
|
||||
maxSQlParam: sqlParamLimit(dbType),
|
||||
}
|
||||
}
|
||||
|
||||
type shareClient struct {
|
||||
maxSQlParam int
|
||||
client *ent.Client
|
||||
hasher hashid.Encoder
|
||||
}
|
||||
|
||||
func (c *shareClient) SetClient(newClient *ent.Client) TxOperator {
|
||||
return &shareClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam}
|
||||
}
|
||||
|
||||
func (c *shareClient) GetClient() *ent.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (c *shareClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) {
|
||||
if start == nil || end == nil {
|
||||
return c.client.Share.Query().Count(ctx)
|
||||
}
|
||||
|
||||
return c.client.Share.Query().Where(share.CreatedAtGTE(*start), share.CreatedAtLT(*end)).Count(ctx)
|
||||
}
|
||||
|
||||
func (c *shareClient) Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error) {
|
||||
if params.Existed != nil {
|
||||
createQuery := c.client.Share.
|
||||
UpdateOne(params.Existed)
|
||||
if params.RemainDownloads > 0 {
|
||||
createQuery.SetRemainDownloads(params.RemainDownloads)
|
||||
} else {
|
||||
createQuery.ClearRemainDownloads()
|
||||
}
|
||||
if params.Expires != nil {
|
||||
createQuery.SetNillableExpires(params.Expires)
|
||||
} else {
|
||||
createQuery.ClearExpires()
|
||||
}
|
||||
|
||||
return createQuery.Save(ctx)
|
||||
}
|
||||
|
||||
query := c.client.Share.
|
||||
Create().
|
||||
SetUserID(params.OwnerID).
|
||||
SetFileID(params.FileID)
|
||||
if params.Password != "" {
|
||||
query.SetPassword(params.Password)
|
||||
}
|
||||
if params.RemainDownloads > 0 {
|
||||
query.SetRemainDownloads(params.RemainDownloads)
|
||||
}
|
||||
if params.Expires != nil {
|
||||
query.SetNillableExpires(params.Expires)
|
||||
}
|
||||
|
||||
return query.Save(ctx)
|
||||
}
|
||||
|
||||
func (c *shareClient) GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error) {
|
||||
id, err := c.hasher.Decode(idRaw, hashid.ShareID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode hash id %q: %w", idRaw, err)
|
||||
}
|
||||
|
||||
return c.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (c *shareClient) GetByID(ctx context.Context, id int) (*ent.Share, error) {
|
||||
s, err := withShareEagerLoading(ctx, c.client.Share.Query().Where(share.ID(id))).First(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query share %d: %w", id, err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error) {
|
||||
s, err := withShareEagerLoading(ctx, c.client.Share.Query().
|
||||
Where(share.ID(id))).
|
||||
Where(share.HasUserWith(user.ID(uid))).First(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query share %d: %w", id, err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) DeleteBatch(ctx context.Context, shareIds []int) error {
|
||||
_, err := c.client.Share.Delete().Where(share.IDIn(shareIds...)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *shareClient) Delete(ctx context.Context, shareId int) error {
|
||||
return c.client.Share.DeleteOneID(shareId).Exec(ctx)
|
||||
}
|
||||
|
||||
// Viewed increments the view count of the share.
|
||||
func (c *shareClient) Viewed(ctx context.Context, share *ent.Share) error {
|
||||
_, err := c.client.Share.UpdateOneID(share.ID).AddViews(1).Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Downloaded increments the download count of the share.
|
||||
func (c *shareClient) Downloaded(ctx context.Context, share *ent.Share) error {
|
||||
stm := c.client.Share.
|
||||
UpdateOneID(share.ID).
|
||||
AddDownloads(1)
|
||||
if share.RemainDownloads != nil && *share.RemainDownloads >= 0 {
|
||||
stm.AddRemainDownloads(-1)
|
||||
}
|
||||
_, err := stm.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func IsValidShare(share *ent.Share) error {
|
||||
// Check if share is expired
|
||||
if err := IsShareExpired(share); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check owner status
|
||||
owner, err := share.Edges.UserOrErr()
|
||||
if err != nil || owner.Status != user.StatusActive {
|
||||
// Owner already deleted, or not active.
|
||||
return ErrOwnerInactive
|
||||
}
|
||||
|
||||
// Check source file status
|
||||
file, err := share.Edges.FileOrErr()
|
||||
if err != nil || file.FileChildren == 0 || file.OwnerID != owner.ID {
|
||||
// Source file already deleted
|
||||
return ErrSourceFileInvalid
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsShareExpired(share *ent.Share) error {
|
||||
// Check if share is expired
|
||||
if (share.Expires != nil && share.Expires.Before(time.Now())) ||
|
||||
(share.RemainDownloads != nil && *share.RemainDownloads <= 0) {
|
||||
return ErrShareLinkExpired
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *shareClient) List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error) {
|
||||
rawQuery := c.listQuery(args)
|
||||
query := withShareEagerLoading(ctx, rawQuery)
|
||||
|
||||
var (
|
||||
shares []*ent.Share
|
||||
err error
|
||||
paginationRes *PaginationResults
|
||||
)
|
||||
if args.UseCursorPagination {
|
||||
shares, paginationRes, err = c.cursorPagination(ctx, query, args, 10)
|
||||
} else {
|
||||
shares, paginationRes, err = c.offsetPagination(ctx, query, args, 10)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed with paginiation: %w", err)
|
||||
}
|
||||
|
||||
return &ListShareResult{
|
||||
Shares: shares,
|
||||
PaginationResults: paginationRes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) cursorPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(getShareOrderOption(args)...)
|
||||
|
||||
var (
|
||||
pageToken *PageToken
|
||||
err error
|
||||
)
|
||||
if args.PageToken != "" {
|
||||
pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.ShareID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
|
||||
}
|
||||
}
|
||||
queryPaged := getShareCursorQuery(args, pageToken, query)
|
||||
|
||||
// Use page size + 1 to determine if there are more items to come
|
||||
queryPaged.Limit(pageSize + 1)
|
||||
|
||||
logs, err := queryPaged.
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// More items to come
|
||||
nextTokenStr := ""
|
||||
if len(logs) > pageSize {
|
||||
lastItem := logs[len(logs)-2]
|
||||
nextToken, err := getShareNextPageToken(c.hasher, lastItem, args)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
|
||||
}
|
||||
|
||||
nextTokenStr = nextToken
|
||||
}
|
||||
|
||||
return lo.Subset(logs, 0, uint(pageSize)), &PaginationResults{
|
||||
PageSize: pageSize,
|
||||
NextPageToken: nextTokenStr,
|
||||
IsCursor: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) offsetPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(getShareOrderOption(args)...)
|
||||
|
||||
total, err := query.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
logs, err := query.Limit(pageSize).Offset(args.Page * args.PageSize).All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return logs, &PaginationResults{
|
||||
PageSize: pageSize,
|
||||
TotalItems: total,
|
||||
Page: args.Page,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) listQuery(args *ListShareArgs) *ent.ShareQuery {
|
||||
query := c.client.Share.Query()
|
||||
if args.UserID > 0 {
|
||||
query.Where(share.HasUserWith(user.ID(args.UserID)))
|
||||
}
|
||||
|
||||
if args.PublicOnly {
|
||||
query.Where(share.PasswordIsNil())
|
||||
}
|
||||
|
||||
if args.FileID > 0 {
|
||||
query.Where(share.HasFileWith(file.ID(args.FileID)))
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// getShareNextPageToken returns the next page token for the given last share.
|
||||
func getShareNextPageToken(hasher hashid.Encoder, last *ent.Share, args *ListShareArgs) (string, error) {
|
||||
token := &PageToken{
|
||||
ID: last.ID,
|
||||
}
|
||||
|
||||
return token.Encode(hasher, hashid.EncodeShareID)
|
||||
}
|
||||
|
||||
func getShareCursorQuery(args *ListShareArgs, token *PageToken, query *ent.ShareQuery) *ent.ShareQuery {
|
||||
o := &sql.OrderTermOptions{}
|
||||
getOrderTerm(args.Order)(o)
|
||||
|
||||
predicates, ok := shareCursorQuery[args.OrderBy]
|
||||
if !ok {
|
||||
predicates = shareCursorQuery[share.FieldID]
|
||||
}
|
||||
|
||||
if token != nil {
|
||||
query.Where(predicates[o.Desc](token))
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
var shareCursorQuery = map[string]map[bool]func(token *PageToken) predicate.Share{
|
||||
share.FieldID: {
|
||||
true: func(token *PageToken) predicate.Share {
|
||||
return share.IDLT(token.ID)
|
||||
},
|
||||
false: func(token *PageToken) predicate.Share {
|
||||
return share.IDGT(token.ID)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func getShareOrderOption(args *ListShareArgs) []share.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
case share.FieldViews:
|
||||
return []share.OrderOption{share.ByViews(orderTerm), share.ByID(orderTerm)}
|
||||
case share.FieldDownloads:
|
||||
return []share.OrderOption{share.ByDownloads(orderTerm), share.ByID(orderTerm)}
|
||||
case share.FieldRemainDownloads:
|
||||
return []share.OrderOption{share.ByRemainDownloads(orderTerm), share.ByID(orderTerm)}
|
||||
default:
|
||||
return []share.OrderOption{share.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
func withShareEagerLoading(ctx context.Context, q *ent.ShareQuery) *ent.ShareQuery {
|
||||
if v, ok := ctx.Value(LoadShareFile{}).(bool); ok && v {
|
||||
q.WithFile(func(q *ent.FileQuery) {
|
||||
withFileEagerLoading(ctx, q)
|
||||
})
|
||||
}
|
||||
if v, ok := ctx.Value(LoadShareUser{}).(bool); ok && v {
|
||||
q.WithUser(func(q *ent.UserQuery) {
|
||||
withUserEagerLoading(ctx, q)
|
||||
})
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
314
inventory/task.go
Normal file
314
inventory/task.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/task"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type (
|
||||
// Ctx keys for eager loading options.
|
||||
LoadTaskUser struct{}
|
||||
|
||||
TaskArgs struct {
|
||||
Status task.Status
|
||||
Type string
|
||||
PublicState *types.TaskPublicState
|
||||
PrivateState string
|
||||
OwnerID int
|
||||
CorrelationID uuid.UUID
|
||||
}
|
||||
)
|
||||
|
||||
type TaskClient interface {
|
||||
TxOperator
|
||||
// New creates a new task with the given args.
|
||||
New(ctx context.Context, task *TaskArgs) (*ent.Task, error)
|
||||
// Update updates the task with the given args.
|
||||
Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error)
|
||||
// GetPendingTasks returns all pending tasks of given type.
|
||||
GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error)
|
||||
// GetTaskByID returns the task with the given ID.
|
||||
GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error)
|
||||
// SetCompleteByID sets the task with the given ID to complete.
|
||||
SetCompleteByID(ctx context.Context, taskID int) error
|
||||
// List returns a list of tasks with the given args.
|
||||
List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error)
|
||||
// DeleteByIDs deletes the tasks with the given IDs.
|
||||
DeleteByIDs(ctx context.Context, ids ...int) error
|
||||
}
|
||||
|
||||
type (
|
||||
ListTaskArgs struct {
|
||||
*PaginationArgs
|
||||
Types []string
|
||||
Status []task.Status
|
||||
UserID int
|
||||
CorrelationID *uuid.UUID
|
||||
}
|
||||
|
||||
ListTaskResult struct {
|
||||
*PaginationResults
|
||||
Tasks []*ent.Task
|
||||
}
|
||||
)
|
||||
|
||||
func NewTaskClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) TaskClient {
|
||||
return &taskClient{client: client, maxSQlParam: sqlParamLimit(dbType), hasher: hasher}
|
||||
}
|
||||
|
||||
type taskClient struct {
|
||||
maxSQlParam int
|
||||
hasher hashid.Encoder
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func (c *taskClient) SetClient(newClient *ent.Client) TxOperator {
|
||||
return &taskClient{client: newClient, maxSQlParam: c.maxSQlParam, hasher: c.hasher}
|
||||
}
|
||||
|
||||
func (c *taskClient) GetClient() *ent.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (c *taskClient) New(ctx context.Context, task *TaskArgs) (*ent.Task, error) {
|
||||
stm := c.client.Task.
|
||||
Create().
|
||||
SetType(task.Type).
|
||||
SetPublicState(task.PublicState)
|
||||
if task.PrivateState != "" {
|
||||
stm.SetPrivateState(task.PrivateState)
|
||||
}
|
||||
|
||||
if task.OwnerID != 0 {
|
||||
stm.SetUserID(task.OwnerID)
|
||||
}
|
||||
|
||||
if task.Status != "" {
|
||||
stm.SetStatus(task.Status)
|
||||
}
|
||||
|
||||
if task.CorrelationID.String() != uuid.Nil.String() {
|
||||
stm.SetCorrelationID(task.CorrelationID)
|
||||
}
|
||||
|
||||
newTask, err := stm.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create task: %w", err)
|
||||
}
|
||||
|
||||
return newTask, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) DeleteByIDs(ctx context.Context, ids ...int) error {
|
||||
_, err := c.client.Task.Delete().Where(task.IDIn(ids...)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *taskClient) Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error) {
|
||||
stm := c.client.Task.UpdateOne(task).
|
||||
SetPublicState(args.PublicState)
|
||||
|
||||
task.PublicState = args.PublicState
|
||||
|
||||
if task.PrivateState != "" {
|
||||
stm.SetPrivateState(task.PrivateState)
|
||||
task.PrivateState = args.PrivateState
|
||||
}
|
||||
|
||||
if task.Status != "" {
|
||||
stm.SetStatus(args.Status)
|
||||
task.Status = args.Status
|
||||
}
|
||||
|
||||
if err := stm.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to create task: %w", err)
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error) {
|
||||
tasks, err := withTaskEagerLoading(ctx, c.client.Task.Query()).
|
||||
Where(task.StatusIn(task.StatusProcessing, task.StatusQueued, task.StatusSuspending)).
|
||||
Where(task.TypeIn(taskType...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Anonymous user is not loaded by default, so we need to load it manually.
|
||||
userClient := NewUserClient(c.client)
|
||||
anonymous, err := userClient.AnonymousUser(ctx)
|
||||
for _, t := range tasks {
|
||||
if t.UserTasks == 0 {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.SetUser(anonymous)
|
||||
}
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error) {
|
||||
return withTaskEagerLoading(ctx, c.client.Task.Query()).
|
||||
Where(task.ID(taskID)).
|
||||
First(ctx)
|
||||
}
|
||||
|
||||
func (c *taskClient) SetCompleteByID(ctx context.Context, taskID int) error {
|
||||
_, err := c.client.Task.UpdateOneID(taskID).
|
||||
SetStatus(task.StatusCompleted).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *taskClient) List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error) {
|
||||
q := c.client.Task.Query()
|
||||
if args.UserID != 0 {
|
||||
q.Where(task.UserTasks(args.UserID))
|
||||
}
|
||||
|
||||
if args.Types != nil {
|
||||
q.Where(task.TypeIn(args.Types...))
|
||||
}
|
||||
|
||||
if args.Status != nil {
|
||||
q.Where(task.StatusIn(args.Status...))
|
||||
}
|
||||
|
||||
if args.CorrelationID != nil {
|
||||
q.Where(task.CorrelationID(*args.CorrelationID))
|
||||
}
|
||||
|
||||
q = withTaskEagerLoading(ctx, q)
|
||||
var (
|
||||
tasks []*ent.Task
|
||||
err error
|
||||
paginationRes *PaginationResults
|
||||
)
|
||||
|
||||
if args.UseCursorPagination {
|
||||
tasks, paginationRes, err = c.cursorPagination(ctx, q, args, 1)
|
||||
} else {
|
||||
tasks, paginationRes, err = c.offsetPagination(ctx, q, args, 1)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed with paginiation: %w", err)
|
||||
}
|
||||
|
||||
return &ListTaskResult{
|
||||
Tasks: tasks,
|
||||
PaginationResults: paginationRes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) cursorPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(task.ByID(sql.OrderDesc()))
|
||||
|
||||
var (
|
||||
pageToken *PageToken
|
||||
err error
|
||||
queryPaged = query
|
||||
)
|
||||
if args.PageToken != "" {
|
||||
pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.TaskID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
|
||||
}
|
||||
|
||||
queryPaged = query.Where(task.IDLT(pageToken.ID))
|
||||
}
|
||||
|
||||
// Use page size + 1 to determine if there are more items to come
|
||||
queryPaged.Limit(pageSize + 1)
|
||||
|
||||
tasks, err := queryPaged.
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// More items to come
|
||||
nextTokenStr := ""
|
||||
if len(tasks) > pageSize {
|
||||
lastItem := tasks[len(tasks)-2]
|
||||
nextToken, err := getTaskNextPageToken(c.hasher, lastItem)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
|
||||
}
|
||||
|
||||
nextTokenStr = nextToken
|
||||
}
|
||||
|
||||
return lo.Subset(tasks, 0, uint(pageSize)), &PaginationResults{
|
||||
PageSize: pageSize,
|
||||
NextPageToken: nextTokenStr,
|
||||
IsCursor: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) offsetPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(getTaskOrderOption(args)...)
|
||||
|
||||
// Count total items
|
||||
total, err := query.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
logs, err := query.
|
||||
Limit(pageSize).
|
||||
Offset(args.Page * args.PageSize).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return logs, &PaginationResults{
|
||||
PageSize: pageSize,
|
||||
TotalItems: total,
|
||||
Page: args.Page,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func getTaskOrderOption(args *ListTaskArgs) []task.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
default:
|
||||
return []task.OrderOption{task.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
// getTaskNextPageToken returns the next page token for the given last task.
|
||||
func getTaskNextPageToken(hasher hashid.Encoder, last *ent.Task) (string, error) {
|
||||
token := &PageToken{
|
||||
ID: last.ID,
|
||||
}
|
||||
|
||||
return token.Encode(hasher, hashid.EncodeTaskID)
|
||||
}
|
||||
|
||||
func withTaskEagerLoading(ctx context.Context, q *ent.TaskQuery) *ent.TaskQuery {
|
||||
if v, ok := ctx.Value(LoadTaskUser{}).(bool); ok && v {
|
||||
q.WithUser(func(q *ent.UserQuery) {
|
||||
withUserEagerLoading(ctx, q)
|
||||
})
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
101
inventory/tx.go
Normal file
101
inventory/tx.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
)
|
||||
|
||||
type TxOperator interface {
|
||||
SetClient(newClient *ent.Client) TxOperator
|
||||
GetClient() *ent.Client
|
||||
}
|
||||
|
||||
type (
|
||||
Tx struct {
|
||||
tx *ent.Tx
|
||||
parent *Tx
|
||||
inherited bool
|
||||
finished bool
|
||||
storageDiff StorageDiff
|
||||
}
|
||||
|
||||
// TxCtx is the context key for inherited transaction
|
||||
TxCtx struct{}
|
||||
)
|
||||
|
||||
// AppendStorageDiff appends the given storage diff to the transaction.
|
||||
func (t *Tx) AppendStorageDiff(diff StorageDiff) {
|
||||
root := t
|
||||
for root.inherited {
|
||||
root = root.parent
|
||||
}
|
||||
|
||||
if root.storageDiff == nil {
|
||||
root.storageDiff = diff
|
||||
} else {
|
||||
root.storageDiff.Merge(diff)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTx wraps the given inventory client with a transaction.
|
||||
func WithTx[T TxOperator](ctx context.Context, c T) (T, *Tx, context.Context, error) {
|
||||
var txClient *ent.Client
|
||||
var txWrapper *Tx
|
||||
|
||||
if txInherited, ok := ctx.Value(TxCtx{}).(*Tx); ok && !txInherited.finished {
|
||||
txWrapper = &Tx{inherited: true, tx: txInherited.tx, parent: txInherited}
|
||||
} else {
|
||||
tx, err := c.GetClient().Tx(ctx)
|
||||
if err != nil {
|
||||
return c, nil, ctx, fmt.Errorf("failed to create transaction: %w", err)
|
||||
}
|
||||
|
||||
txWrapper = &Tx{inherited: false, tx: tx}
|
||||
ctx = context.WithValue(ctx, TxCtx{}, txWrapper)
|
||||
}
|
||||
|
||||
txClient = txWrapper.tx.Client()
|
||||
return c.SetClient(txClient).(T), txWrapper, ctx, nil
|
||||
}
|
||||
|
||||
func Rollback(tx *Tx) error {
|
||||
if !tx.inherited {
|
||||
tx.finished = true
|
||||
return tx.tx.Rollback()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func commit(tx *Tx) (bool, error) {
|
||||
if !tx.inherited {
|
||||
tx.finished = true
|
||||
return true, tx.tx.Commit()
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func Commit(tx *Tx) error {
|
||||
_, err := commit(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
// CommitWithStorageDiff commits the transaction and applies the storage diff, only if the transaction is not inherited.
|
||||
func CommitWithStorageDiff(ctx context.Context, tx *Tx, l logging.Logger, uc UserClient) error {
|
||||
commited, err := commit(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !commited {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := uc.ApplyStorageDiff(ctx, tx.storageDiff); err != nil {
|
||||
l.Error("Failed to apply storage diff", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
224
inventory/types/types.go
Normal file
224
inventory/types/types.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// UserSetting 用户其他配置
|
||||
type (
|
||||
UserSetting struct {
|
||||
ProfileOff bool `json:"profile_off,omitempty"`
|
||||
PreferredTheme string `json:"preferred_theme,omitempty"`
|
||||
VersionRetention bool `json:"version_retention,omitempty"`
|
||||
VersionRetentionExt []string `json:"version_retention_ext,omitempty"`
|
||||
VersionRetentionMax int `json:"version_retention_max,omitempty"`
|
||||
Pined []PinedFile `json:"pined,omitempty"`
|
||||
Language string `json:"email_language,omitempty"`
|
||||
}
|
||||
|
||||
PinedFile struct {
|
||||
Uri string `json:"uri"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// GroupSetting 用户组其他配置
|
||||
GroupSetting struct {
|
||||
CompressSize int64 `json:"compress_size,omitempty"` // 可压缩大小
|
||||
DecompressSize int64 `json:"decompress_size,omitempty"`
|
||||
RemoteDownloadOptions map[string]interface{} `json:"remote_download_options,omitempty"` // 离线下载用户组配置
|
||||
SourceBatchSize int `json:"source_batch,omitempty"`
|
||||
Aria2BatchSize int `json:"aria2_batch,omitempty"`
|
||||
MaxWalkedFiles int `json:"max_walked_files,omitempty"`
|
||||
TrashRetention int `json:"trash_retention,omitempty"`
|
||||
RedirectedSource bool `json:"redirected_source,omitempty"`
|
||||
}
|
||||
|
||||
// PolicySetting 非公有的存储策略属性
|
||||
PolicySetting struct {
|
||||
// Upyun访问Token
|
||||
Token string `json:"token"`
|
||||
// 允许的文件扩展名
|
||||
FileType []string `json:"file_type"`
|
||||
// OauthRedirect Oauth 重定向地址
|
||||
OauthRedirect string `json:"od_redirect,omitempty"`
|
||||
// CustomProxy whether to use custom-proxy to get file content
|
||||
CustomProxy bool `json:"custom_proxy,omitempty"`
|
||||
// ProxyServer 反代地址
|
||||
ProxyServer string `json:"proxy_server,omitempty"`
|
||||
// InternalProxy whether to use Cloudreve internal proxy to get file content
|
||||
InternalProxy bool `json:"internal_proxy,omitempty"`
|
||||
// OdDriver OneDrive 驱动器定位符
|
||||
OdDriver string `json:"od_driver,omitempty"`
|
||||
// Region 区域代码
|
||||
Region string `json:"region,omitempty"`
|
||||
// ServerSideEndpoint 服务端请求使用的 Endpoint,为空时使用 Policy.Server 字段
|
||||
ServerSideEndpoint string `json:"server_side_endpoint,omitempty"`
|
||||
// 分片上传的分片大小
|
||||
ChunkSize int64 `json:"chunk_size,omitempty"`
|
||||
// 每秒对存储端的 API 请求上限
|
||||
TPSLimit float64 `json:"tps_limit,omitempty"`
|
||||
// 每秒 API 请求爆发上限
|
||||
TPSLimitBurst int `json:"tps_limit_burst,omitempty"`
|
||||
// Set this to `true` to force the request to use path-style addressing,
|
||||
// i.e., `http://s3.amazonaws.com/BUCKET/KEY `
|
||||
S3ForcePathStyle bool `json:"s3_path_style"`
|
||||
// File extensions that support thumbnail generation using native policy API.
|
||||
ThumbExts []string `json:"thumb_exts,omitempty"`
|
||||
// Whether to support all file extensions for thumbnail generation.
|
||||
ThumbSupportAllExts bool `json:"thumb_support_all_exts,omitempty"`
|
||||
// ThumbMaxSize indicates the maximum allowed size of a thumbnail. 0 indicates that no limit is set.
|
||||
ThumbMaxSize int64 `json:"thumb_max_size,omitempty"`
|
||||
// Whether to upload file through server's relay.
|
||||
Relay bool `json:"relay,omitempty"`
|
||||
// Whether to pre allocate space for file before upload in physical disk.
|
||||
PreAllocate bool `json:"pre_allocate,omitempty"`
|
||||
// MediaMetaExts file extensions that support media meta generation using native policy API.
|
||||
MediaMetaExts []string `json:"media_meta_exts,omitempty"`
|
||||
// MediaMetaGeneratorProxy whether to use local proxy to generate media meta.
|
||||
MediaMetaGeneratorProxy bool `json:"media_meta_generator_proxy,omitempty"`
|
||||
// ThumbGeneratorProxy whether to use local proxy to generate thumbnail.
|
||||
ThumbGeneratorProxy bool `json:"thumb_generator_proxy,omitempty"`
|
||||
// NativeMediaProcessing whether to use native media processing API from storage provider.
|
||||
NativeMediaProcessing bool `json:"native_media_processing"`
|
||||
// S3DeleteBatchSize the number of objects to delete in each batch.
|
||||
S3DeleteBatchSize int `json:"s3_delete_batch_size,omitempty"`
|
||||
// StreamSaver whether to use stream saver to download file in Web.
|
||||
StreamSaver bool `json:"stream_saver,omitempty"`
|
||||
// UseCname whether to use CNAME for endpoint (OSS).
|
||||
UseCname bool `json:"use_cname,omitempty"`
|
||||
// CDN domain does not need to be signed.
|
||||
SourceAuth bool `json:"source_auth,omitempty"`
|
||||
}
|
||||
|
||||
FileType int
|
||||
EntityType int
|
||||
GroupPermission int
|
||||
FilePermission int
|
||||
DavAccountOption int
|
||||
NodeCapability int
|
||||
|
||||
NodeSetting struct {
|
||||
Provider DownloaderProvider `json:"provider,omitempty"`
|
||||
*QBittorrentSetting `json:"qbittorrent,omitempty"`
|
||||
*Aria2Setting `json:"aria2,omitempty"`
|
||||
// 下载监控间隔
|
||||
Interval int `json:"interval,omitempty"`
|
||||
WaitForSeeding bool `json:"wait_for_seeding,omitempty"`
|
||||
}
|
||||
|
||||
DownloaderProvider string
|
||||
|
||||
QBittorrentSetting struct {
|
||||
Server string `json:"server,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
TempPath string `json:"temp_path,omitempty"`
|
||||
}
|
||||
|
||||
Aria2Setting struct {
|
||||
Server string `json:"server,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
TempPath string `json:"temp_path,omitempty"`
|
||||
}
|
||||
|
||||
TaskPublicState struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorHistory []string `json:"error_history,omitempty"`
|
||||
ExecutedDuration time.Duration `json:"executed_duration,omitempty"`
|
||||
RetryCount int `json:"retry_count,omitempty"`
|
||||
ResumeTime int64 `json:"resume_time,omitempty"`
|
||||
SlaveTaskProps *SlaveTaskProps `json:"slave_task_props,omitempty"`
|
||||
}
|
||||
|
||||
SlaveTaskProps struct {
|
||||
NodeID int `json:"node_id,omitempty"`
|
||||
MasterSiteURl string `json:"master_site_u_rl,omitempty"`
|
||||
MasterSiteID string `json:"master_site_id,omitempty"`
|
||||
MasterSiteVersion string `json:"master_site_version,omitempty"`
|
||||
}
|
||||
|
||||
EntityRecycleOption struct {
|
||||
UnlinkOnly bool `json:"unlink_only,omitempty"`
|
||||
}
|
||||
|
||||
DavAccountProps struct {
|
||||
}
|
||||
|
||||
PolicyType string
|
||||
|
||||
FileProps struct {
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
GroupPermissionIsAdmin = GroupPermission(iota)
|
||||
GroupPermissionIsAnonymous
|
||||
GroupPermissionShare
|
||||
GroupPermissionWebDAV
|
||||
GroupPermissionArchiveDownload
|
||||
GroupPermissionArchiveTask
|
||||
GroupPermissionWebDAVProxy
|
||||
GroupPermissionShareDownload
|
||||
GroupPermission_CommunityPlaceholder1
|
||||
GroupPermissionRemoteDownload
|
||||
GroupPermission_CommunityPlaceholder2
|
||||
GroupPermissionRedirectedSource // not used
|
||||
GroupPermissionAdvanceDelete
|
||||
GroupPermission_CommunityPlaceholder3
|
||||
GroupPermission_CommunityPlaceholder4
|
||||
GroupPermissionSetExplicitUser_placeholder
|
||||
GroupPermissionIgnoreFileOwnership // not used
|
||||
)
|
||||
|
||||
const (
|
||||
NodeCapabilityNone NodeCapability = iota
|
||||
NodeCapabilityCreateArchive
|
||||
NodeCapabilityExtractArchive
|
||||
NodeCapabilityRemoteDownload
|
||||
NodeCapability_CommunityPlaceholder
|
||||
)
|
||||
|
||||
const (
|
||||
FileTypeFile FileType = iota
|
||||
FileTypeFolder
|
||||
)
|
||||
|
||||
const (
|
||||
EntityTypeVersion EntityType = iota
|
||||
EntityTypeThumbnail
|
||||
EntityTypeLivePhoto
|
||||
)
|
||||
|
||||
func FileTypeFromString(s string) FileType {
|
||||
switch s {
|
||||
case "file":
|
||||
return FileTypeFile
|
||||
case "folder":
|
||||
return FileTypeFolder
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
const (
|
||||
DavAccountReadOnly DavAccountOption = iota
|
||||
DavAccountProxy
|
||||
)
|
||||
|
||||
const (
|
||||
PolicyTypeLocal = "local"
|
||||
PolicyTypeQiniu = "qiniu"
|
||||
PolicyTypeUpyun = "upyun"
|
||||
PolicyTypeOss = "oss"
|
||||
PolicyTypeCos = "cos"
|
||||
PolicyTypeS3 = "s3"
|
||||
PolicyTypeOd = "onedrive"
|
||||
PolicyTypeRemote = "remote"
|
||||
PolicyTypeObs = "obs"
|
||||
)
|
||||
|
||||
const (
|
||||
DownloaderProviderAria2 = DownloaderProvider("aria2")
|
||||
DownloaderProviderQBittorrent = DownloaderProvider("qbittorrent")
|
||||
)
|
||||
594
inventory/user.go
Normal file
594
inventory/user.go
Normal file
@@ -0,0 +1,594 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/davaccount"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/file"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/passkey"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/schema"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/task"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/user"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
)
|
||||
|
||||
type (
|
||||
// Ctx keys for eager loading options.
|
||||
LoadUserGroup struct{}
|
||||
LoadUserPasskey struct{}
|
||||
|
||||
UserCtx struct{}
|
||||
UserIDCtx struct{}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserEmailExisted = errors.New("user email has been registered")
|
||||
ErrInactiveUserExisted = errors.New("email already registered but not activated")
|
||||
ErrorUnknownPasswordType = errors.New("unknown password type")
|
||||
ErrorIncorrectPassword = errors.New("incorrect password")
|
||||
ErrInsufficientPoints = errors.New("insufficient points")
|
||||
)
|
||||
|
||||
type (
|
||||
UserClient interface {
|
||||
TxOperator
|
||||
// New creates a new user. If user email registered, existed User will be returned.
|
||||
Create(ctx context.Context, args *NewUserArgs) (*ent.User, error)
|
||||
// GetByEmail get the user with given email, user status is ignored.
|
||||
GetByEmail(ctx context.Context, email string) (*ent.User, error)
|
||||
// GetByID get user by its ID, user status is ignored.
|
||||
GetByID(ctx context.Context, id int) (*ent.User, error)
|
||||
// GetActiveByID get user by its ID, only active user will be returned.
|
||||
GetActiveByID(ctx context.Context, id int) (*ent.User, error)
|
||||
// SetStatus Set user to given status
|
||||
SetStatus(ctx context.Context, u *ent.User, status user.Status) (*ent.User, error)
|
||||
// AnonymousUser returns the anonymous user.
|
||||
AnonymousUser(ctx context.Context) (*ent.User, error)
|
||||
// GetLoginUserByID returns the login user by its ID. It emits some errors and fallback to anonymous user.
|
||||
GetLoginUserByID(ctx context.Context, uid int) (*ent.User, error)
|
||||
// GetLoginUserByEmail returns the login user by its WebDAV credentials.
|
||||
GetActiveByDavAccount(ctx context.Context, email, pwd string) (*ent.User, error)
|
||||
// SaveSettings saves user settings.
|
||||
SaveSettings(ctx context.Context, u *ent.User) error
|
||||
// SearchActive search active users by Email or nickname.
|
||||
SearchActive(ctx context.Context, limit int, keyword string) ([]*ent.User, error)
|
||||
// ApplyStorageDiff apply storage diff to user.
|
||||
ApplyStorageDiff(ctx context.Context, diffs StorageDiff) error
|
||||
// UpdateAvatar updates user avatar.
|
||||
UpdateAvatar(ctx context.Context, u *ent.User, avatar string) (*ent.User, error)
|
||||
// UpdateNickname updates user nickname.
|
||||
UpdateNickname(ctx context.Context, u *ent.User, name string) (*ent.User, error)
|
||||
// UpdatePassword updates user password.
|
||||
UpdatePassword(ctx context.Context, u *ent.User, newPassword string) (*ent.User, error)
|
||||
// UpdateTwoFASecret updates user two factor secret.
|
||||
UpdateTwoFASecret(ctx context.Context, u *ent.User, secret string) (*ent.User, error)
|
||||
// ListPasskeys list user's passkeys.
|
||||
ListPasskeys(ctx context.Context, uid int) ([]*ent.Passkey, error)
|
||||
// AddPasskey add passkey to user.
|
||||
AddPasskey(ctx context.Context, uid int, name string, credential *webauthn.Credential) (*ent.Passkey, error)
|
||||
// RemovePasskey remove passkey from user.
|
||||
RemovePasskey(ctx context.Context, uid int, keyId string) error
|
||||
// MarkPasskeyUsed updates passkey used at.
|
||||
MarkPasskeyUsed(ctx context.Context, uid int, keyId string) error
|
||||
// CountByTimeRange count users by time range. Will return all records if start or end is nil.
|
||||
CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error)
|
||||
// ListUsers list users with pagination.
|
||||
ListUsers(ctx context.Context, args *ListUserParameters) (*ListUserResult, error)
|
||||
// Upsert upserts a user.
|
||||
Upsert(ctx context.Context, u *ent.User, password, twoFa string) (*ent.User, error)
|
||||
// Delete deletes a user.
|
||||
Delete(ctx context.Context, uid int) error
|
||||
// CalculateStorage calculate user's storage from scratch and update user's storage.
|
||||
CalculateStorage(ctx context.Context, uid int) (int64, error)
|
||||
}
|
||||
ListUserParameters struct {
|
||||
*PaginationArgs
|
||||
GroupID int
|
||||
Status user.Status
|
||||
Nick string
|
||||
Email string
|
||||
}
|
||||
ListUserResult struct {
|
||||
*PaginationResults
|
||||
Users []*ent.User
|
||||
}
|
||||
)
|
||||
|
||||
func NewUserClient(client *ent.Client) UserClient {
|
||||
return &userClient{client: client}
|
||||
}
|
||||
|
||||
type userClient struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
type (
|
||||
// NewUserArgs args to create a new user
|
||||
NewUserArgs struct {
|
||||
Email string
|
||||
Nick string // Optional
|
||||
PlainPassword string
|
||||
Status user.Status
|
||||
GroupID int
|
||||
Avatar string // Optional
|
||||
Language string // Optional
|
||||
}
|
||||
CreateStoragePackArgs struct {
|
||||
UserID int
|
||||
Name string
|
||||
Size int64
|
||||
ExpireAt time.Time
|
||||
}
|
||||
)
|
||||
|
||||
func (c *userClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) {
|
||||
if start == nil || end == nil {
|
||||
return c.client.User.Query().Count(ctx)
|
||||
}
|
||||
return c.client.User.Query().Where(user.CreatedAtGTE(*start), user.CreatedAtLT(*end)).Count(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) UpdateNickname(ctx context.Context, u *ent.User, name string) (*ent.User, error) {
|
||||
return c.client.User.UpdateOne(u).SetNick(name).Save(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) UpdateAvatar(ctx context.Context, u *ent.User, avatar string) (*ent.User, error) {
|
||||
return c.client.User.UpdateOne(u).SetAvatar(avatar).Save(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) UpdateTwoFASecret(ctx context.Context, u *ent.User, secret string) (*ent.User, error) {
|
||||
if secret == "" {
|
||||
return c.client.User.UpdateOne(u).ClearTwoFactorSecret().Save(ctx)
|
||||
}
|
||||
return c.client.User.UpdateOne(u).SetTwoFactorSecret(secret).Save(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) UpdatePassword(ctx context.Context, u *ent.User, newPassword string) (*ent.User, error) {
|
||||
digest, err := digestPassword(newPassword)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.client.User.UpdateOne(u).SetPassword(digest).Save(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) SetClient(newClient *ent.Client) TxOperator {
|
||||
return &userClient{client: newClient}
|
||||
}
|
||||
|
||||
func (c *userClient) GetClient() *ent.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (c *userClient) ListPasskeys(ctx context.Context, uid int) ([]*ent.Passkey, error) {
|
||||
return c.client.Passkey.Query().Where(passkey.UserID(uid)).All(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) AddPasskey(ctx context.Context, uid int, name string, credential *webauthn.Credential) (*ent.Passkey, error) {
|
||||
return c.client.Passkey.Create().
|
||||
SetName(name).
|
||||
SetCredentialID(base64.StdEncoding.EncodeToString(credential.ID)).
|
||||
SetUserID(uid).
|
||||
SetCredential(credential).
|
||||
Save(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) RemovePasskey(ctx context.Context, uid int, keyId string) error {
|
||||
ctx = schema.SkipSoftDelete(ctx)
|
||||
_, err := c.client.Passkey.Delete().Where(passkey.UserID(uid), passkey.CredentialID(keyId)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *userClient) MarkPasskeyUsed(ctx context.Context, uid int, keyId string) error {
|
||||
_, err := c.client.Passkey.Update().Where(passkey.UserID(uid), passkey.CredentialID(keyId)).SetUsedAt(time.Now()).Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *userClient) Delete(ctx context.Context, uid int) error {
|
||||
// Dav accounts
|
||||
if _, err := c.client.DavAccount.Delete().Where(davaccount.OwnerID(uid)).Exec(schema.SkipSoftDelete(ctx)); err != nil {
|
||||
return fmt.Errorf("failed to delete dav accounts: %w", err)
|
||||
}
|
||||
|
||||
// Passkeys
|
||||
if _, err := c.client.Passkey.Delete().Where(passkey.UserID(uid)).Exec(schema.SkipSoftDelete(ctx)); err != nil {
|
||||
return fmt.Errorf("failed to delete passkeys: %w", err)
|
||||
}
|
||||
|
||||
// Tasks
|
||||
if _, err := c.client.Task.Delete().Where(task.UserTasks(uid)).Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete tasks: %w", err)
|
||||
}
|
||||
|
||||
return c.client.User.DeleteOneID(uid).Exec(schema.SkipSoftDelete(ctx))
|
||||
}
|
||||
|
||||
func (c *userClient) ApplyStorageDiff(ctx context.Context, diffs StorageDiff) error {
|
||||
ae := serializer.NewAggregateError()
|
||||
for uid, diff := range diffs {
|
||||
if err := c.client.User.Update().Where(user.ID(uid)).AddStorage(diff).Exec(ctx); err != nil {
|
||||
ae.Add(fmt.Sprintf("%d", uid), fmt.Errorf("failed to apply storage diff for user %d: %w", uid, err))
|
||||
}
|
||||
}
|
||||
|
||||
return ae.Aggregate()
|
||||
}
|
||||
|
||||
func (c *userClient) CalculateStorage(ctx context.Context, uid int) (int64, error) {
|
||||
var sum int64
|
||||
batchSize := 5000
|
||||
offset := 0
|
||||
|
||||
for {
|
||||
allFiles, err := c.client.File.Query().
|
||||
Where(file.HasOwnerWith(user.ID(uid))).
|
||||
WithEntities().
|
||||
Offset(offset).
|
||||
Limit(batchSize).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to list user files: %w", err)
|
||||
}
|
||||
|
||||
if len(allFiles) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, file := range allFiles {
|
||||
for _, entity := range file.Edges.Entities {
|
||||
sum += entity.Size
|
||||
}
|
||||
}
|
||||
|
||||
offset += batchSize
|
||||
}
|
||||
|
||||
if _, err := c.client.User.UpdateOneID(uid).SetStorage(sum).Save(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return sum, nil
|
||||
}
|
||||
|
||||
func (c *userClient) SetStatus(ctx context.Context, u *ent.User, status user.Status) (*ent.User, error) {
|
||||
return c.client.User.UpdateOne(u).SetStatus(status).Save(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) Create(ctx context.Context, args *NewUserArgs) (*ent.User, error) {
|
||||
// Try to check if there's user with same email.
|
||||
if existedUser, err := c.GetByEmail(ctx, args.Email); err == nil {
|
||||
if existedUser.Status == user.StatusInactive {
|
||||
return existedUser, ErrInactiveUserExisted
|
||||
}
|
||||
return existedUser, ErrUserEmailExisted
|
||||
}
|
||||
|
||||
nick := args.Nick
|
||||
if nick == "" {
|
||||
nick = strings.Split(args.Email, "@")[0]
|
||||
}
|
||||
|
||||
userSetting := &types.UserSetting{VersionRetention: true, VersionRetentionMax: 10}
|
||||
query := c.client.User.Create().
|
||||
SetEmail(args.Email).
|
||||
SetNick(nick).
|
||||
SetStatus(args.Status).
|
||||
SetGroupID(args.GroupID).
|
||||
SetAvatar(args.Avatar)
|
||||
|
||||
if args.PlainPassword != "" {
|
||||
pwdDigest, err := digestPassword(args.PlainPassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sha256 password: %w", err)
|
||||
}
|
||||
query.SetPassword(pwdDigest)
|
||||
}
|
||||
|
||||
if args.Language != "" {
|
||||
userSetting.Language = args.Language
|
||||
}
|
||||
query.SetSettings(userSetting)
|
||||
|
||||
// Create user
|
||||
newUser, err := query.
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
if newUser.ID == 1 {
|
||||
// For the first user registered, elevate it to admin group.
|
||||
if _, err := newUser.Update().SetGroupID(1).Save(ctx); err != nil {
|
||||
return newUser, fmt.Errorf("failed to elevate user to admin: %w", err)
|
||||
}
|
||||
}
|
||||
return newUser, nil
|
||||
}
|
||||
|
||||
func (c *userClient) GetByEmail(ctx context.Context, email string) (*ent.User, error) {
|
||||
return withUserEagerLoading(ctx, c.client.User.Query().Where(user.EmailEqualFold(email))).First(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) GetByID(ctx context.Context, id int) (*ent.User, error) {
|
||||
return withUserEagerLoading(ctx, c.client.User.Query().Where(user.ID(id))).First(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) GetActiveByID(ctx context.Context, id int) (*ent.User, error) {
|
||||
return withUserEagerLoading(
|
||||
ctx,
|
||||
c.client.User.Query().
|
||||
Where(user.ID(id)).
|
||||
Where(user.StatusEQ(user.StatusActive)),
|
||||
).First(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) GetActiveByDavAccount(ctx context.Context, email, pwd string) (*ent.User, error) {
|
||||
ctx = context.WithValue(ctx, LoadUserGroup{}, true)
|
||||
return withUserEagerLoading(
|
||||
ctx,
|
||||
c.client.User.Query().
|
||||
Where(user.EmailEqualFold(email)).
|
||||
Where(user.StatusEQ(user.StatusActive)).
|
||||
WithDavAccounts(func(q *ent.DavAccountQuery) {
|
||||
q.Where(davaccount.Password(pwd))
|
||||
}),
|
||||
).First(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) GetLoginUserByID(ctx context.Context, uid int) (*ent.User, error) {
|
||||
ctx = context.WithValue(ctx, LoadUserGroup{}, true)
|
||||
if uid > 0 {
|
||||
expectedUser, err := c.GetActiveByID(ctx, uid)
|
||||
if err == nil {
|
||||
return expectedUser, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to get user by id: %w", err)
|
||||
}
|
||||
|
||||
anonymous, err := c.AnonymousUser(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct anonymous user: %w", err)
|
||||
}
|
||||
|
||||
return anonymous, nil
|
||||
}
|
||||
|
||||
func (c *userClient) SearchActive(ctx context.Context, limit int, keyword string) ([]*ent.User, error) {
|
||||
ctx = context.WithValue(ctx, LoadUserGroup{}, true)
|
||||
return withUserEagerLoading(
|
||||
ctx,
|
||||
c.client.User.Query().
|
||||
Where(user.Or(user.EmailContainsFold(keyword), user.NickContainsFold(keyword))).
|
||||
Limit(limit),
|
||||
).All(ctx)
|
||||
}
|
||||
|
||||
func (c *userClient) SaveSettings(ctx context.Context, u *ent.User) error {
|
||||
return c.client.User.UpdateOne(u).SetSettings(u.Settings).Exec(ctx)
|
||||
}
|
||||
|
||||
// UserFromContext get user from context
|
||||
func UserFromContext(ctx context.Context) *ent.User {
|
||||
u, _ := ctx.Value(UserCtx{}).(*ent.User)
|
||||
return u
|
||||
}
|
||||
|
||||
// UserIDFromContext get user id from context.
|
||||
func UserIDFromContext(ctx context.Context) int {
|
||||
uid, ok := ctx.Value(UserIDCtx{}).(int)
|
||||
if !ok {
|
||||
u := UserFromContext(ctx)
|
||||
if u != nil {
|
||||
uid = u.ID
|
||||
}
|
||||
}
|
||||
|
||||
return uid
|
||||
}
|
||||
|
||||
func (c *userClient) AnonymousUser(ctx context.Context) (*ent.User, error) {
|
||||
groupClient := NewGroupClient(c.client, "", nil)
|
||||
anonymousGroup, err := groupClient.AnonymousGroup(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("anyonymous group not found: %w", err)
|
||||
}
|
||||
|
||||
// TODO: save into cache
|
||||
anonymous := &ent.User{
|
||||
Settings: &types.UserSetting{},
|
||||
}
|
||||
anonymous.SetGroup(anonymousGroup)
|
||||
return anonymous, nil
|
||||
}
|
||||
|
||||
func (c *userClient) ListUsers(ctx context.Context, args *ListUserParameters) (*ListUserResult, error) {
|
||||
query := c.client.User.Query()
|
||||
if args.GroupID != 0 {
|
||||
query = query.Where(user.GroupUsers(args.GroupID))
|
||||
}
|
||||
if args.Status != "" {
|
||||
query = query.Where(user.StatusEQ(args.Status))
|
||||
}
|
||||
if args.Nick != "" {
|
||||
query = query.Where(user.NickContainsFold(args.Nick))
|
||||
}
|
||||
if args.Email != "" {
|
||||
query = query.Where(user.EmailContainsFold(args.Email))
|
||||
}
|
||||
query.Order(getUserOrderOption(args)...)
|
||||
|
||||
// Count total items
|
||||
total, err := query.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users, err := withUserEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ListUserResult{
|
||||
PaginationResults: &PaginationResults{
|
||||
TotalItems: total,
|
||||
Page: args.Page,
|
||||
PageSize: args.PageSize,
|
||||
},
|
||||
Users: users,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *userClient) Upsert(ctx context.Context, u *ent.User, password, twoFa string) (*ent.User, error) {
|
||||
if u.ID == 0 {
|
||||
q := c.client.User.Create().
|
||||
SetEmail(u.Email).
|
||||
SetNick(u.Nick).
|
||||
SetAvatar(u.Avatar).
|
||||
SetStatus(u.Status).
|
||||
SetGroupID(u.GroupUsers).
|
||||
SetPassword(u.Password).
|
||||
SetSettings(&types.UserSetting{})
|
||||
|
||||
if password != "" {
|
||||
pwdDigest, err := digestPassword(password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sha256 password: %w", err)
|
||||
}
|
||||
q.SetPassword(pwdDigest)
|
||||
}
|
||||
|
||||
return q.Save(ctx)
|
||||
}
|
||||
|
||||
q := c.client.User.UpdateOne(u).
|
||||
SetEmail(u.Email).
|
||||
SetNick(u.Nick).
|
||||
SetAvatar(u.Avatar).
|
||||
SetStatus(u.Status).
|
||||
SetGroupID(u.GroupUsers)
|
||||
|
||||
if password != "" {
|
||||
pwdDigest, err := digestPassword(password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sha256 password: %w", err)
|
||||
}
|
||||
q.SetPassword(pwdDigest)
|
||||
}
|
||||
|
||||
if twoFa != "" {
|
||||
q.ClearTwoFactorSecret()
|
||||
}
|
||||
|
||||
return q.Save(ctx)
|
||||
}
|
||||
|
||||
func getUserOrderOption(args *ListUserParameters) []user.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
case user.FieldNick:
|
||||
return []user.OrderOption{user.ByNick(orderTerm), user.ByID(orderTerm)}
|
||||
case user.FieldStorage:
|
||||
return []user.OrderOption{user.ByStorage(orderTerm), user.ByID(orderTerm)}
|
||||
case user.FieldEmail:
|
||||
return []user.OrderOption{user.ByEmail(orderTerm), user.ByID(orderTerm)}
|
||||
case user.FieldUpdatedAt:
|
||||
return []user.OrderOption{user.ByUpdatedAt(orderTerm), user.ByID(orderTerm)}
|
||||
default:
|
||||
return []user.OrderOption{user.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
// IsAnonymousUser check if given user is anonymous user.
|
||||
func IsAnonymousUser(u *ent.User) bool {
|
||||
return u.ID == 0
|
||||
}
|
||||
|
||||
// CheckPassword 根据明文校验密码
|
||||
func CheckPassword(u *ent.User, password string) error {
|
||||
// 根据存储密码拆分为 Salt 和 Digest
|
||||
passwordStore := strings.Split(u.Password, ":")
|
||||
if len(passwordStore) != 2 && len(passwordStore) != 3 {
|
||||
return ErrorUnknownPasswordType
|
||||
}
|
||||
|
||||
// 兼容V2密码,升级后存储格式为: md5:$HASH:$SALT
|
||||
if len(passwordStore) == 3 {
|
||||
if passwordStore[0] != "md5" {
|
||||
return ErrorUnknownPasswordType
|
||||
}
|
||||
hash := md5.New()
|
||||
_, err := hash.Write([]byte(passwordStore[2] + password))
|
||||
bs := hex.EncodeToString(hash.Sum(nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if bs != passwordStore[1] {
|
||||
return ErrorIncorrectPassword
|
||||
}
|
||||
}
|
||||
|
||||
//计算 Salt 和密码组合的SHA1摘要
|
||||
var hasher hash.Hash
|
||||
if len(passwordStore[1]) == 64 {
|
||||
hasher = sha256.New()
|
||||
} else {
|
||||
// Compatible with V3
|
||||
hasher = sha1.New()
|
||||
}
|
||||
|
||||
_, err := hasher.Write([]byte(password + passwordStore[0]))
|
||||
bs := hex.EncodeToString(hasher.Sum(nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if bs != passwordStore[1] {
|
||||
return ErrorIncorrectPassword
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func withUserEagerLoading(ctx context.Context, q *ent.UserQuery) *ent.UserQuery {
|
||||
if v, ok := ctx.Value(LoadUserGroup{}).(bool); ok && v {
|
||||
q.WithGroup(func(gq *ent.GroupQuery) {
|
||||
withGroupEagerLoading(ctx, gq)
|
||||
})
|
||||
}
|
||||
if v, ok := ctx.Value(LoadUserPasskey{}).(bool); ok && v {
|
||||
q.WithPasskey()
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
func digestPassword(password string) (string, error) {
|
||||
//生成16位 Salt
|
||||
salt := util.RandStringRunes(16)
|
||||
|
||||
//计算 Salt 和密码组合的SHA1摘要
|
||||
hash := sha256.New()
|
||||
_, err := hash.Write([]byte(password + salt))
|
||||
bs := hex.EncodeToString(hash.Sum(nil))
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
//存储 Salt 值和摘要, ":"分割
|
||||
return salt + ":" + string(bs), nil
|
||||
}
|
||||
Reference in New Issue
Block a user