Skip to content

Instantly share code, notes, and snippets.

@kran
Last active December 20, 2024 06:01
Show Gist options
  • Save kran/48a9e53e89387c57acb97154a8e51eef to your computer and use it in GitHub Desktop.
Save kran/48a9e53e89387c57acb97154a8e51eef to your computer and use it in GitHub Desktop.
apache dbutils wrapper
import org.apache.commons.dbutils.*;
import org.apache.commons.dbutils.handlers.*;
import javax.sql.DataSource;
import java.lang.reflect.Array;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.*;
import java.util.AbstractMap.SimpleEntry;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
 * A small fluent SQL builder on top of Apache Commons DbUtils' {@link QueryRunner}.
 *
 * <p>SQL fragments are accumulated with {@link #add(String, Object...)} and
 * {@link #mark(String, String)}. Those builder methods (and {@link #escaper(Function)} /
 * {@link #snakeCase(boolean)}) return a modified copy and leave the receiver untouched, while
 * {@link #reset()} clears the receiver in place. A {@link Collection} or array bound to a single
 * {@code ?} is expanded to {@code ?,?,...} (e.g. for {@code IN (?)}); {@code byte[]} is bound as
 * one binary value.
 *
 * <p>Map keys given to the insert/update helpers are converted to snake_case (unless disabled via
 * {@link #snakeCase(boolean)}) and quoted with the configured escaper (backticks by default).
 *
 * <p>Maven dependency:
 * <pre>{@code
 * <dependency>
 *   <groupId>commons-dbutils</groupId>
 *   <artifactId>commons-dbutils</artifactId>
 *   <version>1.8.1</version>
 * </dependency>
 * }</pre>
 */
public class Duck extends QueryRunner {
    private static final Logger logger = Logger.getLogger(Duck.class.getName());
    /** Runs of upper-case letters; each run starts a new snake_case word. */
    private static final Pattern UPPER_RUN = Pattern.compile("([A-Z]+)");

    /** SQL fragments, joined with a single space by {@link #getSql()}. */
    protected final List<String> queryParts = new ArrayList<>();
    /** Bind values in placeholder order. */
    protected final List<Object> bindValues = new ArrayList<>();
    /** Name -> index into {@link #queryParts} for fragments created by {@link #mark(String, String)}. */
    protected final Map<String, Integer> marks = new HashMap<>();
    /** Connection of the enclosing transaction; {@code null} outside {@link #transaction(Function)}. */
    protected Connection txConn = null;
    /** Quotes table and column identifiers. */
    protected Function<String, String> escaper = this.makeEscaper("`");
    /** Whether map keys are converted to snake_case column names. */
    protected boolean isSnakeCase = true;

    protected Duck(DataSource dataSource_) {
        super(dataSource_);
    }

    /** Creates a query bound to an already-open transaction connection (no DataSource). */
    protected Duck(Connection conn) {
        super();
        txConn = conn;
    }

    /** Creates an empty query that borrows connections from {@code dataSource_}. */
    public static Duck init(DataSource dataSource_) {
        return new Duck(dataSource_);
    }

    /**
     * Builds the column list of an INSERT.
     *
     * @return the quoted column names joined with {@code ", "}, and the values in matching order
     */
    public SimpleEntry<String, List<Object>> pairInsert(Map<String, Object> params) {
        SimpleEntry<List<String>, List<Object>> kv = pair(params);
        String sql = String.join(", ", kv.getKey());
        return new SimpleEntry<>(sql, kv.getValue());
    }

    /**
     * Builds the SET list of an UPDATE.
     *
     * @return {@code "col1 = ?, col2 = ?"} and the values in matching order
     */
    public SimpleEntry<String, Object[]> pairUpdate(Map<String, Object> params) {
        SimpleEntry<List<String>, List<Object>> kv = pair(params);
        String sql = kv.getKey().stream()
                .map(k -> String.format("%s = ?", k))
                .collect(Collectors.joining(", "));
        return new SimpleEntry<>(sql, kv.getValue().toArray());
    }

    /**
     * Splits {@code params} into quoted column names and values, preserving the map's iteration order.
     */
    public SimpleEntry<List<String>, List<Object>> pair(Map<String, Object> params) {
        List<String> keys = new ArrayList<>(params.size());
        List<Object> values = new ArrayList<>(params.size());
        for (Map.Entry<String, Object> kv : params.entrySet()) {
            keys.add(columnName(kv.getKey()));
            values.add(kv.getValue());
        }
        return new SimpleEntry<>(keys, values);
    }

    /** Maps a map key to a quoted column name, applying snake_case conversion when enabled. */
    private String columnName(String key) {
        return escaper.apply(isSnakeCase ? toSnakeCase(key) : key);
    }

    /** Returns the accumulated SQL: all fragments joined with a single space. */
    public String getSql() {
        return String.join(" ", queryParts);
    }

    /** Returns a snapshot of the bind values in placeholder order. */
    public Object[] getParams() {
        return bindValues.toArray();
    }

    /** Clears fragments, bind values and marks of THIS instance (unlike the copying builders). */
    public Duck reset() {
        this.queryParts.clear();
        this.bindValues.clear();
        this.marks.clear();
        return this;
    }

    /**
     * Returns the transaction connection.
     *
     * @throws SQLQueryException when called outside {@link #transaction(Function)}
     */
    public Connection getTxConn() {
        if (txConn == null) {
            throw new SQLQueryException("Not in a transaction context");
        }
        return txConn;
    }

    /** Uses the transaction connection when present, otherwise borrows one from the DataSource. */
    protected Connection selectConnection() throws SQLException {
        if (txConn != null) {
            return txConn;
        }
        if (getDataSource() != null) {
            return getDataSource().getConnection();
        }
        throw new SQLQueryException("Unable to obtain database connection");
    }

    /**
     * Runs {@code action} on a selected connection, wrapping {@link SQLException} in
     * {@link SQLQueryException}. Borrowed connections are closed afterwards; the transaction
     * connection is left open for the enclosing {@link #transaction(Function)}.
     */
    private <T> T executeInConnection(ConnectionCallback<T> action) {
        Connection conn = null;
        try {
            conn = selectConnection();
            return action.doInConnection(conn);
        } catch (SQLException cause) {
            throw new SQLQueryException(cause);
        } finally {
            if (conn != null && conn != txConn) {
                try {
                    conn.close();
                } catch (SQLException ex) {
                    logger.warning("Failed to close connection: " + ex.getMessage());
                }
            }
        }
    }

    /** Executes the accumulated query and hands the result set to {@code rst}. */
    public <T> T fetch(final ResultSetHandler<T> rst) {
        return executeInConnection(conn -> query(conn, getSql(), rst, getParams()));
    }

    /** Maps every row to a bean; column names are matched leniently (e.g. snake_case to camelCase). */
    public <T> List<T> fetchBeanList(Class<T> tClass) {
        RowProcessor rowProcessor = new BasicRowProcessor(new GenerousBeanProcessor());
        ResultSetHandler<List<T>> handler = new BeanListHandler<>(tClass, rowProcessor);
        return fetch(handler);
    }

    /**
     * Maps at most one row to a bean.
     *
     * @return the bean, or {@code null} when the query returned no rows
     * @throws SQLQueryException when the query returned more than one row
     */
    public <T> T fetchBean(Class<T> tClass) {
        List<T> result = fetchBeanList(tClass);
        if (result.size() > 1) {
            String errmsg = String.format("Non-unique result: query returned %d rows when expecting exactly one row", result.size());
            throw new SQLQueryException(errmsg);
        }
        return result.isEmpty() ? null : result.get(0);
    }

    /** Maps rows to beans keyed by the value of {@code columnName}. */
    public <K, V> Map<K, V> fetchBeanMap(String columnName, Class<V> vClass) {
        BeanMapHandler<K, V> handler = new BeanMapHandler<>(vClass, columnName);
        return fetch(handler);
    }

    /** Returns every row as a column-name-to-value map. */
    public List<Map<String, Object>> fetchMapList() {
        MapListHandler handler = new MapListHandler();
        return fetch(handler);
    }

    /** Returns the first row as a column-name-to-value map. */
    public Map<String, Object> fetchMap() {
        MapHandler handler = new MapHandler();
        return fetch(handler);
    }

    /**
     * Returns the first column of the first row, converted to {@code tClass} when it is
     * {@link BigInteger}, {@link Long}/{@code long} or {@link Integer}/{@code int} and the driver
     * returned another integral number type.
     */
    public <T> T fetchScalar(Class<T> tClass) {
        ScalarHandler<Object> handler = new ScalarHandler<>();
        return convertScalar(fetch(handler), tClass);
    }

    /** Executes the accumulated statement and returns the affected row count. */
    public int update() {
        return executeInConnection(conn -> update(conn, getSql(), getParams()));
    }

    /**
     * Runs {@code update <table> set col = ?, ... where <where>}.
     *
     * @param params      column values; bound as-is, so a List/array value is NOT expanded
     * @param where       condition appended after {@code where}; may contain {@code ?}
     * @param whereParams values for the placeholders in {@code where} (collections are expanded)
     * @throws SQLQueryException when {@code params} is null or empty
     */
    public int update(String tableName, Map<String, Object> params, String where, Object... whereParams) {
        if (params == null || params.isEmpty()) {
            throw new SQLQueryException("Update params empty");
        }
        SimpleEntry<String, Object[]> paramsKv = pairUpdate(params);
        String sql = String.format("update %s set %s", escaper.apply(tableName), paramsKv.getKey());
        return appendVerbatim(sql, Arrays.asList(paramsKv.getValue()))
                .add("where " + where, whereParams)
                .update();
    }

    /** Executes the accumulated INSERT and returns the generated key as a {@link BigInteger}. */
    public BigInteger insert() {
        return insert(BigInteger.class);
    }

    /** Executes the accumulated INSERT and returns the generated key converted to {@code tClass}. */
    public <T> T insert(Class<T> tClass) {
        Object key = executeInConnection(conn -> insert(conn, getSql(), new ScalarHandler<Object>(), getParams()));
        return convertScalar(key, tClass);
    }

    /** Inserts one row and returns its generated key as a {@link BigInteger}. */
    public BigInteger insert(String tableName, Map<String, Object> params) {
        return insert(tableName, params, BigInteger.class);
    }

    /**
     * Inserts one row built from {@code params} and returns its generated key.
     * Values are bound as-is (a List/array value is NOT expanded).
     */
    public <T> T insert(String tableName, Map<String, Object> params, Class<T> tClass) {
        SimpleEntry<String, List<Object>> kv = pairInsert(params);
        String placeholders = String.join(", ", Collections.nCopies(kv.getValue().size(), "?"));
        String sql = String.format("insert into %s (%s) values (%s)",
                escaper.apply(tableName), kv.getKey(), placeholders);
        return appendVerbatim(sql, kv.getValue()).insert(tClass);
    }

    /**
     * Inserts all rows with one batched statement and returns their generated keys.
     * The column list is taken from the first map; keys missing from later maps are bound as NULL.
     * Runs on the transaction connection when called inside {@link #transaction(Function)}.
     *
     * @throws SQLQueryException when {@code paramsList} is null or empty
     */
    public <T> List<T> insertBatch(String tableName, List<Map<String, Object>> paramsList, Class<T> tClass) {
        if (paramsList == null || paramsList.isEmpty()) {
            throw new SQLQueryException("Batch params empty");
        }
        // Take the keys of the first map and fix their order; every row is read in this order.
        List<String> keys = new ArrayList<>(paramsList.get(0).keySet());
        // Build the SQL statement.
        String fieldStr = keys.stream()
                .map(this::columnName)
                .collect(Collectors.joining(", "));
        String placeholders = String.join(", ", Collections.nCopies(keys.size(), "?"));
        String sql = String.format("insert into %s (%s) values (%s)",
                escaper.apply(tableName), fieldStr, placeholders);
        // Build the per-row argument arrays for the batch.
        Object[][] batchArgs = new Object[paramsList.size()][];
        for (int i = 0; i < paramsList.size(); i++) {
            Map<String, Object> params = paramsList.get(i);
            Object[] values = new Object[keys.size()];
            for (int j = 0; j < keys.size(); j++) {
                values[j] = params.get(keys.get(j));
            }
            batchArgs[i] = values;
        }
        // Use the selected connection so the batch joins an enclosing transaction (the
        // DataSource-only overload would fail inside one and leak the unused connection otherwise).
        List<Object> generated = executeInConnection(
                conn -> super.insertBatch(conn, sql, new ColumnListHandler<Object>(), batchArgs));
        List<T> result = new ArrayList<>(generated.size());
        for (Object key : generated) {
            result.add(convertScalar(key, tClass));
        }
        return result;
    }

    /** {@link #insertBatch(String, List, Class)} returning {@link BigInteger} keys. */
    public List<BigInteger> insertBatch(String tableName, List<Map<String, Object>> paramsList) {
        return insertBatch(tableName, paramsList, BigInteger.class);
    }

    /**
     * Runs {@code action} inside a transaction on a dedicated connection and commits on success.
     * Any failure rolls back; exceptions are wrapped in {@link SQLQueryException} ("Transaction
     * failed"), Errors are rethrown unchanged. The {@link Duck} passed to {@code action} inherits
     * this instance's escaper and snake_case settings but starts with an empty query.
     *
     * @throws SQLQueryException when already inside a transaction, when {@code action} fails, or when
     *                           the rollback itself fails (the original failure is attached as suppressed)
     */
    public <R> R transaction(Function<Duck, R> action) {
        if (txConn != null) {
            throw new SQLQueryException("Nested transactions are not allowed");
        }
        Connection conn = null;
        // True once the transaction was committed or rolled back, i.e. nothing is left pending.
        boolean settled = false;
        try {
            conn = getDataSource().getConnection();
            conn.setAutoCommit(false);
            Duck tx = new Duck(conn);
            tx.escaper = escaper;
            tx.isSnakeCase = isSnakeCase;
            R result = action.apply(tx);
            conn.commit();
            settled = true;
            return result;
        } catch (Throwable failure) {
            // Throwable, not Exception: an Error must roll back too, otherwise re-enabling
            // auto-commit afterwards would commit the partial work.
            if (conn != null) {
                try {
                    conn.rollback();
                    settled = true;
                } catch (SQLException ex) {
                    SQLQueryException rollbackFailure = new SQLQueryException("Failed to rollback transaction", ex);
                    rollbackFailure.addSuppressed(failure);
                    throw rollbackFailure;
                }
            }
            if (failure instanceof Error) {
                throw (Error) failure;
            }
            throw new SQLQueryException("Transaction failed", failure);
        } finally {
            if (conn != null) {
                releaseTxConnection(conn, settled);
            }
        }
    }

    /**
     * Restores auto-commit (only when nothing is pending, since switching it on commits an open
     * transaction) and closes the connection; each step is attempted independently and failures
     * are logged, not thrown.
     */
    private static void releaseTxConnection(Connection conn, boolean restoreAutoCommit) {
        if (restoreAutoCommit) {
            try {
                conn.setAutoCommit(true);
            } catch (SQLException e) {
                logger.warning("Failed to restore auto-commit: " + e.getMessage());
            }
        }
        try {
            conn.close();
        } catch (SQLException e) {
            logger.warning("Failed to close connection: " + e.getMessage());
        }
    }

    /** Returns an independent copy of the query state and settings (sharing DataSource/tx connection). */
    protected Duck copy() {
        Duck query = new Duck(getDataSource());
        query.txConn = txConn;
        query.queryParts.addAll(queryParts);
        query.bindValues.addAll(bindValues);
        query.marks.putAll(marks);
        query.escaper = escaper;
        query.isSnakeCase = isSnakeCase;
        return query;
    }

    /** Returns a copy that quotes identifiers with {@code escaper_}. */
    public Duck escaper(Function<String, String> escaper_) {
        Duck query = copy();
        query.escaper = escaper_;
        return query;
    }

    /**
     * Creates an identifier quoter: each dot-separated part is wrapped in {@code quote}, with
     * embedded quotes doubled (e.g. {@code db.t} becomes {@code `db`.`t`} for a backtick).
     *
     * @throws SQLQueryException (from the returned function) for a null or blank identifier
     */
    public Function<String, String> makeEscaper(String quote) {
        return identifier -> {
            if (identifier == null) {
                throw new SQLQueryException("Identifier cannot be null");
            }
            if (identifier.trim().isEmpty()) {
                throw new SQLQueryException("Identifier cannot be empty");
            }
            return Arrays.stream(identifier.split("\\."))
                    .map(part -> quote
                            + part.replace(quote, quote + quote)
                            + quote)
                    .collect(Collectors.joining("."));
        };
    }

    /**
     * Returns a copy with {@code sql} appended; each {@code ?} consumes one of {@code params}.
     * A Collection or non-{@code byte[]} array parameter expands to {@code ?,?,...}.
     *
     * @throws SQLQueryException when the placeholder count differs from the parameter count
     */
    public Duck add(String sql, Object... params) {
        Duck query = copy();
        query.appendParams(sql, params);
        return query;
    }

    /** Returns a copy with snake_case conversion of map keys switched on or off. */
    public Duck snakeCase(boolean yes) {
        Duck duck = copy();
        duck.isSnakeCase = yes;
        return duck;
    }

    /**
     * Returns a copy where the fragment named {@code name} is set to {@code sql}: replaced in place
     * if the mark already exists, otherwise appended and remembered under that name.
     */
    public Duck mark(String name, String sql) {
        Duck query = copy();
        if (query.marks.containsKey(name)) {
            query.queryParts.set(query.marks.get(name), sql);
        } else {
            query.queryParts.add(sql);
            query.marks.put(name, query.queryParts.size() - 1);
        }
        return query;
    }

    /**
     * Returns a copy with {@code sql} appended as-is and {@code values} bound in order; unlike
     * {@link #add(String, Object...)} no placeholder counting or collection expansion takes place.
     */
    private Duck appendVerbatim(String sql, List<Object> values) {
        Duck query = copy();
        query.queryParts.add(sql);
        query.bindValues.addAll(values);
        return query;
    }

    /**
     * Appends {@code sql} to this instance, splitting it at each {@code ?} and binding {@code args}
     * in order. Collections and arrays (except {@code byte[]}, which is one binary value) expand
     * into comma-separated placeholders.
     */
    protected void appendParams(String sql, Object[] args) {
        // A bare `null` passed to the varargs add(sql, null) arrives as a null array;
        // treat it as the single NULL bind value the caller meant.
        Object[] values = args == null ? new Object[]{null} : args;
        // The trailing space keeps a final "?" from being dropped: split() discards trailing empties.
        String[] parts = (sql + " ").split("\\?");
        if (parts.length != values.length + 1) {
            String errmsg = String.format("Placeholders length (%d) doesn't match parameters length (%d)",
                    parts.length - 1, values.length);
            throw new SQLQueryException(errmsg);
        }
        for (int i = 0; i < values.length; i++) {
            Object arg = values[i];
            if (arg instanceof Collection) {
                appendArray(parts[i], ((Collection<?>) arg).toArray());
            } else if (arg != null && arg.getClass().isArray() && !(arg instanceof byte[])) {
                appendArray(parts[i], toObjectArray(arg));
            } else {
                queryParts.add(parts[i] + '?');
                bindValues.add(arg);
            }
        }
        queryParts.add(parts[values.length]);
    }

    /** Boxes any array (including primitive arrays) into an {@code Object[]}. */
    private static Object[] toObjectArray(Object array) {
        if (array instanceof Object[]) {
            return (Object[]) array;
        }
        int length = Array.getLength(array);
        Object[] boxed = new Object[length];
        for (int i = 0; i < length; i++) {
            boxed[i] = Array.get(array, i);
        }
        return boxed;
    }

    /** Appends {@code sql} followed by one comma-separated {@code ?} per element, binding each element. */
    protected void appendArray(String sql, Object[] arrayArg) {
        String placeholders = String.join(",", Collections.nCopies(arrayArg.length, "?"));
        queryParts.add(sql + placeholders);
        bindValues.addAll(Arrays.asList(arrayArg));
    }

    /** Logs the current SQL and parameters at INFO level and returns this instance. */
    public Duck debug() {
        logger.info("{" + Duck.class.getSimpleName() + "}\n" + this);
        return this;
    }

    @Override
    public String toString() {
        return "SQL: " + getSql() + "\nParams: " + Arrays.toString(getParams());
    }

    /**
     * Converts camelCase to snake_case: every run of upper-case letters is prefixed with
     * {@code _} and the result lower-cased ({@code userId -> user_id}, {@code userID -> user_id}).
     * A leading underscore is kept only if the input had one.
     */
    public String toSnakeCase(String str) {
        // Locale.ROOT: e.g. the Turkish locale would lower-case "I" to a dotless "ı".
        String res = UPPER_RUN.matcher(str).replaceAll("_$1").toLowerCase(Locale.ROOT);
        if (res.startsWith("_") && !str.startsWith("_")) {
            res = res.substring(1);
        }
        return res;
    }

    /**
     * Converts a scalar returned by the driver to {@code tClass}. Integral numbers (or exact
     * BigDecimals) are converted when the target is BigInteger, Long/long or Integer/int, since
     * drivers differ in the key/count types they return. Otherwise reference targets are checked
     * with {@link Class#cast} and primitive targets are passed through unchecked.
     *
     * @throws SQLQueryException when a number does not fit the requested integral type
     */
    @SuppressWarnings("unchecked")
    private static <T> T convertScalar(Object value, Class<T> tClass) {
        if (value == null || tClass == null || tClass.isInstance(value)) {
            return (T) value;
        }
        if (value instanceof Number) {
            try {
                if (tClass == BigInteger.class) {
                    return (T) toBigInteger((Number) value);
                }
                if (tClass == Long.class || tClass == long.class) {
                    return (T) Long.valueOf(toBigInteger((Number) value).longValueExact());
                }
                if (tClass == Integer.class || tClass == int.class) {
                    return (T) Integer.valueOf(toBigInteger((Number) value).intValueExact());
                }
            } catch (ArithmeticException | NumberFormatException ex) {
                throw new SQLQueryException(
                        String.format("Cannot convert %s to %s", value, tClass.getName()), ex);
            }
        }
        return tClass.isPrimitive() ? (T) value : tClass.cast(value);
    }

    /** Converts a number to BigInteger, rejecting values with a fractional part. */
    private static BigInteger toBigInteger(Number number) {
        if (number instanceof BigInteger) {
            return (BigInteger) number;
        }
        if (number instanceof BigDecimal) {
            return ((BigDecimal) number).toBigIntegerExact();
        }
        if (number instanceof Long || number instanceof Integer
                || number instanceof Short || number instanceof Byte) {
            return BigInteger.valueOf(number.longValue());
        }
        return new BigDecimal(number.toString()).toBigIntegerExact();
    }

    /** Work to run against a JDBC connection; may throw {@link SQLException}. */
    @FunctionalInterface
    protected interface ConnectionCallback<T> {
        T doInConnection(Connection conn) throws SQLException;
    }

    /** Unchecked wrapper for SQL and query-building failures raised by {@link Duck}. */
    public static class SQLQueryException extends RuntimeException {
        public SQLQueryException(String message) {
            super(message);
        }

        public SQLQueryException(Throwable cause) {
            super(cause);
        }

        public SQLQueryException(String message, Throwable cause) {
            super(message, cause);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment