Last active
December 20, 2024 06:01
-
-
Save kran/48a9e53e89387c57acb97154a8e51eef to your computer and use it in GitHub Desktop.
Apache Commons DbUtils wrapper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.commons.dbutils.*;
import org.apache.commons.dbutils.handlers.*;

import javax.sql.DataSource;
import java.lang.reflect.Array;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.*;
import java.util.AbstractMap.SimpleEntry;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;
/** | |
* <dependency> | |
* <groupId>commons-dbutils</groupId> | |
* <artifactId>commons-dbutils</artifactId> | |
* <version>1.8.1</version> | |
* </dependency> | |
*/ | |
public class Duck extends QueryRunner { | |
private static final Logger logger = Logger.getLogger(Duck.class.getName()); | |
protected final List<String> queryParts = new ArrayList<>(); | |
protected final List<Object> bindValues = new ArrayList<>(); | |
protected final Map<String, Integer> marks = new HashMap<>(); | |
protected Connection txConn = null; | |
protected Function<String, String> escaper = this.makeEscaper("`"); | |
protected boolean isSnakeCase = true; | |
protected Duck(DataSource dataSource_) { | |
super(dataSource_); | |
} | |
protected Duck(Connection conn) { | |
super(); | |
txConn = conn; | |
} | |
public static Duck init(DataSource dataSource_) { | |
return new Duck(dataSource_); | |
} | |
public SimpleEntry<String, List<Object>> pairInsert(Map<String, Object> params) { | |
SimpleEntry<List<String>, List<Object>> kv = pair(params); | |
String sql = String.join(", ", kv.getKey()); | |
return new SimpleEntry<>(sql, kv.getValue()); | |
} | |
public SimpleEntry<String, Object[]> pairUpdate(Map<String, Object> params) { | |
SimpleEntry<List<String>, List<Object>> kv = pair(params); | |
String sql = kv.getKey().stream() | |
.map(k -> String.format("%s = ?", k)) | |
.collect(Collectors.joining(", ")); | |
return new SimpleEntry<>(sql, kv.getValue().toArray()); | |
} | |
public SimpleEntry<List<String>, List<Object>> pair(Map<String, Object> params) { | |
List<String> keys = new ArrayList<>(params.size()); | |
List<Object> values = new ArrayList<>(params.size()); | |
int index = 0; | |
for (Map.Entry<String, Object> kv : params.entrySet()) { | |
String key = escaper.apply(isSnakeCase ? toSnakeCase(kv.getKey()) : kv.getKey()); | |
keys.add(index, key); | |
values.add(index, kv.getValue()); | |
index++; | |
} | |
return new SimpleEntry<>(keys, values); | |
} | |
public String getSql() { | |
return String.join(" ", queryParts); | |
} | |
public Object[] getParams() { | |
return bindValues.toArray(); | |
} | |
public Duck reset() { | |
this.queryParts.clear(); | |
this.bindValues.clear(); | |
this.marks.clear(); | |
return this; | |
} | |
public Connection getTxConn() { | |
if (txConn == null) { | |
throw new SQLQueryException("Not in a transaction context"); | |
} | |
return txConn; | |
} | |
protected Connection selectConnection() throws SQLException { | |
if (txConn != null) { | |
return txConn; | |
} | |
if (getDataSource() != null) { | |
return getDataSource().getConnection(); | |
} | |
throw new SQLQueryException("Unable to obtain database connection"); | |
} | |
private <T> T executeInConnection(ConnectionCallback<T> action) { | |
Connection conn = null; | |
try { | |
conn = selectConnection(); | |
return action.doInConnection(conn); | |
} catch (SQLException cause) { | |
throw new SQLQueryException(cause); | |
} finally { | |
if (conn != null && conn != txConn) { | |
try { | |
conn.close(); | |
} catch (SQLException ex) { | |
logger.warning("Failed to close connection: " + ex.getMessage()); | |
} | |
} | |
} | |
} | |
public <T> T fetch(final ResultSetHandler<T> rst) { | |
return executeInConnection(conn -> query(conn, getSql(), rst, getParams())); | |
} | |
public <T> List<T> fetchBeanList(Class<T> tClass) { | |
RowProcessor rowProcessor = new BasicRowProcessor(new GenerousBeanProcessor()); | |
ResultSetHandler<List<T>> handler = new BeanListHandler<>(tClass, rowProcessor); | |
return fetch(handler); | |
} | |
public <T> T fetchBean(Class<T> tClass) { | |
List<T> result = fetchBeanList(tClass); | |
if (result.size() > 1) { | |
String errmsg = String.format("Non-unique result: query returned %d rows when expecting exactly one row", result.size()); | |
throw new SQLQueryException(errmsg); | |
} | |
return result.isEmpty() ? null : result.get(0); | |
} | |
public <K, V> Map<K, V> fetchBeanMap(String columnName, Class<V> vClass) { | |
BeanMapHandler<K, V> handler = new BeanMapHandler<>(vClass, columnName); | |
return fetch(handler); | |
} | |
public List<Map<String, Object>> fetchMapList() { | |
MapListHandler handler = new MapListHandler(); | |
return fetch(handler); | |
} | |
public Map<String, Object> fetchMap() { | |
MapHandler handler = new MapHandler(); | |
return fetch(handler); | |
} | |
public <T> T fetchScalar(Class<T> tClass) { | |
ScalarHandler<T> handler = new ScalarHandler<>(); | |
return fetch(handler); | |
} | |
public int update() { | |
return executeInConnection(conn -> update(conn, getSql(), getParams())); | |
} | |
public int update(String tableName, Map<String, Object> params, String where, Object... whereParams) { | |
SimpleEntry<String, Object[]> paramsKv = pairUpdate(params); | |
String sql = String.format("update %s set %s", escaper.apply(tableName), paramsKv.getKey()); | |
return this.add(sql, paramsKv.getValue()).add("where " + where, whereParams).update(); | |
} | |
public BigInteger insert() { | |
return insert(BigInteger.class); | |
} | |
public <T> T insert(Class<T> tClass) { | |
return executeInConnection(conn -> insert(conn, getSql(), new ScalarHandler<>(), getParams())); | |
} | |
public BigInteger insert(String tableName, Map<String, Object> params) { | |
return insert(tableName, params, BigInteger.class); | |
} | |
public <T> T insert(String tableName, Map<String, Object> params, Class<T> tClass) { | |
SimpleEntry<String, List<Object>> kv = pairInsert(params); | |
String sql = String.format("insert into %s (%s) values (?)", escaper.apply(tableName), kv.getKey()); | |
return this.add(sql, kv.getValue()).insert(tClass); | |
} | |
public <T> List<T> insertBatch(String tableName, List<Map<String, Object>> paramsList, Class<T> tClass) { | |
if (paramsList == null || paramsList.isEmpty()) { | |
throw new SQLQueryException("Batch params empty"); | |
} | |
// 获取第一个Map的所有键并固定顺序 | |
Map<String, Object> firstMap = paramsList.get(0); | |
List<String> keys = new ArrayList<>(firstMap.keySet()); | |
// 构建SQL语句 | |
String fieldStr = keys.stream() | |
.map(it -> escaper.apply(isSnakeCase ? toSnakeCase(it) : it)) | |
.collect(Collectors.joining(", ")); | |
String placeholders = String.join(", ", Collections.nCopies(keys.size(), "?")); | |
String sql = String.format("insert into %s (%s) values (%s)", | |
escaper.apply(tableName), fieldStr, placeholders); | |
// 构建批量参数数组 | |
Object[][] batchArgs = new Object[paramsList.size()][]; | |
for (int i = 0; i < paramsList.size(); i++) { | |
Map<String, Object> params = paramsList.get(i); | |
Object[] values = new Object[keys.size()]; | |
for (int j = 0; j < keys.size(); j++) { | |
values[j] = params.get(keys.get(j)); | |
} | |
batchArgs[i] = values; | |
} | |
return executeInConnection(conn -> super.insertBatch(sql, new ColumnListHandler<>(), batchArgs)); | |
} | |
public List<BigInteger> insertBatch(String tableName, List<Map<String, Object>> paramsList) { | |
return insertBatch(tableName, paramsList, BigInteger.class); | |
} | |
public <R> R transaction(Function<Duck, R> action) { | |
if (txConn != null) { | |
throw new SQLQueryException("Nested transactions are not allowed"); | |
} | |
Connection conn = null; | |
try { | |
conn = getDataSource().getConnection(); | |
conn.setAutoCommit(false); | |
Duck tx = new Duck(conn); | |
R result = action.apply(tx); | |
conn.commit(); | |
return result; | |
} catch (Exception e) { | |
try { | |
if (conn != null) { | |
conn.rollback(); | |
} | |
} catch (SQLException ex) { | |
throw new SQLQueryException("Failed to rollback transaction", ex); | |
} | |
throw new SQLQueryException("Transaction failed", e); | |
} finally { | |
try { | |
if (conn != null) { | |
conn.setAutoCommit(true); | |
conn.close(); | |
} | |
} catch (SQLException e) { | |
logger.warning("Failed to close connection: " + e.getMessage()); | |
} | |
} | |
} | |
protected Duck copy() { | |
Duck query = new Duck(getDataSource()); | |
query.txConn = txConn; | |
query.queryParts.addAll(queryParts); | |
query.bindValues.addAll(bindValues); | |
query.marks.putAll(marks); | |
query.escaper = escaper; | |
query.isSnakeCase = isSnakeCase; | |
return query; | |
} | |
public Duck escaper(Function<String, String> escaper_) { | |
Duck query = copy(); | |
query.escaper = escaper_; | |
return query; | |
} | |
public Function<String, String> makeEscaper(String quote) { | |
return identifier -> { | |
if (identifier == null) { | |
throw new SQLQueryException("Identifier cannot be null"); | |
} | |
if (identifier.trim().isEmpty()) { | |
throw new SQLQueryException("Identifier cannot be empty"); | |
} | |
return Arrays.stream(identifier.split("\\.")) | |
.map(part -> quote | |
+ part.replace(quote, quote+quote) | |
+ quote) | |
.collect(Collectors.joining(".")); | |
}; | |
} | |
public Duck add(String sql, Object... params) { | |
Duck query = copy(); | |
query.appendParams(sql, params); | |
return query; | |
} | |
public Duck snakeCase(boolean yes) { | |
Duck duck = copy(); | |
duck.isSnakeCase = yes; | |
return duck; | |
} | |
public Duck mark(String name, String sql) { | |
Duck query = copy(); | |
if (query.marks.containsKey(name)) { | |
query.queryParts.set(query.marks.get(name), sql); | |
} else { | |
query.queryParts.add(sql); | |
query.marks.put(name, query.queryParts.size() - 1); | |
} | |
return query; | |
} | |
protected void appendParams(String sql, Object[] args) { | |
String[] parts = (sql + " ").split("\\?"); | |
if (parts.length != args.length + 1) { | |
String errmsg = String.format("Placeholders length (%d) doesn't match parameters length (%d)", | |
parts.length - 1, args.length); | |
throw new SQLQueryException(errmsg); | |
} | |
for (int i = 0; i < parts.length; i++) { | |
if (args.length <= i) { | |
queryParts.add(parts[i]); | |
return; | |
} | |
Object arg = args[i]; | |
if (arg == null) { | |
queryParts.add(parts[i] + '?'); | |
bindValues.add(null); | |
} else if (arg instanceof Collection) { | |
appendArray(parts[i], ((Collection<?>) arg).toArray()); | |
} else if (arg.getClass().isArray()) { | |
appendArray(parts[i], (Object[]) arg); | |
} else { | |
queryParts.add(parts[i] + '?'); | |
bindValues.add(arg); | |
} | |
} | |
} | |
protected void appendArray(String sql, Object[] arrayArg) { | |
StringBuilder marks = new StringBuilder(); | |
for (int i = 0; i < arrayArg.length; i++) { | |
if (i > 0) marks.append(','); | |
marks.append('?'); | |
} | |
queryParts.add(sql + marks); | |
bindValues.addAll(Arrays.asList(arrayArg)); | |
} | |
public Duck debug() { | |
logger.info("{" + Duck.class.getSimpleName() + "}\n" + this); | |
return this; | |
} | |
@Override | |
public String toString() { | |
return "SQL: " + getSql() + "\nParams: " + Arrays.toString(getParams()); | |
} | |
public String toSnakeCase(String str) { | |
String res = str.replaceAll("([A-Z]+)", "_$1").toLowerCase(); | |
if (res.startsWith("_") && !str.startsWith("_")) { | |
res = res.substring(1); | |
} | |
return res; | |
} | |
@FunctionalInterface | |
protected interface ConnectionCallback<T> { | |
T doInConnection(Connection conn) throws SQLException; | |
} | |
public static class SQLQueryException extends RuntimeException { | |
public SQLQueryException(String message) { | |
super(message); | |
} | |
public SQLQueryException(Throwable cause) { | |
super(cause); | |
} | |
public SQLQueryException(String message, Throwable cause) { | |
super(message, cause); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment