Files
cloudreve/inventory/user.go

596 lines
17 KiB
Go
Raw Normal View History

package inventory
import (
"context"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"hash"
"strings"
"time"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/davaccount"
"github.com/cloudreve/Cloudreve/v4/ent/file"
"github.com/cloudreve/Cloudreve/v4/ent/passkey"
"github.com/cloudreve/Cloudreve/v4/ent/schema"
"github.com/cloudreve/Cloudreve/v4/ent/task"
"github.com/cloudreve/Cloudreve/v4/ent/user"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/go-webauthn/webauthn/webauthn"
)
type (
// Ctx keys for eager loading options.
LoadUserGroup struct{}
LoadUserPasskey struct{}
UserCtx struct{}
UserIDCtx struct{}
)
var (
ErrUserEmailExisted = errors.New("user email has been registered")
ErrInactiveUserExisted = errors.New("email already registered but not activated")
ErrorUnknownPasswordType = errors.New("unknown password type")
ErrorIncorrectPassword = errors.New("incorrect password")
ErrInsufficientPoints = errors.New("insufficient points")
)
type (
UserClient interface {
TxOperator
// New creates a new user. If user email registered, existed User will be returned.
Create(ctx context.Context, args *NewUserArgs) (*ent.User, error)
// GetByEmail get the user with given email, user status is ignored.
GetByEmail(ctx context.Context, email string) (*ent.User, error)
// GetByID get user by its ID, user status is ignored.
GetByID(ctx context.Context, id int) (*ent.User, error)
// GetActiveByID get user by its ID, only active user will be returned.
GetActiveByID(ctx context.Context, id int) (*ent.User, error)
// SetStatus Set user to given status
SetStatus(ctx context.Context, u *ent.User, status user.Status) (*ent.User, error)
// AnonymousUser returns the anonymous user.
AnonymousUser(ctx context.Context) (*ent.User, error)
// GetLoginUserByID returns the login user by its ID. It emits some errors and fallback to anonymous user.
GetLoginUserByID(ctx context.Context, uid int) (*ent.User, error)
// GetLoginUserByEmail returns the login user by its WebDAV credentials.
GetActiveByDavAccount(ctx context.Context, email, pwd string) (*ent.User, error)
// SaveSettings saves user settings.
SaveSettings(ctx context.Context, u *ent.User) error
// SearchActive search active users by Email or nickname.
SearchActive(ctx context.Context, limit int, keyword string) ([]*ent.User, error)
// ApplyStorageDiff apply storage diff to user.
ApplyStorageDiff(ctx context.Context, diffs StorageDiff) error
// UpdateAvatar updates user avatar.
UpdateAvatar(ctx context.Context, u *ent.User, avatar string) (*ent.User, error)
// UpdateNickname updates user nickname.
UpdateNickname(ctx context.Context, u *ent.User, name string) (*ent.User, error)
// UpdatePassword updates user password.
UpdatePassword(ctx context.Context, u *ent.User, newPassword string) (*ent.User, error)
// UpdateTwoFASecret updates user two factor secret.
UpdateTwoFASecret(ctx context.Context, u *ent.User, secret string) (*ent.User, error)
// ListPasskeys list user's passkeys.
ListPasskeys(ctx context.Context, uid int) ([]*ent.Passkey, error)
// AddPasskey add passkey to user.
AddPasskey(ctx context.Context, uid int, name string, credential *webauthn.Credential) (*ent.Passkey, error)
// RemovePasskey remove passkey from user.
RemovePasskey(ctx context.Context, uid int, keyId string) error
// MarkPasskeyUsed updates passkey used at.
MarkPasskeyUsed(ctx context.Context, uid int, keyId string) error
// CountByTimeRange count users by time range. Will return all records if start or end is nil.
CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error)
// ListUsers list users with pagination.
ListUsers(ctx context.Context, args *ListUserParameters) (*ListUserResult, error)
// Upsert upserts a user.
Upsert(ctx context.Context, u *ent.User, password, twoFa string) (*ent.User, error)
// Delete deletes a user.
Delete(ctx context.Context, uid int) error
// CalculateStorage calculate user's storage from scratch and update user's storage.
CalculateStorage(ctx context.Context, uid int) (int64, error)
}
ListUserParameters struct {
*PaginationArgs
GroupID int
Status user.Status
Nick string
Email string
}
ListUserResult struct {
*PaginationResults
Users []*ent.User
}
)
func NewUserClient(client *ent.Client) UserClient {
return &userClient{client: client}
}
type userClient struct {
client *ent.Client
}
type (
// NewUserArgs args to create a new user
NewUserArgs struct {
Email string
Nick string // Optional
PlainPassword string
Status user.Status
GroupID int
Avatar string // Optional
Language string // Optional
}
CreateStoragePackArgs struct {
UserID int
Name string
Size int64
ExpireAt time.Time
}
)
func (c *userClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) {
if start == nil || end == nil {
return c.client.User.Query().Count(ctx)
}
return c.client.User.Query().Where(user.CreatedAtGTE(*start), user.CreatedAtLT(*end)).Count(ctx)
}
func (c *userClient) UpdateNickname(ctx context.Context, u *ent.User, name string) (*ent.User, error) {
return c.client.User.UpdateOne(u).SetNick(name).Save(ctx)
}
func (c *userClient) UpdateAvatar(ctx context.Context, u *ent.User, avatar string) (*ent.User, error) {
return c.client.User.UpdateOne(u).SetAvatar(avatar).Save(ctx)
}
func (c *userClient) UpdateTwoFASecret(ctx context.Context, u *ent.User, secret string) (*ent.User, error) {
if secret == "" {
return c.client.User.UpdateOne(u).ClearTwoFactorSecret().Save(ctx)
}
return c.client.User.UpdateOne(u).SetTwoFactorSecret(secret).Save(ctx)
}
func (c *userClient) UpdatePassword(ctx context.Context, u *ent.User, newPassword string) (*ent.User, error) {
digest, err := digestPassword(newPassword)
if err != nil {
return nil, err
}
return c.client.User.UpdateOne(u).SetPassword(digest).Save(ctx)
}
func (c *userClient) SetClient(newClient *ent.Client) TxOperator {
return &userClient{client: newClient}
}
func (c *userClient) GetClient() *ent.Client {
return c.client
}
func (c *userClient) ListPasskeys(ctx context.Context, uid int) ([]*ent.Passkey, error) {
return c.client.Passkey.Query().Where(passkey.UserID(uid)).All(ctx)
}
func (c *userClient) AddPasskey(ctx context.Context, uid int, name string, credential *webauthn.Credential) (*ent.Passkey, error) {
return c.client.Passkey.Create().
SetName(name).
SetCredentialID(base64.StdEncoding.EncodeToString(credential.ID)).
SetUserID(uid).
SetCredential(credential).
Save(ctx)
}
func (c *userClient) RemovePasskey(ctx context.Context, uid int, keyId string) error {
ctx = schema.SkipSoftDelete(ctx)
_, err := c.client.Passkey.Delete().Where(passkey.UserID(uid), passkey.CredentialID(keyId)).Exec(ctx)
return err
}
func (c *userClient) MarkPasskeyUsed(ctx context.Context, uid int, keyId string) error {
_, err := c.client.Passkey.Update().Where(passkey.UserID(uid), passkey.CredentialID(keyId)).SetUsedAt(time.Now()).Save(ctx)
return err
}
func (c *userClient) Delete(ctx context.Context, uid int) error {
// Dav accounts
if _, err := c.client.DavAccount.Delete().Where(davaccount.OwnerID(uid)).Exec(schema.SkipSoftDelete(ctx)); err != nil {
return fmt.Errorf("failed to delete dav accounts: %w", err)
}
// Passkeys
if _, err := c.client.Passkey.Delete().Where(passkey.UserID(uid)).Exec(schema.SkipSoftDelete(ctx)); err != nil {
return fmt.Errorf("failed to delete passkeys: %w", err)
}
// Tasks
if _, err := c.client.Task.Delete().Where(task.UserTasks(uid)).Exec(ctx); err != nil {
return fmt.Errorf("failed to delete tasks: %w", err)
}
return c.client.User.DeleteOneID(uid).Exec(schema.SkipSoftDelete(ctx))
}
func (c *userClient) ApplyStorageDiff(ctx context.Context, diffs StorageDiff) error {
ae := serializer.NewAggregateError()
for uid, diff := range diffs {
if err := c.client.User.Update().Where(user.ID(uid)).AddStorage(diff).Exec(ctx); err != nil {
ae.Add(fmt.Sprintf("%d", uid), fmt.Errorf("failed to apply storage diff for user %d: %w", uid, err))
}
}
return ae.Aggregate()
}
func (c *userClient) CalculateStorage(ctx context.Context, uid int) (int64, error) {
var sum int64
batchSize := 30000
offset := 0
for {
allFiles, err := c.client.File.Query().
Where(file.HasOwnerWith(user.ID(uid))).
Where(file.Type(int(types.FileTypeFile))).
WithEntities().
Offset(offset).
Limit(batchSize).
All(ctx)
if err != nil {
return 0, fmt.Errorf("failed to list user files: %w", err)
}
if len(allFiles) == 0 {
break
}
for _, file := range allFiles {
for _, entity := range file.Edges.Entities {
sum += entity.Size
}
}
offset += batchSize
}
if _, err := c.client.User.UpdateOneID(uid).SetStorage(sum).Save(ctx); err != nil {
return 0, err
}
return sum, nil
}
func (c *userClient) SetStatus(ctx context.Context, u *ent.User, status user.Status) (*ent.User, error) {
return c.client.User.UpdateOne(u).SetStatus(status).Save(ctx)
}
func (c *userClient) Create(ctx context.Context, args *NewUserArgs) (*ent.User, error) {
// Try to check if there's user with same email.
if existedUser, err := c.GetByEmail(ctx, args.Email); err == nil {
if existedUser.Status == user.StatusInactive {
return existedUser, ErrInactiveUserExisted
}
return existedUser, ErrUserEmailExisted
}
nick := args.Nick
if nick == "" {
nick = strings.Split(args.Email, "@")[0]
}
userSetting := &types.UserSetting{VersionRetention: true, VersionRetentionMax: 10}
query := c.client.User.Create().
SetEmail(args.Email).
SetNick(nick).
SetStatus(args.Status).
SetGroupID(args.GroupID).
SetAvatar(args.Avatar)
if args.PlainPassword != "" {
pwdDigest, err := digestPassword(args.PlainPassword)
if err != nil {
return nil, fmt.Errorf("failed to sha256 password: %w", err)
}
query.SetPassword(pwdDigest)
}
if args.Language != "" {
userSetting.Language = args.Language
}
query.SetSettings(userSetting)
// Create user
newUser, err := query.
Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
if newUser.ID == 1 {
// For the first user registered, elevate it to admin group.
if _, err := newUser.Update().SetGroupID(1).Save(ctx); err != nil {
return newUser, fmt.Errorf("failed to elevate user to admin: %w", err)
}
}
return newUser, nil
}
func (c *userClient) GetByEmail(ctx context.Context, email string) (*ent.User, error) {
return withUserEagerLoading(ctx, c.client.User.Query().Where(user.EmailEqualFold(email))).First(ctx)
}
func (c *userClient) GetByID(ctx context.Context, id int) (*ent.User, error) {
return withUserEagerLoading(ctx, c.client.User.Query().Where(user.ID(id))).First(ctx)
}
func (c *userClient) GetActiveByID(ctx context.Context, id int) (*ent.User, error) {
return withUserEagerLoading(
ctx,
c.client.User.Query().
Where(user.ID(id)).
Where(user.StatusEQ(user.StatusActive)),
).First(ctx)
}
func (c *userClient) GetActiveByDavAccount(ctx context.Context, email, pwd string) (*ent.User, error) {
ctx = context.WithValue(ctx, LoadUserGroup{}, true)
return withUserEagerLoading(
ctx,
c.client.User.Query().
Where(user.EmailEqualFold(email)).
Where(user.StatusEQ(user.StatusActive)).
WithDavAccounts(func(q *ent.DavAccountQuery) {
q.Where(davaccount.Password(pwd))
}),
).First(ctx)
}
func (c *userClient) GetLoginUserByID(ctx context.Context, uid int) (*ent.User, error) {
ctx = context.WithValue(ctx, LoadUserGroup{}, true)
if uid > 0 {
expectedUser, err := c.GetActiveByID(ctx, uid)
if err == nil {
return expectedUser, nil
}
return nil, fmt.Errorf("failed to get user by id: %w", err)
}
anonymous, err := c.AnonymousUser(ctx)
if err != nil {
return nil, fmt.Errorf("failed to construct anonymous user: %w", err)
}
return anonymous, nil
}
func (c *userClient) SearchActive(ctx context.Context, limit int, keyword string) ([]*ent.User, error) {
ctx = context.WithValue(ctx, LoadUserGroup{}, true)
return withUserEagerLoading(
ctx,
c.client.User.Query().
Where(user.Or(user.EmailContainsFold(keyword), user.NickContainsFold(keyword))).
Limit(limit),
).All(ctx)
}
func (c *userClient) SaveSettings(ctx context.Context, u *ent.User) error {
return c.client.User.UpdateOne(u).SetSettings(u.Settings).Exec(ctx)
}
// UserFromContext get user from context
func UserFromContext(ctx context.Context) *ent.User {
u, _ := ctx.Value(UserCtx{}).(*ent.User)
return u
}
// UserIDFromContext get user id from context.
func UserIDFromContext(ctx context.Context) int {
uid, ok := ctx.Value(UserIDCtx{}).(int)
if !ok {
u := UserFromContext(ctx)
if u != nil {
uid = u.ID
}
}
return uid
}
func (c *userClient) AnonymousUser(ctx context.Context) (*ent.User, error) {
groupClient := NewGroupClient(c.client, "", nil)
anonymousGroup, err := groupClient.AnonymousGroup(ctx)
if err != nil {
return nil, fmt.Errorf("anyonymous group not found: %w", err)
}
// TODO: save into cache
anonymous := &ent.User{
Settings: &types.UserSetting{},
}
anonymous.SetGroup(anonymousGroup)
return anonymous, nil
}
func (c *userClient) ListUsers(ctx context.Context, args *ListUserParameters) (*ListUserResult, error) {
query := c.client.User.Query()
if args.GroupID != 0 {
query = query.Where(user.GroupUsers(args.GroupID))
}
if args.Status != "" {
query = query.Where(user.StatusEQ(args.Status))
}
if args.Nick != "" {
query = query.Where(user.NickContainsFold(args.Nick))
}
if args.Email != "" {
query = query.Where(user.EmailContainsFold(args.Email))
}
query.Order(getUserOrderOption(args)...)
// Count total items
total, err := query.Clone().Count(ctx)
if err != nil {
return nil, err
}
users, err := withUserEagerLoading(ctx, query).Limit(args.PageSize).Offset(args.Page * args.PageSize).All(ctx)
if err != nil {
return nil, err
}
return &ListUserResult{
PaginationResults: &PaginationResults{
TotalItems: total,
Page: args.Page,
PageSize: args.PageSize,
},
Users: users,
}, nil
}
func (c *userClient) Upsert(ctx context.Context, u *ent.User, password, twoFa string) (*ent.User, error) {
if u.ID == 0 {
q := c.client.User.Create().
SetEmail(u.Email).
SetNick(u.Nick).
SetAvatar(u.Avatar).
SetStatus(u.Status).
SetGroupID(u.GroupUsers).
SetPassword(u.Password).
SetSettings(&types.UserSetting{})
if password != "" {
pwdDigest, err := digestPassword(password)
if err != nil {
return nil, fmt.Errorf("failed to sha256 password: %w", err)
}
q.SetPassword(pwdDigest)
}
return q.Save(ctx)
}
q := c.client.User.UpdateOne(u).
SetEmail(u.Email).
SetNick(u.Nick).
SetAvatar(u.Avatar).
SetStatus(u.Status).
SetGroupID(u.GroupUsers)
if password != "" {
pwdDigest, err := digestPassword(password)
if err != nil {
return nil, fmt.Errorf("failed to sha256 password: %w", err)
}
q.SetPassword(pwdDigest)
}
if twoFa != "" {
q.ClearTwoFactorSecret()
}
return q.Save(ctx)
}
func getUserOrderOption(args *ListUserParameters) []user.OrderOption {
orderTerm := getOrderTerm(args.Order)
switch args.OrderBy {
case user.FieldNick:
return []user.OrderOption{user.ByNick(orderTerm), user.ByID(orderTerm)}
case user.FieldStorage:
return []user.OrderOption{user.ByStorage(orderTerm), user.ByID(orderTerm)}
case user.FieldEmail:
return []user.OrderOption{user.ByEmail(orderTerm), user.ByID(orderTerm)}
case user.FieldUpdatedAt:
return []user.OrderOption{user.ByUpdatedAt(orderTerm), user.ByID(orderTerm)}
default:
return []user.OrderOption{user.ByID(orderTerm)}
}
}
// IsAnonymousUser check if given user is anonymous user.
func IsAnonymousUser(u *ent.User) bool {
return u.ID == 0
}
// CheckPassword 根据明文校验密码
func CheckPassword(u *ent.User, password string) error {
// 根据存储密码拆分为 Salt 和 Digest
passwordStore := strings.Split(u.Password, ":")
if len(passwordStore) != 2 && len(passwordStore) != 3 {
return ErrorUnknownPasswordType
}
// 兼容V2密码升级后存储格式为: md5:$HASH:$SALT
if len(passwordStore) == 3 {
if passwordStore[0] != "md5" {
return ErrorUnknownPasswordType
}
hash := md5.New()
_, err := hash.Write([]byte(passwordStore[2] + password))
bs := hex.EncodeToString(hash.Sum(nil))
if err != nil {
return err
}
if bs != passwordStore[1] {
return ErrorIncorrectPassword
}
}
//计算 Salt 和密码组合的SHA1摘要
var hasher hash.Hash
if len(passwordStore[1]) == 64 {
hasher = sha256.New()
} else {
// Compatible with V3
hasher = sha1.New()
}
_, err := hasher.Write([]byte(password + passwordStore[0]))
bs := hex.EncodeToString(hasher.Sum(nil))
if err != nil {
return err
}
if bs != passwordStore[1] {
return ErrorIncorrectPassword
}
return nil
}
func withUserEagerLoading(ctx context.Context, q *ent.UserQuery) *ent.UserQuery {
if v, ok := ctx.Value(LoadUserGroup{}).(bool); ok && v {
q.WithGroup(func(gq *ent.GroupQuery) {
withGroupEagerLoading(ctx, gq)
})
}
if v, ok := ctx.Value(LoadUserPasskey{}).(bool); ok && v {
q.WithPasskey()
}
return q
}
func digestPassword(password string) (string, error) {
//生成16位 Salt
salt := util.RandStringRunes(16)
//计算 Salt 和密码组合的SHA1摘要
hash := sha256.New()
_, err := hash.Write([]byte(password + salt))
bs := hex.EncodeToString(hash.Sum(nil))
if err != nil {
return "", err
}
//存储 Salt 值和摘要, ":"分割
return salt + ":" + string(bs), nil
}