Init V4 community edition (#2265)

* Init V4 community edition

* Init V4 community edition
This commit is contained in:
AaronLiu
2025-04-20 17:31:25 +08:00
committed by GitHub
parent da4e44b77a
commit 21d158db07
597 changed files with 119415 additions and 41692 deletions

View File

@@ -0,0 +1,517 @@
package onedrive
import (
"context"
"encoding/json"
"fmt"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/chunk/backoff"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"io"
"net/http"
"net/url"
"path"
"strings"
"time"
)
const (
// SmallFileSize 单文件上传接口最大尺寸
SmallFileSize uint64 = 4 * 1024 * 1024
// ChunkSize 服务端中转分片上传分片大小
ChunkSize uint64 = 10 * 1024 * 1024
// ListRetry 列取请求重试次数
ListRetry = 1
chunkRetrySleep = time.Second * 5
notFoundError = "itemNotFound"
)
type RetryCtx struct{}
// GetSourcePath 获取文件的绝对路径
func (info *FileInfo) GetSourcePath() string {
res, err := url.PathUnescape(info.ParentReference.Path)
if err != nil {
return ""
}
return strings.TrimPrefix(
path.Join(
strings.TrimPrefix(res, "/drive/root:"),
info.Name,
),
"/",
)
}
func (client *client) getRequestURL(api string, opts ...Option) string {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
base, _ := url.Parse(client.endpoints.endpointURL)
if base == nil {
return ""
}
if options.useDriverResource {
base.Path = path.Join(base.Path, client.endpoints.driverResource, api)
} else {
base.Path = path.Join(base.Path, api)
}
return base.String()
}
// ListChildren 根据路径列取子对象
func (client *client) ListChildren(ctx context.Context, path string) ([]FileInfo, error) {
var requestURL string
dst := strings.TrimPrefix(path, "/")
if dst == "" {
requestURL = client.getRequestURL("root/children")
} else {
requestURL = client.getRequestURL("root:/" + dst + ":/children")
}
res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
if err != nil {
retried := 0
if v, ok := ctx.Value(RetryCtx{}).(int); ok {
retried = v
}
if retried < ListRetry {
retried++
client.l.Debug("Failed to list path %q: %s, will retry in 5 seconds.", path, err)
time.Sleep(time.Duration(5) * time.Second)
return client.ListChildren(context.WithValue(ctx, RetryCtx{}, retried), path)
}
return nil, err
}
var (
decodeErr error
fileInfo ListResponse
)
decodeErr = json.Unmarshal([]byte(res), &fileInfo)
if decodeErr != nil {
return nil, decodeErr
}
return fileInfo.Value, nil
}
// Meta 根据资源ID或文件路径获取文件元信息
func (client *client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
var requestURL string
if id != "" {
requestURL = client.getRequestURL("items/" + id)
} else {
dst := strings.TrimPrefix(path, "/")
requestURL = client.getRequestURL("root:/" + dst)
}
res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200)
if err != nil {
return nil, err
}
var (
decodeErr error
fileInfo FileInfo
)
decodeErr = json.Unmarshal([]byte(res), &fileInfo)
if decodeErr != nil {
return nil, decodeErr
}
return &fileInfo, nil
}
// CreateUploadSession 创建分片上传会话
func (client *client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession")
body := map[string]map[string]interface{}{
"item": {
"@microsoft.graph.conflictBehavior": options.conflictBehavior,
},
}
bodyBytes, _ := json.Marshal(body)
res, err := client.requestWithStr(ctx, "POST", requestURL, string(bodyBytes), 200)
if err != nil {
return "", err
}
var (
decodeErr error
uploadSession UploadSessionResponse
)
decodeErr = json.Unmarshal([]byte(res), &uploadSession)
if decodeErr != nil {
return "", decodeErr
}
return uploadSession.UploadURL, nil
}
// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
func (client *client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
siteUrlParsed, err := url.Parse(siteUrl)
if err != nil {
return "", err
}
hostName := siteUrlParsed.Hostname()
relativePath := strings.Trim(siteUrlParsed.Path, "/")
requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false))
res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200)
if reqErr != nil {
return "", reqErr
}
var (
decodeErr error
siteInfo Site
)
decodeErr = json.Unmarshal([]byte(res), &siteInfo)
if decodeErr != nil {
return "", decodeErr
}
return siteInfo.ID, nil
}
// GetUploadSessionStatus 查询上传会话状态
func (client *client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
if err != nil {
return nil, err
}
var (
decodeErr error
uploadSession UploadSessionResponse
)
decodeErr = json.Unmarshal([]byte(res), &uploadSession)
if decodeErr != nil {
return nil, decodeErr
}
return &uploadSession, nil
}
// UploadChunk 上传分片
func (client *client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) {
res, err := client.request(
ctx, "PUT", uploadURL, content,
request.WithContentLength(current.Length()),
request.WithHeader(http.Header{
"Content-Range": {current.RangeHeader()},
}),
request.WithoutHeader([]string{"Authorization", "Content-Type"}),
request.WithTimeout(0),
)
if err != nil {
return nil, fmt.Errorf("failed to upload OneDrive chunk #%d: %w", current.Index(), err)
}
if current.IsLast() {
return nil, nil
}
var (
decodeErr error
uploadRes UploadSessionResponse
)
decodeErr = json.Unmarshal([]byte(res), &uploadRes)
if decodeErr != nil {
return nil, decodeErr
}
return &uploadRes, nil
}
// Upload 上传文件
func (client *client) Upload(ctx context.Context, file *fs.UploadRequest) error {
// 决定是否覆盖文件
overwrite := "fail"
if file.Mode&fs.ModeOverwrite == fs.ModeOverwrite {
overwrite = "replace"
}
size := int(file.Props.Size)
dst := file.Props.SavePath
// 小文件,使用简单上传接口上传
if size <= int(SmallFileSize) {
_, err := client.SimpleUpload(ctx, dst, file, int64(size), WithConflictBehavior(overwrite))
return err
}
// 大文件,进行分片
// 创建上传会话
uploadURL, err := client.CreateUploadSession(ctx, dst, WithConflictBehavior(overwrite))
if err != nil {
return err
}
// Initial chunk groups
chunks := chunk.NewChunkGroup(file, client.chunkSize, &backoff.ConstantBackoff{
Max: client.settings.ChunkRetryLimit(ctx),
Sleep: chunkRetrySleep,
}, client.settings.UseChunkBuffer(ctx), client.l, client.settings.TempPath(ctx))
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
_, err := client.UploadChunk(ctx, uploadURL, content, current)
return err
}
// upload chunks
for chunks.Next() {
if err := chunks.Process(uploadFunc); err != nil {
if err := client.DeleteUploadSession(ctx, uploadURL); err != nil {
client.l.Warning("Failed to delete upload session: %s", err)
}
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
}
}
return nil
}
// DeleteUploadSession 删除上传会话
func (client *client) DeleteUploadSession(ctx context.Context, uploadURL string) error {
_, err := client.requestWithStr(ctx, "DELETE", uploadURL, "", 204)
if err != nil {
return err
}
return nil
}
// SimpleUpload 上传小文件到dst
func (client *client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("root:/" + dst + ":/content")
requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior)
res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)),
request.WithTimeout(0),
)
if err != nil {
return nil, err
}
var (
decodeErr error
uploadRes UploadResult
)
decodeErr = json.Unmarshal([]byte(res), &uploadRes)
if decodeErr != nil {
return nil, decodeErr
}
return &uploadRes, nil
}
// BatchDelete 并行删除给出的文件,返回删除失败的文件,及第一个遇到的错误。此方法将文件分为
// 20个一组调用Delete并行删除
func (client *client) BatchDelete(ctx context.Context, dst []string) ([]string, error) {
groupNum := len(dst)/20 + 1
finalRes := make([]string, 0, len(dst))
res := make([]string, 0, 20)
var err error
for i := 0; i < groupNum; i++ {
end := 20*i + 20
if i == groupNum-1 {
end = len(dst)
}
client.l.Debug("Delete file group: %v.", dst[20*i:end])
res, err = client.Delete(ctx, dst[20*i:end])
finalRes = append(finalRes, res...)
}
return finalRes, err
}
// Delete 并行删除文件,返回删除失败的文件,及第一个遇到的错误,
// 由于API限制最多删除20个
func (client *client) Delete(ctx context.Context, dst []string) ([]string, error) {
body := client.makeBatchDeleteRequestsBody(dst)
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
WithDriverResource(false)), body, 200)
if err != nil {
return dst, err
}
var (
decodeErr error
deleteRes BatchResponses
)
decodeErr = json.Unmarshal([]byte(res), &deleteRes)
if decodeErr != nil {
return dst, decodeErr
}
// 取得删除失败的文件
failed := getDeleteFailed(&deleteRes)
if len(failed) != 0 {
return failed, ErrDeleteFile
}
return failed, nil
}
func getDeleteFailed(res *BatchResponses) []string {
var failed = make([]string, 0, len(res.Responses))
for _, v := range res.Responses {
if v.Status != 204 && v.Status != 404 {
failed = append(failed, v.ID)
}
}
return failed
}
// makeBatchDeleteRequestsBody 生成批量删除请求正文
func (client *client) makeBatchDeleteRequestsBody(files []string) string {
req := BatchRequests{
Requests: make([]BatchRequest, len(files)),
}
for i, v := range files {
v = strings.TrimPrefix(v, "/")
filePath, _ := url.Parse("/" + client.endpoints.driverResource + "/root:/")
filePath.Path = path.Join(filePath.Path, v)
req.Requests[i] = BatchRequest{
ID: v,
Method: "DELETE",
URL: filePath.EscapedPath(),
}
}
res, _ := json.Marshal(req)
return string(res)
}
// GetThumbURL 获取给定尺寸的缩略图URL
func (client *client) GetThumbURL(ctx context.Context, dst string) (string, error) {
dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/large"
res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200)
if err != nil {
return "", err
}
var (
decodeErr error
thumbRes ThumbResponse
)
decodeErr = json.Unmarshal([]byte(res), &thumbRes)
if decodeErr != nil {
return "", decodeErr
}
if thumbRes.URL != "" {
return thumbRes.URL, nil
}
if len(thumbRes.Value) == 1 {
if res, ok := thumbRes.Value[0]["large"]; ok {
return res.(map[string]interface{})["url"].(string), nil
}
}
return "", ErrThumbSizeNotFound
}
func sysError(err error) *RespError {
return &RespError{APIError: APIError{
Code: "system",
Message: err.Error(),
}}
}
func (client *client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) {
// 获取凭证
err := client.UpdateCredential(ctx)
if err != nil {
return "", sysError(err)
}
opts := []request.Option{
request.WithHeader(http.Header{
"Authorization": {"Bearer " + client.credential.String()},
"Content-Type": {"application/json"},
}),
request.WithContext(ctx),
request.WithTPSLimit(
fmt.Sprintf("policy_%d", client.policy.ID),
client.policy.Settings.TPSLimit,
client.policy.Settings.TPSLimitBurst,
),
}
// 发送请求
res := client.httpClient.Request(
method,
url,
body,
append(opts, option...)...,
)
if res.Err != nil {
return "", sysError(res.Err)
}
respBody, err := res.GetResponse()
if err != nil {
return "", sysError(err)
}
// 解析请求响应
var (
errResp RespError
decodeErr error
)
// 如果有错误
if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 {
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
if decodeErr != nil {
client.l.Debug("Onedrive returns unknown response: %s", respBody)
return "", sysError(decodeErr)
}
if res.Response.StatusCode == 429 {
client.l.Warning("OneDrive request is throttled.")
return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header)
}
return "", &errResp
}
return respBody, nil
}
func (client *client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) {
// 发送请求
bodyReader := io.NopCloser(strings.NewReader(body))
return client.request(ctx, method, url, bodyReader,
request.WithContentLength(int64(len(body))),
)
}

View File

@@ -0,0 +1,90 @@
package onedrive
import (
"context"
"errors"
"io"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/pkg/credmanager"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
)
var (
// ErrAuthEndpoint 无法解析授权端点地址
ErrAuthEndpoint = errors.New("failed to parse endpoint url")
// ErrInvalidRefreshToken 上传策略无有效的RefreshToken
ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy")
// ErrDeleteFile 无法删除文件
ErrDeleteFile = errors.New("cannot delete file")
// ErrClientCanceled 客户端取消操作
ErrClientCanceled = errors.New("client canceled")
// Desired thumb size not available
ErrThumbSizeNotFound = errors.New("thumb size not found")
)
type Client interface {
ListChildren(ctx context.Context, path string) ([]FileInfo, error)
Meta(ctx context.Context, id string, path string) (*FileInfo, error)
CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error)
GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error)
GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error)
Upload(ctx context.Context, file *fs.UploadRequest) error
SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error)
DeleteUploadSession(ctx context.Context, uploadURL string) error
BatchDelete(ctx context.Context, dst []string) ([]string, error)
GetThumbURL(ctx context.Context, dst string) (string, error)
OAuthURL(ctx context.Context, scopes []string) string
ObtainToken(ctx context.Context, opts ...Option) (*Credential, error)
}
// client OneDrive客户端
type client struct {
endpoints *endpoints
policy *ent.StoragePolicy
credential credmanager.Credential
httpClient request.Client
cred credmanager.CredManager
l logging.Logger
settings setting.Provider
chunkSize int64
}
// endpoints OneDrive客户端相关设置
type endpoints struct {
oAuthEndpoints *oauthEndpoint
endpointURL string // 接口请求的基URL
driverResource string // 要使用的驱动器
}
// NewClient 根据存储策略获取新的client
func NewClient(policy *ent.StoragePolicy, httpClient request.Client, cred credmanager.CredManager,
l logging.Logger, settings setting.Provider, chunkSize int64) Client {
client := &client{
endpoints: &endpoints{
endpointURL: policy.Server,
driverResource: policy.Settings.OdDriver,
},
policy: policy,
httpClient: httpClient,
cred: cred,
l: l,
settings: settings,
chunkSize: chunkSize,
}
if client.endpoints.driverResource == "" {
client.endpoints.driverResource = "me/drive"
}
oauthBase := getOAuthEndpoint(policy.Server)
client.endpoints.oAuthEndpoints = oauthBase
return client
}

View File

@@ -0,0 +1,271 @@
package onedrive
import (
"context"
"encoding/gob"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/inventory"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/credmanager"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/samber/lo"
)
const (
AccessTokenExpiryMargin = 600 // 10 minutes
)
// Error 实现error接口
func (err OAuthError) Error() string {
return err.ErrorDescription
}
// OAuthURL 获取OAuth认证页面URL
func (client *client) OAuthURL(ctx context.Context, scope []string) string {
query := url.Values{
"client_id": {client.policy.BucketName},
"scope": {strings.Join(scope, " ")},
"response_type": {"code"},
"redirect_uri": {client.policy.Settings.OauthRedirect},
"state": {strconv.Itoa(client.policy.ID)},
}
client.endpoints.oAuthEndpoints.authorize.RawQuery = query.Encode()
return client.endpoints.oAuthEndpoints.authorize.String()
}
// getOAuthEndpoint gets OAuth endpoints from API endpoint
func getOAuthEndpoint(apiEndpoint string) *oauthEndpoint {
base, err := url.Parse(apiEndpoint)
if err != nil {
return nil
}
var (
token *url.URL
authorize *url.URL
)
switch base.Host {
//case "login.live.com":
// token, _ = url.Parse("https://login.live.com/oauth20_token.srf")
// authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf")
case "microsoftgraph.chinacloudapi.cn":
token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token")
authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize")
default:
token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token")
authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
}
return &oauthEndpoint{
token: *token,
authorize: *authorize,
}
}
// Credential 获取token时返回的凭证
type Credential struct {
ExpiresIn int64 `json:"expires_in"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
RefreshedAtUnix int64 `json:"refreshed_at"`
PolicyID int `json:"policy_id"`
}
func init() {
gob.Register(Credential{})
}
func (c Credential) Refresh(ctx context.Context) (credmanager.Credential, error) {
if c.RefreshToken == "" {
return nil, ErrInvalidRefreshToken
}
dep := dependency.FromContext(ctx)
storagePolicyClient := dep.StoragePolicyClient()
policy, err := storagePolicyClient.GetPolicyByID(ctx, c.PolicyID)
if err != nil {
return nil, fmt.Errorf("failed to get storage policy: %w", err)
}
oauthBase := getOAuthEndpoint(policy.Server)
newCredential, err := obtainToken(ctx, &obtainTokenArgs{
clientId: policy.BucketName,
redirect: policy.Settings.OauthRedirect,
secret: policy.SecretKey,
refreshToken: c.RefreshToken,
client: dep.RequestClient(request.WithLogger(dep.Logger())),
tokenEndpoint: oauthBase.token.String(),
policyID: c.PolicyID,
})
if err != nil {
return nil, err
}
c.RefreshToken = newCredential.RefreshToken
c.AccessToken = newCredential.AccessToken
c.ExpiresIn = newCredential.ExpiresIn
c.RefreshedAtUnix = time.Now().Unix()
// Write refresh token to db
if err := storagePolicyClient.UpdateAccessKey(ctx, policy, newCredential.RefreshToken); err != nil {
return nil, err
}
return c, nil
}
func (c Credential) Key() string {
return CredentialKey(c.PolicyID)
}
func (c Credential) Expiry() time.Time {
return time.Unix(c.ExpiresIn-AccessTokenExpiryMargin, 0)
}
func (c Credential) String() string {
return c.AccessToken
}
func (c Credential) RefreshedAt() *time.Time {
if c.RefreshedAtUnix == 0 {
return nil
}
refreshedAt := time.Unix(c.RefreshedAtUnix, 0)
return &refreshedAt
}
// ObtainToken 通过code或refresh_token兑换token
func (client *client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
return obtainToken(ctx, &obtainTokenArgs{
clientId: client.policy.BucketName,
redirect: client.policy.Settings.OauthRedirect,
secret: client.policy.SecretKey,
code: options.code,
refreshToken: options.refreshToken,
client: client.httpClient,
tokenEndpoint: client.endpoints.oAuthEndpoints.token.String(),
policyID: client.policy.ID,
})
}
type obtainTokenArgs struct {
clientId string
redirect string
secret string
code string
refreshToken string
client request.Client
tokenEndpoint string
policyID int
}
// obtainToken fetch new access token from Microsoft Graph API
func obtainToken(ctx context.Context, args *obtainTokenArgs) (*Credential, error) {
body := url.Values{
"client_id": {args.clientId},
"redirect_uri": {args.redirect},
"client_secret": {args.secret},
}
if args.code != "" {
body.Add("grant_type", "authorization_code")
body.Add("code", args.code)
} else {
body.Add("grant_type", "refresh_token")
body.Add("refresh_token", args.refreshToken)
}
strBody := body.Encode()
res := args.client.Request(
"POST",
args.tokenEndpoint,
io.NopCloser(strings.NewReader(strBody)),
request.WithHeader(http.Header{
"Content-Type": {"application/x-www-form-urlencoded"}},
),
request.WithContentLength(int64(len(strBody))),
request.WithContext(ctx),
)
if res.Err != nil {
return nil, res.Err
}
respBody, err := res.GetResponse()
if err != nil {
return nil, err
}
var (
errResp OAuthError
credential Credential
decodeErr error
)
if res.Response.StatusCode != 200 {
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
} else {
decodeErr = json.Unmarshal([]byte(respBody), &credential)
}
if decodeErr != nil {
return nil, decodeErr
}
if errResp.ErrorType != "" {
return nil, errResp
}
credential.PolicyID = args.policyID
credential.ExpiresIn = time.Now().Unix() + credential.ExpiresIn
if args.code != "" {
credential.ExpiresIn = time.Now().Unix() - 10
}
return &credential, nil
}
// UpdateCredential 更新凭证,并检查有效期
func (client *client) UpdateCredential(ctx context.Context) error {
newCred, err := client.cred.Obtain(ctx, CredentialKey(client.policy.ID))
if err != nil {
return fmt.Errorf("failed to obtain token from CredManager: %w", err)
}
client.credential = newCred
return nil
}
// RetrieveOneDriveCredentials retrieves OneDrive credentials from DB inventory
func RetrieveOneDriveCredentials(ctx context.Context, storagePolicyClient inventory.StoragePolicyClient) ([]credmanager.Credential, error) {
odPolicies, err := storagePolicyClient.ListPolicyByType(ctx, types.PolicyTypeOd)
if err != nil {
return nil, fmt.Errorf("failed to list OneDrive policies: %w", err)
}
return lo.Map(odPolicies, func(item *ent.StoragePolicy, index int) credmanager.Credential {
return &Credential{
PolicyID: item.ID,
ExpiresIn: 0,
RefreshToken: item.AccessKey,
}
}), nil
}
func CredentialKey(policyId int) string {
return fmt.Sprintf("cred_od_%d", policyId)
}

View File

@@ -0,0 +1,247 @@
package onedrive
import (
"context"
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/boolset"
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/credmanager"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
"net/url"
"os"
"strings"
"time"
)
// Driver OneDrive 适配器
type Driver struct {
policy *ent.StoragePolicy
client Client
settings setting.Provider
config conf.ConfigProvider
l logging.Logger
chunkSize int64
}
var (
features = &boolset.BooleanSet{}
)
const (
streamSaverParam = "stream_saver"
)
func init() {
boolset.Sets(map[driver.HandlerCapability]bool{
driver.HandlerCapabilityUploadSentinelRequired: true,
}, features)
}
// NewDriver 从存储策略初始化新的Driver实例
func New(ctx context.Context, policy *ent.StoragePolicy, settings setting.Provider,
config conf.ConfigProvider, l logging.Logger, cred credmanager.CredManager) (*Driver, error) {
chunkSize := policy.Settings.ChunkSize
if policy.Settings.ChunkSize == 0 {
chunkSize = 50 << 20 // 50MB
}
c := NewClient(policy, request.NewClient(config, request.WithLogger(l)), cred, l, settings, chunkSize)
return &Driver{
policy: policy,
client: c,
settings: settings,
l: l,
config: config,
chunkSize: chunkSize,
}, nil
}
//// List 列取项目
//func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
// base = strings.TrimPrefix(base, "/")
// // 列取子项目
// objects, _ := handler.client.ListChildren(ctx, base)
//
// // 获取真实的列取起始根目录
// rootPath := base
// if realBase, ok := ctx.Value(fsctx.PathCtx).(string); ok {
// rootPath = realBase
// } else {
// ctx = context.WithValue(ctx, fsctx.PathCtx, base)
// }
//
// // 整理结果
// res := make([]response.Object, 0, len(objects))
// for _, object := range objects {
// source := path.Join(base, object.Name)
// rel, err := filepath.Rel(rootPath, source)
// if err != nil {
// continue
// }
// res = append(res, response.Object{
// Name: object.Name,
// RelativePath: filepath.ToSlash(rel),
// Source: source,
// Size: uint64(object.Size),
// IsDir: object.Folder != nil,
// LastModify: time.Now(),
// })
// }
//
// // 递归列取子目录
// if recursive {
// for _, object := range objects {
// if object.Folder != nil {
// sub, _ := handler.List(ctx, path.Join(base, object.Name), recursive)
// res = append(res, sub...)
// }
// }
// }
//
// return res, nil
//}
func (handler *Driver) Open(ctx context.Context, path string) (*os.File, error) {
return nil, errors.New("not implemented")
}
// Put 将文件流保存到指定目录
func (handler *Driver) Put(ctx context.Context, file *fs.UploadRequest) error {
defer file.Close()
return handler.client.Upload(ctx, file)
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler *Driver) Delete(ctx context.Context, files ...string) ([]string, error) {
return handler.client.BatchDelete(ctx, files)
}
// Thumb 获取文件缩略图
func (handler *Driver) Thumb(ctx context.Context, expire *time.Time, ext string, e fs.Entity) (string, error) {
res, err := handler.client.GetThumbURL(ctx, e.Source())
if err != nil {
var apiErr *RespError
if errors.As(err, &apiErr); err == ErrThumbSizeNotFound || (apiErr != nil && apiErr.APIError.Code == notFoundError) {
// OneDrive cannot generate thumbnail for this file
return "", fmt.Errorf("thumb not supported in OneDrive: %w", err)
}
}
return res, nil
}
// Source 获取外链URL
func (handler *Driver) Source(ctx context.Context, e fs.Entity, args *driver.GetSourceArgs) (string, error) {
// 缓存不存在,重新获取
res, err := handler.client.Meta(ctx, "", e.Source())
if err != nil {
return "", err
}
if args.IsDownload && handler.policy.Settings.StreamSaver {
downloadUrl := res.DownloadURL + "&" + streamSaverParam + "=" + url.QueryEscape(args.DisplayName)
return downloadUrl, nil
}
return res.DownloadURL, nil
}
// Token 获取上传会话URL
func (handler *Driver) Token(ctx context.Context, uploadSession *fs.UploadSession, file *fs.UploadRequest) (*fs.UploadCredential, error) {
// 生成回调地址
siteURL := handler.settings.SiteURL(setting.UseFirstSiteUrl(ctx))
uploadSession.Callback = routes.MasterSlaveCallbackUrl(siteURL, types.PolicyTypeOd, uploadSession.Props.UploadSessionID, uploadSession.CallbackSecret).String()
uploadURL, err := handler.client.CreateUploadSession(ctx, file.Props.SavePath, WithConflictBehavior("fail"))
if err != nil {
return nil, err
}
// 监控回调及上传
//go handler.client.MonitorUpload(uploadURL, uploadSession.Key, fileInfo.SavePath, fileInfo.Size, ttl)
uploadSession.ChunkSize = handler.chunkSize
uploadSession.UploadURL = uploadURL
return &fs.UploadCredential{
ChunkSize: handler.chunkSize,
UploadURLs: []string{uploadURL},
}, nil
}
// 取消上传凭证
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *fs.UploadSession) error {
err := handler.client.DeleteUploadSession(ctx, uploadSession.UploadURL)
// Create empty placeholder file to stop upload
if err == nil {
_, err := handler.client.SimpleUpload(ctx, uploadSession.Props.SavePath, strings.NewReader(""), 0, WithConflictBehavior("replace"))
if err != nil {
handler.l.Warning("Failed to create placeholder file %q:%s", uploadSession.Props.SavePath, err)
}
}
return err
}
func (handler *Driver) CompleteUpload(ctx context.Context, session *fs.UploadSession) error {
if session.SentinelTaskID == 0 {
return nil
}
// Make sure uploaded file size is correct
res, err := handler.client.Meta(ctx, "", session.Props.SavePath)
if err != nil {
// Create empty placeholder file to stop further upload
return fmt.Errorf("failed to get uploaded file size: %w", err)
}
isSharePoint := strings.Contains(handler.policy.Settings.OdDriver, "sharepoint.com") ||
strings.Contains(handler.policy.Settings.OdDriver, "sharepoint.cn")
sizeMismatch := res.Size != session.Props.Size
// SharePoint 会对 Office 文档增加 meta data 导致文件大小不一致,这里增加 1 MB 宽容
// See: https://github.com/OneDrive/onedrive-api-docs/issues/935
if isSharePoint && sizeMismatch && (res.Size > session.Props.Size) && (res.Size-session.Props.Size <= 1048576) {
sizeMismatch = false
}
if sizeMismatch {
return serializer.NewError(
serializer.CodeMetaMismatch,
fmt.Sprintf("File size not match, expected: %d, actual: %d", session.Props.Size, res.Size),
nil,
)
}
return nil
}
func (handler *Driver) Capabilities() *driver.Capabilities {
return &driver.Capabilities{
StaticFeatures: features,
ThumbSupportedExts: handler.policy.Settings.ThumbExts,
ThumbSupportAllExts: handler.policy.Settings.ThumbSupportAllExts,
ThumbMaxSize: handler.policy.Settings.ThumbMaxSize,
ThumbProxy: handler.policy.Settings.ThumbGeneratorProxy,
MediaMetaProxy: handler.policy.Settings.MediaMetaGeneratorProxy,
}
}
func (handler *Driver) MediaMeta(ctx context.Context, path, ext string) ([]driver.MediaMeta, error) {
return nil, errors.New("not implemented")
}
func (handler *Driver) LocalPath(ctx context.Context, path string) string {
return ""
}

View File

@@ -0,0 +1,59 @@
package onedrive
import "time"
// Option 发送请求的额外设置
type Option interface {
apply(*options)
}
type options struct {
redirect string
code string
refreshToken string
conflictBehavior string
expires time.Time
useDriverResource bool
}
type optionFunc func(*options)
// WithCode 设置接口Code
func WithCode(t string) Option {
return optionFunc(func(o *options) {
o.code = t
})
}
// WithRefreshToken 设置接口RefreshToken
func WithRefreshToken(t string) Option {
return optionFunc(func(o *options) {
o.refreshToken = t
})
}
// WithConflictBehavior 设置文件重名后的处理方式
func WithConflictBehavior(t string) Option {
return optionFunc(func(o *options) {
o.conflictBehavior = t
})
}
// WithConflictBehavior 设置文件重名后的处理方式
func WithDriverResource(t bool) Option {
return optionFunc(func(o *options) {
o.useDriverResource = t
})
}
func (f optionFunc) apply(o *options) {
f(o)
}
func newDefaultOption() *options {
return &options{
conflictBehavior: "fail",
useDriverResource: true,
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
}
}

View File

@@ -0,0 +1,130 @@
package onedrive
import (
"encoding/gob"
"net/url"
)
// RespError 接口返回错误
type RespError struct {
APIError APIError `json:"error"`
}
// APIError 接口返回的错误内容
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// UploadSessionResponse 分片上传会话
type UploadSessionResponse struct {
DataContext string `json:"@odata.context"`
ExpirationDateTime string `json:"expirationDateTime"`
NextExpectedRanges []string `json:"nextExpectedRanges"`
UploadURL string `json:"uploadUrl"`
}
// FileInfo 文件元信息
type FileInfo struct {
Name string `json:"name"`
Size int64 `json:"size"`
Image imageInfo `json:"image"`
ParentReference parentReference `json:"parentReference"`
DownloadURL string `json:"@microsoft.graph.downloadUrl"`
File *file `json:"file"`
Folder *folder `json:"folder"`
}
type file struct {
MimeType string `json:"mimeType"`
}
type folder struct {
ChildCount int `json:"childCount"`
}
type imageInfo struct {
Height int `json:"height"`
Width int `json:"width"`
}
type parentReference struct {
Path string `json:"path"`
Name string `json:"name"`
ID string `json:"id"`
}
// UploadResult 上传结果
type UploadResult struct {
ID string `json:"id"`
Name string `json:"name"`
Size uint64 `json:"size"`
}
// BatchRequests 批量操作请求
type BatchRequests struct {
Requests []BatchRequest `json:"requests"`
}
// BatchRequest 批量操作单个请求
type BatchRequest struct {
ID string `json:"id"`
Method string `json:"method"`
URL string `json:"url"`
Body interface{} `json:"body,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
}
// BatchResponses 批量操作响应
type BatchResponses struct {
Responses []BatchResponse `json:"responses"`
}
// BatchResponse 批量操作单个响应
type BatchResponse struct {
ID string `json:"id"`
Status int `json:"status"`
}
// ThumbResponse 获取缩略图的响应
type ThumbResponse struct {
Value []map[string]interface{} `json:"value"`
URL string `json:"url"`
}
// ListResponse 列取子项目响应
type ListResponse struct {
Value []FileInfo `json:"value"`
Context string `json:"@odata.context"`
}
// oauthEndpoint OAuth接口地址
type oauthEndpoint struct {
token url.URL
authorize url.URL
}
// OAuthError OAuth相关接口的错误响应
type OAuthError struct {
ErrorType string `json:"error"`
ErrorDescription string `json:"error_description"`
CorrelationID string `json:"correlation_id"`
}
// Site SharePoint 站点信息
type Site struct {
Description string `json:"description"`
ID string `json:"id"`
Name string `json:"name"`
DisplayName string `json:"displayName"`
WebUrl string `json:"webUrl"`
}
func init() {
gob.Register(Credential{})
}
// Error 实现error接口
func (err RespError) Error() string {
return err.APIError.Message
}