Skip to content

Commit

Permalink
fix: propagate body stream error to close function (valyala#1743)
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Denushev committed Apr 12, 2024
1 parent d3aa5a1 commit b863648
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2975,12 +2975,12 @@ func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (ret
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReader(rbs, func() error {
resp.bodyStream = newCloseReaderWithError(rbs, func(wErr error) error {
hc.releaseReader(br)
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
if closeConn || resp.ConnectionClose() {
if closeConn || resp.ConnectionClose() || wErr != nil {
hc.closeConn(cc)
} else {
hc.releaseConn(cc)
Expand Down
40 changes: 24 additions & 16 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,26 +321,31 @@ func (resp *Response) BodyStream() io.Reader {
}

func (resp *Response) CloseBodyStream() error {
return resp.closeBodyStream()
return resp.closeBodyStream(nil)
}

type ReadCloserWithError interface {
io.Reader
CloseWithError(err error) error
}

type closeReader struct {
io.Reader
closeFunc func() error
closeFunc func(err error) error
}

func newCloseReader(r io.Reader, closeFunc func() error) io.ReadCloser {
func newCloseReaderWithError(r io.Reader, closeFunc func(err error) error) ReadCloserWithError {
if r == nil {
panic(`BUG: reader is nil`)
}
return &closeReader{Reader: r, closeFunc: closeFunc}
}

func (c *closeReader) Close() error {
func (c *closeReader) CloseWithError(err error) error {
if c.closeFunc == nil {
return nil
}
return c.closeFunc()
return c.closeFunc(err)
}

// BodyWriter returns writer for populating request body.
Expand Down Expand Up @@ -394,7 +399,7 @@ func (resp *Response) Body() []byte {
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, resp.bodyStream)
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(err) //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
Expand Down Expand Up @@ -618,7 +623,7 @@ func (req *Request) BodyWriteTo(w io.Writer) error {
func (resp *Response) BodyWriteTo(w io.Writer) error {
if resp.bodyStream != nil {
_, err := copyZeroAlloc(w, resp.bodyStream)
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(err) //nolint:errcheck
return err
}
_, err := w.Write(resp.bodyBytes())
Expand All @@ -629,29 +634,29 @@ func (resp *Response) BodyWriteTo(w io.Writer) error {
//
// It is safe re-using p after the function returns.
func (resp *Response) AppendBody(p []byte) {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
resp.bodyBuffer().Write(p) //nolint:errcheck
}

// AppendBodyString appends s to response body.
func (resp *Response) AppendBodyString(s string) {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
resp.bodyBuffer().WriteString(s) //nolint:errcheck
}

// SetBody sets response body.
//
// It is safe re-using body argument after the function returns.
func (resp *Response) SetBody(body []byte) {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.Write(body) //nolint:errcheck
}

// SetBodyString sets response body.
func (resp *Response) SetBodyString(body string) {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.WriteString(body) //nolint:errcheck
Expand All @@ -660,7 +665,7 @@ func (resp *Response) SetBodyString(body string) {
// ResetBody resets response body.
func (resp *Response) ResetBody() {
resp.bodyRaw = nil
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
if resp.body != nil {
if resp.keepBodyBuffer {
resp.body.Reset()
Expand Down Expand Up @@ -700,7 +705,7 @@ func (resp *Response) ReleaseBody(size int) {
return
}
if cap(resp.body.B) > size {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
resp.body = nil
}
}
Expand Down Expand Up @@ -734,7 +739,7 @@ func (resp *Response) SwapBody(body []byte) []byte {
if resp.bodyStream != nil {
bb.Reset()
_, err := copyZeroAlloc(bb, resp.bodyStream)
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(err) //nolint:errcheck
if err != nil {
bb.Reset()
bb.SetString(err.Error())
Expand Down Expand Up @@ -2061,7 +2066,7 @@ func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error
}
}
}
errc := resp.closeBodyStream()
errc := resp.closeBodyStream(err)
if err == nil {
err = errc
}
Expand All @@ -2083,14 +2088,17 @@ func (req *Request) closeBodyStream() error {
return err
}

func (resp *Response) closeBodyStream() error {
func (resp *Response) closeBodyStream(wErr error) error {
if resp.bodyStream == nil {
return nil
}
var err error
if bsc, ok := resp.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
if bsc, ok := resp.bodyStream.(ReadCloserWithError); ok {
err = bsc.CloseWithError(wErr)
}
if bsr, ok := resp.bodyStream.(*requestStream); ok {
releaseRequestStream(bsr)
}
Expand Down

0 comments on commit b863648

Please sign in to comment.