Skip to content

Instantly share code, notes, and snippets.

@jedvardsson
Last active November 30, 2023 09:20
Show Gist options
  • Save jedvardsson/7ba7bbc94b4951f82da4b590ace725d2 to your computer and use it in GitHub Desktop.
Save jedvardsson/7ba7bbc94b4951f82da4b590ace725d2 to your computer and use it in GitHub Desktop.
A socket factory for connecting to Postgres via unix domain sockets
package com.github.jedvardsson.pgsf;
import org.postgresql.PGProperty;
import org.postgresql.util.PSQLException;
import javax.net.SocketFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.*;
import java.nio.channels.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Properties;
import java.util.Set;
// Usage: jdbc:postgresql:///mydatabase?socketFactory=com.github.jedvardsson.pgsf.PostgresSocketFactory
public class PostgresSocketFactory extends SocketFactory {
/** Connection properties handed to the constructor; consulted by {@link #createSocket()}. */
private final Properties properties;

/**
 * Creates a factory bound to the given connection properties.
 *
 * @param properties the properties later queried for the host and port settings;
 *                   stored by reference, not copied
 */
public PostgresSocketFactory(Properties properties) {
    this.properties = properties;
}
/**
 * Creates the socket used for the database connection.
 *
 * <p>If the host property is missing or empty, the method looks for a PostgreSQL
 * unix domain socket file named {@code .s.PGSQL.<port>} in a fixed list of
 * directories, connects a {@link SocketChannel} to the first one that exists and
 * returns it wrapped as an already-connected {@link Socket}. Otherwise it returns a
 * plain, unconnected socket from the JVM default {@link SocketFactory}.
 *
 * @return a connected unix-domain socket adaptor, or an unconnected default socket
 *         when a host is configured
 * @throws IOException              if the channel cannot be opened or connected
 * @throws IllegalArgumentException if the host/port properties cannot be read
 * @throws IllegalStateException    if no host is configured and none of the candidate
 *                                  socket files exists
 */
@Override
public Socket createSocket() throws IOException {
    final String host;
    final int port;
    try {
        host = PGProperty.PG_HOST.getOrNull(properties);
        port = PGProperty.PG_PORT.getInt(properties);
    } catch (PSQLException e) {
        throw new IllegalArgumentException("Could not read host/port from connection properties", e);
    }

    if (host != null && !host.isEmpty()) {
        // A host was configured: hand back the default socket. Previously this socket
        // was created and dropped, and the call then failed in super.createSocket().
        return SocketFactory.getDefault().createSocket();
    }

    String socketFileName = ".s.PGSQL." + port;
    // Original candidates first (order preserved); /var/run/postgresql is the
    // directory used by Debian/Ubuntu packages and is tried last.
    List<Path> sockets = List.of(
            Path.of("/tmp", socketFileName),
            Path.of("/var/run/postgres", socketFileName),
            Path.of("/var/run/postgresql", socketFileName));
    UnixDomainSocketAddress socketAddress = sockets.stream()
            .filter(Files::exists)
            .map(UnixDomainSocketAddress::of)
            .findFirst()
            .orElseThrow(() -> new IllegalStateException(
                    "host or unix_socket was not specified and default sockets not found: " + sockets));

    SocketChannel socketChannel = SocketChannel.open(StandardProtocolFamily.UNIX);
    try {
        socketChannel.configureBlocking(true);
        socketChannel.connect(socketAddress);
        return new SocketAdaptor(socketChannel);
    } catch (IOException | RuntimeException e) {
        // Do not leak the channel (and its file descriptor) when setup fails.
        try {
            socketChannel.close();
        } catch (IOException closeFailure) {
            e.addSuppressed(closeFailure);
        }
        throw e;
    }
}
/**
 * Not supported by this factory; only the no-argument {@link #createSocket()} is implemented.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
    throw new UnsupportedOperationException();
}
/**
 * Not supported by this factory; only the no-argument {@link #createSocket()} is implemented.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public Socket createSocket(String host, int port, InetAddress localHost, int localPort) throws IOException, UnknownHostException {
    throw new UnsupportedOperationException();
}
/**
 * Not supported by this factory; only the no-argument {@link #createSocket()} is implemented.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public Socket createSocket(InetAddress host, int port) throws IOException {
    throw new UnsupportedOperationException();
}
/**
 * Not supported by this factory; only the no-argument {@link #createSocket()} is implemented.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) throws IOException {
    throw new UnsupportedOperationException();
}
// sun.nio.ch.SocketAdaptor
/**
 * Exposes a {@link SocketChannel} through the {@link Socket} API.
 *
 * <p>Every operation is forwarded to the wrapped channel; the {@link SocketImpl}
 * handed to the superclass is a placeholder that is never meant to be invoked.
 * Address accessors return {@code null} / {@code 0} / {@code -1} when the channel's
 * address is not an {@link InetSocketAddress} (for example a unix domain address).
 *
 * <p>Limitations: connect and read timeouts are accepted and remembered but not
 * enforced, and urgent data is not supported.
 */
@SuppressWarnings("SameParameterValue")
private static class SocketAdaptor extends Socket {

    /** Option key for SO_OOBINLINE; only applied if the channel reports this key as supported. */
    private static final SocketOption<Boolean> SO_OOBINLINE =
            new SocketOption<Boolean>() {
                @Override
                public String name() {
                    return "SO_OOBINLINE";
                }

                @Override
                public Class<Boolean> type() {
                    return Boolean.class;
                }

                @Override
                public String toString() {
                    return name();
                }
            };

    /** The channel all calls are delegated to. */
    private final SocketChannel sc;
    /** Set once {@link #shutdownInput()} has completed successfully. */
    private volatile boolean inputShutdown;
    /** Set once {@link #shutdownOutput()} has completed successfully. */
    private volatile boolean outputShutdown;
    /** Value last passed to {@link #setSoTimeout(int)}; stored only, not applied to reads. */
    private volatile int timeout;

    /**
     * Wraps the given channel.
     *
     * @param sc the channel to delegate to
     * @throws SocketException if the {@link Socket} superclass constructor fails
     */
    public SocketAdaptor(SocketChannel sc) throws SocketException {
        super(new DummySocketImpl());
        this.sc = sc;
    }

    /** Connects the channel to {@code endpoint} without a timeout. */
    @Override
    public void connect(SocketAddress endpoint) throws IOException {
        connect(endpoint, 0);
    }

    /**
     * Connects the channel to {@code endpoint}.
     *
     * @param timeout currently ignored; the connect blocks according to the channel's mode
     */
    @Override
    public void connect(SocketAddress endpoint, int timeout) throws IOException {
        // ignore timeout for now
        sc.connect(endpoint);
    }

    /** Binds the channel to the given local address. */
    @Override
    public void bind(SocketAddress bindpoint) throws IOException {
        sc.bind(bindpoint);
    }

    /**
     * @return the remote IP address, or {@code null} if the remote address is not an
     *         {@link InetSocketAddress}
     * @throws UncheckedIOException if the channel cannot report its remote address
     */
    @Override
    public InetAddress getInetAddress() {
        try {
            SocketAddress address = sc.getRemoteAddress();
            if (address instanceof InetSocketAddress inetSocketAddress) {
                return inetSocketAddress.getAddress();
            }
            return null;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * @return the local IP address, or {@code null} if the local address is not an
     *         {@link InetSocketAddress}
     * @throws UncheckedIOException if the channel cannot report its local address
     */
    @Override
    public InetAddress getLocalAddress() {
        try {
            SocketAddress address = sc.getLocalAddress();
            if (address instanceof InetSocketAddress inetSocketAddress) {
                return inetSocketAddress.getAddress();
            }
            return null;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * @return the remote port, or {@code 0} if the remote address is not an
     *         {@link InetSocketAddress}
     * @throws UncheckedIOException if the channel cannot report its remote address
     */
    @Override
    public int getPort() {
        try {
            SocketAddress address = sc.getRemoteAddress();
            if (address instanceof InetSocketAddress inetSocketAddress) {
                return inetSocketAddress.getPort();
            }
            return 0;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * @return the local port, or {@code -1} if the local address is not an
     *         {@link InetSocketAddress}
     * @throws UncheckedIOException if the channel cannot report its local address
     */
    @Override
    public int getLocalPort() {
        try {
            SocketAddress address = sc.getLocalAddress();
            if (address instanceof InetSocketAddress inetSocketAddress) {
                return inetSocketAddress.getPort();
            }
            return -1;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** @return the channel's remote address, whatever its type */
    @Override
    public SocketAddress getRemoteSocketAddress() {
        try {
            return sc.getRemoteAddress();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** @return the channel's local address, whatever its type */
    @Override
    public SocketAddress getLocalSocketAddress() {
        try {
            return sc.getLocalAddress();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** @return the wrapped channel */
    @Override
    public SocketChannel getChannel() {
        return sc;
    }

    /**
     * Returns a new stream reading from the channel.
     *
     * @throws SocketException if the channel is closed, not connected, or input was shut down
     */
    @Override
    public InputStream getInputStream() throws IOException {
        if (!sc.isOpen()) {
            throw new SocketException("Socket is closed");
        }
        if (!sc.isConnected()) {
            throw new SocketException("Socket is not connected");
        }
        if (inputShutdown) {
            throw new SocketException("Socket input is shutdown");
        }
        // ignore timeout for now
        return Channels.newInputStream(sc);
    }

    /**
     * Returns a new stream writing to the channel.
     *
     * @throws SocketException if the channel is closed, not connected, or output was shut down
     */
    @Override
    public OutputStream getOutputStream() throws IOException {
        if (!sc.isOpen()) {
            throw new SocketException("Socket is closed");
        }
        if (!sc.isConnected()) {
            throw new SocketException("Socket is not connected");
        }
        if (outputShutdown) {
            throw new SocketException("Socket output is shutdown");
        }
        // ignore timeout for now
        return Channels.newOutputStream(sc);
    }

    /** Sets a boolean channel option, reporting I/O failures as {@link SocketException}. */
    private void setBooleanOption(SocketOption<Boolean> name, boolean value) throws SocketException {
        try {
            sc.setOption(name, value);
        } catch (IOException x) {
            translateToSocketException(x);
        }
    }

    /** Sets an integer channel option, reporting I/O failures as {@link SocketException}. */
    private void setIntOption(SocketOption<Integer> name, int value) throws SocketException {
        try {
            sc.setOption(name, value);
        } catch (IOException e) {
            translateToSocketException(e);
        }
    }

    /** Reads a boolean channel option, reporting I/O failures as {@link SocketException}. */
    private boolean getBooleanOption(SocketOption<Boolean> name) throws SocketException {
        try {
            return sc.getOption(name).booleanValue();
        } catch (IOException e) {
            translateToSocketException(e);
            return false; // keep compiler happy
        }
    }

    /** Reads an integer channel option, reporting I/O failures as {@link SocketException}. */
    private int getIntOption(SocketOption<Integer> name) throws SocketException {
        try {
            return sc.getOption(name).intValue();
        } catch (IOException e) {
            translateToSocketException(e);
            return -1; // keep compiler happy
        }
    }

    /** Applies TCP_NODELAY if the channel supports it; otherwise does nothing. */
    @Override
    public void setTcpNoDelay(boolean on) throws SocketException {
        if (sc.supportedOptions().contains(StandardSocketOptions.TCP_NODELAY)) {
            setBooleanOption(StandardSocketOptions.TCP_NODELAY, on);
        }
    }

    /** @return the TCP_NODELAY setting, or {@code false} if the channel does not support it */
    @Override
    public boolean getTcpNoDelay() throws SocketException {
        if (sc.supportedOptions().contains(StandardSocketOptions.TCP_NODELAY)) {
            return getBooleanOption(StandardSocketOptions.TCP_NODELAY);
        }
        return false;
    }

    /** Sets SO_LINGER; a disabled linger is passed to the channel as {@code -1}. */
    @Override
    public void setSoLinger(boolean on, int linger) throws SocketException {
        if (!on) {
            linger = -1;
        }
        setIntOption(StandardSocketOptions.SO_LINGER, linger);
    }

    /** @return the channel's SO_LINGER value */
    @Override
    public int getSoLinger() throws SocketException {
        return getIntOption(StandardSocketOptions.SO_LINGER);
    }

    /**
     * Urgent (out-of-band) data is not supported by this adaptor.
     *
     * @throws SocketException always
     */
    @Override
    public void sendUrgentData(int data) throws IOException {
        throw new SocketException("Urgent data not supported");
    }

    /** Applies SO_OOBINLINE if the channel reports it as supported; otherwise does nothing. */
    @Override
    public void setOOBInline(boolean on) throws SocketException {
        if (sc.supportedOptions().contains(SO_OOBINLINE)) {
            setBooleanOption(SO_OOBINLINE, on);
        }
    }

    /** @return the SO_OOBINLINE setting, or {@code false} if the channel does not report it as supported */
    @Override
    public boolean getOOBInline() throws SocketException {
        if (sc.supportedOptions().contains(SO_OOBINLINE)) {
            return getBooleanOption(SO_OOBINLINE);
        }
        return false;
    }

    /**
     * Remembers the requested read timeout. The value is only stored; streams returned
     * by {@link #getInputStream()} do not enforce it.
     *
     * @throws SocketException          if the channel is closed
     * @throws IllegalArgumentException if {@code timeout} is negative
     */
    @Override
    public void setSoTimeout(int timeout) throws SocketException {
        if (!sc.isOpen()) {
            throw new SocketException("Socket is closed");
        }
        if (timeout < 0) {
            throw new IllegalArgumentException("timeout < 0");
        }
        this.timeout = timeout;
    }

    /**
     * @return the value last stored by {@link #setSoTimeout(int)}
     * @throws SocketException if the channel is closed
     */
    @Override
    public int getSoTimeout() throws SocketException {
        if (!sc.isOpen()) {
            throw new SocketException("Socket is closed");
        }
        return timeout;
    }

    /** Sets SO_SNDBUF; rejects non-positive sizes as {@link Socket} does. */
    @Override
    public void setSendBufferSize(int size) throws SocketException {
        // size 0 valid for SocketChannel, invalid for Socket
        if (size <= 0) {
            throw new IllegalArgumentException("Invalid send size");
        }
        setIntOption(StandardSocketOptions.SO_SNDBUF, size);
    }

    /** @return the channel's SO_SNDBUF value */
    @Override
    public int getSendBufferSize() throws SocketException {
        return getIntOption(StandardSocketOptions.SO_SNDBUF);
    }

    /** Sets SO_RCVBUF; rejects non-positive sizes as {@link Socket} does. */
    @Override
    public void setReceiveBufferSize(int size) throws SocketException {
        // size 0 valid for SocketChannel, invalid for Socket
        if (size <= 0) {
            throw new IllegalArgumentException("Invalid receive size");
        }
        setIntOption(StandardSocketOptions.SO_RCVBUF, size);
    }

    /** @return the channel's SO_RCVBUF value */
    @Override
    public int getReceiveBufferSize() throws SocketException {
        return getIntOption(StandardSocketOptions.SO_RCVBUF);
    }

    /** Applies SO_KEEPALIVE if the channel supports it; otherwise does nothing. */
    @Override
    public void setKeepAlive(boolean on) throws SocketException {
        if (sc.supportedOptions().contains(StandardSocketOptions.SO_KEEPALIVE)) {
            setBooleanOption(StandardSocketOptions.SO_KEEPALIVE, on);
        }
    }

    /** @return the SO_KEEPALIVE setting, or {@code false} if the channel does not support it */
    @Override
    public boolean getKeepAlive() throws SocketException {
        if (sc.supportedOptions().contains(StandardSocketOptions.SO_KEEPALIVE)) {
            return getBooleanOption(StandardSocketOptions.SO_KEEPALIVE);
        }
        return false;
    }

    /** Applies IP_TOS if the channel supports it; otherwise does nothing. */
    @Override
    public void setTrafficClass(int tc) throws SocketException {
        if (sc.supportedOptions().contains(StandardSocketOptions.IP_TOS)) {
            setIntOption(StandardSocketOptions.IP_TOS, tc);
        }
    }

    /** @return the IP_TOS setting, or {@code 0} if the channel does not support it */
    @Override
    public int getTrafficClass() throws SocketException {
        if (sc.supportedOptions().contains(StandardSocketOptions.IP_TOS)) {
            return getIntOption(StandardSocketOptions.IP_TOS);
        }
        return 0;
    }

    /** Applies SO_REUSEADDR if the channel supports it; otherwise does nothing. */
    @Override
    public void setReuseAddress(boolean on) throws SocketException {
        if (sc.supportedOptions().contains(StandardSocketOptions.SO_REUSEADDR)) {
            setBooleanOption(StandardSocketOptions.SO_REUSEADDR, on);
        }
    }

    /** @return the SO_REUSEADDR setting, or {@code false} if the channel does not support it */
    @Override
    public boolean getReuseAddress() throws SocketException {
        if (sc.supportedOptions().contains(StandardSocketOptions.SO_REUSEADDR)) {
            return getBooleanOption(StandardSocketOptions.SO_REUSEADDR);
        }
        return false;
    }

    /** Closes the wrapped channel. */
    @Override
    public void close() throws IOException {
        sc.close();
    }

    /**
     * Shuts down the channel for reading. The input-shutdown flag is raised only after
     * the channel call succeeds, so a failed attempt does not change
     * {@link #isInputShutdown()}.
     */
    @Override
    public void shutdownInput() throws IOException {
        try {
            sc.shutdownInput();
            inputShutdown = true;
        } catch (Exception x) {
            translateException(x);
        }
    }

    /**
     * Shuts down the channel for writing. The output-shutdown flag is raised only after
     * the channel call succeeds, so a failed attempt does not change
     * {@link #isOutputShutdown()}.
     */
    @Override
    public void shutdownOutput() throws IOException {
        try {
            sc.shutdownOutput();
            outputShutdown = true;
        } catch (Exception x) {
            translateException(x);
        }
    }

    /** Describes the socket by remote address, remote port and local port when connected. */
    @Override
    public String toString() {
        if (sc.isConnected()) {
            return "Socket[addr=" + getInetAddress() +
                    ",port=" + getPort() +
                    ",localport=" + getLocalPort() + "]";
        }
        return "Socket[unconnected]";
    }

    /** @return whether the channel is connected */
    @Override
    public boolean isConnected() {
        return sc.isConnected();
    }

    /** @return {@code true} if the channel reports a local address; {@code false} otherwise or on error */
    @Override
    public boolean isBound() {
        try {
            return sc.getLocalAddress() != null;
        } catch (IOException ignored) {
            // Deliberately treated as "not bound": the address could not be queried.
        }
        return false;
    }

    /** @return whether the channel has been closed */
    @Override
    public boolean isClosed() {
        return !sc.isOpen();
    }

    /** @return whether {@link #shutdownInput()} has completed on this adaptor */
    @Override
    public boolean isInputShutdown() {
        return inputShutdown;
    }

    /** @return whether {@link #shutdownOutput()} has completed on this adaptor */
    @Override
    public boolean isOutputShutdown() {
        return outputShutdown;
    }

    /** Sets an option directly on the channel and returns this socket. */
    @Override
    public <T> Socket setOption(SocketOption<T> name, T value) throws IOException {
        sc.setOption(name, value);
        return this;
    }

    /** Reads an option directly from the channel. */
    @Override
    public <T> T getOption(SocketOption<T> name) throws IOException {
        return sc.getOption(name);
    }

    /** @return the options supported by the channel */
    @Override
    public Set<SocketOption<?>> supportedOptions() {
        return sc.supportedOptions();
    }

    /**
     * Rethrows {@code x} in the form the {@link Socket} API uses. This method never
     * returns normally.
     *
     * <p>A {@link SocketException} is rethrown as is. Known channel exceptions and any
     * other {@link IOException} are wrapped in a {@link SocketException} with the
     * original as cause. Remaining runtime exceptions are rethrown unchanged, and any
     * other exception is wrapped in an {@link Error}.
     */
    private static void translateToSocketException(Exception x)
            throws SocketException {
        if (x instanceof SocketException socketException) {
            throw socketException;
        }
        Exception nx = x;
        if (x instanceof ClosedChannelException) {
            nx = new SocketException("Socket is closed");
        } else if (x instanceof NotYetConnectedException) {
            nx = new SocketException("Socket is not connected");
        } else if (x instanceof AlreadyBoundException) {
            nx = new SocketException("Already bound");
        } else if (x instanceof NotYetBoundException) {
            nx = new SocketException("Socket is not bound yet");
        } else if (x instanceof UnsupportedAddressTypeException) {
            nx = new SocketException("Unsupported address type");
        } else if (x instanceof UnresolvedAddressException) {
            nx = new SocketException("Unresolved address");
        } else if (x instanceof IOException) {
            nx = new SocketException(x.getMessage());
        }
        if (nx != x) {
            nx.initCause(x);
        }
        if (nx instanceof SocketException translated) {
            throw translated;
        } else if (nx instanceof RuntimeException runtimeException) {
            throw runtimeException;
        } else {
            throw new Error("Untranslated exception", nx);
        }
    }

    /**
     * Rethrows {@code x} as an {@link IOException} where possible. This method never
     * returns normally.
     *
     * @param unknownHostForUnresolved when {@code true}, an
     *        {@link UnresolvedAddressException} becomes an {@link UnknownHostException}
     */
    private static void translateException(Exception x, boolean unknownHostForUnresolved) throws IOException {
        if (x instanceof IOException ioException) {
            throw ioException;
        }
        // Throw UnknownHostException from here since it cannot
        // be thrown as a SocketException
        if (unknownHostForUnresolved && (x instanceof UnresolvedAddressException)) {
            throw new UnknownHostException();
        }
        translateToSocketException(x);
    }

    /** Same as {@link #translateException(Exception, boolean)} with no unknown-host mapping. */
    private static void translateException(Exception x) throws IOException {
        translateException(x, false);
    }
}
// sun.nio.ch.DummySocketImpl
/**
 * Placeholder {@link SocketImpl} whose every operation fails with an
 * {@link InternalError}. It exists only so that a {@link Socket} subclass can be
 * constructed with an implementation object it never intends to use.
 */
private static class DummySocketImpl extends SocketImpl {

    private DummySocketImpl() {
    }

    /**
     * Signals that a method of this placeholder was invoked. Declared generic so it can
     * be used in {@code return} position for any return type; it never returns.
     */
    private static <T> T unreachable() {
        throw new InternalError("Should not get here");
    }

    // --- connection lifecycle ---

    @Override
    protected void create(boolean stream) {
        unreachable();
    }

    @Override
    protected void connect(SocketAddress remote, int millis) {
        unreachable();
    }

    @Override
    protected void connect(String host, int port) {
        unreachable();
    }

    @Override
    protected void connect(InetAddress address, int port) {
        unreachable();
    }

    @Override
    protected void bind(InetAddress host, int port) {
        unreachable();
    }

    @Override
    protected void listen(int backlog) {
        unreachable();
    }

    @Override
    protected void accept(SocketImpl si) {
        unreachable();
    }

    // --- streams ---

    @Override
    protected InputStream getInputStream() {
        return unreachable();
    }

    @Override
    protected OutputStream getOutputStream() {
        return unreachable();
    }

    @Override
    protected int available() {
        return unreachable();
    }

    @Override
    protected void close() {
        unreachable();
    }

    // --- options ---

    @Override
    protected Set<SocketOption<?>> supportedOptions() {
        return unreachable();
    }

    @Override
    protected <T> void setOption(SocketOption<T> opt, T value) {
        unreachable();
    }

    @Override
    protected <T> T getOption(SocketOption<T> opt) {
        return unreachable();
    }

    @Override
    public void setOption(int opt, Object value) {
        unreachable();
    }

    @Override
    public Object getOption(int opt) {
        return unreachable();
    }

    // --- shutdown and urgent data ---

    @Override
    protected void shutdownInput() {
        unreachable();
    }

    @Override
    protected void shutdownOutput() {
        unreachable();
    }

    @Override
    protected boolean supportsUrgentData() {
        return unreachable();
    }

    @Override
    protected void sendUrgentData(int data) {
        unreachable();
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment