Init V4 community edition (#2265)
* Init V4 community edition * Init V4 community edition
This commit is contained in:
161
pkg/auth/auth.go
161
pkg/auth/auth.go
@@ -2,18 +2,18 @@ package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,37 +23,59 @@ var (
|
||||
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "signature expired", nil)
|
||||
)
|
||||
|
||||
const CrHeaderPrefix = "X-Cr-"
|
||||
const (
|
||||
TokenHeaderPrefixCr = "Bearer Cr "
|
||||
)
|
||||
|
||||
// General 通用的认证接口
|
||||
// Deprecated
|
||||
var General Auth
|
||||
|
||||
// Auth 鉴权认证
|
||||
type Auth interface {
|
||||
// 对给定Body进行签名,expires为0表示永不过期
|
||||
Sign(body string, expires int64) string
|
||||
// 对给定Body和Sign进行检查
|
||||
Check(body string, sign string) error
|
||||
}
|
||||
type (
|
||||
// Auth 鉴权认证
|
||||
Auth interface {
|
||||
// 对给定Body进行签名,expires为0表示永不过期
|
||||
Sign(body string, expires int64) string
|
||||
// 对给定Body和Sign进行检查
|
||||
Check(body string, sign string) error
|
||||
}
|
||||
)
|
||||
|
||||
// SignRequest 对PUT\POST等复杂HTTP请求签名,只会对URI部分、
|
||||
// 请求正文、`X-Cr-`开头的header进行签名
|
||||
func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request {
|
||||
func SignRequest(ctx context.Context, instance Auth, r *http.Request, expires *time.Time) *http.Request {
|
||||
// 处理有效期
|
||||
expireTime := int64(0)
|
||||
if expires != nil {
|
||||
expireTime = expires.Unix()
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign := instance.Sign(getSignContent(ctx, r), expireTime)
|
||||
|
||||
// 将签名加到请求Header中
|
||||
r.Header["Authorization"] = []string{TokenHeaderPrefixCr + sign}
|
||||
return r
|
||||
}
|
||||
|
||||
// SignRequestDeprecated 对PUT\POST等复杂HTTP请求签名,只会对URI部分、
|
||||
// 请求正文、`X-Cr-`开头的header进行签名
|
||||
func SignRequestDeprecated(instance Auth, r *http.Request, expires int64) *http.Request {
|
||||
// 处理有效期
|
||||
if expires > 0 {
|
||||
expires += time.Now().Unix()
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign := instance.Sign(getSignContent(r), expires)
|
||||
sign := instance.Sign(getSignContent(context.Background(), r), expires)
|
||||
|
||||
// 将签名加到请求Header中
|
||||
r.Header["Authorization"] = []string{"Bearer " + sign}
|
||||
r.Header["Authorization"] = []string{TokenHeaderPrefixCr + sign}
|
||||
return r
|
||||
}
|
||||
|
||||
// CheckRequest 对复杂请求进行签名验证
|
||||
func CheckRequest(instance Auth, r *http.Request) error {
|
||||
func CheckRequest(ctx context.Context, instance Auth, r *http.Request) error {
|
||||
var (
|
||||
sign []string
|
||||
ok bool
|
||||
@@ -61,41 +83,71 @@ func CheckRequest(instance Auth, r *http.Request) error {
|
||||
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
|
||||
return ErrAuthHeaderMissing
|
||||
}
|
||||
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
|
||||
sign[0] = strings.TrimPrefix(sign[0], TokenHeaderPrefixCr)
|
||||
|
||||
return instance.Check(getSignContent(ctx, r), sign[0])
|
||||
}
|
||||
|
||||
func isUploadDataRequest(r *http.Request) bool {
|
||||
return strings.Contains(r.URL.Path, constants.APIPrefix+"/slave/upload/") && r.Method != http.MethodPut
|
||||
|
||||
return instance.Check(getSignContent(r), sign[0])
|
||||
}
|
||||
|
||||
// getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果请求 path 为从机上传 API,
|
||||
// 则不对正文签名。返回待签名/验证的字符串
|
||||
func getSignContent(r *http.Request) (rawSignString string) {
|
||||
func getSignContent(ctx context.Context, r *http.Request) (rawSignString string) {
|
||||
// 读取所有body正文
|
||||
var body = []byte{}
|
||||
if !strings.Contains(r.URL.Path, "/api/v3/slave/upload/") {
|
||||
if !isUploadDataRequest(r) {
|
||||
if r.Body != nil {
|
||||
body, _ = ioutil.ReadAll(r.Body)
|
||||
body, _ = io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
}
|
||||
|
||||
// 决定要签名的header
|
||||
var signedHeader []string
|
||||
for k, _ := range r.Header {
|
||||
if strings.HasPrefix(k, CrHeaderPrefix) && k != CrHeaderPrefix+"Filename" {
|
||||
if strings.HasPrefix(k, constants.CrHeaderPrefix) && k != constants.CrHeaderPrefix+"Filename" {
|
||||
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
|
||||
}
|
||||
}
|
||||
sort.Strings(signedHeader)
|
||||
|
||||
// 读取所有待签名Header
|
||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))
|
||||
rawSignString = serializer.NewRequestSignString(getUrlSignContent(ctx, r.URL), strings.Join(signedHeader, "&"), string(body))
|
||||
|
||||
return rawSignString
|
||||
}
|
||||
|
||||
// SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证
|
||||
func SignURI(instance Auth, uri string, expires int64) (*url.URL, error) {
|
||||
// SignURI 对URI进行签名
|
||||
func SignURI(ctx context.Context, instance Auth, uri string, expires *time.Time) (*url.URL, error) {
|
||||
// 处理有效期
|
||||
expireTime := int64(0)
|
||||
if expires != nil {
|
||||
expireTime = expires.Unix()
|
||||
}
|
||||
|
||||
base, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign := instance.Sign(getUrlSignContent(ctx, base), expireTime)
|
||||
|
||||
// 将签名加到URI中
|
||||
queries := base.Query()
|
||||
queries.Set("sign", sign)
|
||||
base.RawQuery = queries.Encode()
|
||||
|
||||
return base, nil
|
||||
}
|
||||
|
||||
// SignURIDeprecated 对URI进行签名,签名只针对Path部分,query部分不做验证
|
||||
// Deprecated
|
||||
func SignURIDeprecated(instance Auth, uri string, expires int64) (*url.URL, error) {
|
||||
// 处理有效期
|
||||
if expires != 0 {
|
||||
expires += time.Now().Unix()
|
||||
@@ -118,28 +170,55 @@ func SignURI(instance Auth, uri string, expires int64) (*url.URL, error) {
|
||||
}
|
||||
|
||||
// CheckURI 对URI进行鉴权
|
||||
func CheckURI(instance Auth, url *url.URL) error {
|
||||
func CheckURI(ctx context.Context, instance Auth, url *url.URL) error {
|
||||
//获取待验证的签名正文
|
||||
queries := url.Query()
|
||||
sign := queries.Get("sign")
|
||||
queries.Del("sign")
|
||||
url.RawQuery = queries.Encode()
|
||||
|
||||
return instance.Check(url.Path, sign)
|
||||
return instance.Check(getUrlSignContent(ctx, url), sign)
|
||||
}
|
||||
|
||||
// Init 初始化通用鉴权器
|
||||
func Init() {
|
||||
var secretKey string
|
||||
if conf.SystemConfig.Mode == "master" {
|
||||
secretKey = model.GetSettingByName("secret_key")
|
||||
} else {
|
||||
secretKey = conf.SlaveConfig.Secret
|
||||
if secretKey == "" {
|
||||
util.Log().Panic("SlaveSecret is not set, please specify it in config file.")
|
||||
func RedactSensitiveValues(errorMessage string) string {
|
||||
// Regular expression to match URLs
|
||||
urlRegex := regexp.MustCompile(`https?://[^\s]+`)
|
||||
// Find all URLs in the error message
|
||||
urls := urlRegex.FindAllString(errorMessage, -1)
|
||||
|
||||
for _, urlStr := range urls {
|
||||
// Parse the URL
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get the query parameters
|
||||
queryParams := parsedURL.Query()
|
||||
|
||||
// Redact the 'sign' parameter if it exists
|
||||
if _, exists := queryParams["sign"]; exists {
|
||||
queryParams.Set("sign", "REDACTED")
|
||||
parsedURL.RawQuery = queryParams.Encode()
|
||||
}
|
||||
|
||||
// Replace the original URL with the redacted one in the error message
|
||||
errorMessage = strings.Replace(errorMessage, urlStr, parsedURL.String(), -1)
|
||||
}
|
||||
General = HMACAuth{
|
||||
SecretKey: []byte(secretKey),
|
||||
}
|
||||
|
||||
return errorMessage
|
||||
}
|
||||
|
||||
func getUrlSignContent(ctx context.Context, url *url.URL) string {
|
||||
// host := url.Host
|
||||
// if host == "" {
|
||||
// reqInfo := requestinfo.RequestInfoFromContext(ctx)
|
||||
// if reqInfo != nil {
|
||||
// host = reqInfo.Host
|
||||
// }
|
||||
// }
|
||||
// host = strings.TrimSuffix(host, "/")
|
||||
// // remove port if it exists
|
||||
// host = strings.Split(host, ":")[0]
|
||||
return url.Path
|
||||
}
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSignURI(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 成功
|
||||
{
|
||||
sign, err := SignURI(General, "/api/v3/something?id=1", 0)
|
||||
asserts.NoError(err)
|
||||
queries := sign.Query()
|
||||
asserts.Equal("1", queries.Get("id"))
|
||||
asserts.NotEmpty(queries.Get("sign"))
|
||||
}
|
||||
|
||||
// URI解码失败
|
||||
{
|
||||
sign, err := SignURI(General, "://dg.;'f]gh./'", 0)
|
||||
asserts.Error(err)
|
||||
asserts.Nil(sign)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckURI(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 成功
|
||||
{
|
||||
sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", 10)
|
||||
asserts.NoError(err)
|
||||
asserts.NoError(CheckURI(General, sign))
|
||||
}
|
||||
|
||||
// 过期
|
||||
{
|
||||
sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", -1)
|
||||
asserts.NoError(err)
|
||||
asserts.Error(CheckURI(General, sign))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignRequest(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 非上传请求
|
||||
{
|
||||
req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/slave/upload", strings.NewReader("I am body."))
|
||||
asserts.NoError(err)
|
||||
req = SignRequest(General, req, 0)
|
||||
asserts.NotEmpty(req.Header["Authorization"])
|
||||
}
|
||||
|
||||
// 上传请求
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/slave/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
|
||||
req = SignRequest(General, req, 10)
|
||||
asserts.NotEmpty(req.Header["Authorization"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRequest(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 缺少请求头
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
err = CheckRequest(General, req)
|
||||
asserts.Error(err)
|
||||
asserts.Equal(ErrAuthHeaderMissing, err)
|
||||
}
|
||||
|
||||
// 非上传请求 验证成功
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req = SignRequest(General, req, 0)
|
||||
err = CheckRequest(General, req)
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
// 上传请求 验证成功
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
|
||||
req = SignRequest(General, req, 0)
|
||||
err = CheckRequest(General, req)
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
// 非上传请求 失败
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req = SignRequest(General, req, 0)
|
||||
req.Body = ioutil.NopCloser(strings.NewReader("2333"))
|
||||
err = CheckRequest(General, req)
|
||||
asserts.Error(err)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
)
|
||||
|
||||
// HMACAuth HMAC算法鉴权
|
||||
@@ -39,7 +41,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
|
||||
// 验证是否过期
|
||||
expires, err := strconv.ParseInt(signSlice[len(signSlice)-1], 10, 64)
|
||||
if err != nil {
|
||||
return ErrAuthFailed.WithError(err)
|
||||
return serializer.NewError(serializer.CodeInvalidSign, "sign expired", nil)
|
||||
}
|
||||
// 如果签名过期
|
||||
if expires < time.Now().Unix() && expires != 0 {
|
||||
@@ -48,7 +50,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
|
||||
|
||||
// 验证签名
|
||||
if auth.Sign(body, expires) != sign {
|
||||
return ErrAuthFailed
|
||||
return serializer.NewError(serializer.CodeInvalidSign, "invalid sign", nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// 设置gin为测试模式
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 初始化sqlmock
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
|
||||
mockDB, _ := gorm.Open("mysql", db)
|
||||
model.DB = mockDB
|
||||
defer db.Close()
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestHMACAuth_Sign(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
auth := HMACAuth{
|
||||
SecretKey: []byte(util.RandStringRunes(256)),
|
||||
}
|
||||
|
||||
asserts.NotEmpty(auth.Sign("content", 0))
|
||||
}
|
||||
|
||||
func TestHMACAuth_Check(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
auth := HMACAuth{
|
||||
SecretKey: []byte(util.RandStringRunes(256)),
|
||||
}
|
||||
|
||||
// 正常,永不过期
|
||||
{
|
||||
sign := auth.Sign("content", 0)
|
||||
asserts.NoError(auth.Check("content", sign))
|
||||
}
|
||||
|
||||
// 过期
|
||||
{
|
||||
sign := auth.Sign("content", 1)
|
||||
asserts.Error(auth.Check("content", sign))
|
||||
}
|
||||
|
||||
// 签名格式错误
|
||||
{
|
||||
sign := auth.Sign("content", 1)
|
||||
asserts.Error(auth.Check("content", sign+":"))
|
||||
}
|
||||
|
||||
// 过期日期格式错误
|
||||
{
|
||||
asserts.Error(auth.Check("content", "ErrAuthFailed:ErrAuthFailed"))
|
||||
}
|
||||
|
||||
// 签名有误
|
||||
{
|
||||
asserts.Error(auth.Check("content", fmt.Sprintf("sign:%d", time.Now().Unix()+10)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312"))
|
||||
Init()
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
|
||||
// slave模式
|
||||
conf.SystemConfig.Mode = "slave"
|
||||
asserts.Panics(func() {
|
||||
Init()
|
||||
})
|
||||
}
|
||||
200
pkg/auth/jwt.go
Normal file
200
pkg/auth/jwt.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/setting"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type TokenAuth interface {
|
||||
// Issue issues a new pair of credentials for the given user.
|
||||
Issue(ctx context.Context, u *ent.User) (*Token, error)
|
||||
// VerifyAndRetrieveUser verifies the given token and inject the user into current context.
|
||||
// Returns if upper caller should continue process other session provider.
|
||||
VerifyAndRetrieveUser(c *gin.Context) (bool, error)
|
||||
// Refresh refreshes the given refresh token and returns a new pair of credentials.
|
||||
Refresh(ctx context.Context, refreshToken string) (*Token, error)
|
||||
}
|
||||
|
||||
// Token stores token pair for authentication
|
||||
type Token struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccessExpires time.Time `json:"access_expires"`
|
||||
RefreshExpires time.Time `json:"refresh_expires"`
|
||||
|
||||
UID int `json:"-"`
|
||||
}
|
||||
|
||||
type (
|
||||
TokenType string
|
||||
TokenIDContextKey struct{}
|
||||
)
|
||||
|
||||
var (
|
||||
TokenTypeAccess = TokenType("access")
|
||||
TokenTypeRefresh = TokenType("refresh")
|
||||
|
||||
ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
)
|
||||
|
||||
const (
|
||||
AuthorizationHeader = "Authorization"
|
||||
TokenHeaderPrefix = "Bearer "
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
TokenType TokenType `json:"token_type"`
|
||||
jwt.RegisteredClaims
|
||||
StateHash []byte `json:"state_hash,omitempty"`
|
||||
}
|
||||
|
||||
// NewTokenAuth creates a new token based auth provider.
|
||||
func NewTokenAuth(idEncoder hashid.Encoder, s setting.Provider, secret []byte, userClient inventory.UserClient, l logging.Logger) TokenAuth {
|
||||
return &tokenAuth{
|
||||
idEncoder: idEncoder,
|
||||
s: s,
|
||||
secret: secret,
|
||||
userClient: userClient,
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
|
||||
type tokenAuth struct {
|
||||
l logging.Logger
|
||||
idEncoder hashid.Encoder
|
||||
s setting.Provider
|
||||
secret []byte
|
||||
userClient inventory.UserClient
|
||||
}
|
||||
|
||||
func (t *tokenAuth) Refresh(ctx context.Context, refreshToken string) (*Token, error) {
|
||||
token, err := jwt.ParseWithClaims(refreshToken, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return t.secret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || claims.TokenType != TokenTypeRefresh {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
uid, err := t.idEncoder.Decode(claims.Subject, hashid.UserID)
|
||||
if err != nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
expectedUser, err := t.userClient.GetActiveByID(ctx, uid)
|
||||
if err != nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
// Check if user changed password or revoked session
|
||||
expectedHash := t.hashUserState(ctx, expectedUser)
|
||||
if !bytes.Equal(claims.StateHash, expectedHash[:]) {
|
||||
return nil, ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
return t.Issue(ctx, expectedUser)
|
||||
}
|
||||
|
||||
func (t *tokenAuth) VerifyAndRetrieveUser(c *gin.Context) (bool, error) {
|
||||
headerVal := c.GetHeader(AuthorizationHeader)
|
||||
if strings.HasPrefix(headerVal, TokenHeaderPrefixCr) {
|
||||
// This is an HMAC auth header, skip JWT verification
|
||||
return false, nil
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(headerVal, TokenHeaderPrefix)
|
||||
if tokenString == "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return t.secret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.l.Warning("Failed to parse jwt token: %s", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || claims.TokenType != TokenTypeAccess {
|
||||
return false, serializer.NewError(serializer.CodeCredentialInvalid, "Invalid token type", nil)
|
||||
}
|
||||
|
||||
uid, err := t.idEncoder.Decode(claims.Subject, hashid.UserID)
|
||||
if err != nil {
|
||||
return false, serializer.NewError(serializer.CodeNotFound, "User not found", err)
|
||||
}
|
||||
|
||||
util.WithValue(c, inventory.UserIDCtx{}, uid)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t *tokenAuth) Issue(ctx context.Context, u *ent.User) (*Token, error) {
|
||||
uidEncoded := hashid.EncodeUserID(t.idEncoder, u.ID)
|
||||
tokenSettings := t.s.TokenAuth(ctx)
|
||||
issueDate := time.Now()
|
||||
accessTokenExpired := time.Now().Add(tokenSettings.AccessTokenTTL)
|
||||
refreshTokenExpired := time.Now().Add(tokenSettings.RefreshTokenTTL)
|
||||
|
||||
accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
|
||||
TokenType: TokenTypeAccess,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: uidEncoded,
|
||||
NotBefore: jwt.NewNumericDate(issueDate),
|
||||
ExpiresAt: jwt.NewNumericDate(accessTokenExpired),
|
||||
},
|
||||
}).SignedString(t.secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("faield to sign access token: %w", err)
|
||||
}
|
||||
|
||||
userHash := t.hashUserState(ctx, u)
|
||||
refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, Claims{
|
||||
TokenType: TokenTypeRefresh,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: uidEncoded,
|
||||
NotBefore: jwt.NewNumericDate(issueDate),
|
||||
ExpiresAt: jwt.NewNumericDate(refreshTokenExpired),
|
||||
},
|
||||
StateHash: userHash[:],
|
||||
}).SignedString(t.secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("faield to sign refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &Token{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
AccessExpires: accessTokenExpired,
|
||||
RefreshExpires: refreshTokenExpired,
|
||||
UID: u.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// hashUserState returns a hash string for user state for critical fields, it is used
|
||||
// to detect refresh token revocation after user changed password.
|
||||
func (t *tokenAuth) hashUserState(ctx context.Context, u *ent.User) [32]byte {
|
||||
return sha256.Sum256([]byte(fmt.Sprintf("%s/%s/%s", u.Email, u.Password, t.s.SiteBasic(ctx).ID)))
|
||||
}
|
||||
25
pkg/auth/requestinfo/requestinfo.go
Normal file
25
pkg/auth/requestinfo/requestinfo.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package requestinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// RequestInfoCtx context key for RequestInfo
|
||||
type RequestInfoCtx struct{}
|
||||
|
||||
// RequestInfoFromContext retrieves RequestInfo from context
|
||||
func RequestInfoFromContext(ctx context.Context) *RequestInfo {
|
||||
v, ok := ctx.Value(RequestInfoCtx{}).(*RequestInfo)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// RequestInfo store request info for audit
|
||||
type RequestInfo struct {
|
||||
Host string
|
||||
IP string
|
||||
UserAgent string
|
||||
}
|
||||
Reference in New Issue
Block a user