Init V4 community edition (#2265)
* Init V4 community edition * Init V4 community edition
This commit is contained in:
@@ -2,7 +2,8 @@ package request
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -15,18 +16,24 @@ type Option interface {
|
||||
}
|
||||
|
||||
type options struct {
|
||||
timeout time.Duration
|
||||
header http.Header
|
||||
sign auth.Auth
|
||||
signTTL int64
|
||||
ctx context.Context
|
||||
contentLength int64
|
||||
masterMeta bool
|
||||
endpoint *url.URL
|
||||
slaveNodeID string
|
||||
tpsLimiterToken string
|
||||
tps float64
|
||||
tpsBurst int
|
||||
timeout time.Duration
|
||||
header http.Header
|
||||
sign auth.Auth
|
||||
signTTL int64
|
||||
ctx context.Context
|
||||
contentLength int64
|
||||
masterMeta bool
|
||||
siteID string
|
||||
siteURL string
|
||||
endpoint *url.URL
|
||||
slaveNodeID int
|
||||
tpsLimiterToken string
|
||||
tps float64
|
||||
tpsBurst int
|
||||
logger logging.Logger
|
||||
withCorrelationID bool
|
||||
cookieJar http.CookieJar
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
type optionFunc func(*options)
|
||||
@@ -38,7 +45,7 @@ func (f optionFunc) apply(o *options) {
|
||||
func newDefaultOption() *options {
|
||||
return &options{
|
||||
header: http.Header{},
|
||||
timeout: time.Duration(30) * time.Second,
|
||||
timeout: 0,
|
||||
contentLength: -1,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
@@ -50,6 +57,13 @@ func (o *options) clone() options {
|
||||
return newOptions
|
||||
}
|
||||
|
||||
// WithTransport 设置请求Transport
|
||||
func WithTransport(transport *http.Transport) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.transport = transport
|
||||
})
|
||||
}
|
||||
|
||||
// WithTimeout 设置请求超时
|
||||
func WithTimeout(t time.Duration) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
@@ -68,7 +82,9 @@ func WithContext(c context.Context) Option {
|
||||
func WithCredential(instance auth.Auth, ttl int64) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.sign = instance
|
||||
o.signTTL = ttl
|
||||
if ttl > 0 {
|
||||
o.signTTL = ttl
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -99,14 +115,16 @@ func WithContentLength(s int64) Option {
|
||||
}
|
||||
|
||||
// WithMasterMeta 请求时携带主机信息
|
||||
func WithMasterMeta() Option {
|
||||
func WithMasterMeta(siteID string, siteURL string) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.masterMeta = true
|
||||
o.siteID = siteID
|
||||
o.siteURL = siteURL
|
||||
})
|
||||
}
|
||||
|
||||
// WithSlaveMeta 请求时携带从机信息
|
||||
func WithSlaveMeta(s string) Option {
|
||||
// WithSlaveMeta set slave node ID in master's request header
|
||||
func WithSlaveMeta(s int) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.slaveNodeID = s
|
||||
})
|
||||
@@ -135,3 +153,24 @@ func WithTPSLimit(token string, tps float64, burst int) Option {
|
||||
o.tpsBurst = burst
|
||||
})
|
||||
}
|
||||
|
||||
// WithLogger set logger for logging requests
|
||||
func WithLogger(logger logging.Logger) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.logger = logger
|
||||
})
|
||||
}
|
||||
|
||||
// WithCorrelationID set correlation ID for logging requests
|
||||
func WithCorrelationID() Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.withCorrelationID = true
|
||||
})
|
||||
}
|
||||
|
||||
// WithCookieJar set cookie jar for request
|
||||
func WithCookieJar(jar http.CookieJar) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.cookieJar = jar
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,34 +1,50 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v4/application/constants"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// GeneralClient 通用 HTTP Client
|
||||
var GeneralClient Client = NewClient()
|
||||
var GeneralClient Client = NewClientDeprecated()
|
||||
|
||||
const (
|
||||
CorrelationHeader = constants.CrHeaderPrefix + "Correlation-Id"
|
||||
SiteURLHeader = constants.CrHeaderPrefix + "Site-Url"
|
||||
SiteVersionHeader = constants.CrHeaderPrefix + "Version"
|
||||
SiteIDHeader = constants.CrHeaderPrefix + "Site-Id"
|
||||
SlaveNodeIDHeader = constants.CrHeaderPrefix + "Node-Id"
|
||||
LocalIP = "localhost"
|
||||
)
|
||||
|
||||
// Response 请求的响应或错误信息
|
||||
type Response struct {
|
||||
Err error
|
||||
Response *http.Response
|
||||
l logging.Logger
|
||||
}
|
||||
|
||||
// Client 请求客户端
|
||||
type Client interface {
|
||||
// Apply applies the given options to this client.
|
||||
Apply(opts ...Option)
|
||||
// Request send a HTTP request
|
||||
Request(method, target string, body io.Reader, opts ...Option) *Response
|
||||
}
|
||||
|
||||
@@ -37,9 +53,26 @@ type HTTPClient struct {
|
||||
mu sync.Mutex
|
||||
options *options
|
||||
tpsLimiter TPSLimiter
|
||||
l logging.Logger
|
||||
config conf.ConfigProvider
|
||||
}
|
||||
|
||||
func NewClient(opts ...Option) Client {
|
||||
func NewClient(config conf.ConfigProvider, opts ...Option) Client {
|
||||
client := &HTTPClient{
|
||||
options: newDefaultOption(),
|
||||
tpsLimiter: globalTPSLimiter,
|
||||
config: config,
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
o.apply(client.options)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func NewClientDeprecated(opts ...Option) Client {
|
||||
client := &HTTPClient{
|
||||
options: newDefaultOption(),
|
||||
tpsLimiter: globalTPSLimiter,
|
||||
@@ -52,6 +85,12 @@ func NewClient(opts ...Option) Client {
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *HTTPClient) Apply(opts ...Option) {
|
||||
for _, o := range opts {
|
||||
o.apply(c.options)
|
||||
}
|
||||
}
|
||||
|
||||
// Request 发送HTTP请求
|
||||
func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
|
||||
// 应用额外设置
|
||||
@@ -63,7 +102,14 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti
|
||||
}
|
||||
|
||||
// 创建请求客户端
|
||||
client := &http.Client{Timeout: options.timeout}
|
||||
client := &http.Client{
|
||||
Timeout: options.timeout,
|
||||
Jar: options.cookieJar,
|
||||
}
|
||||
|
||||
if options.transport != nil {
|
||||
client.Transport = options.transport
|
||||
}
|
||||
|
||||
// size为0时将body设为nil
|
||||
if options.contentLength == 0 {
|
||||
@@ -86,6 +132,7 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti
|
||||
req *http.Request
|
||||
err error
|
||||
)
|
||||
start := time.Now()
|
||||
if options.ctx != nil {
|
||||
req, err = http.NewRequestWithContext(options.ctx, method, target, body)
|
||||
} else {
|
||||
@@ -102,14 +149,21 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti
|
||||
}
|
||||
}
|
||||
|
||||
if options.masterMeta && conf.SystemConfig.Mode == "master" {
|
||||
req.Header.Add(auth.CrHeaderPrefix+"Site-Url", model.GetSiteURL().String())
|
||||
req.Header.Add(auth.CrHeaderPrefix+"Site-Id", model.GetSettingByName("siteID"))
|
||||
req.Header.Add(auth.CrHeaderPrefix+"Cloudreve-Version", conf.BackendVersion)
|
||||
req.Header.Set("User-Agent", "Cloudreve/"+constants.BackendVersion)
|
||||
|
||||
if options.ctx != nil && options.withCorrelationID {
|
||||
req.Header.Add(CorrelationHeader, logging.CorrelationID(options.ctx).String())
|
||||
}
|
||||
|
||||
if options.slaveNodeID != "" && conf.SystemConfig.Mode == "slave" {
|
||||
req.Header.Add(auth.CrHeaderPrefix+"Node-Id", options.slaveNodeID)
|
||||
mode := c.config.System().Mode
|
||||
if options.masterMeta && mode == conf.MasterMode {
|
||||
req.Header.Add(SiteURLHeader, options.siteURL)
|
||||
req.Header.Add(SiteIDHeader, options.siteID)
|
||||
req.Header.Add(SiteVersionHeader, constants.BackendVersion)
|
||||
}
|
||||
|
||||
if options.slaveNodeID > 0 {
|
||||
req.Header.Add(SlaveNodeIDHeader, strconv.Itoa(options.slaveNodeID))
|
||||
}
|
||||
|
||||
if options.contentLength != -1 {
|
||||
@@ -118,11 +172,16 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti
|
||||
|
||||
// 签名请求
|
||||
if options.sign != nil {
|
||||
ctx := options.ctx
|
||||
if options.ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
expire := time.Now().Add(time.Second * time.Duration(options.signTTL))
|
||||
switch method {
|
||||
case "PUT", "POST", "PATCH":
|
||||
auth.SignRequest(options.sign, req, options.signTTL)
|
||||
auth.SignRequest(ctx, options.sign, req, &expire)
|
||||
default:
|
||||
if resURL, err := auth.SignURI(options.sign, req.URL.String(), options.signTTL); err == nil {
|
||||
if resURL, err := auth.SignURI(ctx, options.sign, req.URL.String(), &expire); err == nil {
|
||||
req.URL = resURL
|
||||
}
|
||||
}
|
||||
@@ -134,11 +193,32 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
|
||||
// Logging request
|
||||
if options.logger != nil {
|
||||
statusCode := 0
|
||||
errStr := ""
|
||||
if resp != nil {
|
||||
statusCode = resp.StatusCode
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
logging.Request(options.logger, false, statusCode, req.Method, LocalIP, req.URL.String(), errStr, start)
|
||||
}
|
||||
|
||||
// Apply cookies
|
||||
if resp != nil && resp.Cookies() != nil && options.cookieJar != nil {
|
||||
options.cookieJar.SetCookies(req.URL, resp.Cookies())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return &Response{Err: err}
|
||||
}
|
||||
|
||||
return &Response{Err: nil, Response: resp}
|
||||
return &Response{Err: nil, Response: resp, l: options.logger}
|
||||
}
|
||||
|
||||
// GetResponse 检查响应并获取响应正文
|
||||
@@ -146,21 +226,33 @@ func (resp *Response) GetResponse() (string, error) {
|
||||
if resp.Err != nil {
|
||||
return "", resp.Err
|
||||
}
|
||||
respBody, err := ioutil.ReadAll(resp.Response.Body)
|
||||
respBody, err := io.ReadAll(resp.Response.Body)
|
||||
_ = resp.Response.Body.Close()
|
||||
|
||||
return string(respBody), err
|
||||
}
|
||||
|
||||
// GetResponseIgnoreErr 获取响应正文
|
||||
func (resp *Response) GetResponseIgnoreErr() (string, error) {
|
||||
if resp.Response == nil {
|
||||
return "", resp.Err
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Response.Body)
|
||||
_ = resp.Response.Body.Close()
|
||||
|
||||
return string(respBody), resp.Err
|
||||
}
|
||||
|
||||
// CheckHTTPResponse 检查请求响应HTTP状态码
|
||||
func (resp *Response) CheckHTTPResponse(status int) *Response {
|
||||
func (resp *Response) CheckHTTPResponse(status ...int) *Response {
|
||||
if resp.Err != nil {
|
||||
return resp
|
||||
}
|
||||
|
||||
// 检查HTTP状态码
|
||||
if resp.Response.StatusCode != status {
|
||||
resp.Err = fmt.Errorf("服务器返回非正常HTTP状态%d", resp.Response.StatusCode)
|
||||
if !lo.Contains(status, resp.Response.StatusCode) {
|
||||
resp.Err = fmt.Errorf("Remote returns unexpected status code: %d", resp.Response.StatusCode)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@@ -179,7 +271,10 @@ func (resp *Response) DecodeResponse() (*serializer.Response, error) {
|
||||
var res serializer.Response
|
||||
err = json.Unmarshal([]byte(respString), &res)
|
||||
if err != nil {
|
||||
util.Log().Debug("Failed to parse response: %s", string(respString))
|
||||
if resp.l != nil {
|
||||
resp.l.Debug("Failed to parse response: %s", respString)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
return &res, nil
|
||||
@@ -254,10 +349,3 @@ func (instance NopRSCloser) Seek(offset int64, whence int) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
|
||||
}
|
||||
|
||||
// BlackHole 将客户端发来的数据放入黑洞
|
||||
func BlackHole(r io.Reader) {
|
||||
if !model.IsTrueVal(model.GetSettingByName("reset_after_upload_failed")) {
|
||||
io.Copy(ioutil.Discard, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,17 +3,16 @@ package request
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type ClientMock struct {
|
||||
@@ -55,7 +54,7 @@ func TestWithContext(t *testing.T) {
|
||||
|
||||
func TestHTTPClient_Request(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
client := NewClient(WithSlaveMeta("test"))
|
||||
client := NewClientDeprecated(WithSlaveMeta("test"))
|
||||
|
||||
// 正常
|
||||
{
|
||||
@@ -63,8 +62,6 @@ func TestHTTPClient_Request(t *testing.T) {
|
||||
"POST",
|
||||
"/test",
|
||||
strings.NewReader(""),
|
||||
WithContentLength(0),
|
||||
WithEndpoint("http://cloudreveisnotexist.com"),
|
||||
WithTimeout(time.Duration(1)*time.Microsecond),
|
||||
WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10),
|
||||
WithoutHeader([]string{"origin", "origin"}),
|
||||
@@ -79,11 +76,11 @@ func TestHTTPClient_Request(t *testing.T) {
|
||||
"GET",
|
||||
"http://cloudreveisnotexist.com",
|
||||
strings.NewReader(""),
|
||||
WithContentLength(0),
|
||||
WithEndpoint("http://cloudreveisnotexist.com"),
|
||||
WithTimeout(time.Duration(1)*time.Microsecond),
|
||||
WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10),
|
||||
WithContext(context.Background()),
|
||||
WithoutHeader([]string{"s s", "s s"}),
|
||||
WithMasterMeta(),
|
||||
)
|
||||
asserts.Error(resp.Err)
|
||||
asserts.Nil(resp.Response)
|
||||
@@ -241,7 +238,7 @@ func TestBlackHole(t *testing.T) {
|
||||
|
||||
func TestHTTPClient_TPSLimit(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
client := NewClient()
|
||||
client := NewClientDeprecated()
|
||||
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
62
pkg/request/utils.go
Normal file
62
pkg/request/utils.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var contentLengthHeaders = []string{
|
||||
"Content-Length",
|
||||
"X-Expected-Entity-Length", // DavFS on MacOS
|
||||
}
|
||||
|
||||
// BlackHole 将客户端发来的数据放入黑洞
|
||||
func BlackHole(r io.Reader) {
|
||||
io.Copy(io.Discard, r)
|
||||
}
|
||||
|
||||
// SniffContentLength tries to get the content length from the request. It also returns
|
||||
// a reader that will limit to the sniffed content length.
|
||||
func SniffContentLength(r *http.Request) (LimitReaderCloser, int64, error) {
|
||||
for _, header := range contentLengthHeaders {
|
||||
if length := r.Header.Get(header); length != "" {
|
||||
res, err := strconv.ParseInt(length, 10, 64)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return newLimitReaderCloser(r.Body, res), res, nil
|
||||
}
|
||||
}
|
||||
return newLimitReaderCloser(r.Body, 0), 0, nil
|
||||
}
|
||||
|
||||
type LimitReaderCloser interface {
|
||||
io.Reader
|
||||
io.Closer
|
||||
Count() int64
|
||||
}
|
||||
|
||||
type limitReaderCloser struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
read int64
|
||||
}
|
||||
|
||||
func newLimitReaderCloser(r io.ReadCloser, limit int64) LimitReaderCloser {
|
||||
return &limitReaderCloser{
|
||||
Reader: io.LimitReader(r, limit),
|
||||
Closer: r,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *limitReaderCloser) Read(p []byte) (n int, err error) {
|
||||
n, err = l.Reader.Read(p)
|
||||
l.read += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (l *limitReaderCloser) Count() int64 {
|
||||
return l.read
|
||||
}
|
||||
Reference in New Issue
Block a user