From 08a38c83afefbec7854d92fbc5419e53dea01af5 Mon Sep 17 00:00:00 2001 From: Avi Deitcher Date: Tue, 19 Sep 2023 20:39:43 +0300 Subject: [PATCH] simplify the inactivity timeout Signed-off-by: Avi Deitcher --- zedUpload/httputil/http.go | 22 +++++----- zedUpload/httputil/timeoutreader.go | 55 ++++++++++++++++++++++++ zedUpload/httputil/timeoutreader_test.go | 48 +++++++++++++++++++++ 3 files changed, 113 insertions(+), 12 deletions(-) create mode 100644 zedUpload/httputil/timeoutreader.go create mode 100644 zedUpload/httputil/timeoutreader_test.go diff --git a/zedUpload/httputil/http.go b/zedUpload/httputil/http.go index 9b0a7ae..43a278f 100644 --- a/zedUpload/httputil/http.go +++ b/zedUpload/httputil/http.go @@ -255,14 +255,7 @@ func execCmdGet(ctx context.Context, objSize int64, localFile string, host strin copiedSize = 0 forceRestart = false } - // we need innerCtx cancel to call in case of inactivity - innerCtx, innerCtxCancel := context.WithCancel(ctx) - inactivityTimer := time.AfterFunc(inactivityTimeout, func() { - //keep it to call cancel regardless of logic to releases resources - innerCtxCancel() - }) - defer inactivityTimer.Stop() - req, err := http.NewRequestWithContext(innerCtx, http.MethodGet, host, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, host, nil) if err != nil { appendToErrorList("request failed for get %s: %s", host, err) return stats, Resp{} @@ -276,6 +269,10 @@ func execCmdGet(ctx context.Context, objSize int64, localFile string, host strin withRange = true req.Header.Set("Range", fmt.Sprintf("bytes=%d-", copiedSize)) } + // set the inactivity timer just for retrieving the headers + if transport, ok := client.Transport.(*http.Transport); ok { + transport.ResponseHeaderTimeout = inactivityTimeout + } resp, err := client.Do(req) if err != nil { // break the retries loop and skip the error from @@ -317,12 +314,13 @@ func execCmdGet(ctx context.Context, objSize int64, localFile string, host strin rsp.BodyLength = int(resp.ContentLength) } var written int64 + + // use the inactivityReader to trigger failure for the timeouts + inactivityReader := NewTimeoutReader(inactivityTimeout, resp.Body) for { var copyErr error - // reset the timer for each read - inactivityTimer.Reset(inactivityTimeout) - written, copyErr = io.CopyN(local, resp.Body, chunkSize) + written, copyErr = io.CopyN(local, inactivityReader, chunkSize) copiedSize += written stats.Asize = copiedSize @@ -339,7 +337,7 @@ func execCmdGet(ctx context.Context, objSize int64, localFile string, host strin // empty out the error list errorList = nil return stats, rsp - case copyErr != nil && innerCtx.Err() != nil: + case copyErr != nil && errors.Is(copyErr, &ErrTimeout{}): // the error comes from timeout appendToErrorList("inactivity for %s", inactivityTimeout) case copyErr != nil: diff --git a/zedUpload/httputil/timeoutreader.go b/zedUpload/httputil/timeoutreader.go new file mode 100644 index 0000000..2859bf8 --- /dev/null +++ b/zedUpload/httputil/timeoutreader.go @@ -0,0 +1,55 @@ +package http + +import ( + "io" + "time" +) + +// TimeoutReader reads until a preset timeout, +// then returns a timeout error. +// The timeout is for each read. +type TimeoutReader struct { + timeout time.Duration + reader io.Reader +} + +// NewTimeoutReader creates a new TimeoutReader. +func NewTimeoutReader(timeout time.Duration, r io.Reader) *TimeoutReader { + return &TimeoutReader{timeout, r} +} + +// Read reads from the underlying reader. +func (r *TimeoutReader) Read(p []byte) (int, error) { + // channel is just used to signal when the read is done + var ( + n int + err error + ) + c := make(chan byte, 1) + // we have to put this in a goroutine, so we do not block our main routine + // waiting on it + go func() { + n, err = r.reader.Read(p) + c <- 0 + }() + select { + case <-c: + return n, err + case <-time.After(r.timeout): + return 0, &ErrTimeout{} + } +} + +type ErrTimeout struct { + timeout time.Duration +} + +func (e *ErrTimeout) Error() string { + return e.timeout.String() +} + +// Is the other error the same type? We do not really care about the properties +func (e *ErrTimeout) Is(err error) bool { + _, ok := err.(*ErrTimeout) + return ok +} diff --git a/zedUpload/httputil/timeoutreader_test.go b/zedUpload/httputil/timeoutreader_test.go new file mode 100644 index 0000000..cefcb49 --- /dev/null +++ b/zedUpload/httputil/timeoutreader_test.go @@ -0,0 +1,48 @@ +package http + +import ( + "errors" + "io" + "strings" + "testing" + "time" +) + +func TestTimeoutReaderSucceed(t *testing.T) { + // Test that TimeoutReader succeeds when the underlying reader succeeds. + r := NewTimeoutReader(1*time.Second, strings.NewReader("hello world")) + p := make([]byte, 11) + n, err := r.Read(p) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if n != 11 { + t.Errorf("unexpected read size: %d", n) + } + if string(p) != "hello world" { + t.Errorf("unexpected read: %q", string(p)) + } +} + +func TestTimeoutReaderFail(t *testing.T) { + // Test that TimeoutReader fails when the underlying reader fails. + r := NewTimeoutReader(1*time.Second, blockedReader{delay: 2 * time.Second, reader: strings.NewReader("hello world")}) + p := make([]byte, 12) + _, err := r.Read(p) + if err == nil { + t.Errorf("unexpected success") + } + if !errors.Is(err, &ErrTimeout{}) { + t.Errorf("error was %v and not ErrTimeout", err) + } +} + +type blockedReader struct { + delay time.Duration + reader io.Reader +} + +func (r blockedReader) Read(p []byte) (n int, err error) { + time.Sleep(r.delay) + return r.reader.Read(p) +}