Skip to content

Commit

Permalink
simplify the inactivity timeout
Browse files Browse the repository at this point in the history
Signed-off-by: Avi Deitcher <avi@deitcher.net>
  • Loading branch information
deitch authored and eriknordmark committed Sep 20, 2023
1 parent 127266b commit 08a38c8
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 12 deletions.
22 changes: 10 additions & 12 deletions zedUpload/httputil/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,7 @@ func execCmdGet(ctx context.Context, objSize int64, localFile string, host strin
copiedSize = 0
forceRestart = false
}
// we need innerCtx cancel to call in case of inactivity
innerCtx, innerCtxCancel := context.WithCancel(ctx)
inactivityTimer := time.AfterFunc(inactivityTimeout, func() {
//keep it to call cancel regardless of logic to releases resources
innerCtxCancel()
})
defer inactivityTimer.Stop()
req, err := http.NewRequestWithContext(innerCtx, http.MethodGet, host, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, host, nil)
if err != nil {
appendToErrorList("request failed for get %s: %s", host, err)
return stats, Resp{}
Expand All @@ -276,6 +269,10 @@ func execCmdGet(ctx context.Context, objSize int64, localFile string, host strin
withRange = true
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", copiedSize))
}
// set the inactivity timer just for retrieving the headers
if transport, ok := client.Transport.(*http.Transport); ok {
transport.ResponseHeaderTimeout = inactivityTimeout
}
resp, err := client.Do(req)
if err != nil {
// break the retries loop and skip the error from
Expand Down Expand Up @@ -317,12 +314,13 @@ func execCmdGet(ctx context.Context, objSize int64, localFile string, host strin
rsp.BodyLength = int(resp.ContentLength)
}
var written int64

// use the inactivityReader to trigger failure for the timeouts
inactivityReader := NewTimeoutReader(inactivityTimeout, resp.Body)
for {
var copyErr error

// reset the timer for each read
inactivityTimer.Reset(inactivityTimeout)
written, copyErr = io.CopyN(local, resp.Body, chunkSize)
written, copyErr = io.CopyN(local, inactivityReader, chunkSize)
copiedSize += written
stats.Asize = copiedSize

Expand All @@ -339,7 +337,7 @@ func execCmdGet(ctx context.Context, objSize int64, localFile string, host strin
// empty out the error list
errorList = nil
return stats, rsp
case copyErr != nil && innerCtx.Err() != nil:
case copyErr != nil && errors.Is(copyErr, &ErrTimeout{}):
// the error comes from timeout
appendToErrorList("inactivity for %s", inactivityTimeout)
case copyErr != nil:
Expand Down
55 changes: 55 additions & 0 deletions zedUpload/httputil/timeoutreader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package http

import (
"io"
"time"
)

// TimeoutReader reads until a preset timeout,
// then returns a timeout error.
// The timeout is for each read.
type TimeoutReader struct {
timeout time.Duration
reader io.Reader
}

// NewTimeoutReader creates a new TimeoutReader.
func NewTimeoutReader(timeout time.Duration, r io.Reader) *TimeoutReader {
return &TimeoutReader{timeout, r}
}

// Read reads from the underlying reader.
func (r *TimeoutReader) Read(p []byte) (int, error) {
// channel is just used to signal when the read is done
var (
n int
err error
)
c := make(chan byte, 1)
// we have to put this in a goroutine, so we do not block our main routine
// waiting on it
go func() {
n, err = r.reader.Read(p)
c <- 0
}()
select {
case <-c:
return n, err
case <-time.After(r.timeout):
return 0, &ErrTimeout{}
}
}

type ErrTimeout struct {
timeout time.Duration
}

func (e *ErrTimeout) Error() string {
return e.timeout.String()
}

// Is the other error the same type? We do not really care about the properties
func (e *ErrTimeout) Is(err error) bool {
_, ok := err.(*ErrTimeout)
return ok
}
48 changes: 48 additions & 0 deletions zedUpload/httputil/timeoutreader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package http

import (
"errors"
"io"
"strings"
"testing"
"time"
)

func TestTimeoutReaderSucceed(t *testing.T) {
// Test that TimeoutReader succeeds when the underlying reader succeeds.
r := NewTimeoutReader(1*time.Second, strings.NewReader("hello world"))
p := make([]byte, 11)
n, err := r.Read(p)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if n != 11 {
t.Errorf("unexpected read size: %d", n)
}
if string(p) != "hello world" {
t.Errorf("unexpected read: %q", string(p))
}
}

func TestTimeoutReaderFail(t *testing.T) {
// Test that TimeoutReader fails when the underlying reader fails.
r := NewTimeoutReader(1*time.Second, blockedReader{delay: 2 * time.Second, reader: strings.NewReader("hello world")})
p := make([]byte, 12)
_, err := r.Read(p)
if err == nil {
t.Errorf("unexpected success")
}
if !errors.Is(err, &ErrTimeout{}) {
t.Errorf("error was %v and not ErrTimeout", err)
}
}

type blockedReader struct {
delay time.Duration
reader io.Reader
}

func (r blockedReader) Read(p []byte) (n int, err error) {
time.Sleep(r.delay)
return r.reader.Read(p)
}

0 comments on commit 08a38c8

Please sign in to comment.