更新:从jetty-9.4.15.v20190215开始,对于端口统一支持已内置于Jetty中;请参考此答案。
是的,我们可以
这是可能的,而且我们已经做到了。这里的代码适用于Jetty 8;我尚未测试过Jetty 9,但此答案包含了适用于Jetty 9的类似代码。
顺便说一下,这被称为端口统一,在使用Glassfish和Grizzly的情况下,它显然已经被支持。
概述
基本思路是创建一个实现了org.eclipse.jetty.server.Connector
接口的处理器,该处理器可以预先查看客户端请求的第一个字节。幸运的是,HTTP和HTTPS都要求客户端启动通信。对于HTTPS(以及TLS/SSL一般情况),第一个字节将是0x16
(TLS)或>= 0x80
(SSLv2)。对于HTTP,第一个字节将是良好的可打印7位ASCII码。现在,根据第一个字节,处理器将生成SSL连接或普通连接。
在这里的代码中,我们利用了Jetty的SslSelectChannelConnector
本身扩展了SelectChannelConnector
并具有newPlainConnection()
方法(调用其超类以生成非SSL连接)和newConnection()
方法(生成SSL连接)。因此,我们新的Connector
可以扩展SslSelectChannelConnector
并在观察到客户端的第一个字节后委托给其中的一个方法。
不幸的是,在第一个字节可用之前,我们将需要创建AsyncConnection
的实例。AsyncConnection
的某些方法甚至可能在第一个字节可用之前被调用。因此,我们创建了一个LazyConnection 实现 AsyncConnection
,它可以稍后确定要委托哪种类型的连接,或者在不知道之前返回合理的默认响应。
由于基于NIO,所以我们的Connector
将使用SocketChannel
工作。幸运的是,我们可以扩展SocketChannel
以创建一个ReadAheadSocketChannelWrapper
,该封装器委托给“真实”的SocketChannel
,但可以检查并存储客户端消息的前几个字节。
一些细节
有一个非常巧妙的部分。我们的Connector
必须重写的方法之一是customize(Endpoint,Request)
。如果我们最终使用基于SSL的Endpoint
,我们可以直接传递给我们的超类;否则,超类将抛出ClassCastException
,但只有在传递给它的超类并且设置Request
的方案后才会抛出异常。因此,我们传递给超类,但在看到异常时撤消设置方案。
我们还覆盖了isConfidential()
和isIntegral()
,以确保我们的servlet可以正确使用HttpServletRequest.isSecure()
来确定是使用HTTP还是HTTPS。
尝试从客户端读取第一个字节可能会抛出IOException
,但我们可能必须在不期望IOException
的地方尝试,这种情况下我们将保留异常并稍后抛出。
在Java >= 7和Java 6中,扩展SocketChannel
看起来不同。在后者的情况下,只需注释掉Java 6 SocketChannel
没有的方法即可。
代码
public class PortUnificationSelectChannelConnector extends SslSelectChannelConnector {
public PortUnificationSelectChannelConnector() {
super();
}
public PortUnificationSelectChannelConnector(SslContextFactory sslContextFactory) {
super(sslContextFactory);
}
@Override
protected SelectChannelEndPoint newEndPoint(SocketChannel channel, SelectSet selectSet, SelectionKey key) throws IOException {
return super.newEndPoint(new ReadAheadSocketChannelWrapper(channel, 1), selectSet, key);
}
@Override
protected AsyncConnection newConnection(SocketChannel channel, AsyncEndPoint endPoint) {
return new LazyConnection((ReadAheadSocketChannelWrapper)channel, endPoint);
}
@Override
public void customize(EndPoint endpoint, Request request) throws IOException {
String scheme = request.getScheme();
try {
super.customize(endpoint, request);
} catch (ClassCastException e) {
request.setScheme(scheme);
}
}
@Override
public boolean isConfidential(Request request) {
if (request.getAttribute("javax.servlet.request.cipher_suite") != null) return true;
else return isForwarded() && request.getScheme().equalsIgnoreCase(HttpSchemes.HTTPS);
}
@Override
public boolean isIntegral(Request request) {
return isConfidential(request);
}
class LazyConnection implements AsyncConnection {
private final ReadAheadSocketChannelWrapper channel;
private final AsyncEndPoint endPoint;
private final long timestamp;
private AsyncConnection connection;
public LazyConnection(ReadAheadSocketChannelWrapper channel, AsyncEndPoint endPoint) {
this.channel = channel;
this.endPoint = endPoint;
this.timestamp = System.currentTimeMillis();
this.connection = determineNewConnection(channel, endPoint, false);
}
public Connection handle() throws IOException {
if (connection == null) {
connection = determineNewConnection(channel, endPoint, false);
channel.throwPendingException();
}
if (connection != null) return connection.handle();
else return this;
}
public long getTimeStamp() {
return timestamp;
}
public void onInputShutdown() throws IOException {
if (connection == null) connection = determineNewConnection(channel, endPoint, true);
connection.onInputShutdown();
}
public boolean isIdle() {
if (connection == null) connection = determineNewConnection(channel, endPoint, false);
if (connection != null) return connection.isIdle();
else return false;
}
public boolean isSuspended() {
if (connection == null) connection = determineNewConnection(channel, endPoint, false);
if (connection != null) return connection.isSuspended();
else return false;
}
public void onClose() {
if (connection == null) connection = determineNewConnection(channel, endPoint, true);
connection.onClose();
}
public void onIdleExpired(long l) {
if (connection == null) connection = determineNewConnection(channel, endPoint, true);
connection.onIdleExpired(l);
}
AsyncConnection determineNewConnection(ReadAheadSocketChannelWrapper channel, AsyncEndPoint endPoint, boolean force) {
byte[] bytes = channel.getBytes();
if ((bytes == null || bytes.length == 0) && !force) return null;
if (looksLikeSsl(bytes)) {
return PortUnificationSelectChannelConnector.super.newConnection(channel, endPoint);
} else {
return PortUnificationSelectChannelConnector.super.newPlainConnection(channel, endPoint);
}
}
private boolean looksLikeSsl(byte[] bytes) {
if (bytes == null || bytes.length == 0) return false;
byte b = bytes[0];
return b >= 0x7F || (b < 0x20 && b != '\n' && b != '\r' && b != '\t');
}
}
static class ReadAheadSocketChannelWrapper extends SocketChannel {
private final SocketChannel channel;
private final ByteBuffer start;
private byte[] bytes;
private IOException pendingException;
private int leftToRead;
public ReadAheadSocketChannelWrapper(SocketChannel channel, int readAheadLength) throws IOException {
super(channel.provider());
this.channel = channel;
start = ByteBuffer.allocate(readAheadLength);
leftToRead = readAheadLength;
readAhead();
}
public synchronized void readAhead() throws IOException {
if (leftToRead > 0) {
int n = channel.read(start);
if (n == -1) {
leftToRead = -1;
} else {
leftToRead -= n;
}
if (leftToRead <= 0) {
start.flip();
bytes = new byte[start.remaining()];
start.get(bytes);
start.rewind();
}
}
}
public byte[] getBytes() {
if (pendingException == null) {
try {
readAhead();
} catch (IOException e) {
pendingException = e;
}
}
return bytes;
}
public void throwPendingException() throws IOException {
if (pendingException != null) {
IOException e = pendingException;
pendingException = null;
throw e;
}
}
private int readFromStart(ByteBuffer dst) throws IOException {
int sr = start.remaining();
int dr = dst.remaining();
if (dr == 0) return 0;
int n = Math.min(dr, sr);
dst.put(bytes, start.position(), n);
start.position(start.position() + n);
return n;
}
public synchronized int read(ByteBuffer dst) throws IOException {
throwPendingException();
readAhead();
if (leftToRead > 0) return 0;
int sr = start.remaining();
if (sr > 0) {
int n = readFromStart(dst);
if (n < sr) return n;
}
return sr + channel.read(dst);
}
public synchronized long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
throwPendingException();
if (offset + length > dsts.length || length < 0 || offset < 0) {
throw new IndexOutOfBoundsException();
}
readAhead();
if (leftToRead > 0) return 0;
int sr = start.remaining();
int newOffset = offset;
if (sr > 0) {
int accum = 0;
for (; newOffset < offset + length; newOffset++) {
accum += readFromStart(dsts[newOffset]);
if (accum == sr) break;
}
if (accum < sr) return accum;
}
return sr + channel.read(dsts, newOffset, length - newOffset + offset);
}
public int hashCode() {
return channel.hashCode();
}
public boolean equals(Object obj) {
return channel.equals(obj);
}
public String toString() {
return channel.toString();
}
public Socket socket() {
return channel.socket();
}
public boolean isConnected() {
return channel.isConnected();
}
public boolean isConnectionPending() {
return channel.isConnectionPending();
}
public boolean connect(SocketAddress remote) throws IOException {
return channel.connect(remote);
}
public boolean finishConnect() throws IOException {
return channel.finishConnect();
}
public int write(ByteBuffer src) throws IOException {
return channel.write(src);
}
public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
return channel.write(srcs, offset, length);
}
@Override
protected void implCloseSelectableChannel() throws IOException {
channel.close();
}
@Override
protected void implConfigureBlocking(boolean block) throws IOException {
channel.configureBlocking(block);
}
}
}