Init V4 community edition (#2265)

* Init V4 community edition

* Init V4 community edition
This commit is contained in:
AaronLiu
2025-04-20 17:31:25 +08:00
committed by GitHub
parent da4e44b77a
commit 21d158db07
597 changed files with 119415 additions and 41692 deletions

219
application/application.go Normal file
View File

@@ -0,0 +1,219 @@
package application
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"os"
"time"
"github.com/cloudreve/Cloudreve/v4/application/constants"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/crontab"
"github.com/cloudreve/Cloudreve/v4/pkg/email"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/onedrive"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/cloudreve/Cloudreve/v4/routers"
"github.com/gin-gonic/gin"
)
// Server is the lifecycle facade of the Cloudreve application.
type Server interface {
	// Start starts the Cloudreve server.
	Start() error
	// PrintBanner prints the startup banner with build/version info to stdout.
	PrintBanner()
	// Close gracefully shuts down the server and its owned resources.
	Close()
}
// NewServer constructs a new Cloudreve server instance with given dependency.
func NewServer(dep dependency.Dep) Server {
	s := &server{dep: dep}
	s.logger = dep.Logger()
	s.config = dep.ConfigProvider()
	return s
}
// server is the default Server implementation. It keeps handles to the pieces
// it must shut down explicitly in Close (DB client, HTTP server, KV store).
type server struct {
	dep       dependency.Dep
	logger    logging.Logger
	dbClient  *ent.Client
	config    conf.ConfigProvider
	server    *http.Server
	kv        cache.Driver
	mailQueue email.Driver // NOTE(review): never read or written in this file — confirm before removing.
}
// PrintBanner writes the ASCII-art startup banner, including the backend
// version, last commit hash and Pro flag, to stdout.
func (s *server) PrintBanner() {
	fmt.Print(`
   ___ _                 _
  / __\ | ___  _   _  __| |_ __ _____   _____
 / /  | |/ _ \| | | |/ _ | '__/ _ \ \ / / _ \
/ /___| | (_) | |_| | (_| | | |  __/\ V /  __/
\____/|_|\___/ \__,_|\__,_|_|  \___| \_/ \___|

   V` + constants.BackendVersion + `  Commit #` + constants.LastCommit + `  Pro=` + constants.IsPro + `
================================================

`)
}
// Start boots the Cloudreve server and blocks until the HTTP listener stops.
// In master mode it initializes the database, credential manager, task queues,
// cron jobs and node pool before serving; in slave mode only the slave queue
// is started. Serving prefers, in order: TLS, Unix socket, then plain TCP.
func (s *server) Start() error {
	// Switch gin to release mode when debug is disabled.
	if !s.config.System().Debug {
		gin.SetMode(gin.ReleaseMode)
	}

	s.kv = s.dep.KV()
	// Delete all cached settings so stale values from a previous run are not served.
	_ = s.kv.Delete(setting.KvSettingPrefix)

	// TODO: make sure redis is connected in dep before user traffic.
	if s.config.System().Mode == conf.MasterMode {
		s.dbClient = s.dep.DBClient()
		// TODO: make sure all dep is initialized before server start.
		s.dep.LockSystem()
		s.dep.UAParser()

		// Initialize OneDrive credentials.
		credentials, err := onedrive.RetrieveOneDriveCredentials(context.Background(), s.dep.StoragePolicyClient())
		if err != nil {
			// Fix: original message misspelled "failed" as "faield".
			return fmt.Errorf("failed to retrieve OneDrive credentials for CredManager: %w", err)
		}

		if err := s.dep.CredManager().Upsert(context.Background(), credentials...); err != nil {
			return fmt.Errorf("failed to upsert OneDrive credentials to CredManager: %w", err)
		}

		// Periodically refresh OAuth credentials via cron.
		crontab.Register(setting.CronTypeOauthCredRefresh, func(ctx context.Context) {
			dep := dependency.FromContext(ctx)
			cred := dep.CredManager()
			cred.RefreshAll(ctx)
		})

		// Initialize email queue before user traffic starts.
		_ = s.dep.EmailClient(context.Background())

		// Start all queues.
		s.dep.MediaMetaQueue(context.Background()).Start()
		s.dep.EntityRecycleQueue(context.Background()).Start()
		s.dep.IoIntenseQueue(context.Background()).Start()
		s.dep.RemoteDownloadQueue(context.Background()).Start()

		// Start cron jobs.
		c, err := crontab.NewCron(context.Background(), s.dep)
		if err != nil {
			return err
		}
		c.Start()

		// Start node pool.
		if _, err := s.dep.NodePool(context.Background()); err != nil {
			return err
		}
	} else {
		s.dep.SlaveQueue(context.Background()).Start()
	}

	// Thumbnail queue runs in both master and slave modes.
	s.dep.ThumbQueue(context.Background()).Start()

	api := routers.InitRouter(s.dep)
	api.TrustedPlatform = s.config.System().ProxyHeader
	s.server = &http.Server{Handler: api}

	// Serve over TLS when a certificate is configured.
	if s.config.SSL().CertPath != "" {
		s.logger.Info("Listening to %q", s.config.SSL().Listen)
		s.server.Addr = s.config.SSL().Listen
		if err := s.server.ListenAndServeTLS(s.config.SSL().CertPath, s.config.SSL().KeyPath); err != nil && !errors.Is(err, http.ErrServerClosed) {
			return fmt.Errorf("failed to listen to %q: %w", s.config.SSL().Listen, err)
		}
		return nil
	}

	// Serve over a Unix domain socket when configured.
	if s.config.Unix().Listen != "" {
		// Delete a leftover socket file before listening.
		if _, err := os.Stat(s.config.Unix().Listen); err == nil {
			if err = os.Remove(s.config.Unix().Listen); err != nil {
				return fmt.Errorf("failed to delete socket file %q: %w", s.config.Unix().Listen, err)
			}
		}

		s.logger.Info("Listening to %q", s.config.Unix().Listen)
		if err := s.runUnix(s.server); err != nil && !errors.Is(err, http.ErrServerClosed) {
			return fmt.Errorf("failed to listen to %q: %w", s.config.Unix().Listen, err)
		}
		return nil
	}

	// Default: plain TCP.
	s.logger.Info("Listening to %q", s.config.System().Listen)
	s.server.Addr = s.config.System().Listen
	if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
		return fmt.Errorf("failed to listen to %q: %w", s.config.System().Listen, err)
	}
	return nil
}
// Close gracefully shuts the server down: the database connection first, then
// the HTTP server (bounded by the configured grace period), then the persisted
// KV cache, and finally the dependency manager.
func (s *server) Close() {
	if s.dbClient != nil {
		s.logger.Info("Shutting down database connection...")
		if err := s.dbClient.Close(); err != nil {
			s.logger.Error("Failed to close database connection: %s", err)
		}
	}

	ctx := context.Background()
	// Fix: gate on the injected config provider's grace period; the original
	// checked the global conf.SystemConfig.GracePeriod while reading the
	// duration from s.config.System().GracePeriod — two different sources.
	if s.config.System().GracePeriod != 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, time.Duration(s.config.System().GracePeriod)*time.Second)
		defer cancel()
	}

	// Shut down the HTTP server, letting in-flight requests drain until ctx expires.
	if s.server != nil {
		if err := s.server.Shutdown(ctx); err != nil {
			s.logger.Error("Failed to shutdown server: %s", err)
		}
	}

	// Persist the in-memory cache so it survives restarts.
	if s.kv != nil {
		if err := s.kv.Persist(util.DataPath(cache.DefaultCacheFile)); err != nil {
			s.logger.Warning("Failed to persist cache: %s", err)
		}
	}

	if err := s.dep.Shutdown(ctx); err != nil {
		s.logger.Warning("Failed to shutdown dependency manager: %s", err)
	}
}
// runUnix serves HTTP over a Unix domain socket at the configured path,
// optionally applying the configured file mode to the socket file. The socket
// file is removed when the listener stops.
func (s *server) runUnix(server *http.Server) error {
	listener, err := net.Listen("unix", s.config.Unix().Listen)
	if err != nil {
		return err
	}

	defer listener.Close()
	defer os.Remove(s.config.Unix().Listen)

	// Fix: read Perm/Listen consistently from the injected config provider
	// instead of mixing in the global conf.UnixConfig, and format the
	// permission in octal (%o) rather than as a quoted rune (%q).
	if s.config.Unix().Perm > 0 {
		err = os.Chmod(s.config.Unix().Listen, os.FileMode(s.config.Unix().Perm))
		if err != nil {
			s.logger.Warning(
				"Failed to set permission to %o for socket file %q: %s",
				s.config.Unix().Perm,
				s.config.Unix().Listen,
				err,
			)
		}
	}

	return server.Serve(listener)
}

View File

@@ -0,0 +1,34 @@
package constants
// These values will be injected at build time, DO NOT EDIT.

// BackendVersion is the current backend version number.
var BackendVersion = "4.0.0-alpha.1"

// IsPro marks whether this build is the Pro edition ("true"/"false" string).
var IsPro = "false"

// IsProBool is the boolean form of IsPro, evaluated once at package init.
var IsProBool = IsPro == "true"

// LastCommit is the short id of the last commit baked into this build.
var LastCommit = "000000"
const (
	// APIPrefix is the URL prefix shared by all v4 API routes.
	APIPrefix = "/api/v4"
	// APIPrefixSlave is the URL prefix for slave-node API routes.
	APIPrefixSlave = "/api/v4/slave"
	// CrHeaderPrefix is the prefix of Cloudreve-specific HTTP headers.
	CrHeaderPrefix = "X-Cr-"
)

// CloudreveScheme is the custom URI scheme used to address internal resources.
const CloudreveScheme = "cloudreve"

type (
	// FileSystemType identifies which virtual file system a path belongs to.
	FileSystemType string
)

const (
	FileSystemMy           = FileSystemType("my")
	FileSystemShare        = FileSystemType("share")
	FileSystemTrash        = FileSystemType("trash")
	FileSystemSharedWithMe = FileSystemType("shared_with_me")
	FileSystemUnknown      = FileSystemType("unknown")
)

View File

@@ -0,0 +1,8 @@
package constants
// Binary size units in bytes, each 1024 times the previous.
const (
	MB = 1 << (10 * (iota + 2)) // 1 MiB
	GB                          // 1 GiB
	TB                          // 1 TiB
	PB                          // 1 PiB
)

View File

@@ -0,0 +1,874 @@
package dependency
import (
"context"
"errors"
iofs "io/fs"
"net/url"
"sync"
"time"
"github.com/cloudreve/Cloudreve/v4/application/statics"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/inventory"
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
"github.com/cloudreve/Cloudreve/v4/pkg/cluster"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/credmanager"
"github.com/cloudreve/Cloudreve/v4/pkg/email"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs/mime"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/lock"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/mediameta"
"github.com/cloudreve/Cloudreve/v4/pkg/queue"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"github.com/cloudreve/Cloudreve/v4/pkg/thumb"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/gin-contrib/static"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/robfig/cron/v3"
"github.com/samber/lo"
"github.com/ua-parser/uap-go/uaparser"
)
var (
	// ErrorConfigPathNotSet is raised when a config provider is requested
	// before any config file path has been supplied via WithConfigPath.
	ErrorConfigPathNotSet = errors.New("config path not set")
)

type (
	// DepCtx defines keys for dependency manager
	DepCtx struct{}
	// ReloadCtx force reload new dependency
	ReloadCtx struct{}
)
// Dep manages all dependencies of the server application. The default implementation is not
// concurrent safe, so all inner deps should be initialized before any goroutine starts.
type Dep interface {
	// ConfigProvider Get a singleton conf.ConfigProvider instance.
	ConfigProvider() conf.ConfigProvider
	// Logger Get a singleton logging.Logger instance.
	Logger() logging.Logger
	// Statics Get a singleton fs.FS instance for embedded static resources.
	Statics() iofs.FS
	// ServerStaticFS Get a singleton static.ServeFileSystem instance for serving static resources.
	ServerStaticFS() static.ServeFileSystem
	// DBClient Get a singleton ent.Client instance for database access.
	DBClient() *ent.Client
	// KV Get a singleton cache.Driver instance for KV store.
	KV() cache.Driver
	// NavigatorStateKV Get a singleton cache.Driver instance for navigator state store. It forces use in-memory
	// map instead of Redis to get better performance for complex nested linked list.
	NavigatorStateKV() cache.Driver
	// SettingClient Get a singleton inventory.SettingClient instance for access DB setting store.
	SettingClient() inventory.SettingClient
	// SettingProvider Get a singleton setting.Provider instance for access setting store in strong type.
	SettingProvider() setting.Provider
	// UserClient Creates a new inventory.UserClient instance for access DB user store.
	UserClient() inventory.UserClient
	// GroupClient Creates a new inventory.GroupClient instance for access DB group store.
	GroupClient() inventory.GroupClient
	// EmailClient Get a singleton email.Driver instance for sending emails.
	EmailClient(ctx context.Context) email.Driver
	// GeneralAuth Get a singleton auth.Auth instance for general authentication.
	GeneralAuth() auth.Auth
	// Shutdown the dependencies gracefully.
	Shutdown(ctx context.Context) error
	// FileClient Creates a new inventory.FileClient instance for access DB file store.
	FileClient() inventory.FileClient
	// NodeClient Creates a new inventory.NodeClient instance for access DB node store.
	NodeClient() inventory.NodeClient
	// DavAccountClient Creates a new inventory.DavAccountClient instance for access DB dav account store.
	DavAccountClient() inventory.DavAccountClient
	// DirectLinkClient Creates a new inventory.DirectLinkClient instance for access DB direct link store.
	DirectLinkClient() inventory.DirectLinkClient
	// HashIDEncoder Get a singleton hashid.Encoder instance for encoding/decoding hashids.
	HashIDEncoder() hashid.Encoder
	// TokenAuth Get a singleton auth.TokenAuth instance for token authentication.
	TokenAuth() auth.TokenAuth
	// LockSystem Get a singleton lock.LockSystem instance for file lock management.
	LockSystem() lock.LockSystem
	// StoragePolicyClient Creates a new inventory.StoragePolicyClient instance for access DB storage policy store.
	StoragePolicyClient() inventory.StoragePolicyClient
	// RequestClient Creates a new request.Client instance for HTTP requests.
	RequestClient(opts ...request.Option) request.Client
	// ShareClient Creates a new inventory.ShareClient instance for access DB share store.
	ShareClient() inventory.ShareClient
	// TaskClient Creates a new inventory.TaskClient instance for access DB task store.
	TaskClient() inventory.TaskClient
	// ForkWithLogger create a shallow copy of dependency with a new correlated logger, used as per-request dep.
	ForkWithLogger(ctx context.Context, l logging.Logger) context.Context
	// MediaMetaQueue Get a singleton queue.Queue instance for media metadata processing.
	MediaMetaQueue(ctx context.Context) queue.Queue
	// SlaveQueue Get a singleton queue.Queue instance for slave tasks.
	SlaveQueue(ctx context.Context) queue.Queue
	// MediaMetaExtractor Get a singleton mediameta.Extractor instance for media metadata extraction.
	MediaMetaExtractor(ctx context.Context) mediameta.Extractor
	// ThumbPipeline Get a singleton thumb.Generator instance for chained thumbnail generation.
	ThumbPipeline() thumb.Generator
	// ThumbQueue Get a singleton queue.Queue instance for thumbnail generation.
	ThumbQueue(ctx context.Context) queue.Queue
	// EntityRecycleQueue Get a singleton queue.Queue instance for entity recycle.
	EntityRecycleQueue(ctx context.Context) queue.Queue
	// MimeDetector Get a singleton fs.MimeDetector instance for MIME type detection.
	MimeDetector(ctx context.Context) mime.MimeDetector
	// CredManager Get a singleton credmanager.CredManager instance for credential management.
	CredManager() credmanager.CredManager
	// IoIntenseQueue Get a singleton queue.Queue instance for IO intense tasks.
	IoIntenseQueue(ctx context.Context) queue.Queue
	// RemoteDownloadQueue Get a singleton queue.Queue instance for remote download tasks.
	RemoteDownloadQueue(ctx context.Context) queue.Queue
	// NodePool Get a singleton cluster.NodePool instance for node pool management.
	NodePool(ctx context.Context) (cluster.NodePool, error)
	// TaskRegistry Get a singleton queue.TaskRegistry instance for task registration.
	TaskRegistry() queue.TaskRegistry
	// WebAuthn Get a singleton webauthn.WebAuthn instance for WebAuthn authentication.
	WebAuthn(ctx context.Context) (*webauthn.WebAuthn, error)
	// UAParser Get a singleton uaparser.Parser instance for user agent parsing.
	UAParser() *uaparser.Parser
}
// dependency is the default Dep implementation. Most accessors lazily build
// and cache their value into one of these fields; any field may also be preset
// through an Option at construction time, in which case the preset wins.
type dependency struct {
	configProvider      conf.ConfigProvider
	logger              logging.Logger
	statics             iofs.FS
	serverStaticFS      static.ServeFileSystem
	dbClient            *ent.Client
	rawEntClient        *ent.Client // un-migrated client, wrapped by DBClient()
	kv                  cache.Driver
	navigatorStateKv    cache.Driver
	settingClient       inventory.SettingClient
	fileClient          inventory.FileClient
	shareClient         inventory.ShareClient
	settingProvider     setting.Provider
	userClient          inventory.UserClient
	groupClient         inventory.GroupClient
	storagePolicyClient inventory.StoragePolicyClient
	taskClient          inventory.TaskClient
	nodeClient          inventory.NodeClient
	davAccountClient    inventory.DavAccountClient
	directLinkClient    inventory.DirectLinkClient
	emailClient         email.Driver
	generalAuth         auth.Auth
	hashidEncoder       hashid.Encoder
	tokenAuth           auth.TokenAuth
	lockSystem          lock.LockSystem
	requestClient       request.Client
	ioIntenseQueue      queue.Queue
	thumbQueue          queue.Queue
	mediaMetaQueue      queue.Queue
	entityRecycleQueue  queue.Queue
	slaveQueue          queue.Queue
	remoteDownloadQueue queue.Queue
	ioIntenseQueueTask  queue.Task // NOTE(review): appears unused in this chunk — confirm before removing.
	mediaMeta           mediameta.Extractor
	thumbPipeline       thumb.Generator
	mimeDetector        mime.MimeDetector
	credManager         credmanager.CredManager
	nodePool            cluster.NodePool
	taskRegistry        queue.TaskRegistry
	webauthn            *webauthn.WebAuthn
	parser              *uaparser.Parser
	cron                *cron.Cron
	// Construction-time options, set via Option values.
	configPath        string
	isPro             bool
	requiredDbVersion string
	licenseKey        string

	// Protects inner deps that can be reloaded at runtime.
	mu sync.Mutex
}
// NewDependency creates a new Dep instance for construct dependencies.
func NewDependency(opts ...Option) Dep {
	dep := new(dependency)
	for _, opt := range opts {
		opt.apply(dep)
	}
	return dep
}
// FromContext retrieves the Dep instance stored under DepCtx; it panics if
// the context carries no dependency manager.
func FromContext(ctx context.Context) Dep {
	v := ctx.Value(DepCtx{})
	return v.(Dep)
}
// RequestClient returns the injected request client when one was preset;
// otherwise it builds a fresh client from the config provider and options.
func (d *dependency) RequestClient(opts ...request.Option) request.Client {
	if preset := d.requestClient; preset != nil {
		return preset
	}
	return request.NewClient(d.ConfigProvider(), opts...)
}
// WebAuthn lazily constructs and caches the WebAuthn handler, configured with
// the site display name, RP ID (site URL hostname) and allowed origins.
func (d *dependency) WebAuthn(ctx context.Context) (*webauthn.WebAuthn, error) {
	if d.webauthn != nil {
		return d.webauthn, nil
	}

	settings := d.SettingProvider()
	siteBasic := settings.SiteBasic(ctx)
	wConfig := &webauthn.Config{
		RPDisplayName: siteBasic.Name,
		RPID:          settings.SiteURL(ctx).Hostname(),
		RPOrigins: lo.Map(settings.AllSiteURLs(ctx), func(item *url.URL, index int) string {
			// NOTE(review): this clears Path on the *url.URL values returned by
			// AllSiteURLs — confirm the provider returns copies, otherwise
			// shared URLs are mutated here.
			item.Path = ""
			return item.String()
		}), // The origin URLs allowed for WebAuthn requests
	}

	w, err := webauthn.New(wConfig)
	if err != nil {
		return nil, err
	}
	// Fix: cache the handler so the nil-check above takes effect; the original
	// never assigned d.webauthn, rebuilding the handler on every call.
	d.webauthn = w
	return w, nil
}
// UAParser lazily builds and caches the user-agent parser from the saved
// uap-go definitions.
func (d *dependency) UAParser() *uaparser.Parser {
	if d.parser == nil {
		d.parser = uaparser.NewFromSaved()
	}
	return d.parser
}
// ConfigProvider lazily builds the INI-backed config provider from the path
// set via WithConfigPath; missing path or parse failures panic via panicError.
func (d *dependency) ConfigProvider() conf.ConfigProvider {
	if d.configProvider == nil {
		if d.configPath == "" {
			d.panicError(ErrorConfigPathNotSet)
		}
		provider, err := conf.NewIniConfigProvider(d.configPath, logging.NewConsoleLogger(logging.LevelInformational))
		if err != nil {
			d.panicError(err)
		}
		d.configProvider = provider
	}
	return d.configProvider
}
// Logger lazily builds a console logger whose level comes from the system
// config; debug mode forces LevelDebug.
func (d *dependency) Logger() logging.Logger {
	if d.logger == nil {
		config := d.ConfigProvider()
		level := logging.LogLevel(config.System().LogLevel)
		if config.System().Debug {
			level = logging.LevelDebug
		}
		d.logger = logging.NewConsoleLogger(level)
		d.logger.Info("Logger initialized with LogLevel=%q.", level)
	}
	return d.logger
}
// Statics lazily wraps the embedded static resources as an fs.FS.
func (d *dependency) Statics() iofs.FS {
	if d.statics == nil {
		d.statics = statics.NewStaticFS(d.Logger())
	}
	return d.statics
}
// ServerStaticFS lazily builds the servable static file system on top of the
// embedded statics; construction failures panic via panicError.
func (d *dependency) ServerStaticFS() static.ServeFileSystem {
	if d.serverStaticFS == nil {
		sfs, err := statics.NewServerStaticFS(d.Logger(), d.Statics(), d.isPro)
		if err != nil {
			d.panicError(err)
		}
		d.serverStaticFS = sfs
	}
	return d.serverStaticFS
}
// DBClient lazily builds the fully-initialized database client: a raw ent
// client is created first (unless preset via WithRawEntClient), then run
// through InitializeDBClient, which validates against requiredDbVersion.
// Any failure panics via panicError.
func (d *dependency) DBClient() *ent.Client {
	if d.dbClient != nil {
		return d.dbClient
	}

	// Build the raw (un-migrated) client once.
	if d.rawEntClient == nil {
		client, err := inventory.NewRawEntClient(d.Logger(), d.ConfigProvider())
		if err != nil {
			d.panicError(err)
		}

		d.rawEntClient = client
	}

	client, err := inventory.InitializeDBClient(d.Logger(), d.rawEntClient, d.KV(), d.requiredDbVersion)
	if err != nil {
		d.panicError(err)
	}

	d.dbClient = client
	return d.dbClient
}
// KV lazily selects the cache backend: Redis when a server address is
// configured, otherwise a file-backed in-memory store.
func (d *dependency) KV() cache.Driver {
	if d.kv == nil {
		redisConf := d.ConfigProvider().Redis()
		if redisConf.Server == "" {
			d.kv = cache.NewMemoStore(util.DataPath(cache.DefaultCacheFile), d.Logger())
		} else {
			d.kv = cache.NewRedisStore(
				d.Logger(),
				10,
				redisConf.Network,
				redisConf.Server,
				redisConf.User,
				redisConf.Password,
				redisConf.DB,
			)
		}
	}
	return d.kv
}
// NavigatorStateKV lazily builds a dedicated in-memory store (never Redis)
// for navigator state, which favors map performance over persistence.
func (d *dependency) NavigatorStateKV() cache.Driver {
	if d.navigatorStateKv == nil {
		d.navigatorStateKv = cache.NewMemoStore("", d.Logger())
	}
	return d.navigatorStateKv
}
// SettingClient lazily builds the DB-backed setting client with KV caching.
func (d *dependency) SettingClient() inventory.SettingClient {
	if d.settingClient == nil {
		d.settingClient = inventory.NewSettingClient(d.DBClient(), d.KV())
	}
	return d.settingClient
}
// SettingProvider lazily builds the strongly-typed setting provider. The store
// chain differs by run mode; the nesting order below defines lookup priority.
func (d *dependency) SettingProvider() setting.Provider {
	if d.settingProvider != nil {
		return d.settingProvider
	}

	if d.ConfigProvider().System().Mode == conf.MasterMode {
		// For master mode, setting value will be retrieved in order:
		// Env overwrite -> KV Store -> DB Setting Store
		d.settingProvider = setting.NewProvider(
			setting.NewEnvOverrideStore(
				setting.NewKvSettingStore(d.KV(),
					setting.NewDbSettingStore(d.SettingClient(), nil),
				),
				d.Logger(),
			),
		)
	} else {
		// For slave mode, setting value will be retrieved in order:
		// Env overwrite -> Config file overwrites -> Setting defaults in DB schema
		d.settingProvider = setting.NewProvider(
			setting.NewEnvOverrideStore(
				setting.NewConfSettingStore(d.ConfigProvider(),
					setting.NewDbDefaultStore(nil),
				),
				d.Logger(),
			),
		)
	}
	return d.settingProvider
}
// UserClient returns the preset user client when injected via options;
// otherwise it constructs a fresh client bound to the shared DB connection.
func (d *dependency) UserClient() inventory.UserClient {
	if preset := d.userClient; preset != nil {
		return preset
	}
	return inventory.NewUserClient(d.DBClient())
}
// GroupClient returns the preset group client when injected via options;
// otherwise it constructs a fresh client with DB type and KV cache wired in.
func (d *dependency) GroupClient() inventory.GroupClient {
	if preset := d.groupClient; preset != nil {
		return preset
	}
	return inventory.NewGroupClient(d.DBClient(), d.ConfigProvider().Database().Type, d.KV())
}
// NodeClient returns the preset node client when injected via options;
// otherwise it constructs a fresh client bound to the shared DB connection.
func (d *dependency) NodeClient() inventory.NodeClient {
	if preset := d.nodeClient; preset != nil {
		return preset
	}
	return inventory.NewNodeClient(d.DBClient())
}
// NodePool lazily builds the cluster node pool; a true ReloadCtx value forces
// a rebuild. Master mode builds a real pool from the node inventory; slave
// mode uses a dummy pool.
// NOTE(review): unlike the queue accessors this is not guarded by d.mu —
// confirm it is only called before goroutines start or from guarded reloads.
func (d *dependency) NodePool(ctx context.Context) (cluster.NodePool, error) {
	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.nodePool != nil && !reload {
		return d.nodePool, nil
	}

	if d.ConfigProvider().System().Mode == conf.MasterMode {
		np, err := cluster.NewNodePool(ctx, d.Logger(), d.ConfigProvider(), d.SettingProvider(), d.NodeClient())
		if err != nil {
			return nil, err
		}
		d.nodePool = np
	} else {
		d.nodePool = cluster.NewSlaveDummyNodePool(ctx, d.ConfigProvider(), d.SettingProvider())
	}

	return d.nodePool, nil
}
// EmailClient returns the shared SMTP pool, rebuilding (and closing the old
// pool) when the context carries a true ReloadCtx flag. Guarded by d.mu.
func (d *dependency) EmailClient(ctx context.Context) email.Driver {
	d.mu.Lock()
	defer d.mu.Unlock()

	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.emailClient == nil || reload {
		if d.emailClient != nil {
			d.emailClient.Close()
		}
		d.emailClient = email.NewSMTPPool(d.SettingProvider(), d.Logger())
	}
	return d.emailClient
}
// MimeDetector lazily builds the MIME detector; a true ReloadCtx value forces
// a rebuild. Guarded by d.mu.
func (d *dependency) MimeDetector(ctx context.Context) mime.MimeDetector {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Fix: read the flag's value (as EmailClient/NodePool do); the original
	// captured the type-assertion's "ok" result, so any stored bool — even
	// false — forced a reload.
	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.mimeDetector != nil && !reload {
		return d.mimeDetector
	}

	d.mimeDetector = mime.NewMimeDetector(ctx, d.SettingProvider(), d.Logger())
	return d.mimeDetector
}
// MediaMetaExtractor lazily builds the media metadata extractor; a true
// ReloadCtx value forces a rebuild. Guarded by d.mu.
func (d *dependency) MediaMetaExtractor(ctx context.Context) mediameta.Extractor {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Fix: read the flag's value instead of the type-assertion's "ok" result,
	// which treated any stored bool (even false) as a reload request.
	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.mediaMeta != nil && !reload {
		return d.mediaMeta
	}

	d.mediaMeta = mediameta.NewExtractorManager(ctx, d.SettingProvider(), d.Logger())
	return d.mediaMeta
}
// ThumbQueue lazily builds the thumbnail generation queue; a true ReloadCtx
// value shuts down the old queue and rebuilds it. Guarded by d.mu. In slave
// mode the queue runs without a task client.
func (d *dependency) ThumbQueue(ctx context.Context) queue.Queue {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Fix: read the flag's value instead of the type-assertion's "ok" result,
	// which treated any stored bool (even false) as a reload request.
	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.thumbQueue != nil && !reload {
		return d.thumbQueue
	}

	// Reloading: stop the previous queue before replacing it.
	if d.thumbQueue != nil {
		d.thumbQueue.Shutdown()
	}

	settings := d.SettingProvider()
	queueSetting := settings.Queue(context.Background(), setting.QueueTypeThumb)
	var (
		t inventory.TaskClient
	)
	if d.ConfigProvider().System().Mode == conf.MasterMode {
		t = d.TaskClient()
	}
	d.thumbQueue = queue.New(d.Logger(), t, nil, d,
		queue.WithBackoffFactor(queueSetting.BackoffFactor),
		queue.WithMaxRetry(queueSetting.MaxRetry),
		queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration),
		queue.WithRetryDelay(queueSetting.RetryDelay),
		queue.WithWorkerCount(queueSetting.WorkerNum),
		queue.WithName("ThumbQueue"),
		queue.WithMaxTaskExecution(queueSetting.MaxExecution),
	)
	return d.thumbQueue
}
// MediaMetaQueue lazily builds the media metadata queue; a true ReloadCtx
// value shuts down the old queue and rebuilds it. Guarded by d.mu.
func (d *dependency) MediaMetaQueue(ctx context.Context) queue.Queue {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Fix: read the flag's value instead of the type-assertion's "ok" result,
	// which treated any stored bool (even false) as a reload request.
	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.mediaMetaQueue != nil && !reload {
		return d.mediaMetaQueue
	}

	// Reloading: stop the previous queue before replacing it.
	if d.mediaMetaQueue != nil {
		d.mediaMetaQueue.Shutdown()
	}

	settings := d.SettingProvider()
	queueSetting := settings.Queue(context.Background(), setting.QueueTypeMediaMeta)
	d.mediaMetaQueue = queue.New(d.Logger(), d.TaskClient(), nil, d,
		queue.WithBackoffFactor(queueSetting.BackoffFactor),
		queue.WithMaxRetry(queueSetting.MaxRetry),
		queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration),
		queue.WithRetryDelay(queueSetting.RetryDelay),
		queue.WithWorkerCount(queueSetting.WorkerNum),
		queue.WithName("MediaMetadataQueue"),
		queue.WithMaxTaskExecution(queueSetting.MaxExecution),
		queue.WithResumeTaskType(queue.MediaMetaTaskType),
	)
	return d.mediaMetaQueue
}
// IoIntenseQueue lazily builds the IO-intense task queue; a true ReloadCtx
// value shuts down the old queue and rebuilds it. Guarded by d.mu.
func (d *dependency) IoIntenseQueue(ctx context.Context) queue.Queue {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Fix: read the flag's value instead of the type-assertion's "ok" result,
	// which treated any stored bool (even false) as a reload request.
	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.ioIntenseQueue != nil && !reload {
		return d.ioIntenseQueue
	}

	// Reloading: stop the previous queue before replacing it.
	if d.ioIntenseQueue != nil {
		d.ioIntenseQueue.Shutdown()
	}

	settings := d.SettingProvider()
	queueSetting := settings.Queue(context.Background(), setting.QueueTypeIOIntense)
	d.ioIntenseQueue = queue.New(d.Logger(), d.TaskClient(), d.TaskRegistry(), d,
		queue.WithBackoffFactor(queueSetting.BackoffFactor),
		queue.WithMaxRetry(queueSetting.MaxRetry),
		queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration),
		queue.WithRetryDelay(queueSetting.RetryDelay),
		queue.WithWorkerCount(queueSetting.WorkerNum),
		queue.WithName("IoIntenseQueue"),
		queue.WithMaxTaskExecution(queueSetting.MaxExecution),
		queue.WithResumeTaskType(queue.CreateArchiveTaskType, queue.ExtractArchiveTaskType, queue.RelocateTaskType),
		queue.WithTaskPullInterval(10*time.Second),
	)
	return d.ioIntenseQueue
}
// RemoteDownloadQueue lazily builds the remote download task queue; a true
// ReloadCtx value shuts down the old queue and rebuilds it. Guarded by d.mu.
func (d *dependency) RemoteDownloadQueue(ctx context.Context) queue.Queue {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Fix: read the flag's value instead of the type-assertion's "ok" result,
	// which treated any stored bool (even false) as a reload request.
	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.remoteDownloadQueue != nil && !reload {
		return d.remoteDownloadQueue
	}

	// Reloading: stop the previous queue before replacing it.
	if d.remoteDownloadQueue != nil {
		d.remoteDownloadQueue.Shutdown()
	}

	settings := d.SettingProvider()
	queueSetting := settings.Queue(context.Background(), setting.QueueTypeRemoteDownload)
	d.remoteDownloadQueue = queue.New(d.Logger(), d.TaskClient(), d.TaskRegistry(), d,
		queue.WithBackoffFactor(queueSetting.BackoffFactor),
		queue.WithMaxRetry(queueSetting.MaxRetry),
		queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration),
		queue.WithRetryDelay(queueSetting.RetryDelay),
		queue.WithWorkerCount(queueSetting.WorkerNum),
		queue.WithName("RemoteDownloadQueue"),
		queue.WithMaxTaskExecution(queueSetting.MaxExecution),
		queue.WithResumeTaskType(queue.RemoteDownloadTaskType),
		queue.WithTaskPullInterval(20*time.Second),
	)
	return d.remoteDownloadQueue
}
// EntityRecycleQueue lazily builds the entity recycle queue; a true ReloadCtx
// value shuts down the old queue and rebuilds it. Guarded by d.mu.
func (d *dependency) EntityRecycleQueue(ctx context.Context) queue.Queue {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Fix: read the flag's value instead of the type-assertion's "ok" result,
	// which treated any stored bool (even false) as a reload request.
	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.entityRecycleQueue != nil && !reload {
		return d.entityRecycleQueue
	}

	// Reloading: stop the previous queue before replacing it.
	if d.entityRecycleQueue != nil {
		d.entityRecycleQueue.Shutdown()
	}

	settings := d.SettingProvider()
	queueSetting := settings.Queue(context.Background(), setting.QueueTypeEntityRecycle)
	d.entityRecycleQueue = queue.New(d.Logger(), d.TaskClient(), nil, d,
		queue.WithBackoffFactor(queueSetting.BackoffFactor),
		queue.WithMaxRetry(queueSetting.MaxRetry),
		queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration),
		queue.WithRetryDelay(queueSetting.RetryDelay),
		queue.WithWorkerCount(queueSetting.WorkerNum),
		queue.WithName("EntityRecycleQueue"),
		queue.WithMaxTaskExecution(queueSetting.MaxExecution),
		queue.WithResumeTaskType(queue.EntityRecycleRoutineTaskType, queue.ExplicitEntityRecycleTaskType, queue.UploadSentinelCheckTaskType),
		queue.WithTaskPullInterval(10*time.Second),
	)
	return d.entityRecycleQueue
}
// SlaveQueue lazily builds the slave-mode task queue (no task client or
// registry); a true ReloadCtx value shuts down the old queue and rebuilds it.
// Guarded by d.mu.
func (d *dependency) SlaveQueue(ctx context.Context) queue.Queue {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Fix: read the flag's value instead of the type-assertion's "ok" result,
	// which treated any stored bool (even false) as a reload request.
	reload, _ := ctx.Value(ReloadCtx{}).(bool)
	if d.slaveQueue != nil && !reload {
		return d.slaveQueue
	}

	// Reloading: stop the previous queue before replacing it.
	if d.slaveQueue != nil {
		d.slaveQueue.Shutdown()
	}

	settings := d.SettingProvider()
	queueSetting := settings.Queue(context.Background(), setting.QueueTypeSlave)
	d.slaveQueue = queue.New(d.Logger(), nil, nil, d,
		queue.WithBackoffFactor(queueSetting.BackoffFactor),
		queue.WithMaxRetry(queueSetting.MaxRetry),
		queue.WithBackoffMaxDuration(queueSetting.BackoffMaxDuration),
		queue.WithRetryDelay(queueSetting.RetryDelay),
		queue.WithWorkerCount(queueSetting.WorkerNum),
		queue.WithName("SlaveQueue"),
		queue.WithMaxTaskExecution(queueSetting.MaxExecution),
	)
	return d.slaveQueue
}
// GeneralAuth lazily builds the HMAC signer. Master mode signs with the site
// secret key from settings; slave mode requires a configured slave secret and
// panics without one.
func (d *dependency) GeneralAuth() auth.Auth {
	if d.generalAuth != nil {
		return d.generalAuth
	}

	secretKey := ""
	if d.ConfigProvider().System().Mode != conf.MasterMode {
		secretKey = d.ConfigProvider().Slave().Secret
		if secretKey == "" {
			d.panicError(errors.New("SlaveSecret is not set, please specify it in config file"))
		}
	} else {
		secretKey = d.SettingProvider().SecretKey(context.Background())
	}

	d.generalAuth = auth.HMACAuth{SecretKey: []byte(secretKey)}
	return d.generalAuth
}
// FileClient returns the preset file client when injected via options;
// otherwise it constructs a fresh client with DB type and hashid encoder.
func (d *dependency) FileClient() inventory.FileClient {
	if preset := d.fileClient; preset != nil {
		return preset
	}
	return inventory.NewFileClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder())
}
// ShareClient returns the preset share client when injected via options;
// otherwise it constructs a fresh client with DB type and hashid encoder.
func (d *dependency) ShareClient() inventory.ShareClient {
	if preset := d.shareClient; preset != nil {
		return preset
	}
	return inventory.NewShareClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder())
}
// TaskClient returns the preset task client when injected via options;
// otherwise it constructs a fresh client with DB type and hashid encoder.
func (d *dependency) TaskClient() inventory.TaskClient {
	if preset := d.taskClient; preset != nil {
		return preset
	}
	return inventory.NewTaskClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder())
}
// DavAccountClient returns the preset dav account client when injected via
// options; otherwise it constructs a fresh client with DB type and encoder.
func (d *dependency) DavAccountClient() inventory.DavAccountClient {
	if preset := d.davAccountClient; preset != nil {
		return preset
	}
	return inventory.NewDavAccountClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder())
}
// DirectLinkClient returns the preset direct link client when injected via
// options; otherwise it constructs a fresh client with DB type and encoder.
func (d *dependency) DirectLinkClient() inventory.DirectLinkClient {
	if preset := d.directLinkClient; preset != nil {
		return preset
	}
	return inventory.NewDirectLinkClient(d.DBClient(), d.ConfigProvider().Database().Type, d.HashIDEncoder())
}
// HashIDEncoder lazily builds the hashid encoder from the configured salt;
// construction failures panic via panicError.
func (d *dependency) HashIDEncoder() hashid.Encoder {
	if d.hashidEncoder == nil {
		salt := d.SettingProvider().HashIDSalt(context.Background())
		encoder, err := hashid.New(salt)
		if err != nil {
			d.panicError(err)
		}
		d.hashidEncoder = encoder
	}
	return d.hashidEncoder
}
// CredManager lazily builds the credential manager: a full manager in master
// mode, a slave manager otherwise.
func (d *dependency) CredManager() credmanager.CredManager {
	if d.credManager == nil {
		switch d.ConfigProvider().System().Mode {
		case conf.MasterMode:
			d.credManager = credmanager.New(d.KV())
		default:
			d.credManager = credmanager.NewSlaveManager(d.KV(), d.ConfigProvider())
		}
	}
	return d.credManager
}
// TokenAuth lazily builds the token authenticator keyed by the site secret.
func (d *dependency) TokenAuth() auth.TokenAuth {
	if d.tokenAuth == nil {
		settings := d.SettingProvider()
		secret := []byte(settings.SecretKey(context.Background()))
		d.tokenAuth = auth.NewTokenAuth(d.HashIDEncoder(), settings, secret, d.UserClient(), d.Logger())
	}
	return d.tokenAuth
}
// LockSystem lazily builds the in-memory file lock system.
func (d *dependency) LockSystem() lock.LockSystem {
	if d.lockSystem == nil {
		d.lockSystem = lock.NewMemLS(d.HashIDEncoder(), d.Logger())
	}
	return d.lockSystem
}
// StoragePolicyClient returns the preset storage policy client when injected
// via options; otherwise it constructs a fresh KV-cached client.
func (d *dependency) StoragePolicyClient() inventory.StoragePolicyClient {
	if preset := d.storagePolicyClient; preset != nil {
		return preset
	}
	return inventory.NewStoragePolicyClient(d.DBClient(), d.KV())
}
// ThumbPipeline lazily builds the chained thumbnail generator.
func (d *dependency) ThumbPipeline() thumb.Generator {
	if d.thumbPipeline == nil {
		d.thumbPipeline = thumb.NewPipeline(d.SettingProvider(), d.Logger())
	}
	return d.thumbPipeline
}
// TaskRegistry lazily builds the shared task registry.
func (d *dependency) TaskRegistry() queue.TaskRegistry {
	if d.taskRegistry == nil {
		d.taskRegistry = queue.NewTaskRegistry()
	}
	return d.taskRegistry
}
// Shutdown gracefully stops owned resources: the email client is closed, then
// every running queue is shut down concurrently. It blocks until all queues
// have stopped and always returns nil.
func (d *dependency) Shutdown(ctx context.Context) error {
	d.mu.Lock()
	if d.emailClient != nil {
		d.emailClient.Close()
	}

	wg := sync.WaitGroup{}
	// Collapse the six copy-pasted goroutine stanzas into one loop; nil
	// entries (queues never started) are skipped by the helper.
	for _, q := range []queue.Queue{
		d.mediaMetaQueue,
		d.thumbQueue,
		d.ioIntenseQueue,
		d.entityRecycleQueue,
		d.slaveQueue,
		d.remoteDownloadQueue,
	} {
		shutdownQueueAsync(&wg, q)
	}
	d.mu.Unlock()

	wg.Wait()
	return nil
}

// shutdownQueueAsync shuts down q in its own goroutine tracked by wg.
// Nil queues are ignored.
func shutdownQueueAsync(wg *sync.WaitGroup, q queue.Queue) {
	if q == nil {
		return
	}
	wg.Add(1)
	go func() {
		// Fix: register Done before calling Shutdown; the original deferred
		// Done after the call, so a panicking Shutdown would deadlock Wait.
		defer wg.Done()
		q.Shutdown()
	}()
}
// panicError reports a fatal dependency-initialization error through the
// logger when one exists, then panics with the original error.
func (d *dependency) panicError(err error) {
	if l := d.logger; l != nil {
		l.Panic("Fatal error in dependency initialization: %s", err)
	}
	panic(err)
}
// ForkWithLogger stores a shallow copy of this dependency manager, with its
// logger swapped for l, under DepCtx in the returned context.
func (d *dependency) ForkWithLogger(ctx context.Context, l logging.Logger) context.Context {
	fork := &dependencyCorrelated{l: l, dependency: d}
	return context.WithValue(ctx, DepCtx{}, fork)
}
// dependencyCorrelated is a shallow fork of dependency that overrides only the
// logger, so per-request correlated logging reuses every other singleton.
type dependencyCorrelated struct {
	l logging.Logger
	*dependency
}

// Logger returns the request-correlated logger instead of the shared one.
func (d *dependencyCorrelated) Logger() logging.Logger {
	return d.l
}

View File

@@ -0,0 +1,165 @@
package dependency
import (
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/inventory"
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/email"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"github.com/gin-contrib/static"
"io/fs"
)
// Option is an optional configuration applied to the dependency manager at
// construction time. (Original comment was a mistranslated copy: "extra
// settings for sending requests".)
type Option interface {
	apply(*dependency)
}

// optionFunc adapts a plain function into an Option.
type optionFunc func(*dependency)

// apply invokes the wrapped function on the dependency being built.
func (f optionFunc) apply(o *dependency) {
	f(o)
}
// WithConfigPath Set the path of the config file.
func WithConfigPath(p string) Option {
	return optionFunc(func(o *dependency) {
		o.configPath = p
	})
}

// WithLogger Set the default logging.
func WithLogger(l logging.Logger) Option {
	return optionFunc(func(o *dependency) {
		o.logger = l
	})
}

// WithConfigProvider Set the default config provider.
func WithConfigProvider(c conf.ConfigProvider) Option {
	return optionFunc(func(o *dependency) {
		o.configProvider = c
	})
}

// WithStatics Set the default statics FS.
func WithStatics(c fs.FS) Option {
	return optionFunc(func(o *dependency) {
		o.statics = c
	})
}

// WithServerStaticFS Set the default statics FS for server.
func WithServerStaticFS(c static.ServeFileSystem) Option {
	return optionFunc(func(o *dependency) {
		o.serverStaticFS = c
	})
}

// WithProFlag Set if current instance is a pro version.
func WithProFlag(c bool) Option {
	return optionFunc(func(o *dependency) {
		o.isPro = c
	})
}

// WithLicenseKey Set the license key of the Pro edition.
func WithLicenseKey(c string) Option {
	return optionFunc(func(o *dependency) {
		o.licenseKey = c
	})
}
// WithRawEntClient Set the default raw ent client.
func WithRawEntClient(c *ent.Client) Option {
return optionFunc(func(o *dependency) {
o.rawEntClient = c
})
}
// WithDbClient Set the default ent client.
func WithDbClient(c *ent.Client) Option {
return optionFunc(func(o *dependency) {
o.dbClient = c
})
}
// WithRequiredDbVersion Set the required db version.
func WithRequiredDbVersion(c string) Option {
return optionFunc(func(o *dependency) {
o.requiredDbVersion = c
})
}
// WithKV Set the default KV store driverold
func WithKV(c cache.Driver) Option {
return optionFunc(func(o *dependency) {
o.kv = c
})
}
// WithSettingClient Set the default setting client
func WithSettingClient(s inventory.SettingClient) Option {
return optionFunc(func(o *dependency) {
o.settingClient = s
})
}
// WithSettingProvider Set the default setting provider
func WithSettingProvider(s setting.Provider) Option {
return optionFunc(func(o *dependency) {
o.settingProvider = s
})
}
// WithUserClient Set the default user client
func WithUserClient(s inventory.UserClient) Option {
return optionFunc(func(o *dependency) {
o.userClient = s
})
}
// WithEmailClient Set the default email client
func WithEmailClient(s email.Driver) Option {
return optionFunc(func(o *dependency) {
o.emailClient = s
})
}
// WithGeneralAuth Set the default general auth
func WithGeneralAuth(s auth.Auth) Option {
return optionFunc(func(o *dependency) {
o.generalAuth = s
})
}
// WithHashIDEncoder Set the default hash id encoder
func WithHashIDEncoder(s hashid.Encoder) Option {
return optionFunc(func(o *dependency) {
o.hashidEncoder = s
})
}
// WithTokenAuth Set the default token auth
func WithTokenAuth(s auth.TokenAuth) Option {
return optionFunc(func(o *dependency) {
o.tokenAuth = s
})
}
// WithFileClient Set the default file client
func WithFileClient(s inventory.FileClient) Option {
return optionFunc(func(o *dependency) {
o.fileClient = s
})
}
// WithShareClient Set the default share client
func WithShareClient(s inventory.ShareClient) Option {
return optionFunc(func(o *dependency) {
o.shareClient = s
})
}

View File

@@ -0,0 +1,47 @@
package migrator
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
)
// migrateAvatars copies v3 avatar files ("avatar_<uid>_2.png" under the v3
// avatar root) into the v4 data path as "avatar_<uid>.png" for every
// migrated user. A failed copy of an individual avatar is logged and
// skipped; open/create failures abort the migration.
func migrateAvatars(m *Migrator) error {
	m.l.Info("Migrating avatars files...")
	avatarRoot := util.RelativePath(m.state.V3AvatarPath)
	for uid := range m.state.UserIDs {
		src := filepath.Join(avatarRoot, fmt.Sprintf("avatar_%d_2.png", uid))
		// Users without an uploaded avatar simply have no file to migrate.
		if !util.Exists(src) {
			continue
		}
		m.l.Info("Migrating avatar for user %d", uid)
		dst := filepath.Join(util.DataPath("avatar"), fmt.Sprintf("avatar_%d.png", uid))
		// Delegate to a helper so the deferred file closes fire at the end
		// of each iteration instead of accumulating until function return.
		if err := copyAvatarFile(m, src, dst); err != nil {
			return err
		}
	}
	return nil
}

// copyAvatarFile copies a single avatar from src to dst. Open/create errors
// are returned; a copy error is only logged (best-effort, matching the
// original behavior).
func copyAvatarFile(m *Migrator, src, dst string) error {
	origin, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("failed to open avatar file: %w", err)
	}
	defer origin.Close()
	dest, err := util.CreatNestedFile(dst)
	if err != nil {
		return fmt.Errorf("failed to create avatar file: %w", err)
	}
	defer dest.Close()
	if _, err := io.Copy(dest, origin); err != nil {
		m.l.Warning("Failed to copy avatar file: %s, skipping...", err)
	}
	return nil
}

View File

@@ -0,0 +1,124 @@
package conf
import (
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/go-ini/ini"
"github.com/go-playground/validator/v10"
)
// database holds v3 database connection settings.
type database struct {
	Type        string
	User        string
	Password    string
	Host        string
	Name        string
	TablePrefix string
	DBFile      string
	Port        int
	Charset     string
	UnixSocket  bool
}

// system holds general system settings.
type system struct {
	Mode          string `validate:"eq=master|eq=slave"`
	Listen        string `validate:"required"`
	Debug         bool
	SessionSecret string
	HashIDSalt    string
	GracePeriod   int    `validate:"gte=0"`
	ProxyHeader   string `validate:"required_with=Listen"`
}

// ssl holds HTTPS listener settings.
type ssl struct {
	CertPath string `validate:"omitempty,required"`
	KeyPath  string `validate:"omitempty,required"`
	Listen   string `validate:"required"`
}

// unix holds unix-socket listener settings.
type unix struct {
	Listen string
	Perm   uint32
}

// slave holds settings used when running as a slave storage node.
type slave struct {
	Secret          string `validate:"omitempty,gte=64"`
	CallbackTimeout int    `validate:"omitempty,gte=1"`
	SignatureTTL    int    `validate:"omitempty,gte=1"`
}

// redis holds Redis connection settings.
type redis struct {
	Network  string
	Server   string
	User     string
	Password string
	DB       string
}

// cors holds cross-origin resource sharing settings.
type cors struct {
	AllowOrigins     []string
	AllowMethods     []string
	AllowHeaders     []string
	AllowCredentials bool
	ExposeHeaders    []string
	SameSite         string
	Secure           bool
}

// cfg is the parsed ini file, populated by Init.
var cfg *ini.File
// Init loads and validates the v3 config file at path, mapping each ini
// section onto its package-level config struct and collecting any
// [OptionOverwrite] keys. Returns the first parse/validation error.
func Init(l logging.Logger, path string) error {
	var err error
	cfg, err = ini.Load(path)
	if err != nil {
		l.Error("Failed to parse config file %q: %s", path, err)
		return err
	}
	// Use a slice instead of a map so sections are mapped (and errors are
	// reported) in a deterministic order.
	sections := []struct {
		name   string
		target interface{}
	}{
		{"Database", DatabaseConfig},
		{"System", SystemConfig},
		{"SSL", SSLConfig},
		{"UnixSocket", UnixConfig},
		{"Redis", RedisConfig},
		{"CORS", CORSConfig},
		{"Slave", SlaveConfig},
	}
	for _, s := range sections {
		if err := mapSection(s.name, s.target); err != nil {
			l.Error("Failed to parse config section %q: %s", s.name, err)
			return err
		}
	}
	// Collect database setting overrides.
	for _, key := range cfg.Section("OptionOverwrite").Keys() {
		OptionOverwrite[key.Name()] = key.Value()
	}
	return nil
}
// mapSection maps the named ini section onto confStruct and validates the
// result against its struct tags.
func mapSection(section string, confStruct interface{}) error {
	if err := cfg.Section(section).MapTo(confStruct); err != nil {
		return err
	}
	return validator.New().Struct(confStruct)
}

View File

@@ -0,0 +1,55 @@
package conf
// RedisConfig holds Redis server defaults.
var RedisConfig = &redis{
	Network:  "tcp",
	Server:   "",
	Password: "",
	DB:       "0",
}

// DatabaseConfig holds database defaults.
var DatabaseConfig = &database{
	Type:       "UNSET",
	Charset:    "utf8",
	DBFile:     "cloudreve.db",
	Port:       3306,
	UnixSocket: false,
}

// SystemConfig holds general system defaults.
var SystemConfig = &system{
	Debug:       false,
	Mode:        "master",
	Listen:      ":5212",
	ProxyHeader: "X-Forwarded-For",
}

// CORSConfig holds cross-origin defaults.
var CORSConfig = &cors{
	AllowOrigins:     []string{"UNSET"},
	AllowMethods:     []string{"PUT", "POST", "GET", "OPTIONS"},
	AllowHeaders:     []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Cr-Path", "X-Cr-FileName"},
	AllowCredentials: false,
	ExposeHeaders:    nil,
	SameSite:         "Default",
	Secure:           false,
}

// SlaveConfig holds slave-mode defaults.
var SlaveConfig = &slave{
	CallbackTimeout: 20,
	SignatureTTL:    60,
}

// SSLConfig holds HTTPS listener defaults.
var SSLConfig = &ssl{
	Listen:   ":443",
	CertPath: "",
	KeyPath:  "",
}

// UnixConfig holds unix-socket listener defaults.
var UnixConfig = &unix{
	Listen: "",
}

// OptionOverwrite collects [OptionOverwrite] keys that override database
// settings.
var OptionOverwrite = map[string]interface{}{}

View File

@@ -0,0 +1,82 @@
package migrator
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/ent/file"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
)
// migrateDirectLink copies v3 source links into v4 direct links in batches
// of 1000, resuming from the persisted offset. Links whose file no longer
// exists are skipped with a warning.
func (m *Migrator) migrateDirectLink() error {
	m.l.Info("Migrating direct links...")
	batchSize := 1000
	offset := m.state.DirectLinkOffset
	ctx := context.Background()
	if m.state.DirectLinkOffset > 0 {
		m.l.Info("Resuming direct link migration from offset %d", offset)
	}
	for {
		m.l.Info("Migrating direct links with offset %d", offset)
		var directLinks []model.SourceLink
		if err := model.DB.Limit(batchSize).Offset(offset).Find(&directLinks).Error; err != nil {
			return fmt.Errorf("failed to list v3 direct links: %w", err)
		}
		if len(directLinks) == 0 {
			if m.dep.ConfigProvider().Database().Type == conf.PostgresDB {
				m.l.Info("Resetting direct link ID sequence for postgres...")
				// Best-effort sequence reset; failure is non-fatal here.
				m.v4client.DirectLink.ExecContext(ctx, "SELECT SETVAL('direct_links_id_seq', (SELECT MAX(id) FROM direct_links))")
			}
			break
		}
		tx, err := m.v4client.Tx(ctx)
		if err != nil {
			// No transaction was opened, so there is nothing to roll back.
			// (Rolling back here would dereference a nil *Tx.)
			return fmt.Errorf("failed to start transaction: %w", err)
		}
		for _, dl := range directLinks {
			// v3 file IDs were shifted by the last migrated folder ID.
			sourceId := int(dl.FileID) + m.state.LastFolderID
			// Skip links whose target file was not migrated.
			_, err = tx.File.Query().Where(file.ID(sourceId)).First(ctx)
			if err != nil {
				m.l.Warning("File %d not found, skipping direct link %d", sourceId, dl.ID)
				continue
			}
			stm := tx.DirectLink.Create().
				SetCreatedAt(formatTime(dl.CreatedAt)).
				SetUpdatedAt(formatTime(dl.UpdatedAt)).
				SetRawID(int(dl.ID)).
				SetFileID(sourceId).
				SetName(dl.Name).
				SetDownloads(dl.Downloads).
				SetSpeed(0)
			if _, err := stm.Save(ctx); err != nil {
				_ = tx.Rollback()
				return fmt.Errorf("failed to create direct link %d: %w", dl.ID, err)
			}
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit transaction: %w", err)
		}
		offset += batchSize
		m.state.DirectLinkOffset = offset
		// Persist progress so an interrupted migration can resume here.
		if err := m.saveState(); err != nil {
			m.l.Warning("Failed to save state after direct link batch: %s", err)
		} else {
			m.l.Info("Saved migration state after processing this batch")
		}
	}
	return nil
}

View File

@@ -0,0 +1,189 @@
package migrator
import (
"context"
"encoding/json"
"fmt"
"os"
"strconv"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
)
// migrateFile copies v3 files into v4 in batches of 1000, creating version
// (and, when present, thumbnail) entities for each file. Files referencing
// missing folders/users/policies are skipped. On a name-conflict constraint
// error the batch is rolled back and retried with a disambiguated name
// recorded in FileConflictRename.
func (m *Migrator) migrateFile() error {
	m.l.Info("Migrating files...")
	batchSize := 1000
	offset := m.state.FileOffset
	ctx := context.Background()
	if m.state.FileConflictRename == nil {
		m.state.FileConflictRename = make(map[uint]string)
	}
	if m.state.EntitySources == nil {
		m.state.EntitySources = make(map[string]int)
	}
	if offset > 0 {
		m.l.Info("Resuming file migration from offset %d", offset)
	}
out:
	for {
		m.l.Info("Migrating files with offset %d", offset)
		var files []model.File
		if err := model.DB.Limit(batchSize).Offset(offset).Find(&files).Error; err != nil {
			return fmt.Errorf("failed to list v3 files: %w", err)
		}
		if len(files) == 0 {
			if m.dep.ConfigProvider().Database().Type == conf.PostgresDB {
				m.l.Info("Resetting file ID sequence for postgres...")
				// Best-effort sequence reset; failure is non-fatal here.
				m.v4client.File.ExecContext(ctx, "SELECT SETVAL('files_id_seq', (SELECT MAX(id) FROM files))")
			}
			break
		}
		tx, err := m.v4client.Tx(ctx)
		if err != nil {
			// No transaction was opened, so there is nothing to roll back.
			// (Rolling back here would dereference a nil *Tx.)
			return fmt.Errorf("failed to start transaction: %w", err)
		}
		for _, f := range files {
			if _, ok := m.state.FolderIDs[int(f.FolderID)]; !ok {
				m.l.Warning("Folder ID %d for file %d not found, skipping", f.FolderID, f.ID)
				continue
			}
			if _, ok := m.state.UserIDs[int(f.UserID)]; !ok {
				m.l.Warning("User ID %d for file %d not found, skipping", f.UserID, f.ID)
				continue
			}
			if _, ok := m.state.PolicyIDs[int(f.PolicyID)]; !ok {
				m.l.Warning("Policy ID %d for file %d not found, skipping", f.PolicyID, f.ID)
				continue
			}
			metadata := make(map[string]string)
			if f.Metadata != "" {
				// Best-effort: malformed metadata degrades to an empty map.
				if err := json.Unmarshal([]byte(f.Metadata), &metadata); err != nil {
					m.l.Warning("Failed to parse metadata for file %d: %s, ignoring", f.ID, err)
				}
			}
			var (
				thumbnail *ent.Entity
				entity    *ent.Entity
				err       error
			)
			if metadata[model.ThumbStatusMetadataKey] == model.ThumbStatusExist {
				size := int64(0)
				if m.state.LocalPolicyIDs[int(f.PolicyID)] {
					if thumbFile, statErr := os.Stat(f.SourceName + m.state.ThumbSuffix); statErr == nil {
						size = thumbFile.Size()
					} else {
						// Only warn when the stat actually failed (the
						// original warned unconditionally).
						m.l.Warning("Thumbnail file %s for file %d not found, use 0 size", f.SourceName+m.state.ThumbSuffix, f.ID)
					}
				}
				// Insert thumbnail entity
				thumbnail, err = m.insertEntity(tx, f.SourceName+m.state.ThumbSuffix, int(types.EntityTypeThumbnail), int(f.PolicyID), int(f.UserID), size)
				if err != nil {
					_ = tx.Rollback()
					return fmt.Errorf("failed to insert thumbnail entity: %w", err)
				}
			}
			// Insert file version entity
			entity, err = m.insertEntity(tx, f.SourceName, int(types.EntityTypeVersion), int(f.PolicyID), int(f.UserID), int64(f.Size))
			if err != nil {
				_ = tx.Rollback()
				return fmt.Errorf("failed to insert file version entity: %w", err)
			}
			// Use the disambiguated name if a previous batch hit a conflict.
			fname := f.Name
			if renamed, ok := m.state.FileConflictRename[f.ID]; ok {
				fname = renamed
			}
			stm := tx.File.Create().
				SetCreatedAt(formatTime(f.CreatedAt)).
				SetUpdatedAt(formatTime(f.UpdatedAt)).
				SetName(fname).
				SetRawID(int(f.ID) + m.state.LastFolderID).
				SetOwnerID(int(f.UserID)).
				SetSize(int64(f.Size)).
				SetPrimaryEntity(entity.ID).
				SetFileChildren(int(f.FolderID)).
				SetType(int(types.FileTypeFile)).
				SetStoragePoliciesID(int(f.PolicyID)).
				AddEntities(entity)
			if thumbnail != nil {
				stm.AddEntities(thumbnail)
			}
			if _, err := stm.Save(ctx); err != nil {
				_ = tx.Rollback()
				if ent.IsConstraintError(err) {
					if _, ok := m.state.FileConflictRename[f.ID]; ok {
						return fmt.Errorf("file %d already exists, but new name is already in conflict rename map, please resolve this manually", f.ID)
					}
					m.l.Warning("File %d already exists, will retry with new name in next batch", f.ID)
					m.state.FileConflictRename[f.ID] = fmt.Sprintf("%d_%s", f.ID, f.Name)
					// Retry the same batch (same offset) in a fresh tx.
					continue out
				}
				return fmt.Errorf("failed to create file %d: %w", f.ID, err)
			}
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit transaction: %w", err)
		}
		offset += batchSize
		m.state.FileOffset = offset
		// Persist progress so an interrupted migration can resume here.
		if err := m.saveState(); err != nil {
			m.l.Warning("Failed to save state after file batch: %s", err)
		} else {
			m.l.Info("Saved migration state after processing this batch")
		}
	}
	return nil
}
// insertEntity returns an entity for (policyID, source), either by bumping
// the reference count of a previously created one (tracked in
// state.EntitySources) or by creating a new entity with reference count 1.
func (m *Migrator) insertEntity(tx *ent.Tx, source string, entityType, policyID, createdBy int, size int64) (*ent.Entity, error) {
	ctx := context.Background()
	// Entities are deduplicated per (policy, source) pair.
	entityKey := strconv.Itoa(policyID) + "+" + source
	if existingId, ok := m.state.EntitySources[entityKey]; ok {
		existing, err := tx.Entity.UpdateOneID(existingId).
			AddReferenceCount(1).
			Save(ctx)
		if err == nil {
			return existing, nil
		}
		// Fall through to creating a fresh entity on update failure.
		m.l.Warning("Failed to update existing entity %d: %s, fallback to create new one.", existingId, err)
	}
	created, err := tx.Entity.Create().
		SetSource(source).
		SetType(entityType).
		SetSize(size).
		SetStoragePolicyEntities(policyID).
		SetCreatedBy(createdBy).
		SetReferenceCount(1).
		Save(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to create new entity: %w", err)
	}
	m.state.EntitySources[entityKey] = created.ID
	return created, nil
}

View File

@@ -0,0 +1,147 @@
package migrator
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
)
// migrateFolders copies v3 folders into v4 File rows (type folder) in
// batches of 1000, resuming from the persisted offset. Root folders get an
// empty name; folders with a missing owner are skipped. Parent links are
// wired up later by migrateFolderParent.
func (m *Migrator) migrateFolders() error {
	m.l.Info("Migrating folders...")
	batchSize := 1000
	// Start from the saved offset if available
	offset := m.state.FolderOffset
	ctx := context.Background()
	foldersCount := 0
	if m.state.FolderIDs == nil {
		m.state.FolderIDs = make(map[int]bool)
	}
	if offset > 0 {
		m.l.Info("Resuming folder migration from offset %d", offset)
	}
	for {
		m.l.Info("Migrating folders with offset %d", offset)
		var folders []model.Folder
		if err := model.DB.Limit(batchSize).Offset(offset).Find(&folders).Error; err != nil {
			return fmt.Errorf("failed to list v3 folders: %w", err)
		}
		if len(folders) == 0 {
			break
		}
		tx, err := m.v4client.Tx(ctx)
		if err != nil {
			// No transaction was opened, so there is nothing to roll back.
			// (Rolling back here would dereference a nil *Tx.)
			return fmt.Errorf("failed to start transaction: %w", err)
		}
		batchFoldersCount := 0
		for _, f := range folders {
			if _, ok := m.state.UserIDs[int(f.OwnerID)]; !ok {
				m.l.Warning("Owner ID %d not found, skipping folder %d", f.OwnerID, f.ID)
				continue
			}
			// Root folders (nil parent) are stored with an empty name in v4.
			isRoot := f.ParentID == nil
			if isRoot {
				f.Name = ""
			} else if *f.ParentID == 0 {
				m.l.Warning("Parent ID %d not found, skipping folder %d", *f.ParentID, f.ID)
				continue
			}
			stm := tx.File.Create().
				SetRawID(int(f.ID)).
				SetType(int(types.FileTypeFolder)).
				SetCreatedAt(formatTime(f.CreatedAt)).
				SetUpdatedAt(formatTime(f.UpdatedAt)).
				SetName(f.Name).
				SetOwnerID(int(f.OwnerID))
			if _, err := stm.Save(ctx); err != nil {
				_ = tx.Rollback()
				return fmt.Errorf("failed to create folder %d: %w", f.ID, err)
			}
			m.state.FolderIDs[int(f.ID)] = true
			m.state.LastFolderID = int(f.ID)
			foldersCount++
			batchFoldersCount++
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit transaction: %w", err)
		}
		// Update the offset in state and save after each batch
		offset += batchSize
		m.state.FolderOffset = offset
		if err := m.saveState(); err != nil {
			m.l.Warning("Failed to save state after folder batch: %s", err)
		} else {
			m.l.Info("Saved migration state after processing %d folders in this batch", batchFoldersCount)
		}
	}
	m.l.Info("Successfully migrated %d folders", foldersCount)
	return nil
}
// migrateFolderParent wires up parent links for the folders created by
// migrateFolders, in batches of 1000, resuming from the persisted offset.
// Folders whose parent was not migrated are skipped with a warning.
func (m *Migrator) migrateFolderParent() error {
	m.l.Info("Migrating folder parent...")
	batchSize := 1000
	offset := m.state.FolderParentOffset
	ctx := context.Background()
	for {
		m.l.Info("Migrating folder parent with offset %d", offset)
		var folderParents []model.Folder
		if err := model.DB.Limit(batchSize).Offset(offset).Find(&folderParents).Error; err != nil {
			return fmt.Errorf("failed to list v3 folder parents: %w", err)
		}
		if len(folderParents) == 0 {
			break
		}
		tx, err := m.v4client.Tx(ctx)
		if err != nil {
			// No transaction was opened, so there is nothing to roll back.
			// (Rolling back here would dereference a nil *Tx.)
			return fmt.Errorf("failed to start transaction: %w", err)
		}
		for _, f := range folderParents {
			if f.ParentID != nil {
				if _, ok := m.state.FolderIDs[int(*f.ParentID)]; !ok {
					// Log the missing parent ID (the original logged f.ID twice).
					m.l.Warning("Folder ID %d not found, skipping folder parent %d", *f.ParentID, f.ID)
					continue
				}
				if _, err := tx.File.UpdateOneID(int(f.ID)).SetParentID(int(*f.ParentID)).Save(ctx); err != nil {
					_ = tx.Rollback()
					return fmt.Errorf("failed to update folder parent %d: %w", f.ID, err)
				}
			}
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit transaction: %w", err)
		}
		// Update the offset in state and save after each batch
		offset += batchSize
		m.state.FolderParentOffset = offset
		if err := m.saveState(); err != nil {
			m.l.Warning("Failed to save state after folder parent batch: %s", err)
		}
	}
	return nil
}

View File

@@ -0,0 +1,92 @@
package migrator
import (
"context"
"encoding/json"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/samber/lo"
)
// migrateGroup copies all v3 groups into v4, converting the JSON-encoded
// group options/policies into the v4 GroupSetting and permission bool-set.
// Group 1 is treated as admin and group 3 as anonymous, matching the fixed
// v3 group IDs.
func (m *Migrator) migrateGroup() error {
	m.l.Info("Migrating groups...")
	var groups []model.Group
	if err := model.DB.Find(&groups).Error; err != nil {
		return fmt.Errorf("failed to list v3 groups: %w", err)
	}
	for _, group := range groups {
		// perms collects the group's permission bits. (Renamed from "cap",
		// which shadowed the builtin.)
		perms := &boolset.BooleanSet{}
		var (
			opts     model.GroupOption
			policies []int
		)
		if err := json.Unmarshal([]byte(group.Options), &opts); err != nil {
			return fmt.Errorf("failed to unmarshal options for group %q: %w", group.Name, err)
		}
		if err := json.Unmarshal([]byte(group.Policies), &policies); err != nil {
			return fmt.Errorf("failed to unmarshal policies for group %q: %w", group.Name, err)
		}
		// Drop references to policies that were not migrated.
		policies = lo.Filter(policies, func(id int, _ int) bool {
			_, exist := m.state.PolicyIDs[id]
			return exist
		})
		newOpts := &types.GroupSetting{
			CompressSize:          int64(opts.CompressSize),
			DecompressSize:        int64(opts.DecompressSize),
			RemoteDownloadOptions: opts.Aria2Options,
			SourceBatchSize:       opts.SourceBatchSize,
			RedirectedSource:      opts.RedirectedSource,
			Aria2BatchSize:        opts.Aria2BatchSize,
			MaxWalkedFiles:        100000,
			TrashRetention:        7 * 24 * 3600,
		}
		boolset.Sets(map[types.GroupPermission]bool{
			types.GroupPermissionIsAdmin:          group.ID == 1,
			types.GroupPermissionIsAnonymous:      group.ID == 3,
			types.GroupPermissionShareDownload:    opts.ShareDownload,
			types.GroupPermissionWebDAV:           group.WebDAVEnabled,
			types.GroupPermissionArchiveDownload:  opts.ArchiveDownload,
			types.GroupPermissionArchiveTask:      opts.ArchiveTask,
			types.GroupPermissionWebDAVProxy:      opts.WebDAVProxy,
			types.GroupPermissionRemoteDownload:   opts.Aria2,
			types.GroupPermissionAdvanceDelete:    opts.AdvanceDelete,
			types.GroupPermissionShare:            group.ShareEnabled,
			types.GroupPermissionRedirectedSource: opts.RedirectedSource,
		}, perms)
		stm := m.v4client.Group.Create().
			SetRawID(int(group.ID)).
			SetCreatedAt(formatTime(group.CreatedAt)).
			SetUpdatedAt(formatTime(group.UpdatedAt)).
			SetName(group.Name).
			SetMaxStorage(int64(group.MaxStorage)).
			SetSpeedLimit(group.SpeedLimit).
			SetPermissions(perms).
			SetSettings(newOpts)
		// v4 keeps a single primary storage policy per group.
		if len(policies) > 0 {
			stm.SetStoragePoliciesID(policies[0])
		}
		if _, err := stm.Save(context.Background()); err != nil {
			return fmt.Errorf("failed to create group %q: %w", group.Name, err)
		}
	}
	if m.dep.ConfigProvider().Database().Type == conf.PostgresDB {
		m.l.Info("Resetting group ID sequence for postgres...")
		// Best-effort sequence reset; failure is non-fatal here.
		m.v4client.Group.ExecContext(context.Background(), "SELECT SETVAL('groups_id_seq', (SELECT MAX(id) FROM groups))")
	}
	return nil
}

View File

@@ -0,0 +1,314 @@
package migrator
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/application/migrator/conf"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/inventory"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
)
// State stores the migration progress so an interrupted run can resume.
// It is serialized to JSON next to the v3 config file (see StateFileName).
type State struct {
	PolicyIDs          map[int]bool    `json:"policy_ids,omitempty"`       // migrated v3 policy IDs
	LocalPolicyIDs     map[int]bool    `json:"local_policy_ids,omitempty"` // v3 policies using local storage
	UserIDs            map[int]bool    `json:"user_ids,omitempty"`         // migrated v3 user IDs
	FolderIDs          map[int]bool    `json:"folder_ids,omitempty"`       // migrated v3 folder IDs
	EntitySources      map[string]int  `json:"entity_sources,omitempty"`   // "policyID+source" -> v4 entity ID
	LastFolderID       int             `json:"last_folder_id,omitempty"`   // highest migrated folder ID; offsets file raw IDs
	Step               int             `json:"step,omitempty"`             // current migration phase (Step* constants)
	UserOffset         int             `json:"user_offset,omitempty"`
	FolderOffset       int             `json:"folder_offset,omitempty"`
	FileOffset         int             `json:"file_offset,omitempty"`
	ShareOffset        int             `json:"share_offset,omitempty"`
	GiftCodeOffset     int             `json:"gift_code_offset,omitempty"`
	DirectLinkOffset   int             `json:"direct_link_offset,omitempty"`
	WebdavOffset       int             `json:"webdav_offset,omitempty"`
	StoragePackOffset  int             `json:"storage_pack_offset,omitempty"`
	FileConflictRename map[uint]string `json:"file_conflict_rename,omitempty"` // v3 file ID -> disambiguated name
	FolderParentOffset int             `json:"folder_parent_offset,omitempty"`
	ThumbSuffix        string          `json:"thumb_suffix,omitempty"`    // filename suffix of v3 thumbnail files
	V3AvatarPath       string          `json:"v3_avatar_path,omitempty"`  // v3 avatar directory
}
// Step identifiers for migration phases, executed in ascending order by
// Migrate. StateFileName is the JSON file the progress is persisted to.
const (
	StepInitial                = 0
	StepSchema                 = 1
	StepSettings               = 2
	StepNode                   = 3
	StepPolicy                 = 4
	StepGroup                  = 5
	StepUser                   = 6
	StepFolders                = 7
	StepFolderParent           = 8
	StepFile                   = 9
	StepShare                  = 10
	StepDirectLink             = 11
	Step_CommunityPlaceholder1 = 12 // reserved for pro-only phases
	Step_CommunityPlaceholder2 = 13 // reserved for pro-only phases
	StepAvatar                 = 14
	StepWebdav                 = 15
	StepCompleted              = 16
	StateFileName              = "migration_state.json"
)
// Migrator drives the v3 -> v4 database/file migration, persisting its
// progress to statePath so it can resume after interruption.
type Migrator struct {
	dep       dependency.Dep
	l         logging.Logger
	v4client  *ent.Client // raw ent client for the v4 database
	state     *State      // resumable progress, persisted as JSON
	statePath string      // location of the persisted State file
}
// NewMigrator builds a Migrator for the v3 installation described by
// v3ConfPath, loading (and logging) any previously saved migration state,
// initializing the v3 config/models and a raw v4 ent client.
func NewMigrator(dep dependency.Dep, v3ConfPath string) (*Migrator, error) {
	m := &Migrator{
		dep: dep,
		l:   dep.Logger(),
		state: &State{
			PolicyIDs:    make(map[int]bool),
			UserIDs:      make(map[int]bool),
			Step:         StepInitial,
			UserOffset:   0,
			FolderOffset: 0,
		},
	}
	// Persist migration progress next to the v3 config file.
	configDir := filepath.Dir(v3ConfPath)
	m.statePath = filepath.Join(configDir, StateFileName)
	// Try to load existing state
	if util.Exists(m.statePath) {
		m.l.Info("Found existing migration state file, loading from %s", m.statePath)
		if err := m.loadState(); err != nil {
			return nil, fmt.Errorf("failed to load migration state: %w", err)
		}
		// Human-readable names for every phase. (The previous switch was
		// missing folder-parent/file/share/direct-link, logging "unknown".)
		stepNames := map[int]string{
			StepInitial:      "initial",
			StepSchema:       "schema creation",
			StepSettings:     "settings migration",
			StepNode:         "node migration",
			StepPolicy:       "policy migration",
			StepGroup:        "group migration",
			StepUser:         "user migration",
			StepFolders:      "folders migration",
			StepFolderParent: "folder parent migration",
			StepFile:         "file migration",
			StepShare:        "share migration",
			StepDirectLink:   "direct link migration",
			StepAvatar:       "avatar migration",
			StepWebdav:       "webdav migration",
			StepCompleted:    "completed",
		}
		stepName := "unknown"
		if name, ok := stepNames[m.state.Step]; ok {
			stepName = name
		}
		m.l.Info("Resumed migration from step %d (%s)", m.state.Step, stepName)
		// Log batch information if applicable
		if m.state.Step == StepUser && m.state.UserOffset > 0 {
			m.l.Info("Will resume user migration from batch offset %d", m.state.UserOffset)
		}
		if m.state.Step == StepFolders && m.state.FolderOffset > 0 {
			m.l.Info("Will resume folder migration from batch offset %d", m.state.FolderOffset)
		}
	}
	err := conf.Init(m.dep.Logger(), v3ConfPath)
	if err != nil {
		return nil, err
	}
	err = model.Init()
	if err != nil {
		return nil, err
	}
	v4client, err := inventory.NewRawEntClient(m.l, m.dep.ConfigProvider())
	if err != nil {
		return nil, err
	}
	m.v4client = v4client
	return m, nil
}
// saveState serializes the current migration state as JSON and writes it to
// the state file, so an interrupted migration can resume later.
func (m *Migrator) saveState() error {
	payload, err := json.Marshal(m.state)
	if err != nil {
		return fmt.Errorf("failed to marshal state: %w", err)
	}
	return os.WriteFile(m.statePath, payload, 0644)
}
// loadState restores migration state from the JSON state file into m.state.
func (m *Migrator) loadState() error {
	raw, err := os.ReadFile(m.statePath)
	if err != nil {
		return fmt.Errorf("failed to read state file: %w", err)
	}
	return json.Unmarshal(raw, m.state)
}
// updateStep records the new migration phase and immediately persists the
// state so the phase transition survives a crash.
func (m *Migrator) updateStep(step int) error {
	m.state.Step = step
	if err := m.saveState(); err != nil {
		return err
	}
	return nil
}
// Migrate runs all migration phases in order, starting from (and including)
// the phase recorded in the persisted state. After each successful phase the
// next step is persisted, so an interrupted run resumes at the failed phase.
func (m *Migrator) Migrate() error {
	// Each phase runs while state.Step <= gate and, on success, advances the
	// persisted step to next. This replaces the previous hand-unrolled
	// ladder of identical if-blocks.
	phases := []struct {
		gate int
		next int
		run  func() error
	}{
		{StepSchema, StepSettings, func() error {
			m.l.Info("Creating basic v4 table schema...")
			if err := m.v4client.Schema.Create(context.Background()); err != nil {
				return fmt.Errorf("failed creating schema resources: %w", err)
			}
			return nil
		}},
		{StepSettings, StepNode, m.migrateSettings},
		{StepNode, StepPolicy, m.migrateNode},
		{StepPolicy, StepGroup, func() error {
			allPolicyIDs, err := m.migratePolicy()
			if err != nil {
				return err
			}
			m.state.PolicyIDs = allPolicyIDs
			return nil
		}},
		{StepGroup, StepUser, m.migrateGroup},
		{StepUser, StepFolders, func() error {
			if err := m.migrateUser(); err != nil {
				// Persist partial progress (batch offsets) before bailing.
				m.saveState()
				return err
			}
			// Reset user offset after completion
			m.state.UserOffset = 0
			return nil
		}},
		{StepFolders, StepFolderParent, func() error {
			if err := m.migrateFolders(); err != nil {
				// Persist partial progress (batch offsets) before bailing.
				m.saveState()
				return err
			}
			// Reset folder offset after completion
			m.state.FolderOffset = 0
			return nil
		}},
		{StepFolderParent, StepFile, m.migrateFolderParent},
		{StepFile, StepShare, m.migrateFile},
		{StepShare, StepDirectLink, m.migrateShare},
		{StepDirectLink, StepAvatar, m.migrateDirectLink},
		{StepAvatar, StepWebdav, func() error { return migrateAvatars(m) }},
		{StepWebdav, StepCompleted, m.migrateWebdav},
	}
	for _, phase := range phases {
		if m.state.Step > phase.gate {
			continue
		}
		if err := phase.run(); err != nil {
			return err
		}
		if err := m.updateStep(phase.next); err != nil {
			return fmt.Errorf("failed to update step: %w", err)
		}
	}
	m.l.Info("Migration completed successfully")
	return nil
}
func formatTime(t time.Time) time.Time {
newTime := time.UnixMilli(t.UnixMilli())
return newTime
}

View File

@@ -0,0 +1,288 @@
package dialects
import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/jinzhu/gorm"
)
// keyNameRegex matches runs of non-alphanumeric characters, which are
// collapsed into underscores when building key names.
var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")

// DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct {
}

// commonDialect is the baseline gorm dialect (vendored from gorm) that
// other dialects embed and override.
type commonDialect struct {
	db gorm.SQLCommon
	DefaultForeignKeyNamer
}

// GetName returns the dialect identifier.
func (commonDialect) GetName() string {
	return "common"
}

// SetDB stores the database handle used by the introspection queries.
func (s *commonDialect) SetDB(db gorm.SQLCommon) {
	s.db = db
}

// BindVar returns the bind-variable placeholder; gorm later rewrites "$$$"
// into the dialect's real placeholder (e.g. "?").
func (commonDialect) BindVar(i int) string {
	return "$$$" // ?
}

// Quote wraps an identifier in double quotes.
func (commonDialect) Quote(key string) string {
	return fmt.Sprintf(`"%s"`, key)
}

// fieldCanAutoIncrement reports whether the field should be declared
// auto-increment: an explicit AUTO_INCREMENT tag wins, otherwise primary
// keys auto-increment by default.
func (s *commonDialect) fieldCanAutoIncrement(field *gorm.StructField) bool {
	if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
		return strings.ToLower(value) != "false"
	}
	return field.IsPrimaryKey
}
// DataTypeOf maps a Go struct field to a generic SQL column type, honoring
// an explicit type from the field's tags (sqlType) and appending any
// additional tag-provided modifiers. Panics on unmappable kinds.
func (s *commonDialect) DataTypeOf(field *gorm.StructField) string {
	var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
	if sqlType == "" {
		switch dataValue.Kind() {
		case reflect.Bool:
			sqlType = "BOOLEAN"
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
			if s.fieldCanAutoIncrement(field) {
				sqlType = "INTEGER AUTO_INCREMENT"
			} else {
				sqlType = "INTEGER"
			}
		case reflect.Int64, reflect.Uint64:
			if s.fieldCanAutoIncrement(field) {
				sqlType = "BIGINT AUTO_INCREMENT"
			} else {
				sqlType = "BIGINT"
			}
		case reflect.Float32, reflect.Float64:
			sqlType = "FLOAT"
		case reflect.String:
			// 65532 is the cap used by this vendored dialect for
			// variable-length text columns.
			if size > 0 && size < 65532 {
				sqlType = fmt.Sprintf("VARCHAR(%d)", size)
			} else {
				sqlType = "VARCHAR(65532)"
			}
		case reflect.Struct:
			if _, ok := dataValue.Interface().(time.Time); ok {
				sqlType = "TIMESTAMP"
			}
		default:
			if _, ok := dataValue.Interface().([]byte); ok {
				if size > 0 && size < 65532 {
					sqlType = fmt.Sprintf("BINARY(%d)", size)
				} else {
					sqlType = "BINARY(65532)"
				}
			}
		}
	}
	if sqlType == "" {
		panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
	}
	if strings.TrimSpace(additionalType) == "" {
		return sqlType
	}
	return fmt.Sprintf("%v %v", sqlType, additionalType)
}
// currentDatabaseAndTable splits a possibly schema-qualified table name
// ("db.table") into its database and table parts; an unqualified name is
// paired with the dialect's current database.
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
	if idx := strings.Index(tableName, "."); idx >= 0 {
		return tableName[:idx], tableName[idx+1:]
	}
	return dialect.CurrentDatabase(), tableName
}
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s commonDialect) HasTable(tableName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
// ModifyColumn changes a column's type using the ANSI-style
// `ALTER TABLE ... ALTER COLUMN ... TYPE ...` form. Identifiers are
// interpolated directly, so callers must pass trusted names only.
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
	_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
	return err
}
// CurrentDatabase returns the schema name reported by SELECT DATABASE().
// The scan error is deliberately ignored; name stays "" when the query fails.
func (s commonDialect) CurrentDatabase() (name string) {
	s.db.QueryRow("SELECT DATABASE()").Scan(&name)
	return
}
// LimitAndOffsetSQL builds the " LIMIT n"/" OFFSET n" suffix for a query.
// Either value may be nil or non-numeric, in which case the corresponding
// clause is omitted; negative values are ignored as well.
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
	appendClause := func(keyword string, value interface{}) {
		if value == nil {
			return
		}
		if parsed, err := strconv.ParseInt(fmt.Sprint(value), 0, 0); err == nil && parsed >= 0 {
			sql += fmt.Sprintf(" %s %d", keyword, parsed)
		}
	}
	appendClause("LIMIT", limit)
	appendClause("OFFSET", offset)
	return
}
// SelectFromDummyTable returns the FROM clause needed for a SELECT without a
// real table; the generic dialect needs none.
func (commonDialect) SelectFromDummyTable() string {
	return ""
}
// LastInsertIDReturningSuffix returns the INSERT suffix used to fetch the new
// row's ID (e.g. RETURNING on postgres); the generic dialect has none.
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
	return ""
}
// DefaultValueStr returns the SQL used to insert a row consisting only of
// default values.
func (commonDialect) DefaultValueStr() string {
	return "DEFAULT VALUES"
}
// BuildKeyName returns a valid key name (foreign key, index key) for the
// given table, fields and reference, joining the parts with underscores and
// normalizing disallowed characters via keyNameRegex (declared elsewhere in
// this file).
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
	keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
	keyName = keyNameRegex.ReplaceAllString(keyName, "_")
	return keyName
}
// NormalizeIndexAndColumn returns the index name and column name unchanged;
// the generic dialect performs no normalization.
func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
	return indexName, columnName
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}
// sqlite is the SQLite dialect; it embeds commonDialect and overrides only
// the behavior that differs from the generic implementation.
type sqlite struct {
	commonDialect
}
// Register the sqlite dialect with gorm at package load time.
func init() {
	gorm.RegisterDialect("sqlite", &sqlite{})
}
// GetName returns the dialect's registered name.
func (sqlite) GetName() string {
	return "sqlite"
}
// DataTypeOf maps a Go struct field to its SQLite column type, honoring an
// explicit type from the field's tag when present.
// Side effect: integer fields eligible for auto-increment get their
// AUTO_INCREMENT tag set on the field before the type string is returned.
func (s *sqlite) DataTypeOf(field *gorm.StructField) string {
	var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
	// Only derive a type when the tag did not specify one.
	if sqlType == "" {
		switch dataValue.Kind() {
		case reflect.Bool:
			sqlType = "bool"
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
			if s.fieldCanAutoIncrement(field) {
				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
				sqlType = "integer primary key autoincrement"
			} else {
				sqlType = "integer"
			}
		case reflect.Int64, reflect.Uint64:
			if s.fieldCanAutoIncrement(field) {
				field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
				sqlType = "integer primary key autoincrement"
			} else {
				sqlType = "bigint"
			}
		case reflect.Float32, reflect.Float64:
			sqlType = "real"
		case reflect.String:
			// Sized strings become varchar; unbounded (or oversized) become text.
			if size > 0 && size < 65532 {
				sqlType = fmt.Sprintf("varchar(%d)", size)
			} else {
				sqlType = "text"
			}
		case reflect.Struct:
			if _, ok := dataValue.Interface().(time.Time); ok {
				sqlType = "datetime"
			}
		default:
			if IsByteArrayOrSlice(dataValue) {
				sqlType = "blob"
			}
		}
	}
	if sqlType == "" {
		panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite", dataValue.Type().Name(), dataValue.Kind().String()))
	}
	// Append any additional tag-provided modifiers (NOT NULL, DEFAULT, ...).
	if strings.TrimSpace(additionalType) == "" {
		return sqlType
	}
	return fmt.Sprintf("%v %v", sqlType, additionalType)
}
// HasIndex reports whether the named index exists on the table by
// pattern-matching the stored CREATE statement in sqlite_master.
// indexName is interpolated into the LIKE pattern via Sprintf, so this only
// works reliably for trusted, plain identifiers.
func (s sqlite) HasIndex(tableName string, indexName string) bool {
	var count int
	s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
	return count > 0
}
// HasTable reports whether a table with the given name exists, based on the
// sqlite_master catalog. A failed query reads as "not found".
func (s sqlite) HasTable(tableName string) bool {
	var found int
	s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&found)
	return found > 0
}
// HasColumn reports whether the table's stored CREATE statement mentions the
// column, matching either the quoted ("col") or bare (col ) spelling.
// columnName is interpolated via Sprintf, so only trusted identifiers are
// safe here; false positives are possible when a name is a substring of
// another token in the DDL.
func (s sqlite) HasColumn(tableName string, columnName string) bool {
	var count int
	s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');", columnName, columnName), tableName).Scan(&count)
	return count > 0
}
// CurrentDatabase returns the second column of the first row of
// `PRAGMA database_list` (presumably the `name` column of the seq/name/file
// triple — confirm against the SQLite docs). Scanning through *string
// pointers tolerates NULL columns; on any error the empty string is returned.
func (s sqlite) CurrentDatabase() (name string) {
	var (
		ifaces   = make([]interface{}, 3)
		pointers = make([]*string, 3)
		i        int
	)
	// Scan wants []interface{}; point each slot at a nullable *string.
	for i = 0; i < 3; i++ {
		ifaces[i] = &pointers[i]
	}
	if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
		return
	}
	if pointers[1] != nil {
		name = *pointers[1]
	}
	return
}

View File

@@ -0,0 +1,39 @@
package model
import (
"github.com/jinzhu/gorm"
)
// File is the legacy v3 file record. Name, UserID and FolderID share the
// composite unique index idx_only_one, so a user's folder cannot hold two
// files with the same name.
type File struct {
	// Table fields
	gorm.Model
	Name            string `gorm:"unique_index:idx_only_one"`
	SourceName      string `gorm:"type:text"`
	UserID          uint   `gorm:"index:user_id;unique_index:idx_only_one"`
	Size            uint64
	PicInfo         string
	FolderID        uint `gorm:"index:folder_id;unique_index:idx_only_one"`
	PolicyID        uint
	UploadSessionID *string `gorm:"index:session_id;unique_index:session_only_one"`
	Metadata        string  `gorm:"type:text"`
	// Associations
	Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"`
	// Fields ignored by the database
	Position           string            `gorm:"-"`
	MetadataSerialized map[string]string `gorm:"-"`
}
// Thumb related metadata: values stored under the thumb_status metadata key,
// plus the metadata key names themselves.
const (
	ThumbStatusNotExist     = ""              // no thumbnail generated yet (zero value)
	ThumbStatusExist        = "exist"         // thumbnail sidecar exists
	ThumbStatusNotAvailable = "not_available" // generation attempted and failed
	ThumbStatusMetadataKey  = "thumb_status"
	ThumbSidecarMetadataKey = "thumb_sidecar"
	ChecksumMetadataKey     = "webdav_checksum"
)

View File

@@ -0,0 +1,18 @@
package model
import (
"github.com/jinzhu/gorm"
)
// Folder is the legacy v3 directory record. Name and ParentID share the
// unique index idx_only_one_name, preventing duplicate names in one parent.
type Folder struct {
	// Table fields
	gorm.Model
	Name     string `gorm:"unique_index:idx_only_one_name"`
	ParentID *uint  `gorm:"index:parent_id;unique_index:idx_only_one_name"`
	OwnerID  uint   `gorm:"index:owner_id"`
	// Fields ignored by the database
	Position      string `gorm:"-"`
	WebdavDstName string `gorm:"-"`
}

View File

@@ -0,0 +1,38 @@
package model
import (
"github.com/jinzhu/gorm"
)
// Group is the legacy v3 user group model.
type Group struct {
	gorm.Model
	Name          string
	Policies      string // JSON-encoded list of storage policy IDs
	MaxStorage    uint64
	ShareEnabled  bool
	WebDAVEnabled bool
	SpeedLimit    int
	Options       string `json:"-" gorm:"size:4294967295"` // JSON-encoded GroupOption
	// Fields ignored by the database
	PolicyList        []uint      `gorm:"-"`
	OptionsSerialized GroupOption `gorm:"-"`
}
// GroupOption holds additional per-group configuration, serialized into
// Group.Options as JSON.
type GroupOption struct {
	ArchiveDownload bool                   `json:"archive_download,omitempty"` // batch download as archive
	ArchiveTask     bool                   `json:"archive_task,omitempty"`     // server-side compression
	CompressSize    uint64                 `json:"compress_size,omitempty"`    // max total size for compression
	DecompressSize  uint64                 `json:"decompress_size,omitempty"`
	OneTimeDownload bool                   `json:"one_time_download,omitempty"`
	ShareDownload   bool                   `json:"share_download,omitempty"`
	Aria2           bool                   `json:"aria2,omitempty"`         // remote (offline) download
	Aria2Options    map[string]interface{} `json:"aria2_options,omitempty"` // per-group aria2 overrides
	SourceBatchSize int                    `json:"source_batch,omitempty"`
	RedirectedSource bool                  `json:"redirected_source,omitempty"`
	Aria2BatchSize  int                    `json:"aria2_batch,omitempty"`
	AdvanceDelete   bool                   `json:"advance_delete,omitempty"`
	WebDAVProxy     bool                   `json:"webdav_proxy,omitempty"`
}

View File

@@ -0,0 +1,91 @@
package model
import (
"fmt"
"time"
"github.com/jinzhu/gorm"
"github.com/cloudreve/Cloudreve/v4/application/migrator/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
_ "github.com/jinzhu/gorm/dialects/mssql"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/postgres"
)
// DB is the package-wide v3 database connection singleton, populated by Init.
var DB *gorm.DB
// Init opens the legacy v3 database described by conf.DatabaseConfig and
// stores the connection in the DB singleton. Supported types: sqlite
// (default when unset), postgres, mysql and mssql.
func Init() error {
	var (
		db         *gorm.DB
		err        error
		confDBType string = conf.DatabaseConfig.Type
	)
	// Accept the legacy "sqlite3" spelling from existing config files.
	if confDBType == "sqlite3" {
		confDBType = "sqlite"
	}
	switch confDBType {
	case "UNSET", "sqlite":
		// Fall back to SQLite when no type is configured, or when requested.
		db, err = gorm.Open("sqlite3", util.RelativePath(conf.DatabaseConfig.DBFile))
	case "postgres":
		db, err = gorm.Open(confDBType, fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
			conf.DatabaseConfig.Host,
			conf.DatabaseConfig.User,
			conf.DatabaseConfig.Password,
			conf.DatabaseConfig.Name,
			conf.DatabaseConfig.Port))
	case "mysql", "mssql":
		var host string
		// Unix-socket and TCP connections use different DSN host syntax.
		if conf.DatabaseConfig.UnixSocket {
			host = fmt.Sprintf("unix(%s)",
				conf.DatabaseConfig.Host)
		} else {
			host = fmt.Sprintf("(%s:%d)",
				conf.DatabaseConfig.Host,
				conf.DatabaseConfig.Port)
		}
		db, err = gorm.Open(confDBType, fmt.Sprintf("%s:%s@%s/%s?charset=%s&parseTime=True&loc=Local",
			conf.DatabaseConfig.User,
			conf.DatabaseConfig.Password,
			host,
			conf.DatabaseConfig.Name,
			conf.DatabaseConfig.Charset))
	default:
		return fmt.Errorf("unsupported database type %q", confDBType)
	}
	//db.SetLogger(util.Log())
	if err != nil {
		return fmt.Errorf("failed to connect to database: %w", err)
	}
	// Apply the configured table prefix to every model's table name.
	gorm.DefaultTableNameHandler = func(db *gorm.DB, defaultTableName string) string {
		return conf.DatabaseConfig.TablePrefix + defaultTableName
	}
	// SQL logging is unconditionally on — presumably intended for the
	// one-shot migrator; confirm before reusing this init elsewhere.
	db.LogMode(true)
	// Connection pool: SQLite is limited to a single connection to avoid
	// write-lock contention; other engines get a larger pool.
	db.DB().SetMaxIdleConns(50)
	if confDBType == "sqlite" || confDBType == "UNSET" {
		db.DB().SetMaxOpenConns(1)
	} else {
		db.DB().SetMaxOpenConns(100)
	}
	// Recycle connections after 30s.
	db.DB().SetConnMaxLifetime(time.Second * 30)
	DB = db
	return nil
}

View File

@@ -0,0 +1,51 @@
package model
import (
"github.com/jinzhu/gorm"
)
// Node is the legacy v3 slave node model.
type Node struct {
	gorm.Model
	Status       NodeStatus // node availability state
	Name         string     // display name
	Type         ModelType  // master or slave (original comment said "status"; the field is the node type)
	Server       string     // server address
	SlaveKey     string     `gorm:"type:text"` // master->slave communication key
	MasterKey    string     `gorm:"type:text"` // slave->master communication key
	Aria2Enabled bool       // whether this node can serve remote downloads
	Aria2Options string     `gorm:"type:text"` // JSON-encoded Aria2Option
	Rank         int        // load-balancing weight
	// Fields ignored by the database
	Aria2OptionsSerialized Aria2Option `gorm:"-"`
}
// Aria2Option holds the non-public aria2 configuration of a node, serialized
// into Node.Aria2Options as JSON.
type Aria2Option struct {
	// RPC server address
	Server string `json:"server,omitempty"`
	// RPC secret token
	Token string `json:"token,omitempty"`
	// Temporary download directory
	TempPath string `json:"temp_path,omitempty"`
	// Extra download options (JSON-encoded map)
	Options string `json:"options,omitempty"`
	// Download status polling interval
	Interval int `json:"interval,omitempty"`
	// RPC API request timeout
	Timeout int `json:"timeout,omitempty"`
}
// NodeStatus is the v3 node availability state.
type NodeStatus int

// ModelType distinguishes slave nodes from the master node record.
type ModelType int

const (
	NodeActive NodeStatus = iota
	NodeSuspend
)
const (
	SlaveNodeType ModelType = iota
	MasterNodeType
)

View File

@@ -0,0 +1,62 @@
package model
import (
"github.com/jinzhu/gorm"
)
// Policy is the legacy v3 storage policy model.
type Policy struct {
	// Table fields
	gorm.Model
	Name               string
	Type               string
	Server             string
	BucketName         string
	IsPrivate          bool
	BaseURL            string
	AccessKey          string `gorm:"type:text"`
	SecretKey          string `gorm:"type:text"`
	MaxSize            uint64
	AutoRename         bool
	DirNameRule        string
	FileNameRule       string
	IsOriginLinkEnable bool
	Options            string `gorm:"type:text"` // JSON-encoded PolicyOption
	// Fields ignored by the database
	OptionsSerialized PolicyOption `gorm:"-"`
	MasterID          string       `gorm:"-"`
}
// PolicyOption holds non-public storage policy attributes, serialized into
// Policy.Options as JSON.
type PolicyOption struct {
	// Upyun access token
	Token string `json:"token"`
	// Allowed file extensions
	FileType []string `json:"file_type"`
	// MimeType
	MimeType string `json:"mimetype"`
	// OauthRedirect is the OAuth redirect URL
	OauthRedirect string `json:"od_redirect,omitempty"`
	// OdProxy is the OneDrive reverse-proxy address
	OdProxy string `json:"od_proxy,omitempty"`
	// OdDriver is the OneDrive drive locator
	OdDriver string `json:"od_driver,omitempty"`
	// Region code
	Region string `json:"region,omitempty"`
	// ServerSideEndpoint is the endpoint used for server-side requests;
	// Policy.Server is used when empty
	ServerSideEndpoint string `json:"server_side_endpoint,omitempty"`
	// Chunk size for multipart uploads
	ChunkSize uint64 `json:"chunk_size,omitempty"`
	// Whether to reserve space when creating upload placeholders
	PlaceholderWithSize bool `json:"placeholder_with_size,omitempty"`
	// Per-second API request limit against the storage backend
	TPSLimit float64 `json:"tps_limit,omitempty"`
	// Per-second API request burst limit
	TPSLimitBurst int `json:"tps_limit_burst,omitempty"`
	// Set this to `true` to force the request to use path-style addressing,
	// i.e., `http://s3.amazonaws.com/BUCKET/KEY `
	S3ForcePathStyle bool `json:"s3_path_style"`
	// File extensions that support thumbnail generation using native policy API.
	ThumbExts []string `json:"thumb_exts,omitempty"`
}

View File

@@ -0,0 +1,13 @@
package model
import (
"github.com/jinzhu/gorm"
)
// Setting is the legacy v3 key/value system setting model.
type Setting struct {
	gorm.Model
	Type  string `gorm:"not null"`
	Name  string `gorm:"unique;not null;index:setting_key"`
	Value string `gorm:"size:65535"`
}

View File

@@ -0,0 +1,27 @@
package model
import (
"time"
"github.com/jinzhu/gorm"
)
// Share is the legacy v3 share link model.
type Share struct {
	gorm.Model
	Password        string     // share password; empty means unprotected
	IsDir           bool       // whether the shared source is a directory
	UserID          uint       // creator's user ID
	SourceID        uint       // ID of the shared file or folder
	Views           int        // view count
	Downloads       int        // download count
	RemainDownloads int        // remaining download quota; negative means unlimited
	Expires         *time.Time // expiry time; nil means never
	PreviewEnabled  bool       // whether direct preview is allowed
	SourceName      string     `gorm:"index:source"` // denormalized name used for searching
	// Associations (not eagerly loaded)
	User   User   `gorm:"PRELOAD:false,association_autoupdate:false"`
	File   File   `gorm:"PRELOAD:false,association_autoupdate:false"`
	Folder Folder `gorm:"PRELOAD:false,association_autoupdate:false"`
}

View File

@@ -0,0 +1,16 @@
package model
import (
"github.com/jinzhu/gorm"
)
// SourceLink represent a shared file source link
type SourceLink struct {
	gorm.Model
	FileID    uint   // corresponding file ID
	Name      string // name of the file while creating the source link, for annotation
	Downloads int    // download count
	// Associations
	// NOTE(review): the tag value "save_associations:false:false" carries a
	// doubled ":false"; it is kept verbatim from v3 — confirm before changing.
	File File `gorm:"save_associations:false:false"`
}

View File

@@ -0,0 +1,23 @@
package model
import (
"github.com/jinzhu/gorm"
)
// Tag is the legacy v3 user-defined tag model.
type Tag struct {
	gorm.Model
	Name       string // tag name
	Icon       string // icon identifier
	Color      string // icon color
	Type       int    // tag kind (file classification / directory shortcut)
	Expression string `gorm:"type:text"` // search expression or shortcut path
	UserID     uint   // creator's user ID
}
const (
	// FileTagType is a file classification tag.
	FileTagType = iota
	// DirectoryLinkType is a directory shortcut tag.
	DirectoryLinkType
)

View File

@@ -0,0 +1,16 @@
package model
import (
"github.com/jinzhu/gorm"
)
// Task is the legacy v3 background task model.
type Task struct {
	gorm.Model
	Status   int    // task status
	Type     int    // task type
	UserID   uint   // initiating user ID; 0 means system-initiated
	Progress int    // progress
	Error    string `gorm:"type:text"` // error message
	Props    string `gorm:"type:text"` // task properties (serialized)
}

View File

@@ -0,0 +1,45 @@
package model
import (
"github.com/jinzhu/gorm"
)
// v3 user account status values (identifier spellings kept verbatim).
const (
	// Active marks a normal account.
	Active = iota
	// NotActivicated marks an account pending activation (sic: v3 spelling).
	NotActivicated
	// Baned marks a manually banned account (sic: v3 spelling).
	Baned
	// OveruseBaned marks an account banned for exceeding quota.
	OveruseBaned
)
// User is the legacy v3 user account model.
type User struct {
	// Table fields
	gorm.Model
	Email     string `gorm:"type:varchar(100);unique_index"`
	Nick      string `gorm:"size:50"`
	Password  string `json:"-"`
	Status    int    // one of Active/NotActivicated/Baned/OveruseBaned
	GroupID   uint
	Storage   uint64 // used storage in bytes
	TwoFactor string // TOTP secret; empty when 2FA is disabled
	Avatar    string
	Options   string `json:"-" gorm:"size:4294967295"` // JSON-encoded UserOption
	Authn     string `gorm:"size:4294967295"`          // WebAuthn credentials blob
	// Associations
	Group  Group  `gorm:"save_associations:false:false"`
	Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"`
	// Fields ignored by the database
	OptionsSerialized UserOption `gorm:"-"`
}
// UserOption holds per-user personalization settings, serialized into
// User.Options as JSON.
type UserOption struct {
	ProfileOff     bool   `json:"profile_off,omitempty"`
	PreferredTheme string `json:"preferred_theme,omitempty"`
}

View File

@@ -0,0 +1,16 @@
package model
import (
"github.com/jinzhu/gorm"
)
// Webdav is the legacy v3 WebDAV application account. Password and UserID
// share a unique index so one password maps to at most one account per user.
type Webdav struct {
	gorm.Model
	Name     string // application name
	Password string `gorm:"unique_index:password_only_on"` // application password
	UserID   uint   `gorm:"unique_index:password_only_on"` // owning user ID
	Root     string `gorm:"type:text"` // root directory
	Readonly bool   `gorm:"type:bool"` // whether the account is read-only
	UseProxy bool   `gorm:"type:bool"` // whether to reverse-proxy transfers
}

View File

@@ -0,0 +1,89 @@
package migrator
import (
"context"
"encoding/json"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/ent/node"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
)
// migrateNode migrates v3 nodes into the v4 node table, mapping the v3
// type/status enums onto v4 values, deriving v4 capability bits, and carrying
// over aria2 remote-download settings.
func (m *Migrator) migrateNode() error {
	m.l.Info("Migrating nodes...")
	var nodes []model.Node
	if err := model.DB.Find(&nodes).Error; err != nil {
		return fmt.Errorf("failed to list v3 nodes: %w", err)
	}
	for _, n := range nodes {
		// Default to a suspended slave; upgrade based on the v3 record.
		nodeType := node.TypeSlave
		nodeStatus := node.StatusSuspended
		if n.Type == model.MasterNodeType {
			nodeType = node.TypeMaster
		}
		if n.Status == model.NodeActive {
			nodeStatus = node.StatusActive
		}
		// Renamed from `cap` to avoid shadowing the builtin.
		capabilities := &boolset.BooleanSet{}
		settings := &types.NodeSetting{
			Provider: types.DownloaderProviderAria2,
		}
		if n.Aria2Enabled {
			boolset.Sets(map[types.NodeCapability]bool{
				types.NodeCapabilityRemoteDownload: true,
			}, capabilities)
			aria2Options := &model.Aria2Option{}
			if err := json.Unmarshal([]byte(n.Aria2Options), aria2Options); err != nil {
				return fmt.Errorf("failed to unmarshal aria2 options: %w", err)
			}
			// The v3 options field nests another JSON document with the raw
			// downloader options.
			downloaderOptions := map[string]any{}
			if aria2Options.Options != "" {
				if err := json.Unmarshal([]byte(aria2Options.Options), &downloaderOptions); err != nil {
					return fmt.Errorf("failed to unmarshal aria2 downloader options: %w", err)
				}
			}
			settings.Aria2Setting = &types.Aria2Setting{
				Server:   aria2Options.Server,
				Token:    aria2Options.Token,
				Options:  downloaderOptions,
				TempPath: aria2Options.TempPath,
			}
		}
		// Master nodes always get archive create/extract capabilities.
		if n.Type == model.MasterNodeType {
			boolset.Sets(map[types.NodeCapability]bool{
				types.NodeCapabilityExtractArchive: true,
				types.NodeCapabilityCreateArchive:  true,
			}, capabilities)
		}
		stm := m.v4client.Node.Create().
			SetRawID(int(n.ID)).
			SetCreatedAt(formatTime(n.CreatedAt)).
			SetUpdatedAt(formatTime(n.UpdatedAt)).
			SetName(n.Name).
			SetType(nodeType).
			SetStatus(nodeStatus).
			SetServer(n.Server).
			SetSlaveKey(n.SlaveKey).
			SetCapabilities(capabilities).
			SetSettings(settings).
			SetWeight(n.Rank)
		if err := stm.Exec(context.Background()); err != nil {
			return fmt.Errorf("failed to create node %q: %w", n.Name, err)
		}
	}
	return nil
}

View File

@@ -0,0 +1,192 @@
package migrator
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/ent/node"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"github.com/samber/lo"
)
// migratePolicy migrates v3 storage policies into the v4 schema. It returns
// the set of migrated policy IDs keyed by v3 ID, records local policies in
// the migration state, and creates a slave node for each remote policy.
// Fix over v3 code: all early error returns now roll the transaction back
// instead of leaking it.
func (m *Migrator) migratePolicy() (map[int]bool, error) {
	m.l.Info("Migrating storage policies...")
	var policies []model.Policy
	if err := model.DB.Find(&policies).Error; err != nil {
		return nil, fmt.Errorf("failed to list v3 storage policies: %w", err)
	}
	if m.state.LocalPolicyIDs == nil {
		m.state.LocalPolicyIDs = make(map[int]bool)
	}
	if m.state.PolicyIDs == nil {
		m.state.PolicyIDs = make(map[int]bool)
	}
	m.l.Info("Found %d v3 storage policies to be migrated.", len(policies))
	// Read v3 thumb proxy settings so the per-policy proxy flag can be set.
	var (
		thumbProxySettings []model.Setting
		thumbProxyEnabled  bool
		thumbProxyPolicy   []int
	)
	if err := model.DB.Where("name in (?)", []string{"thumb_proxy_enabled", "thumb_proxy_policy"}).Find(&thumbProxySettings).Error; err != nil {
		m.l.Warning("Failed to list v3 thumb proxy settings: %w", err)
	}
	// Parse the proxy settings before opening the transaction; this does not
	// touch tx, so failing early cannot leak it.
	for _, s := range thumbProxySettings {
		if s.Name == "thumb_proxy_enabled" {
			thumbProxyEnabled = setting.IsTrueValue(s.Value)
		} else if s.Name == "thumb_proxy_policy" {
			if err := json.Unmarshal([]byte(s.Value), &thumbProxyPolicy); err != nil {
				m.l.Warning("Failed to unmarshal v3 thumb proxy policy: %w", err)
			}
		}
	}
	tx, err := m.v4client.Tx(context.Background())
	if err != nil {
		return nil, fmt.Errorf("failed to start transaction: %w", err)
	}
	// fail rolls the transaction back before propagating an error so the
	// underlying connection is not leaked on early returns.
	fail := func(err error) (map[int]bool, error) {
		_ = tx.Rollback()
		return nil, err
	}
	for _, policy := range policies {
		m.l.Info("Migrating storage policy %q...", policy.Name)
		if err := json.Unmarshal([]byte(policy.Options), &policy.OptionsSerialized); err != nil {
			return fail(fmt.Errorf("failed to unmarshal options for policy %q: %w", policy.Name, err))
		}
		settings := &types.PolicySetting{
			Token:               policy.OptionsSerialized.Token,
			FileType:            policy.OptionsSerialized.FileType,
			OauthRedirect:       policy.OptionsSerialized.OauthRedirect,
			OdDriver:            policy.OptionsSerialized.OdDriver,
			Region:              policy.OptionsSerialized.Region,
			ServerSideEndpoint:  policy.OptionsSerialized.ServerSideEndpoint,
			ChunkSize:           int64(policy.OptionsSerialized.ChunkSize),
			TPSLimit:            policy.OptionsSerialized.TPSLimit,
			TPSLimitBurst:       policy.OptionsSerialized.TPSLimitBurst,
			S3ForcePathStyle:    policy.OptionsSerialized.S3ForcePathStyle,
			ThumbExts:           policy.OptionsSerialized.ThumbExts,
		}
		// OneDrive can thumbnail everything; other vendors get their known
		// native thumbnail extension lists.
		if policy.Type == types.PolicyTypeOd {
			settings.ThumbSupportAllExts = true
		} else {
			switch policy.Type {
			case types.PolicyTypeCos:
				settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heif", "heic"}
			case types.PolicyTypeOss:
				settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heic", "tiff", "avif"}
			case types.PolicyTypeUpyun:
				settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "svg"}
			case types.PolicyTypeQiniu:
				settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "tiff", "avif", "psd"}
			case types.PolicyTypeRemote:
				settings.ThumbExts = []string{"png", "jpg", "jpeg", "gif"}
			}
		}
		// Carry over custom proxy / CDN domain settings.
		if policy.Type != types.PolicyTypeOd && policy.BaseURL != "" {
			settings.CustomProxy = true
			settings.ProxyServer = policy.BaseURL
		} else if policy.OptionsSerialized.OdProxy != "" {
			settings.CustomProxy = true
			settings.ProxyServer = policy.OptionsSerialized.OdProxy
		}
		if policy.DirNameRule == "" {
			policy.DirNameRule = "uploads/{uid}/{path}"
		}
		if policy.Type == types.PolicyTypeCos {
			settings.ChunkSize = 1024 * 1024 * 25
		}
		if thumbProxyEnabled && lo.Contains(thumbProxyPolicy, int(policy.ID)) {
			settings.ThumbGeneratorProxy = true
		}
		// v4 requires a random element in the file name rule to avoid
		// collisions; fall back to a safe default otherwise.
		mustContain := []string{"{randomkey16}", "{randomkey8}", "{uuid}"}
		hasRandomElement := false
		for _, c := range mustContain {
			if strings.Contains(policy.FileNameRule, c) {
				hasRandomElement = true
				break
			}
		}
		if !hasRandomElement {
			policy.FileNameRule = "{uid}_{randomkey8}_{originname}"
			m.l.Warning("Storage policy %q has no random element in file name rule, using default file name rule.", policy.Name)
		}
		stm := tx.StoragePolicy.Create().
			SetRawID(int(policy.ID)).
			SetCreatedAt(formatTime(policy.CreatedAt)).
			SetUpdatedAt(formatTime(policy.UpdatedAt)).
			SetName(policy.Name).
			SetType(policy.Type).
			SetServer(policy.Server).
			SetBucketName(policy.BucketName).
			SetIsPrivate(policy.IsPrivate).
			SetAccessKey(policy.AccessKey).
			SetSecretKey(policy.SecretKey).
			SetMaxSize(int64(policy.MaxSize)).
			SetDirNameRule(policy.DirNameRule).
			SetFileNameRule(policy.FileNameRule).
			SetSettings(settings)
		if policy.Type == types.PolicyTypeRemote {
			m.l.Info("Storage policy %q is remote, creating node for it...", policy.Name)
			bs := &boolset.BooleanSet{}
			n, err := tx.Node.Create().
				SetName(policy.Name).
				SetStatus(node.StatusActive).
				SetServer(policy.Server).
				SetSlaveKey(policy.SecretKey).
				SetType(node.TypeSlave).
				SetCapabilities(bs).
				SetSettings(&types.NodeSetting{
					Provider: types.DownloaderProviderAria2,
				}).
				Save(context.Background())
			if err != nil {
				return fail(fmt.Errorf("failed to create node for storage policy %q: %w", policy.Name, err))
			}
			stm.SetNodeID(n.ID)
		}
		if _, err := stm.Save(context.Background()); err != nil {
			return fail(fmt.Errorf("failed to create storage policy %q: %w", policy.Name, err))
		}
		m.state.PolicyIDs[int(policy.ID)] = true
		if policy.Type == types.PolicyTypeLocal {
			m.state.LocalPolicyIDs[int(policy.ID)] = true
		}
	}
	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("failed to commit transaction: %w", err)
	}
	// Postgres sequences do not advance when rows are inserted with explicit
	// IDs, so bump them to the current maximum.
	if m.dep.ConfigProvider().Database().Type == conf.PostgresDB {
		m.l.Info("Resetting storage policy ID sequence for postgres...")
		m.v4client.StoragePolicy.ExecContext(context.Background(), "SELECT SETVAL('storage_policies_id_seq', (SELECT MAX(id) FROM storage_policies))")
		m.l.Info("Resetting node ID sequence for postgres...")
		m.v4client.Node.ExecContext(context.Background(), "SELECT SETVAL('nodes_id_seq', (SELECT MAX(id) FROM nodes))")
	}
	return m.state.PolicyIDs, nil
}

View File

@@ -0,0 +1,213 @@
package migrator
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/migrator/conf"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
)
// TODO:
// 1. Policy thumb proxy migration
type (
	// settignMigrator transforms one v3 setting pair into zero or more v4
	// pairs. The full v3 setting map is supplied for cross-key lookups.
	// NOTE(review): the name carries a typo ("settign"); renaming would touch
	// every reference, so it is kept as-is.
	settignMigrator func(allSettings map[string]string, name, value string) ([]settingMigrated, error)
	// settingMigrated is a single v4 setting pair produced by a migrator.
	settingMigrated struct {
		name  string
		value string
	}
	// PackProduct is a v3 storage-pack (capacity add-on) product.
	PackProduct struct {
		ID    int64  `json:"id"`
		Name  string `json:"name"`
		Size  uint64 `json:"size"`
		Time  int64  `json:"time"`
		Price int    `json:"price"`
		Score int    `json:"score"`
	}
	// GroupProducts is a v3 purchasable group (membership) product.
	GroupProducts struct {
		ID        int64    `json:"id"`
		Name      string   `json:"name"`
		GroupID   uint     `json:"group_id"`
		Time      int64    `json:"time"`
		Price     int      `json:"price"`
		Score     int      `json:"score"`
		Des       []string `json:"des"`
		Highlight bool     `json:"highlight"`
	}
)
// noopMigrator drops a v3 setting entirely: it produces no v4 pairs.
var noopMigrator = func(allSettings map[string]string, name, value string) ([]settingMigrated, error) {
	return nil, nil
}
// migrators maps v3 setting names to their migration behavior. Keys bound to
// noopMigrator are dropped in v4; closures rename a key or fan one v3 value
// out into several v4 keys. Settings absent from this map are carried over
// unchanged by migrateSettings.
var migrators = map[string]settignMigrator{
	"siteKeywords":                   noopMigrator,
	"over_used_template":             noopMigrator,
	"download_timeout":               noopMigrator,
	"preview_timeout":                noopMigrator,
	"doc_preview_timeout":            noopMigrator,
	"slave_node_retry":               noopMigrator,
	"slave_ping_interval":            noopMigrator,
	"slave_recover_interval":         noopMigrator,
	"slave_transfer_timeout":         noopMigrator,
	"onedrive_monitor_timeout":       noopMigrator,
	"onedrive_source_timeout":        noopMigrator,
	"share_download_session_timeout": noopMigrator,
	"onedrive_callback_check":        noopMigrator,
	"mail_activation_template":       noopMigrator,
	"mail_reset_pwd_template":        noopMigrator,
	"appid":                          noopMigrator,
	"appkey":                         noopMigrator,
	"wechat_enabled":                 noopMigrator,
	"wechat_appid":                   noopMigrator,
	"wechat_mchid":                   noopMigrator,
	"wechat_serial_no":               noopMigrator,
	"wechat_api_key":                 noopMigrator,
	"wechat_pk_content":              noopMigrator,
	"hot_share_num":                  noopMigrator,
	"defaultTheme":                   noopMigrator,
	"theme_options":                  noopMigrator,
	"max_worker_num":                 noopMigrator,
	"max_parallel_transfer":          noopMigrator,
	"secret_key":                     noopMigrator,
	"avatar_size_m":                  noopMigrator,
	"avatar_size_s":                  noopMigrator,
	"home_view_method":               noopMigrator,
	"share_view_method":              noopMigrator,
	"cron_recycle_upload_session":    noopMigrator,
	// v4 has no tcaptcha provider; fold it into the normal captcha type.
	"captcha_type": func(allSettings map[string]string, name, value string) ([]settingMigrated, error) {
		if value == "tcaptcha" {
			value = "normal"
		}
		return []settingMigrated{
			{
				name:  "captcha_type",
				value: value,
			},
		}, nil
	},
	"captcha_TCaptcha_CaptchaAppId": noopMigrator,
	"captcha_TCaptcha_AppSecretKey": noopMigrator,
	"captcha_TCaptcha_SecretId":     noopMigrator,
	"captcha_TCaptcha_SecretKey":    noopMigrator,
	// Renamed in v4.
	"thumb_file_suffix": func(allSettings map[string]string, name, value string) ([]settingMigrated, error) {
		return []settingMigrated{
			{
				name:  "thumb_entity_suffix",
				value: value,
			},
		}, nil
	},
	// v3's single size cap fans out to the per-generator caps in v4.
	"thumb_max_src_size": func(allSettings map[string]string, name, value string) ([]settingMigrated, error) {
		return []settingMigrated{
			{
				name:  "thumb_music_cover_max_size",
				value: value,
			},
			{
				name:  "thumb_libreoffice_max_size",
				value: value,
			},
			{
				name:  "thumb_ffmpeg_max_size",
				value: value,
			},
			{
				name:  "thumb_vips_max_size",
				value: value,
			},
			{
				name:  "thumb_builtin_max_size",
				value: value,
			},
		}, nil
	},
	"initial_files":          noopMigrator,
	"office_preview_service": noopMigrator,
	"phone_required":         noopMigrator,
	"phone_enabled":          noopMigrator,
	// Renamed in v4.
	"wopi_session_timeout": func(allSettings map[string]string, name, value string) ([]settingMigrated, error) {
		return []settingMigrated{
			{
				name:  "viewer_session_timeout",
				value: value,
			},
		}, nil
	},
	"custom_payment_enabled":  noopMigrator,
	"custom_payment_endpoint": noopMigrator,
	"custom_payment_secret":   noopMigrator,
	"custom_payment_name":     noopMigrator,
}
// migrateSettings migrates v3 setting pairs into v4, renaming, dropping or
// fanning out keys via the `migrators` table, and seeds hash_id_salt from the
// v3 conf file. Fix over v3 code: the hash_id_salt presence check now runs
// before the transaction is opened, so the early return no longer leaks it.
func (m *Migrator) migrateSettings() error {
	m.l.Info("Migrating settings...")
	// 1. List all settings
	var settings []model.Setting
	if err := model.DB.Find(&settings).Error; err != nil {
		return fmt.Errorf("failed to list v3 settings: %w", err)
	}
	m.l.Info("Found %d v3 setting pairs to be migrated.", len(settings))
	allSettings := make(map[string]string)
	for _, s := range settings {
		allSettings[s.Name] = s.Value
	}
	migratedSettings := make([]settingMigrated, 0)
	for _, s := range settings {
		// Remember values that later migration steps need.
		if s.Name == "thumb_file_suffix" {
			m.state.ThumbSuffix = s.Value
		}
		if s.Name == "avatar_path" {
			m.state.V3AvatarPath = s.Value
		}
		migrator, ok := migrators[s.Name]
		if ok {
			newSettings, err := migrator(allSettings, s.Name, s.Value)
			if err != nil {
				return fmt.Errorf("failed to migrate setting %q: %w", s.Name, err)
			}
			migratedSettings = append(migratedSettings, newSettings...)
		} else {
			// Unknown keys are carried over unchanged.
			migratedSettings = append(migratedSettings, settingMigrated{
				name:  s.Name,
				value: s.Value,
			})
		}
	}
	// Validate required config before opening the transaction.
	if conf.SystemConfig.HashIDSalt == "" {
		return fmt.Errorf("hash ID salt is not set, please set it from v3 conf file")
	}
	tx, err := m.v4client.Tx(context.Background())
	if err != nil {
		return fmt.Errorf("failed to start transaction: %w", err)
	}
	// Insert hash_id_salt
	if err := tx.Setting.Create().SetName("hash_id_salt").SetValue(conf.SystemConfig.HashIDSalt).Exec(context.Background()); err != nil {
		if err := tx.Rollback(); err != nil {
			return fmt.Errorf("failed to rollback transaction: %w", err)
		}
		return fmt.Errorf("failed to create setting hash_id_salt: %w", err)
	}
	for _, s := range migratedSettings {
		if err := tx.Setting.Create().SetName(s.name).SetValue(s.value).Exec(context.Background()); err != nil {
			if err := tx.Rollback(); err != nil {
				return fmt.Errorf("failed to rollback transaction: %w", err)
			}
			return fmt.Errorf("failed to create setting %q: %w", s.name, err)
		}
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,102 @@
package migrator
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/ent/file"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
)
// migrateShare migrates v3 share links into v4 in batches of 1000, resuming
// from the offset stored in the migration state. Shares whose source file or
// owner was not migrated are skipped with a warning. Fix over v3 code: when
// starting a transaction fails, ent returns a nil *Tx — calling Rollback on
// it would panic, so that call was removed.
func (m *Migrator) migrateShare() error {
	m.l.Info("Migrating shares...")
	batchSize := 1000
	offset := m.state.ShareOffset
	ctx := context.Background()
	if offset > 0 {
		m.l.Info("Resuming share migration from offset %d", offset)
	}
	for {
		m.l.Info("Migrating shares with offset %d", offset)
		var shares []model.Share
		if err := model.DB.Limit(batchSize).Offset(offset).Find(&shares).Error; err != nil {
			return fmt.Errorf("failed to list v3 shares: %w", err)
		}
		if len(shares) == 0 {
			// All batches processed. Postgres sequences do not advance on
			// explicit-ID inserts, so bump them to the current maximum.
			if m.dep.ConfigProvider().Database().Type == conf.PostgresDB {
				m.l.Info("Resetting share ID sequence for postgres...")
				m.v4client.Share.ExecContext(ctx, "SELECT SETVAL('shares_id_seq', (SELECT MAX(id) FROM shares))")
			}
			break
		}
		tx, err := m.v4client.Tx(ctx)
		if err != nil {
			// tx is nil on error; do not call Rollback on it.
			return fmt.Errorf("failed to start transaction: %w", err)
		}
		for _, s := range shares {
			// v3 stores files and folders in separate tables; v4 merges them,
			// with migrated file IDs offset by the last folder ID.
			sourceId := int(s.SourceID)
			if !s.IsDir {
				sourceId += m.state.LastFolderID
			}
			// Skip shares whose source file was not migrated.
			_, err = tx.File.Query().Where(file.ID(sourceId)).First(ctx)
			if err != nil {
				m.l.Warning("File %d not found, skipping share %d", sourceId, s.ID)
				continue
			}
			// Skip shares whose owner was not migrated.
			if _, ok := m.state.UserIDs[int(s.UserID)]; !ok {
				m.l.Warning("User %d not found, skipping share %d", s.UserID, s.ID)
				continue
			}
			stm := tx.Share.Create().
				SetCreatedAt(formatTime(s.CreatedAt)).
				SetUpdatedAt(formatTime(s.UpdatedAt)).
				SetViews(s.Views).
				SetRawID(int(s.ID)).
				SetDownloads(s.Downloads).
				SetFileID(sourceId).
				SetUserID(int(s.UserID))
			if s.Password != "" {
				stm.SetPassword(s.Password)
			}
			if s.Expires != nil {
				stm.SetNillableExpires(s.Expires)
			}
			// Negative remaining downloads means unlimited in v3; only carry
			// over non-negative quotas.
			if s.RemainDownloads >= 0 {
				stm.SetRemainDownloads(s.RemainDownloads)
			}
			if _, err := stm.Save(ctx); err != nil {
				_ = tx.Rollback()
				return fmt.Errorf("failed to create share %d: %w", s.ID, err)
			}
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit transaction: %w", err)
		}
		// Persist progress so an interrupted run can resume at this offset.
		offset += batchSize
		m.state.ShareOffset = offset
		if err := m.saveState(); err != nil {
			m.l.Warning("Failed to save state after share batch: %s", err)
		} else {
			m.l.Info("Saved migration state after processing this batch")
		}
	}
	return nil
}

View File

@@ -0,0 +1,109 @@
package migrator
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/ent/user"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
)
// migrateUser migrates v3 users into v4 in batches of 1000, resuming from the
// offset stored in the migration state and recording migrated user IDs for
// later steps. Fix over v3 code: when starting a transaction fails, ent
// returns a nil *Tx — calling Rollback on it would panic, so that call was
// removed.
func (m *Migrator) migrateUser() error {
	m.l.Info("Migrating users...")
	batchSize := 1000
	// Start from the saved offset if available
	offset := m.state.UserOffset
	ctx := context.Background()
	if m.state.UserIDs == nil {
		m.state.UserIDs = make(map[int]bool)
	}
	// If we're resuming, report how far the previous run got.
	if len(m.state.UserIDs) > 0 {
		m.l.Info("Resuming user migration from offset %d, %d users already migrated", offset, len(m.state.UserIDs))
	}
	for {
		m.l.Info("Migrating users with offset %d", offset)
		var users []model.User
		if err := model.DB.Limit(batchSize).Offset(offset).Find(&users).Error; err != nil {
			return fmt.Errorf("failed to list v3 users: %w", err)
		}
		if len(users) == 0 {
			// All batches processed. Postgres sequences do not advance on
			// explicit-ID inserts, so bump them to the current maximum.
			if m.dep.ConfigProvider().Database().Type == conf.PostgresDB {
				m.l.Info("Resetting user ID sequence for postgres...")
				m.v4client.User.ExecContext(ctx, "SELECT SETVAL('users_id_seq', (SELECT MAX(id) FROM users))")
			}
			break
		}
		tx, err := m.v4client.Tx(context.Background())
		if err != nil {
			// tx is nil on error; do not call Rollback on it.
			return fmt.Errorf("failed to start transaction: %w", err)
		}
		for _, u := range users {
			// Map the v3 status enum onto v4 user statuses.
			userStatus := user.StatusActive
			switch u.Status {
			case model.Active:
				userStatus = user.StatusActive
			case model.NotActivicated:
				userStatus = user.StatusInactive
			case model.Baned:
				userStatus = user.StatusManualBanned
			case model.OveruseBaned:
				userStatus = user.StatusSysBanned
			}
			// v4 defaults: version retention enabled, keep up to 10 versions.
			setting := &types.UserSetting{
				VersionRetention:    true,
				VersionRetentionMax: 10,
			}
			stm := tx.User.Create().
				SetRawID(int(u.ID)).
				SetCreatedAt(formatTime(u.CreatedAt)).
				SetUpdatedAt(formatTime(u.UpdatedAt)).
				SetEmail(u.Email).
				SetNick(u.Nick).
				SetStatus(userStatus).
				SetStorage(int64(u.Storage)).
				SetGroupID(int(u.GroupID)).
				SetSettings(setting).
				SetPassword(u.Password)
			if u.TwoFactor != "" {
				stm.SetTwoFactorSecret(u.TwoFactor)
			}
			if u.Avatar != "" {
				stm.SetAvatar(u.Avatar)
			}
			if _, err := stm.Save(ctx); err != nil {
				_ = tx.Rollback()
				return fmt.Errorf("failed to create user %d: %w", u.ID, err)
			}
			m.state.UserIDs[int(u.ID)] = true
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit transaction: %w", err)
		}
		// Update the offset in state and save after each batch
		offset += batchSize
		m.state.UserOffset = offset
		if err := m.saveState(); err != nil {
			m.l.Warning("Failed to save state after user batch: %s", err)
		} else {
			m.l.Info("Saved migration state after processing %d users", offset)
		}
	}
	return nil
}

View File

@@ -0,0 +1,93 @@
package migrator
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/migrator/model"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
)
// migrateWebdav migrates v3 WebDAV accounts into the v4 DavAccount table in
// batches of 1000 rows. Accounts whose owner was not migrated (absent from
// m.state.UserIDs) are skipped with a warning. Progress is checkpointed in
// m.state.WebdavOffset after every committed batch so a later run can resume.
func (m *Migrator) migrateWebdav() error {
	m.l.Info("Migrating webdav accounts...")
	batchSize := 1000
	offset := m.state.WebdavOffset
	ctx := context.Background()
	if m.state.WebdavOffset > 0 {
		m.l.Info("Resuming webdav migration from offset %d", offset)
	}
	for {
		m.l.Info("Migrating webdav accounts with offset %d", offset)
		var webdavAccounts []model.Webdav
		if err := model.DB.Limit(batchSize).Offset(offset).Find(&webdavAccounts).Error; err != nil {
			return fmt.Errorf("failed to list v3 webdav accounts: %w", err)
		}
		if len(webdavAccounts) == 0 {
			if m.dep.ConfigProvider().Database().Type == conf.PostgresDB {
				// Explicit-ID inserts do not advance the Postgres sequence;
				// realign it with the largest migrated ID.
				m.l.Info("Resetting webdav account ID sequence for postgres...")
				if _, err := m.v4client.DavAccount.ExecContext(ctx, "SELECT SETVAL('dav_accounts_id_seq', (SELECT MAX(id) FROM dav_accounts))"); err != nil {
					m.l.Warning("Failed to reset webdav account ID sequence: %s", err)
				}
			}
			break
		}
		tx, err := m.v4client.Tx(ctx)
		if err != nil {
			// No transaction was opened, so there is nothing to roll back.
			return fmt.Errorf("failed to start transaction: %w", err)
		}
		for _, webdavAccount := range webdavAccounts {
			// Skip accounts whose owner does not exist in v4.
			if _, ok := m.state.UserIDs[int(webdavAccount.UserID)]; !ok {
				m.l.Warning("User %d not found, skipping webdav account %d", webdavAccount.UserID, webdavAccount.ID)
				continue
			}
			props := types.DavAccountProps{}
			// Translate the v3 boolean columns into the v4 option bitset.
			options := boolset.BooleanSet{}
			if webdavAccount.Readonly {
				boolset.Set(int(types.DavAccountReadOnly), true, &options)
			}
			if webdavAccount.UseProxy {
				boolset.Set(int(types.DavAccountProxy), true, &options)
			}
			stm := tx.DavAccount.Create().
				SetCreatedAt(formatTime(webdavAccount.CreatedAt)).
				SetUpdatedAt(formatTime(webdavAccount.UpdatedAt)).
				SetRawID(int(webdavAccount.ID)).
				SetName(webdavAccount.Name).
				SetURI("cloudreve://my" + webdavAccount.Root).
				SetPassword(webdavAccount.Password).
				SetProps(&props).
				SetOptions(&options).
				SetOwnerID(int(webdavAccount.UserID))
			if _, err := stm.Save(ctx); err != nil {
				_ = tx.Rollback()
				return fmt.Errorf("failed to create webdav account %d: %w", webdavAccount.ID, err)
			}
		}
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("failed to commit transaction: %w", err)
		}
		// Persist progress so a later run can resume after this batch.
		offset += batchSize
		m.state.WebdavOffset = offset
		if err := m.saveState(); err != nil {
			m.l.Warning("Failed to save state after webdav batch: %s", err)
		} else {
			m.l.Info("Saved migration state after processing this batch")
		}
	}
	return nil
}

View File

@@ -0,0 +1,432 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package embed provides access to files embedded in the running Go program.
//
// Go source files that import "embed" can use the //go:embed directive
// to initialize a variable of type string, []byte, or FS with the contents of
// files read from the package directory or subdirectories at compile time.
//
// For example, here are three ways to embed a file named hello.txt
// and then print its contents at run time.
//
// Embedding one file into a string:
//
// import _ "embed"
//
// //go:embed hello.txt
// var s string
// print(s)
//
// Embedding one file into a slice of bytes:
//
// import _ "embed"
//
// //go:embed hello.txt
// var b []byte
// print(string(b))
//
// Embedded one or more files into a file system:
//
// import "embed"
//
// //go:embed hello.txt
// var f embed.FS
// data, _ := f.ReadFile("hello.txt")
// print(string(data))
//
// # Directives
//
// A //go:embed directive above a variable declaration specifies which files to embed,
// using one or more path.Match patterns.
//
// The directive must immediately precede a line containing the declaration of a single variable.
// Only blank lines and // line comments are permitted between the directive and the declaration.
//
// The type of the variable must be a string type, or a slice of a byte type,
// or FS (or an alias of FS).
//
// For example:
//
// package server
//
// import "embed"
//
// // content holds our static web server content.
// //go:embed image/* template/*
// //go:embed html/index.html
// var content embed.FS
//
// The Go build system will recognize the directives and arrange for the declared variable
// (in the example above, content) to be populated with the matching files from the file system.
//
// The //go:embed directive accepts multiple space-separated patterns for
// brevity, but it can also be repeated, to avoid very long lines when there are
// many patterns. The patterns are interpreted relative to the package directory
// containing the source file. The path separator is a forward slash, even on
// Windows systems. Patterns may not contain . or .. or empty path elements,
// nor may they begin or end with a slash. To match everything in the current
// directory, use * instead of .. To allow for naming files with spaces in
// their names, patterns can be written as Go double-quoted or back-quoted
// string literals.
//
// If a pattern names a directory, all files in the subtree rooted at that directory are
// embedded (recursively), except that files with names beginning with . or _
// are excluded. So the variable in the above example is almost equivalent to:
//
// // content is our static web server content.
// //go:embed image template html/index.html
// var content embed.FS
//
// The difference is that image/* embeds image/.tempfile while image does not.
// Neither embeds image/dir/.tempfile.
//
// If a pattern begins with the prefix all:, then the rule for walking directories is changed
// to include those files beginning with . or _. For example, all:image embeds
// both image/.tempfile and image/dir/.tempfile.
//
// The //go:embed directive can be used with both exported and unexported variables,
// depending on whether the package wants to make the data available to other packages.
// It can only be used with variables at package scope, not with local variables.
//
// Patterns must not match files outside the package's module, such as .git/* or symbolic links.
// Patterns must not match files whose names include the special punctuation characters " * < > ? ` ' | / \ and :.
// Matches for empty directories are ignored. After that, each pattern in a //go:embed line
// must match at least one file or non-empty directory.
//
// If any patterns are invalid or have invalid matches, the build will fail.
//
// # Strings and Bytes
//
// The //go:embed line for a variable of type string or []byte can have only a single pattern,
// and that pattern can match only a single file. The string or []byte is initialized with
// the contents of that file.
//
// The //go:embed directive requires importing "embed", even when using a string or []byte.
// In source files that don't refer to embed.FS, use a blank import (import _ "embed").
//
// # File Systems
//
// For embedding a single file, a variable of type string or []byte is often best.
// The FS type enables embedding a tree of files, such as a directory of static
// web server content, as in the example above.
//
// FS implements the io/fs package's FS interface, so it can be used with any package that
// understands file systems, including net/http, text/template, and html/template.
//
// For example, given the content variable in the example above, we can write:
//
// http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(content))))
//
// template.ParseFS(content, "*.tmpl")
//
// # Tools
//
// To support tools that analyze Go packages, the patterns found in //go:embed lines
// are available in “go list” output. See the EmbedPatterns, TestEmbedPatterns,
// and XTestEmbedPatterns fields in the “go help list” output.
package statics
import (
"errors"
"io"
"io/fs"
"time"
)
// An FS is a read-only collection of files, usually initialized with a //go:embed directive.
// When declared without a //go:embed directive, an FS is an empty file system.
//
// An FS is a read-only value, so it is safe to use from multiple goroutines
// simultaneously and also safe to assign values of type FS to each other.
//
// FS implements fs.FS, so it can be used with any package that understands
// file system interfaces, including net/http, text/template, and html/template.
//
// See the package documentation for more details about initializing an FS.
type FS struct {
	// The compiler knows the layout of this struct.
	// See cmd/compile/internal/staticdata's WriteEmbed.
	//
	// The files list is sorted by name but not by simple string comparison.
	// Instead, each file's name takes the form "dir/elem" or "dir/elem/".
	// The optional trailing slash indicates that the file is itself a directory.
	// The files list is sorted first by dir (if dir is missing, it is taken to be ".")
	// and then by base, so this list of files:
	//
	//	p
	//	q/
	//	q/r
	//	q/s/
	//	q/s/t
	//	q/s/u
	//	q/v
	//	w
	//
	// is actually sorted as:
	//
	//	p       # dir=.    elem=p
	//	q/      # dir=.    elem=q
	//	w/      # dir=.    elem=w
	//	q/r     # dir=q    elem=r
	//	q/s/    # dir=q    elem=s
	//	q/v     # dir=q    elem=v
	//	q/s/t   # dir=q/s  elem=t
	//	q/s/u   # dir=q/s  elem=u
	//
	// This order brings directory contents together in contiguous sections
	// of the list, allowing a directory read to use binary search to find
	// the relevant sequence of entries.
	//
	// files is a pointer, so copies of an FS share one underlying list.
	files *[]file
}
// split breaks name into the directory and element parts described in the
// comment on FS ("dir/elem" or "dir/elem/"). isDir reports whether a final
// trailing slash was present, indicating that name is a directory. A name
// containing no slash belongs to directory ".".
func split(name string) (dir, elem string, isDir bool) {
	if name[len(name)-1] == '/' {
		// Strip the directory marker but remember it.
		isDir = true
		name = name[:len(name)-1]
	}
	// Locate the last slash, if any, scanning from the end.
	slash := -1
	for j := len(name) - 1; j >= 0; j-- {
		if name[j] == '/' {
			slash = j
			break
		}
	}
	if slash < 0 {
		return ".", name, isDir
	}
	return name[:slash], name[slash+1:], isDir
}
// trimSlash returns name with a single trailing slash removed, if one is
// present; otherwise it returns name unchanged.
func trimSlash(name string) string {
	if n := len(name); n > 0 && name[n-1] == '/' {
		return name[:n-1]
	}
	return name
}
// Compile-time checks that FS provides the optimized directory- and
// file-reading entry points recognized by io/fs helpers.
var (
	_ fs.ReadDirFS  = FS{}
	_ fs.ReadFileFS = FS{}
)
// A file is a single file in the FS.
// It implements fs.FileInfo and fs.DirEntry.
type file struct {
	// The compiler knows the layout of this struct.
	// See cmd/compile/internal/staticdata's WriteEmbed.
	name string   // "dir/elem" for a file, "dir/elem/" for a directory
	data string   // file contents (empty for directories)
	hash [16]byte // truncated SHA256 hash
}

// Compile-time checks: a *file doubles as its own FileInfo and DirEntry.
var (
	_ fs.FileInfo = (*file)(nil)
	_ fs.DirEntry = (*file)(nil)
)
// Name returns the final path element, without any directory prefix.
func (f *file) Name() string { _, elem, _ := split(f.name); return elem }

// Size returns the length of the file contents in bytes.
func (f *file) Size() int64 { return int64(len(f.data)) }

// ModTime returns the zero time; embedded entries carry no timestamps.
func (f *file) ModTime() time.Time { return time.Time{} }

// IsDir reports whether the entry is a directory (its name ends in a slash).
func (f *file) IsDir() bool { _, _, isDir := split(f.name); return isDir }

// Sys returns nil; there is no underlying system representation.
func (f *file) Sys() any { return nil }

// Type returns the type bits of the entry's mode.
func (f *file) Type() fs.FileMode { return f.Mode().Type() }

// Info returns the entry itself, which implements fs.FileInfo.
func (f *file) Info() (fs.FileInfo, error) { return f, nil }

// Mode returns a read-only mode: dir|0555 for directories, 0444 for files.
func (f *file) Mode() fs.FileMode {
	if f.IsDir() {
		return fs.ModeDir | 0555
	}
	return 0444
}
// dotFile is a synthetic entry for the root directory ".",
// which is omitted from the files list in a FS.
var dotFile = &file{name: "./"}
// lookup returns the named file, or nil if it is not present.
func (f FS) lookup(name string) *file {
	if !fs.ValidPath(name) {
		// The compiler should never emit a file with an invalid name, and an
		// invalid name can never match an entry below, but guard anyway.
		return nil
	}
	if name == "." {
		return dotFile
	}
	if f.files == nil {
		return nil
	}
	// The list is ordered by (dir, elem); binary search for the slot where
	// name would live, then verify an exact match at that position.
	dir, elem, _ := split(name)
	list := *f.files
	pos := sortSearch(len(list), func(k int) bool {
		kdir, kelem, _ := split(list[k].name)
		if kdir != dir {
			return kdir > dir
		}
		return kelem >= elem
	})
	if pos < len(list) && trimSlash(list[pos].name) == name {
		return &list[pos]
	}
	return nil
}
// readDir returns the list of files contained in the directory dir.
func (f FS) readDir(dir string) []file {
	if f.files == nil {
		return nil
	}
	// Entries of a directory are contiguous in the (dir, elem)-sorted list;
	// two binary searches bracket that run.
	list := *f.files
	lo := sortSearch(len(list), func(k int) bool {
		kdir, _, _ := split(list[k].name)
		return kdir >= dir
	})
	hi := sortSearch(len(list), func(k int) bool {
		kdir, _, _ := split(list[k].name)
		return kdir > dir
	})
	return list[lo:hi]
}
// Open opens the named file for reading and returns it as an fs.File.
//
// The returned file implements io.Seeker when the file is not a directory.
func (f FS) Open(name string) (fs.File, error) {
	entry := f.lookup(name)
	switch {
	case entry == nil:
		return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
	case entry.IsDir():
		return &openDir{entry, f.readDir(name), 0}, nil
	default:
		return &openFile{entry, 0}, nil
	}
}
// ReadDir reads and returns the entire named directory.
func (f FS) ReadDir(name string) ([]fs.DirEntry, error) {
	opened, err := f.Open(name)
	if err != nil {
		return nil, err
	}
	d, ok := opened.(*openDir)
	if !ok {
		return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("not a directory")}
	}
	entries := make([]fs.DirEntry, len(d.files))
	for i := range entries {
		entries[i] = &d.files[i]
	}
	return entries, nil
}
// ReadFile reads and returns the content of the named file.
func (f FS) ReadFile(name string) ([]byte, error) {
	opened, err := f.Open(name)
	if err != nil {
		return nil, err
	}
	of, ok := opened.(*openFile)
	if !ok {
		return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("is a directory")}
	}
	return []byte(of.f.data), nil
}
// An openFile is a regular file open for reading.
type openFile struct {
	f      *file // the file itself
	offset int64 // current read offset
}

// Compile-time check: regular files support seeking.
var (
	_ io.Seeker = (*openFile)(nil)
)

// Close is a no-op; an embedded file holds no external resources.
func (f *openFile) Close() error { return nil }

// Stat returns the underlying file, which implements fs.FileInfo.
func (f *openFile) Stat() (fs.FileInfo, error) { return f.f, nil }
// Read copies bytes from the current offset into b and advances the offset.
// It returns io.EOF once the offset reaches the end of the data, and an
// fs.PathError wrapping fs.ErrInvalid for a negative offset.
func (f *openFile) Read(b []byte) (int, error) {
	data := f.f.data
	if f.offset >= int64(len(data)) {
		return 0, io.EOF
	}
	if f.offset < 0 {
		return 0, &fs.PathError{Op: "read", Path: f.f.name, Err: fs.ErrInvalid}
	}
	n := copy(b, data[f.offset:])
	f.offset += int64(n)
	return n, nil
}
// Seek implements io.Seeker. whence is interpreted per the io package
// constants (io.SeekStart, io.SeekCurrent, io.SeekEnd). Seeking outside
// [0, len(data)] returns an fs.PathError wrapping fs.ErrInvalid.
func (f *openFile) Seek(offset int64, whence int) (int64, error) {
	switch whence {
	case io.SeekStart:
		// offset is already absolute
	case io.SeekCurrent:
		offset += f.offset
	case io.SeekEnd:
		offset += int64(len(f.f.data))
	}
	if offset < 0 || offset > int64(len(f.f.data)) {
		return 0, &fs.PathError{Op: "seek", Path: f.f.name, Err: fs.ErrInvalid}
	}
	f.offset = offset
	return offset, nil
}
// An openDir is a directory open for reading.
type openDir struct {
	f      *file  // the directory file itself
	files  []file // the directory contents
	offset int    // the read offset, an index into the files slice
}

// Close is a no-op; an embedded directory holds no external resources.
func (d *openDir) Close() error { return nil }

// Stat returns the directory entry, which implements fs.FileInfo.
func (d *openDir) Stat() (fs.FileInfo, error) { return d.f, nil }

// Read always fails: a directory cannot be read as a byte stream.
func (d *openDir) Read([]byte) (int, error) {
	return 0, &fs.PathError{Op: "read", Path: d.f.name, Err: errors.New("is a directory")}
}
// ReadDir returns up to count entries starting at the current read offset,
// advancing the offset past them. With count <= 0 it returns all remaining
// entries; when exhausted it returns io.EOF only for count > 0, matching
// the fs.ReadDirFile contract.
func (d *openDir) ReadDir(count int) ([]fs.DirEntry, error) {
	remaining := len(d.files) - d.offset
	if remaining == 0 {
		if count <= 0 {
			return nil, nil
		}
		return nil, io.EOF
	}
	if count > 0 && remaining > count {
		remaining = count
	}
	entries := make([]fs.DirEntry, remaining)
	for i := range entries {
		entries[i] = &d.files[d.offset+i]
	}
	d.offset += remaining
	return entries, nil
}
// sortSearch is like sort.Search, avoiding an import: it returns the
// smallest index i in [0, n] for which f(i) is true, given that f is false
// for some prefix of indexes and true for the rest (with f(-1) == false and
// f(n) == true by definition).
func sortSearch(n int, f func(int) bool) int {
	lo, hi := 0, n
	// Invariant: f(lo-1) == false and f(hi) == true.
	for lo < hi {
		mid := int(uint(lo+hi) >> 1) // written to avoid overflow in lo+hi
		if f(mid) {
			hi = mid // preserves f(hi) == true
		} else {
			lo = mid + 1 // preserves f(lo-1) == false
		}
	}
	// lo == hi, so lo is the first index where f is true.
	return lo
}

View File

@@ -0,0 +1,206 @@
package statics
import (
"archive/zip"
"bufio"
"crypto/sha256"
_ "embed"
"encoding/json"
"fmt"
"io"
"io/fs"
"net/http"
"path/filepath"
"sort"
"strings"
"github.com/cloudreve/Cloudreve/v4/application/constants"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/gin-contrib/static"
)
// StaticFolder is the folder name, relative to the data path, that (when
// present on disk) overrides the embedded frontend assets.
const StaticFolder = "statics"

// zipContent holds the zipped frontend build, embedded at compile time.
//
//go:embed assets.zip
var zipContent string

// GinFS adapts an http.FileSystem to gin-contrib/static's
// ServeFileSystem interface.
type GinFS struct {
	FS http.FileSystem
}

// version mirrors the schema of the frontend's version.json manifest.
type version struct {
	Name    string `json:"name"`
	Version string `json:"version"`
}
// Open opens the named file from the wrapped file system.
func (b *GinFS) Open(name string) (http.File, error) {
	return b.FS.Open(name)
}
// Exists reports whether filepath can be opened in the wrapped file system.
// The prefix argument is unused; it is required by the
// static.ServeFileSystem interface.
func (b *GinFS) Exists(prefix string, filepath string) bool {
	_, err := b.FS.Open(filepath)
	return err == nil
}
// NewServerStaticFS builds the file system used to serve the web frontend.
// If a local "statics" folder exists under the data path it takes precedence;
// otherwise the embedded assets are served. The bundled version.json is then
// checked against the backend version, logging (but not failing) on mismatch
// or when the manifest is missing/unreadable.
func NewServerStaticFS(l logging.Logger, statics fs.FS, isPro bool) (static.ServeFileSystem, error) {
	var staticFS static.ServeFileSystem
	if util.Exists(util.DataPath(StaticFolder)) {
		l.Info("Folder with %q already exists, it will be used to serve static files.", util.DataPath(StaticFolder))
		staticFS = static.LocalFile(util.DataPath(StaticFolder), false)
	} else {
		// Serve the assets embedded into the binary.
		embedFS, err := fs.Sub(statics, "assets/build")
		if err != nil {
			return nil, fmt.Errorf("failed to initialize static resources: %w", err)
		}
		staticFS = &GinFS{
			FS: http.FS(embedFS),
		}
	}
	// Validate the version of the static resources.
	f, err := staticFS.Open("version.json")
	if err != nil {
		l.Warning("Missing version identifier file in static resources, please delete \"statics\" folder and rebuild it.")
		return staticFS, nil
	}
	// Close the manifest handle; it was leaked before.
	defer f.Close()
	b, err := io.ReadAll(f)
	if err != nil {
		l.Warning("Failed to read version identifier file in static resources, please delete \"statics\" folder and rebuild it.")
		return staticFS, nil
	}
	var v version
	if err := json.Unmarshal(b, &v); err != nil {
		l.Warning("Failed to parse version identifier file in static resources: %s", err)
		return staticFS, nil
	}
	// Pro builds ship a differently named frontend bundle.
	staticName := "cloudreve-frontend"
	if isPro {
		staticName += "-pro"
	}
	if v.Name != staticName {
		l.Error("Static resource version mismatch, please delete \"statics\" folder and rebuild it.")
	}
	if v.Version != constants.BackendVersion {
		l.Error("Static resource version mismatch [Current %s, Desired: %s], please delete \"statics\" folder and rebuild it.", v.Version, constants.BackendVersion)
	}
	return staticFS, nil
}
// NewStaticFS decompresses the embedded assets.zip into an in-memory,
// embed-style read-only file system (the FS type in this package). It
// panics via l.Panic if the embedded archive is corrupt or unreadable,
// since the application cannot serve its frontend without it.
func NewStaticFS(l logging.Logger) fs.FS {
	zipReader, err := zip.NewReader(strings.NewReader(zipContent), int64(len(zipContent)))
	if err != nil {
		l.Panic("Static resource is not a valid zip file: %s", err)
	}
	var files []file
	err = fs.WalkDir(zipReader, ".", func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return fmt.Errorf("cannot walk into %q: %w", path, err)
		}
		if path == "." {
			// The root entry is represented by dotFile, not stored in files.
			return nil
		}
		var f file
		if d.IsDir() {
			// Directories are marked by a trailing slash, as in embed.FS.
			f.name = path + "/"
		} else {
			f.name = path
			rc, err := zipReader.Open(path)
			if err != nil {
				return fmt.Errorf("cannot open %q: %w", path, err)
			}
			defer rc.Close()
			data, err := io.ReadAll(rc)
			if err != nil {
				return fmt.Errorf("cannot read %q: %w", path, err)
			}
			f.data = string(data)
			// Record a bit-inverted, truncated SHA256 of the contents.
			hash := sha256.Sum256(data)
			for i := range f.hash {
				f.hash[i] = ^hash[i]
			}
		}
		files = append(files, f)
		return nil
	})
	if err != nil {
		l.Panic("Failed to initialize static resources: %s", err)
	}
	// Order entries by (dir, elem), as required by FS's binary searches.
	sort.Slice(files, func(i, j int) bool {
		fi, fj := files[i], files[j]
		di, ei, _ := split(fi.name)
		dj, ej, _ := split(fj.name)
		if di != dj {
			return di < dj
		}
		return ei < ej
	})
	var embedFS FS
	embedFS.files = &files
	return embedFS
}
// Eject extracts the embedded frontend assets into the local "statics"
// folder under the data path, so they can be inspected or customized.
// It panics via l.Panic if the embedded assets cannot be located.
func Eject(l logging.Logger, statics fs.FS) error {
	embedFS, err := fs.Sub(statics, "assets/build")
	if err != nil {
		l.Panic("Failed to initialize static resources: %s", err)
	}
	walk := func(relPath string, d fs.DirEntry, err error) error {
		if err != nil {
			return fmt.Errorf("failed to read info of %q: %s, skipping...", relPath, err)
		}
		if d.IsDir() {
			// Directories are created implicitly by CreatNestedFile.
			return nil
		}
		dst := util.DataPath(filepath.Join(StaticFolder, relPath))
		out, err := util.CreatNestedFile(dst)
		if err != nil {
			// Check the error before deferring Close: out is nil on failure.
			return fmt.Errorf("failed to create file %q: %s, skipping...", dst, err)
		}
		defer out.Close()
		l.Info("Ejecting %q...", dst)
		obj, err := embedFS.Open(relPath)
		if err != nil {
			return fmt.Errorf("cannot open embedded file %q: %s, skipping...", relPath, err)
		}
		defer obj.Close()
		if _, err := io.Copy(out, bufio.NewReader(obj)); err != nil {
			return fmt.Errorf("cannot write file %q: %s, skipping...", relPath, err)
		}
		return nil
	}
	if err := fs.WalkDir(embedFS, ".", walk); err != nil {
		return fmt.Errorf("failed to eject static resources: %w", err)
	}
	l.Info("Finish ejecting static resources.")
	return nil
}