Init V4 community edition (#2265)
* Init V4 community edition * Init V4 community edition
This commit is contained in:
271
pkg/filemanager/driver/onedrive/oauth.go
Normal file
271
pkg/filemanager/driver/onedrive/oauth.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/application/dependency"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/credmanager"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
const (
|
||||
AccessTokenExpiryMargin = 600 // 10 minutes
|
||||
)
|
||||
|
||||
// Error 实现error接口
|
||||
func (err OAuthError) Error() string {
|
||||
return err.ErrorDescription
|
||||
}
|
||||
|
||||
// OAuthURL 获取OAuth认证页面URL
|
||||
func (client *client) OAuthURL(ctx context.Context, scope []string) string {
|
||||
query := url.Values{
|
||||
"client_id": {client.policy.BucketName},
|
||||
"scope": {strings.Join(scope, " ")},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {client.policy.Settings.OauthRedirect},
|
||||
"state": {strconv.Itoa(client.policy.ID)},
|
||||
}
|
||||
client.endpoints.oAuthEndpoints.authorize.RawQuery = query.Encode()
|
||||
return client.endpoints.oAuthEndpoints.authorize.String()
|
||||
}
|
||||
|
||||
// getOAuthEndpoint gets OAuth endpoints from API endpoint
|
||||
func getOAuthEndpoint(apiEndpoint string) *oauthEndpoint {
|
||||
base, err := url.Parse(apiEndpoint)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
token *url.URL
|
||||
authorize *url.URL
|
||||
)
|
||||
switch base.Host {
|
||||
//case "login.live.com":
|
||||
// token, _ = url.Parse("https://login.live.com/oauth20_token.srf")
|
||||
// authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf")
|
||||
case "microsoftgraph.chinacloudapi.cn":
|
||||
token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token")
|
||||
authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize")
|
||||
default:
|
||||
token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token")
|
||||
authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
|
||||
}
|
||||
|
||||
return &oauthEndpoint{
|
||||
token: *token,
|
||||
authorize: *authorize,
|
||||
}
|
||||
}
|
||||
|
||||
// Credential 获取token时返回的凭证
|
||||
type Credential struct {
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RefreshedAtUnix int64 `json:"refreshed_at"`
|
||||
|
||||
PolicyID int `json:"policy_id"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(Credential{})
|
||||
}
|
||||
|
||||
func (c Credential) Refresh(ctx context.Context) (credmanager.Credential, error) {
|
||||
if c.RefreshToken == "" {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
dep := dependency.FromContext(ctx)
|
||||
storagePolicyClient := dep.StoragePolicyClient()
|
||||
policy, err := storagePolicyClient.GetPolicyByID(ctx, c.PolicyID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get storage policy: %w", err)
|
||||
}
|
||||
|
||||
oauthBase := getOAuthEndpoint(policy.Server)
|
||||
|
||||
newCredential, err := obtainToken(ctx, &obtainTokenArgs{
|
||||
clientId: policy.BucketName,
|
||||
redirect: policy.Settings.OauthRedirect,
|
||||
secret: policy.SecretKey,
|
||||
refreshToken: c.RefreshToken,
|
||||
client: dep.RequestClient(request.WithLogger(dep.Logger())),
|
||||
tokenEndpoint: oauthBase.token.String(),
|
||||
policyID: c.PolicyID,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.RefreshToken = newCredential.RefreshToken
|
||||
c.AccessToken = newCredential.AccessToken
|
||||
c.ExpiresIn = newCredential.ExpiresIn
|
||||
c.RefreshedAtUnix = time.Now().Unix()
|
||||
|
||||
// Write refresh token to db
|
||||
if err := storagePolicyClient.UpdateAccessKey(ctx, policy, newCredential.RefreshToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c Credential) Key() string {
|
||||
return CredentialKey(c.PolicyID)
|
||||
}
|
||||
|
||||
func (c Credential) Expiry() time.Time {
|
||||
return time.Unix(c.ExpiresIn-AccessTokenExpiryMargin, 0)
|
||||
}
|
||||
|
||||
func (c Credential) String() string {
|
||||
return c.AccessToken
|
||||
}
|
||||
|
||||
func (c Credential) RefreshedAt() *time.Time {
|
||||
if c.RefreshedAtUnix == 0 {
|
||||
return nil
|
||||
}
|
||||
refreshedAt := time.Unix(c.RefreshedAtUnix, 0)
|
||||
return &refreshedAt
|
||||
}
|
||||
|
||||
// ObtainToken 通过code或refresh_token兑换token
|
||||
func (client *client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
return obtainToken(ctx, &obtainTokenArgs{
|
||||
clientId: client.policy.BucketName,
|
||||
redirect: client.policy.Settings.OauthRedirect,
|
||||
secret: client.policy.SecretKey,
|
||||
code: options.code,
|
||||
refreshToken: options.refreshToken,
|
||||
client: client.httpClient,
|
||||
tokenEndpoint: client.endpoints.oAuthEndpoints.token.String(),
|
||||
policyID: client.policy.ID,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
type obtainTokenArgs struct {
|
||||
clientId string
|
||||
redirect string
|
||||
secret string
|
||||
code string
|
||||
refreshToken string
|
||||
client request.Client
|
||||
tokenEndpoint string
|
||||
policyID int
|
||||
}
|
||||
|
||||
// obtainToken fetch new access token from Microsoft Graph API
|
||||
func obtainToken(ctx context.Context, args *obtainTokenArgs) (*Credential, error) {
|
||||
body := url.Values{
|
||||
"client_id": {args.clientId},
|
||||
"redirect_uri": {args.redirect},
|
||||
"client_secret": {args.secret},
|
||||
}
|
||||
if args.code != "" {
|
||||
body.Add("grant_type", "authorization_code")
|
||||
body.Add("code", args.code)
|
||||
} else {
|
||||
body.Add("grant_type", "refresh_token")
|
||||
body.Add("refresh_token", args.refreshToken)
|
||||
}
|
||||
strBody := body.Encode()
|
||||
|
||||
res := args.client.Request(
|
||||
"POST",
|
||||
args.tokenEndpoint,
|
||||
io.NopCloser(strings.NewReader(strBody)),
|
||||
request.WithHeader(http.Header{
|
||||
"Content-Type": {"application/x-www-form-urlencoded"}},
|
||||
),
|
||||
request.WithContentLength(int64(len(strBody))),
|
||||
request.WithContext(ctx),
|
||||
)
|
||||
if res.Err != nil {
|
||||
return nil, res.Err
|
||||
}
|
||||
|
||||
respBody, err := res.GetResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
errResp OAuthError
|
||||
credential Credential
|
||||
decodeErr error
|
||||
)
|
||||
|
||||
if res.Response.StatusCode != 200 {
|
||||
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
|
||||
} else {
|
||||
decodeErr = json.Unmarshal([]byte(respBody), &credential)
|
||||
}
|
||||
if decodeErr != nil {
|
||||
return nil, decodeErr
|
||||
}
|
||||
|
||||
if errResp.ErrorType != "" {
|
||||
return nil, errResp
|
||||
}
|
||||
|
||||
credential.PolicyID = args.policyID
|
||||
credential.ExpiresIn = time.Now().Unix() + credential.ExpiresIn
|
||||
if args.code != "" {
|
||||
credential.ExpiresIn = time.Now().Unix() - 10
|
||||
}
|
||||
return &credential, nil
|
||||
}
|
||||
|
||||
// UpdateCredential 更新凭证,并检查有效期
|
||||
func (client *client) UpdateCredential(ctx context.Context) error {
|
||||
newCred, err := client.cred.Obtain(ctx, CredentialKey(client.policy.ID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to obtain token from CredManager: %w", err)
|
||||
}
|
||||
|
||||
client.credential = newCred
|
||||
return nil
|
||||
}
|
||||
|
||||
// RetrieveOneDriveCredentials retrieves OneDrive credentials from DB inventory
|
||||
func RetrieveOneDriveCredentials(ctx context.Context, storagePolicyClient inventory.StoragePolicyClient) ([]credmanager.Credential, error) {
|
||||
odPolicies, err := storagePolicyClient.ListPolicyByType(ctx, types.PolicyTypeOd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list OneDrive policies: %w", err)
|
||||
}
|
||||
|
||||
return lo.Map(odPolicies, func(item *ent.StoragePolicy, index int) credmanager.Credential {
|
||||
return &Credential{
|
||||
PolicyID: item.ID,
|
||||
ExpiresIn: 0,
|
||||
RefreshToken: item.AccessKey,
|
||||
}
|
||||
}), nil
|
||||
}
|
||||
|
||||
func CredentialKey(policyId int) string {
|
||||
return fmt.Sprintf("cred_od_%d", policyId)
|
||||
}
|
||||
Reference in New Issue
Block a user