Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes HttpClient Content.Source reads from arbitrary threads #12203

Merged
merged 8 commits into from
Aug 30, 2024
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class should have the SerializedInvoker assertions of #12143

Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public abstract class HttpReceiver
{
private static final Logger LOG = LoggerFactory.getLogger(HttpReceiver.class);

private final SerializedInvoker invoker = new SerializedInvoker();
private final SerializedInvoker invoker = new SerializedInvoker(HttpReceiver.class);
private final HttpChannel channel;
private ResponseState responseState = ResponseState.IDLE;
private NotifiableContentSource contentSource;
Expand Down Expand Up @@ -731,6 +731,7 @@ public void onDataAvailable()
{
if (LOG.isDebugEnabled())
LOG.debug("onDataAvailable on {}", this);
invoker.assertCurrentThreadInvoking();
// The onDataAvailable() method is only ever called
// by the invoker so avoid using the invoker again.
invokeDemandCallback(false);
Expand All @@ -755,6 +756,8 @@ private void processDemand()
if (LOG.isDebugEnabled())
LOG.debug("Processing demand on {}", this);

invoker.assertCurrentThreadInvoking();

Content.Chunk current;
try (AutoLock ignored = lock.lock())
{
Expand Down Expand Up @@ -794,9 +797,14 @@ private void invokeDemandCallback(boolean invoke)
try
{
if (invoke)
{
invoker.run(demandCallback);
}
else
{
invoker.assertCurrentThreadInvoking();
demandCallback.run();
}
}
catch (Throwable x)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ public String toString()
public abstract static class AbstractContentSource implements Content.Source, Closeable
{
private final AutoLock lock = new AutoLock();
private final SerializedInvoker invoker = new SerializedInvoker();
private final SerializedInvoker invoker = new SerializedInvoker(AbstractContentSource.class);
private final Queue<Part> parts = new ArrayDeque<>();
private final String boundary;
private final ByteBuffer firstBoundary;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public String toString()
};

private final AutoLock.WithCondition lock = new AutoLock.WithCondition();
private final SerializedInvoker invoker = new SerializedInvoker();
private final SerializedInvoker invoker = new SerializedInvoker(AsyncContent.class);
private final Queue<Content.Chunk> chunks = new ArrayDeque<>();
private Content.Chunk persistentFailure;
private boolean readClosed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
public class ByteBufferContentSource implements Content.Source
{
private final AutoLock lock = new AutoLock();
private final SerializedInvoker invoker = new SerializedInvoker();
private final SerializedInvoker invoker = new SerializedInvoker(ByteBufferContentSource.class);
private final long length;
private final Collection<ByteBuffer> byteBuffers;
private Iterator<ByteBuffer> iterator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
public class ChunksContentSource implements Content.Source
{
private final AutoLock lock = new AutoLock();
private final SerializedInvoker invoker = new SerializedInvoker();
private final SerializedInvoker invoker = new SerializedInvoker(ChunksContentSource.class);
private final long length;
private final Collection<Content.Chunk> chunks;
private Iterator<Content.Chunk> iterator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public abstract class ContentSourceTransformer implements Content.Source

protected ContentSourceTransformer(Content.Source rawSource)
{
this(rawSource, new SerializedInvoker());
this(rawSource, new SerializedInvoker(ContentSourceTransformer.class));
}

protected ContentSourceTransformer(Content.Source rawSource, SerializedInvoker invoker)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
public class InputStreamContentSource implements Content.Source
{
private final AutoLock lock = new AutoLock();
private final SerializedInvoker invoker = new SerializedInvoker();
private final SerializedInvoker invoker = new SerializedInvoker(InputStreamContentSource.class);
private final InputStream inputStream;
private ByteBufferPool.Sized bufferPool;
private Runnable demandCallback;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class PathContentSource implements Content.Source
// TODO in 12.1.x reimplement this class based on ByteChannelContentSource

private final AutoLock lock = new AutoLock();
private final SerializedInvoker invoker = new SerializedInvoker();
private final SerializedInvoker invoker = new SerializedInvoker(PathContentSource.class);
private final Path path;
private final long length;
private final ByteBufferPool byteBufferPool;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
public class ByteChannelContentSource implements Content.Source
{
private final AutoLock lock = new AutoLock();
private final SerializedInvoker _invoker = new SerializedInvoker();
private final SerializedInvoker _invoker = new SerializedInvoker(ByteChannelContentSource.class);
private final ByteBufferPool.Sized _byteBufferPool;
private ByteChannel _byteChannel;
private final long _offset;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ public HttpChannelState(ConnectionMetaData connectionMetaData)
{
_connectionMetaData = connectionMetaData;
// The SerializedInvoker is used to prevent infinite recursion of callbacks calling methods calling callbacks etc.
_readInvoker = new HttpChannelSerializedInvoker();
_writeInvoker = new HttpChannelSerializedInvoker();
_readInvoker = new HttpChannelSerializedInvoker(HttpChannelState.class.getSimpleName() + "_readInvoker");
_writeInvoker = new HttpChannelSerializedInvoker(HttpChannelState.class.getSimpleName() + "_writeInvoker");
}

@Override
Expand Down Expand Up @@ -1812,6 +1812,11 @@ private void completing()

private class HttpChannelSerializedInvoker extends SerializedInvoker
{
public HttpChannelSerializedInvoker(String name)
{
super(name);
}

@Override
protected void onError(Runnable task, Throwable failure)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
*/
public class SerializedExecutor implements Executor
{
private final SerializedInvoker _invoker = new SerializedInvoker()
private final SerializedInvoker _invoker = new SerializedInvoker(SerializedExecutor.class)
{
@Override
protected void onError(Runnable task, Throwable t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@

package org.eclipse.jetty.util.thread;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.concurrent.atomic.AtomicReference;

import org.eclipse.jetty.util.component.Dumpable;
import org.eclipse.jetty.util.thread.Invocable.InvocationType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -35,6 +39,51 @@ public class SerializedInvoker
private static final Logger LOG = LoggerFactory.getLogger(SerializedInvoker.class);

private final AtomicReference<Link> _tail = new AtomicReference<>();
private final String _name;
private volatile Thread _invokerThread;

/**
* Create a new instance whose name is {@code anonymous}.
*/
public SerializedInvoker()
{
this("anonymous");
}

/**
* Create a new instance whose name is derived from the given class.
* @param nameFrom the class to use as a name.
*/
public SerializedInvoker(Class<?> nameFrom)
{
this(nameFrom.getSimpleName());
}

/**
* Create a new instance with the given name.
* @param name the name.
*/
public SerializedInvoker(String name)
{
_name = name;
}

/**
* @return whether the current thread is currently executing a task using this invoker
*/
boolean isCurrentThreadInvoking()
{
return _invokerThread == Thread.currentThread();
}

/**
* @throws IllegalStateException when the current thread is not currently executing a task using this invoker
*/
public void assertCurrentThreadInvoking() throws IllegalStateException
{
if (!isCurrentThreadInvoking())
throw new IllegalStateException();
}

/**
* Arrange for a task to be invoked, mutually excluded from other tasks.
Expand All @@ -59,7 +108,7 @@ public Runnable offer(Runnable task)
{
// Wrap the given task with another one that's going to delegate run() to the wrapped task while the
// wrapper's toString() returns a description of the place in code where SerializedInvoker.run() was called.
task = new NamedRunnable(task, deriveTaskName(task));
task = new NamedRunnable(task);
}
}
Link link = new Link(task);
Expand All @@ -72,18 +121,6 @@ public Runnable offer(Runnable task)
return null;
}

protected String deriveTaskName(Runnable task)
{
StackTraceElement[] stackTrace = new Exception().getStackTrace();
for (StackTraceElement stackTraceElement : stackTrace)
{
String className = stackTraceElement.getClassName();
if (!className.equals(SerializedInvoker.class.getName()) && !className.equals(getClass().getName()))
return "Queued at " + stackTraceElement;
}
return task.toString();
}

/**
* Arrange for tasks to be invoked, mutually excluded from other tasks.
* @param tasks The tasks to invoke
Expand Down Expand Up @@ -116,8 +153,8 @@ public void run(Runnable task)
if (todo != null)
todo.run();
else
if (LOG.isDebugEnabled())
LOG.debug("Queued link in {}", this);
if (LOG.isDebugEnabled())
lorban marked this conversation as resolved.
Show resolved Hide resolved
LOG.debug("Queued link in {}", this);
}

/**
Expand All @@ -131,22 +168,22 @@ public void run(Runnable... tasks)
if (todo != null)
todo.run();
else
if (LOG.isDebugEnabled())
LOG.debug("Queued links in {}", this);
if (LOG.isDebugEnabled())
lorban marked this conversation as resolved.
Show resolved Hide resolved
LOG.debug("Queued links in {}", this);
}

@Override
public String toString()
{
return String.format("%s@%x{tail=%s}", getClass().getSimpleName(), hashCode(), _tail);
return String.format("%s@%x{name=%s,tail=%s,invoker=%s}", getClass().getSimpleName(), hashCode(), _name, _tail, _invokerThread);
}

protected void onError(Runnable task, Throwable t)
{
LOG.warn("Serialized invocation error", t);
}

private class Link implements Runnable, Invocable
private class Link implements Runnable, Invocable, Dumpable
{
private final Runnable _task;
private final AtomicReference<Link> _next = new AtomicReference<>();
Expand All @@ -156,6 +193,24 @@ public Link(Runnable task)
_task = task;
}

@Override
public void dump(Appendable out, String indent) throws IOException
{
if (_task instanceof NamedRunnable nr)
{
StringWriter sw = new StringWriter();
nr.stack.printStackTrace(new PrintWriter(sw));
Dumpable.dumpObjects(out, indent, nr.toString(), sw.toString());
}
else
{
Dumpable.dumpObjects(out, indent, _task);
}
Link link = _next.get();
if (link != null)
link.dump(out, indent);
}

@Override
public InvocationType getInvocationType()
{
Expand Down Expand Up @@ -186,6 +241,7 @@ public void run()
{
if (LOG.isDebugEnabled())
LOG.debug("Running link {} of {}", link, SerializedInvoker.this);
_invokerThread = Thread.currentThread();
try
{
link._task.run();
Expand All @@ -196,6 +252,12 @@ public void run()
LOG.debug("Failed while running link {} of {}", link, SerializedInvoker.this, t);
onError(link._task, t);
}
finally
{
// _invokerThread must be nulled before calling link.next() as
// once the latter has executed, another thread can enter Link.run().
_invokerThread = null;
}
link = link.next();
if (link == null && LOG.isDebugEnabled())
LOG.debug("Next link is null, execution is over in {}", SerializedInvoker.this);
Expand All @@ -209,10 +271,35 @@ public String toString()
}
}

private record NamedRunnable(Runnable delegate, String name) implements Runnable
private class NamedRunnable implements Runnable
{
private static final Logger LOG = LoggerFactory.getLogger(NamedRunnable.class);

private final Runnable delegate;
private final String name;
private final Throwable stack;

public NamedRunnable(Runnable delegate)
lorban marked this conversation as resolved.
Show resolved Hide resolved
{
this.delegate = delegate;
this.stack = new Throwable();
this.name = deriveTaskName(delegate, stack);
}

protected String deriveTaskName(Runnable task, Throwable stack)
lorban marked this conversation as resolved.
Show resolved Hide resolved
{
StackTraceElement[] stackTrace = stack.getStackTrace();
for (StackTraceElement stackTraceElement : stackTrace)
{
String className = stackTraceElement.getClassName();
if (!className.equals(SerializedInvoker.class.getName()) &&
!className.equals(SerializedInvoker.this.getClass().getName()) &&
!className.equals(getClass().getName()))
return "Queued by " + Thread.currentThread().getName() + " at " + stackTraceElement;
}
return task.toString();
}

@Override
public void run()
{
Expand Down
Loading
Loading