fix(dbfs): enforce root protection for single file share

This commit is contained in:
Aaron Liu
2025-04-21 19:43:09 +08:00
parent d60e400f83
commit 7b5e0e8581
11 changed files with 32 additions and 17 deletions

View File

@@ -448,6 +448,10 @@ func (f *DBFS) Get(ctx context.Context, path *fs.URI, opts ...fs.Option) (fs.Fil
return nil, fmt.Errorf("failed to get target file: %w", err)
}
if o.notRoot && target.IsRootFolder() {
return nil, fs.ErrNotSupportedAction.WithError(fmt.Errorf("cannot operate root file"))
}
if o.extendedInfo && target != nil {
extendedInfo := &fs.FileExtendedInfo{
StorageUsed: target.SizeUsed(),

View File

@@ -25,6 +25,7 @@ type dbfsOption struct {
noChainedCreation bool
streamListResponseCallback func(parent fs.File, file []fs.File)
ancestor *File
notRoot bool
}
func newDbfsOption() *dbfsOption {
@@ -56,6 +57,13 @@ func WithFilePublicMetadata() fs.Option {
})
}
// WithNotRoot force the get result cannot be a root folder
func WithNotRoot() fs.Option {
return optionFunc(func(o *dbfsOption) {
o.notRoot = true
})
}
// WithContextHint enables generating context hint for the list operation.
func WithContextHint() fs.Option {
return optionFunc(func(o *dbfsOption) {

View File

@@ -183,7 +183,7 @@ func (n *shareNavigator) To(ctx context.Context, path *fs.URI) (*File, error) {
elements := path.Elements()
// If target is root of single file share, the root itself is the target.
if len(elements) <= 1 && n.singleFileShare {
if len(elements) == 1 && n.singleFileShare {
file, err := n.latestSharedSingleFile(ctx)
if err != nil {
return nil, err

View File

@@ -176,7 +176,7 @@ func (f *DBFS) PrepareUpload(ctx context.Context, req *fs.UploadRequest, opts ..
func (f *DBFS) CompleteUpload(ctx context.Context, session *fs.UploadSession) (fs.File, error) {
// Get placeholder file
file, err := f.Get(ctx, session.Props.Uri, WithFileEntities())
file, err := f.Get(ctx, session.Props.Uri, WithFileEntities(), WithNotRoot())
if err != nil {
return nil, fmt.Errorf("failed to get placeholder file: %w", err)
}
@@ -270,7 +270,7 @@ func (f *DBFS) CompleteUpload(ctx context.Context, session *fs.UploadSession) (f
}
}
file, err = f.Get(ctx, session.Props.Uri, WithFileEntities())
file, err = f.Get(ctx, session.Props.Uri, WithFileEntities(), WithNotRoot())
if err != nil {
return nil, fmt.Errorf("failed to get updated file: %w", err)
}
@@ -284,7 +284,7 @@ func (f *DBFS) CompleteUpload(ctx context.Context, session *fs.UploadSession) (f
// - File unlocked, upload session not valid
func (f *DBFS) CancelUploadSession(ctx context.Context, path *fs.URI, sessionID string, session *fs.UploadSession) ([]fs.Entity, error) {
// Get placeholder file
file, err := f.Get(ctx, path, WithFileEntities())
file, err := f.Get(ctx, path, WithFileEntities(), WithNotRoot())
if err != nil {
return nil, fmt.Errorf("failed to get placeholder file: %w", err)
}

View File

@@ -158,6 +158,7 @@ type (
ExtendedInfo() *FileExtendedInfo
FolderSummary() *FolderSummary
Capabilities() *boolset.BooleanSet
IsRootFolder() bool
}
Entities []Entity

View File

@@ -27,7 +27,7 @@ func (m *manager) CreateArchive(ctx context.Context, uris []*fs.URI, writer io.W
// List all top level files
files := make([]fs.File, 0, len(uris))
for _, uri := range uris {
file, err := m.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile))
file, err := m.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile), dbfs.WithNotRoot())
if err != nil {
return 0, fmt.Errorf("failed to get file %s: %w", uri, err)
}

View File

@@ -227,7 +227,7 @@ func (l *manager) Restore(ctx context.Context, path ...*fs.URI) error {
}
func (l *manager) CreateOrUpdateShare(ctx context.Context, path *fs.URI, args *CreateShareArgs) (*ent.Share, error) {
file, err := l.fs.Get(ctx, path, dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityShare))
file, err := l.fs.Get(ctx, path, dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityShare), dbfs.WithNotRoot())
if err != nil {
return nil, serializer.NewError(serializer.CodeNotFound, "src file not found", err)
}

View File

@@ -45,7 +45,7 @@ func init() {
}
func (m *manager) CreateViewerSession(ctx context.Context, uri *fs.URI, version string, viewer *setting.Viewer) (*ViewerSession, error) {
file, err := m.fs.Get(ctx, uri, dbfs.WithFileEntities())
file, err := m.fs.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithNotRoot())
if err != nil {
return nil, err
}

View File

@@ -171,7 +171,7 @@ func (m *ExtractArchiveTask) createSlaveExtractTask(ctx context.Context, dep dep
fm := manager.NewFileManager(dep, user)
// Get entity source to extract
archiveFile, err := fm.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile))
archiveFile, err := fm.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile), dbfs.WithNotRoot())
if err != nil {
return task.StatusError, fmt.Errorf("failed to get archive file: %s (%w)", err, queue.CriticalErr)
}
@@ -256,7 +256,7 @@ func (m *ExtractArchiveTask) masterExtractArchive(ctx context.Context, dep depen
fm := manager.NewFileManager(dep, user)
// Get entity source to extract
archiveFile, err := fm.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile))
archiveFile, err := fm.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile), dbfs.WithNotRoot())
if err != nil {
return task.StatusError, fmt.Errorf("failed to get archive file: %s (%w)", err, queue.CriticalErr)
}
@@ -413,7 +413,7 @@ func (m *ExtractArchiveTask) masterDownloadZip(ctx context.Context, dep dependen
fm := manager.NewFileManager(dep, user)
// Get entity source to extract
archiveFile, err := fm.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile))
archiveFile, err := fm.Get(ctx, uri, dbfs.WithFileEntities(), dbfs.WithRequiredCapabilities(dbfs.NavigatorCapabilityDownloadFile), dbfs.WithNotRoot())
if err != nil {
return task.StatusError, fmt.Errorf("failed to get archive file: %s (%w)", err, queue.CriticalErr)
}