Last active
November 30, 2023 09:20
-
-
Save jedvardsson/7ba7bbc94b4951f82da4b590ace725d2 to your computer and use it in GitHub Desktop.
A socket factory for connecting to PostgreSQL via Unix domain sockets.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.github.jedvardsson.pgsf; | |
import org.postgresql.PGProperty; | |
import org.postgresql.util.PSQLException; | |
import javax.net.SocketFactory; | |
import java.io.IOException; | |
import java.io.InputStream; | |
import java.io.OutputStream; | |
import java.io.UncheckedIOException; | |
import java.net.*; | |
import java.nio.channels.*; | |
import java.nio.file.Files; | |
import java.nio.file.Path; | |
import java.util.List; | |
import java.util.Properties; | |
import java.util.Set; | |
// Usage: jdbc:postgresql:///mydatabase?socketFactory=com.github.jedvardsson.pgsf.PostgresSocketFactory</code> | |
public class PostgresSocketFactory extends SocketFactory { | |
private final Properties properties; | |
public PostgresSocketFactory(Properties properties) { | |
this.properties = properties; | |
} | |
@Override | |
public Socket createSocket() throws IOException { | |
try { | |
String host = PGProperty.PG_HOST.getOrNull(properties); | |
int port = PGProperty.PG_PORT.getInt(properties); | |
if (host == null || host.isEmpty()) { | |
String socketFileName = ".s.PGSQL." + port; | |
List<Path> sockets = List.of(Path.of("/tmp", socketFileName), Path.of("/var/run/postgres", socketFileName)); | |
UnixDomainSocketAddress socketAddress = sockets.stream() | |
.filter(Files::exists) | |
.map(UnixDomainSocketAddress::of) | |
.findFirst().orElseThrow(() -> new IllegalStateException("host or unix_socket was not specified and default sockets not found: " + sockets)); | |
SocketChannel socketChannel = SocketChannel.open(StandardProtocolFamily.UNIX); | |
socketChannel.configureBlocking(true); | |
socketChannel.connect(socketAddress); | |
return new SocketAdaptor(socketChannel); | |
} else { | |
SocketFactory.getDefault().createSocket(); | |
} | |
return super.createSocket(); | |
} catch (PSQLException e) { | |
throw new RuntimeException(e); | |
} | |
} | |
@Override | |
public Socket createSocket(String host, int port) throws IOException, UnknownHostException { | |
throw new UnsupportedOperationException(); | |
} | |
@Override | |
public Socket createSocket(String host, int port, InetAddress localHost, int localPort) throws IOException, UnknownHostException { | |
throw new UnsupportedOperationException(); | |
} | |
@Override | |
public Socket createSocket(InetAddress host, int port) throws IOException { | |
throw new UnsupportedOperationException(); | |
} | |
@Override | |
public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) throws IOException { | |
throw new UnsupportedOperationException(); | |
} | |
// sun.nio.ch.SocketAdaptor | |
@SuppressWarnings("SameParameterValue") | |
private static class SocketAdaptor extends Socket { | |
private final SocketChannel sc; | |
private volatile boolean inputShutdown; | |
private volatile boolean outputShutdown; | |
private volatile int timeout; | |
public SocketAdaptor(SocketChannel sc) throws SocketException { | |
super(new DummySocketImpl()); | |
this.sc = sc; | |
} | |
@Override | |
public void connect(SocketAddress endpoint) throws IOException { | |
connect(endpoint, 0); | |
} | |
@Override | |
public void connect(SocketAddress endpoint, int timeout) throws IOException { | |
// ignore timeout for now | |
sc.connect(endpoint); | |
// try (Selector selector = Selector.open()) { | |
// boolean connected = sc.connect(endpoint); | |
// if (connected) { | |
// return; | |
// } | |
// | |
// sc.register(selector, SelectionKey.OP_CONNECT); | |
// int selectedKeyCount = selector.select(timeout); | |
// if (selectedKeyCount != 1) { | |
// throw new RuntimeException("Expected 1 selected key: " + selectedKeyCount); | |
// } | |
// | |
// Set<SelectionKey> selectedKeys = selector.selectedKeys(); | |
// Iterator<SelectionKey> iter = selectedKeys.iterator(); | |
// SelectionKey key = iter.next(); | |
// if (!key.isConnectable()) { | |
// throw new IllegalArgumentException("Expected selected key to be connectable: " + key); | |
// } | |
// iter.remove(); | |
// } | |
} | |
@Override | |
public void bind(SocketAddress bindpoint) throws IOException { | |
sc.bind(bindpoint); | |
} | |
@Override | |
public InetAddress getInetAddress() { | |
try { | |
SocketAddress address = sc.getRemoteAddress(); | |
if (address instanceof InetSocketAddress inetSocketAddress) { | |
return inetSocketAddress.getAddress(); | |
} | |
return null; | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
} | |
@Override | |
public InetAddress getLocalAddress() { | |
try { | |
SocketAddress address = sc.getLocalAddress(); | |
if (address instanceof InetSocketAddress inetSocketAddress) { | |
return inetSocketAddress.getAddress(); | |
} | |
return null; | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
} | |
@Override | |
public int getPort() { | |
try { | |
SocketAddress address = sc.getRemoteAddress(); | |
if (address instanceof InetSocketAddress inetSocketAddress) { | |
return inetSocketAddress.getPort(); | |
} | |
return 0; | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
} | |
@Override | |
public int getLocalPort() { | |
try { | |
SocketAddress address = sc.getLocalAddress(); | |
if (address instanceof InetSocketAddress inetSocketAddress) { | |
return inetSocketAddress.getPort(); | |
} | |
return -1; | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
} | |
@Override | |
public SocketAddress getRemoteSocketAddress() { | |
try { | |
return sc.getRemoteAddress(); | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
} | |
@Override | |
public SocketAddress getLocalSocketAddress() { | |
try { | |
return sc.getLocalAddress(); | |
} catch (IOException e) { | |
throw new UncheckedIOException(e); | |
} | |
} | |
@Override | |
public SocketChannel getChannel() { | |
return sc; | |
} | |
@Override | |
public InputStream getInputStream() throws IOException { | |
if (!sc.isOpen()) | |
throw new SocketException("Socket is closed"); | |
if (!sc.isConnected()) | |
throw new SocketException("Socket is not connected"); | |
if (inputShutdown) | |
throw new SocketException("Socket input is shutdown"); | |
// ignore timeout for now | |
return Channels.newInputStream(sc); | |
} | |
@Override | |
public OutputStream getOutputStream() throws IOException { | |
if (!sc.isOpen()) | |
throw new SocketException("Socket is closed"); | |
if (!sc.isConnected()) | |
throw new SocketException("Socket is not connected"); | |
if (outputShutdown) | |
throw new SocketException("Socket output is shutdown"); | |
// ignore timeout for now | |
return Channels.newOutputStream(sc); | |
} | |
private void setBooleanOption(SocketOption<Boolean> name, boolean value) throws SocketException { | |
try { | |
sc.setOption(name, value); | |
} catch (IOException x) { | |
translateToSocketException(x); | |
} | |
} | |
private void setIntOption(SocketOption<Integer> name, int value) throws SocketException { | |
try { | |
sc.setOption(name, value); | |
} catch (IOException e) { | |
translateToSocketException(e); | |
} | |
} | |
private boolean getBooleanOption(SocketOption<Boolean> name) throws SocketException { | |
try { | |
return sc.getOption(name).booleanValue(); | |
} catch (IOException e) { | |
translateToSocketException(e); | |
return false; // keep compiler happy | |
} | |
} | |
private int getIntOption(SocketOption<Integer> name) throws SocketException { | |
try { | |
return sc.getOption(name).intValue(); | |
} catch (IOException e) { | |
translateToSocketException(e); | |
return -1; // keep compiler happy | |
} | |
} | |
@Override | |
public void setTcpNoDelay(boolean on) throws SocketException { | |
if (sc.supportedOptions().contains(StandardSocketOptions.TCP_NODELAY)) { | |
setBooleanOption(StandardSocketOptions.TCP_NODELAY, on); | |
} | |
} | |
@Override | |
public boolean getTcpNoDelay() throws SocketException { | |
if (sc.supportedOptions().contains(StandardSocketOptions.TCP_NODELAY)) { | |
return getBooleanOption(StandardSocketOptions.TCP_NODELAY); | |
} | |
return false; | |
} | |
@Override | |
public void setSoLinger(boolean on, int linger) throws SocketException { | |
if (!on) | |
linger = -1; | |
setIntOption(StandardSocketOptions.SO_LINGER, linger); | |
} | |
@Override | |
public int getSoLinger() throws SocketException { | |
return getIntOption(StandardSocketOptions.SO_LINGER); | |
} | |
@Override | |
public void sendUrgentData(int data) throws IOException { | |
// int n = sc.sendOutOfBandData((byte) data); | |
// if (n == 0) | |
// throw new IOException("Socket buffer full"); | |
throw new SocketException("Urgent data not supported"); | |
} | |
private static final SocketOption<Boolean> SO_OOBINLINE = | |
new SocketOption<Boolean>() { | |
public String name() { | |
return "SO_OOBINLINE"; | |
} | |
public Class<Boolean> type() { | |
return Boolean.class; | |
} | |
public String toString() { | |
return name(); | |
} | |
}; | |
@Override | |
public void setOOBInline(boolean on) throws SocketException { | |
if (sc.supportedOptions().contains(SO_OOBINLINE)) { | |
setBooleanOption(SO_OOBINLINE, on); | |
} | |
} | |
@Override | |
public boolean getOOBInline() throws SocketException { | |
if (sc.supportedOptions().contains(SO_OOBINLINE)) { | |
return getBooleanOption(SO_OOBINLINE); | |
} | |
return false; | |
} | |
@Override | |
public void setSoTimeout(int timeout) throws SocketException { | |
if (!sc.isOpen()) | |
throw new SocketException("Socket is closed"); | |
if (timeout < 0) | |
throw new IllegalArgumentException("timeout < 0"); | |
this.timeout = timeout; | |
} | |
@Override | |
public int getSoTimeout() throws SocketException { | |
if (!sc.isOpen()) | |
throw new SocketException("Socket is closed"); | |
return timeout; | |
} | |
@Override | |
public void setSendBufferSize(int size) throws SocketException { | |
// size 0 valid for SocketChannel, invalid for Socket | |
if (size <= 0) | |
throw new IllegalArgumentException("Invalid send size"); | |
setIntOption(StandardSocketOptions.SO_SNDBUF, size); | |
} | |
@Override | |
public int getSendBufferSize() throws SocketException { | |
return getIntOption(StandardSocketOptions.SO_SNDBUF); | |
} | |
@Override | |
public void setReceiveBufferSize(int size) throws SocketException { | |
// size 0 valid for SocketChannel, invalid for Socket | |
if (size <= 0) | |
throw new IllegalArgumentException("Invalid receive size"); | |
setIntOption(StandardSocketOptions.SO_RCVBUF, size); | |
} | |
@Override | |
public int getReceiveBufferSize() throws SocketException { | |
return getIntOption(StandardSocketOptions.SO_RCVBUF); | |
} | |
@Override | |
public void setKeepAlive(boolean on) throws SocketException { | |
if (sc.supportedOptions().contains(StandardSocketOptions.SO_KEEPALIVE)) { | |
setBooleanOption(StandardSocketOptions.SO_KEEPALIVE, on); | |
} | |
} | |
@Override | |
public boolean getKeepAlive() throws SocketException { | |
if (sc.supportedOptions().contains(StandardSocketOptions.SO_KEEPALIVE)) { | |
return getBooleanOption(StandardSocketOptions.SO_KEEPALIVE); | |
} | |
return false; | |
} | |
@Override | |
public void setTrafficClass(int tc) throws SocketException { | |
if (sc.supportedOptions().contains(StandardSocketOptions.IP_TOS)) { | |
setIntOption(StandardSocketOptions.IP_TOS, tc); | |
} | |
} | |
@Override | |
public int getTrafficClass() throws SocketException { | |
if (sc.supportedOptions().contains(StandardSocketOptions.IP_TOS)) { | |
return getIntOption(StandardSocketOptions.IP_TOS); | |
} | |
return 0; | |
} | |
@Override | |
public void setReuseAddress(boolean on) throws SocketException { | |
if (sc.supportedOptions().contains(StandardSocketOptions.SO_REUSEADDR)) { | |
setBooleanOption(StandardSocketOptions.SO_REUSEADDR, on); | |
} | |
} | |
@Override | |
public boolean getReuseAddress() throws SocketException { | |
if (sc.supportedOptions().contains(StandardSocketOptions.SO_REUSEADDR)) { | |
return getBooleanOption(StandardSocketOptions.SO_REUSEADDR); | |
} | |
return false; | |
} | |
@Override | |
public void close() throws IOException { | |
sc.close(); | |
} | |
@Override | |
public void shutdownInput() throws IOException { | |
try { | |
inputShutdown = true; | |
sc.shutdownInput(); | |
} catch (Exception x) { | |
translateException(x); | |
} | |
} | |
@Override | |
public void shutdownOutput() throws IOException { | |
try { | |
outputShutdown = true; | |
sc.shutdownOutput(); | |
} catch (Exception x) { | |
translateException(x); | |
} | |
} | |
@Override | |
public String toString() { | |
if (sc.isConnected()) | |
return "Socket[addr=" + getInetAddress() + | |
",port=" + getPort() + | |
",localport=" + getLocalPort() + "]"; | |
return "Socket[unconnected]"; | |
} | |
@Override | |
public boolean isConnected() { | |
return sc.isConnected(); | |
} | |
@Override | |
public boolean isBound() { | |
try { | |
return sc.getLocalAddress() != null; | |
} catch (IOException e) { | |
// ignore | |
} | |
return false; | |
} | |
@Override | |
public boolean isClosed() { | |
return !sc.isOpen(); | |
} | |
@Override | |
public boolean isInputShutdown() { | |
return inputShutdown; | |
} | |
@Override | |
public boolean isOutputShutdown() { | |
return outputShutdown; | |
} | |
@Override | |
public <T> Socket setOption(SocketOption<T> name, T value) throws IOException { | |
sc.setOption(name, value); | |
return this; | |
} | |
@Override | |
public <T> T getOption(SocketOption<T> name) throws IOException { | |
return sc.getOption(name); | |
} | |
@Override | |
public Set<SocketOption<?>> supportedOptions() { | |
return sc.supportedOptions(); | |
} | |
private static void translateToSocketException(Exception x) | |
throws SocketException { | |
if (x instanceof SocketException) | |
throw (SocketException) x; | |
Exception nx = x; | |
if (x instanceof ClosedChannelException) | |
nx = new SocketException("Socket is closed"); | |
else if (x instanceof NotYetConnectedException) | |
nx = new SocketException("Socket is not connected"); | |
else if (x instanceof AlreadyBoundException) | |
nx = new SocketException("Already bound"); | |
else if (x instanceof NotYetBoundException) | |
nx = new SocketException("Socket is not bound yet"); | |
else if (x instanceof UnsupportedAddressTypeException) | |
nx = new SocketException("Unsupported address type"); | |
else if (x instanceof UnresolvedAddressException) | |
nx = new SocketException("Unresolved address"); | |
else if (x instanceof IOException) { | |
nx = new SocketException(x.getMessage()); | |
} | |
if (nx != x) | |
nx.initCause(x); | |
if (nx instanceof SocketException) | |
throw (SocketException) nx; | |
else if (nx instanceof RuntimeException) | |
throw (RuntimeException) nx; | |
else | |
throw new Error("Untranslated exception", nx); | |
} | |
private static void translateException(Exception x, boolean unknownHostForUnresolved) throws IOException { | |
if (x instanceof IOException) | |
throw (IOException) x; | |
// Throw UnknownHostException from here since it cannot | |
// be thrown as a SocketException | |
if (unknownHostForUnresolved && | |
(x instanceof UnresolvedAddressException)) { | |
throw new UnknownHostException(); | |
} | |
translateToSocketException(x); | |
} | |
private static void translateException(Exception x) throws IOException { | |
translateException(x, false); | |
} | |
} | |
// sun.nio.ch.DummySocketImpl | |
private static class DummySocketImpl extends SocketImpl { | |
private DummySocketImpl() { | |
} | |
private static <T> T shouldNotGetHere() { | |
throw new InternalError("Should not get here"); | |
} | |
@Override | |
protected void create(boolean stream) { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected void connect(SocketAddress remote, int millis) { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected void connect(String host, int port) { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected void connect(InetAddress address, int port) { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected void bind(InetAddress host, int port) { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected void listen(int backlog) { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected void accept(SocketImpl si) { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected InputStream getInputStream() { | |
return shouldNotGetHere(); | |
} | |
@Override | |
protected OutputStream getOutputStream() { | |
return shouldNotGetHere(); | |
} | |
@Override | |
protected int available() { | |
return shouldNotGetHere(); | |
} | |
@Override | |
protected void close() { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected Set<SocketOption<?>> supportedOptions() { | |
return shouldNotGetHere(); | |
} | |
@Override | |
protected <T> void setOption(SocketOption<T> opt, T value) { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected <T> T getOption(SocketOption<T> opt) { | |
return shouldNotGetHere(); | |
} | |
@Override | |
public void setOption(int opt, Object value) { | |
shouldNotGetHere(); | |
} | |
@Override | |
public Object getOption(int opt) { | |
return shouldNotGetHere(); | |
} | |
@Override | |
protected void shutdownInput() { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected void shutdownOutput() { | |
shouldNotGetHere(); | |
} | |
@Override | |
protected boolean supportsUrgentData() { | |
return shouldNotGetHere(); | |
} | |
@Override | |
protected void sendUrgentData(int data) { | |
shouldNotGetHere(); | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment