Last active: September 12, 2024 19:28
-
-
Save robertmryan/515bbfc751dd97f5d0745dceee32d3b4 to your computer and use it in GitHub Desktop.
Minimalist SQLite wrapper for Swift
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// Database.swift | |
// | |
// Created by Robert Ryan on 3/8/19. | |
// Copyright © 2019 Robert Ryan. All rights reserved. | |
// | |
import Foundation | |
import SQLite3 | |
/// Swift spelling of the C `SQLITE_STATIC` destructor sentinel (bit pattern 0): SQLite assumes a buffer
/// bound with it lives in static, unmanaged memory and will neither copy nor free it.
private let SQLITE_STATIC = unsafeBitCast(0 as Int, to: sqlite3_destructor_type.self)
/// Swift spelling of the C `SQLITE_TRANSIENT` destructor sentinel (bit pattern -1): SQLite makes its own
/// private copy of a buffer bound with it before the bind call returns.
private let SQLITE_TRANSIENT = unsafeBitCast(-1 as Int, to: sqlite3_destructor_type.self)
// MARK: - Database | |
/// A small object wrapper around one SQLite connection and the SQLite C API.
public class Database {

    // MARK: - Properties

    /// Location of the database file; its `path` is what gets handed to `sqlite3_open_v2`.
    let fileURL: URL

    /// Flags passed to `sqlite3_open_v2` when the database is opened.
    public var options: OpenOptions

    /// Shared formatter used to write `Date` values to (and read them back from) the database.
    ///
    /// Uses the POSIX locale and a fixed UTC zone so the text form does not depend on user settings,
    /// e.g. `2019-03-08T12:34:56.789Z`.
    public static let dateFormatter: DateFormatter = {
        let isoFormatter = DateFormatter()
        isoFormatter.locale = .posix
        isoFormatter.dateFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSX"
        isoFormatter.timeZone = TimeZone(secondsFromGMT: 0)
        return isoFormatter
    }()

    /// Connection handle filled in by `sqlite3_open_v2`; `nil` until opened and after closing.
    private var database: OpaquePointer?

    /// Statements handed out by `prepare` that have not been finalized yet.
    private var openStatements: [Statement] = []

    // MARK: - Initialization

    /// Creates a wrapper for the database at `fileURL`.
    ///
    /// Nothing is opened here; call `open()` before issuing any SQL.
    ///
    /// - Parameters:
    ///   - fileURL: Location of the database file.
    ///   - options: Flags for `sqlite3_open_v2`; defaults to `.default`.
    public init(fileURL: URL, options: OpenOptions = .default) {
        self.fileURL = fileURL
        self.options = options
    }

    /// Finalizes any outstanding statements and closes the connection if it is still open.
    /// Errors are deliberately ignored because a deinitializer cannot throw.
    deinit {
        finalizeStatements()
        try? close()
    }
}
// MARK: - Types | |
public extension Database {
    /// Errors thrown by `Database` and `Statement`.
    enum DatabaseError: Error {
        /// A SQLite call returned a non-success code; carries the decoded code and the
        /// connection's `sqlite3_errmsg` text at the time of failure.
        case failed(ReturnCode, String)
        /// The operation needs an open connection (or a live statement) and there is none.
        case notOpened
        /// Preparing SQL did not yield a statement.
        case noStatementPrepared
        /// Closing the connection failed.
        case closeFailed
    }

    /// Flags for `sqlite3_open_v2`, combinable as an option set.
    struct OpenOptions: OptionSet, Sendable {
        public let rawValue: Int32

        /// `SQLITE_OPEN_READONLY`
        public static let readOnly = OpenOptions(rawValue: SQLITE_OPEN_READONLY)
        /// `SQLITE_OPEN_READWRITE`
        public static let readWrite = OpenOptions(rawValue: SQLITE_OPEN_READWRITE)
        /// `SQLITE_OPEN_CREATE`
        public static let create = OpenOptions(rawValue: SQLITE_OPEN_CREATE)
        /// `SQLITE_OPEN_NOMUTEX`
        public static let noMutex = OpenOptions(rawValue: SQLITE_OPEN_NOMUTEX)
        /// `SQLITE_OPEN_FULLMUTEX`
        public static let fullMutex = OpenOptions(rawValue: SQLITE_OPEN_FULLMUTEX)
        /// `SQLITE_OPEN_SHAREDCACHE`
        public static let sharedCache = OpenOptions(rawValue: SQLITE_OPEN_SHAREDCACHE)
        /// `SQLITE_OPEN_PRIVATECACHE`
        public static let privateCache = OpenOptions(rawValue: SQLITE_OPEN_PRIVATECACHE)
        /// `SQLITE_OPEN_MEMORY`: open an in-memory database. Per the SQLite documentation, the
        /// filename is then only used to name the database for shared-cache purposes.
        public static let memory = OpenOptions(rawValue: SQLITE_OPEN_MEMORY)
        /// Read/write, creating the file if needed: `[.readWrite, .create]`.
        public static let `default`: OpenOptions = [.readWrite, .create]

        public init(rawValue: Int32) {
            self.rawValue = rawValue
        }
    }

    /// Swift view of a SQLite primary result code.
    enum ReturnCode: Equatable, Sendable {
        // non error codes
        case ok
        case done
        case row

        // error codes
        case auth
        case busy
        case cantOpen
        case constraint
        case corrupt
        case empty
        case error
        case fail
        case format
        case full
        case `internal`
        case interrupt
        case ioerr
        case locked
        case mismatch
        case misuse
        case nolfs
        case nomem
        case notadb
        case notfound
        case notice
        case perm
        case `protocol`
        case range
        case readonly
        case schema
        case toobig
        case warning

        /// Any value not listed above (for example an extended result code), kept verbatim.
        case unknown(Int32)

        /// Maps a raw integer returned by a SQLite C function to a `ReturnCode`.
        ///
        /// - Parameter code: The raw SQLite result code.
        /// - Returns: The matching case, or `.unknown(code)` when the value is not recognized.
        static func code(for code: Int32) -> ReturnCode {
            switch code {
            case SQLITE_OK: return .ok
            case SQLITE_DONE: return .done
            case SQLITE_ROW: return .row
            case SQLITE_AUTH: return .auth
            case SQLITE_BUSY: return .busy
            case SQLITE_CANTOPEN: return .cantOpen
            case SQLITE_CONSTRAINT: return .constraint
            case SQLITE_CORRUPT: return .corrupt
            case SQLITE_EMPTY: return .empty
            case SQLITE_ERROR: return .error
            case SQLITE_FAIL: return .fail
            case SQLITE_FORMAT: return .format
            case SQLITE_FULL: return .full
            case SQLITE_INTERNAL: return .internal
            case SQLITE_INTERRUPT: return .interrupt
            case SQLITE_IOERR: return .ioerr
            case SQLITE_LOCKED: return .locked
            case SQLITE_MISMATCH: return .mismatch
            case SQLITE_MISUSE: return .misuse
            case SQLITE_NOLFS: return .nolfs
            case SQLITE_NOMEM: return .nomem
            case SQLITE_NOTADB: return .notadb
            case SQLITE_NOTFOUND: return .notfound
            case SQLITE_NOTICE: return .notice
            case SQLITE_PERM: return .perm
            case SQLITE_PROTOCOL: return .protocol
            case SQLITE_RANGE: return .range
            case SQLITE_READONLY: return .readonly
            case SQLITE_SCHEMA: return .schema
            case SQLITE_TOOBIG: return .toobig
            case SQLITE_WARNING: return .warning
            default: return .unknown(code)
            }
        }
    }
}
// MARK: - Public methods | |
public extension Database {
    /// Opens the database at `fileURL` using `options`.
    ///
    /// If this instance already holds an open connection, that connection (and its outstanding
    /// statements) is closed first, so the old handle is not overwritten and leaked.
    ///
    /// - Throws: `DatabaseError.failed` if closing a previous connection or `sqlite3_open_v2` fails.
    func open() throws {
        if database != nil {
            try close()
        }
        do {
            try call { sqlite3_open_v2(fileURL.path, &database, options.rawValue, nil) }
        } catch {
            // sqlite3_open_v2 can hand back a handle even on failure; release it.
            try? close()
            throw error
        }
    }

    /// Finalizes outstanding statements and closes the database. Does nothing if not open.
    ///
    /// - Throws: `DatabaseError.failed` if `sqlite3_close` fails; the connection then stays open.
    func close() throws {
        guard database != nil else { return }
        finalizeStatements()
        // Clear the handle only after a successful close. Clearing it first (as a `defer` inside
        // the block did) made `call` read `sqlite3_errmsg(nil)` and leaked a connection that
        // failed to close.
        try call { sqlite3_close(database) }
        database = nil
    }

    /// Execute statement
    ///
    /// - Parameter sql: SQL to be performed.
    /// - Throws: `DatabaseError.notOpened` if the database is not open; otherwise SQLite errors.
    func exec(_ sql: String) throws {
        guard database != nil else { throw DatabaseError.notOpened }
        try call { sqlite3_exec(database, sql, nil, nil, nil) }
    }

    /// Prepare SQL
    ///
    /// - Parameters:
    ///   - sql: SQL to be prepared.
    ///   - parameters: Any parameters to be bound to any `?` in the SQL.
    /// - Returns: The prepared statement.
    /// - Throws: `DatabaseError.notOpened` if the database is not open,
    ///   `DatabaseError.noStatementPrepared` if `sql` contains no statement (e.g. only whitespace
    ///   or comments), or the SQLite error from preparing or binding.
    func prepare(_ sql: String, parameters: [DatabaseBindable?]? = nil) throws -> Statement {
        guard database != nil else { throw DatabaseError.notOpened }
        var rawStatement: OpaquePointer?
        try call { sqlite3_prepare_v2(database, sql, -1, &rawStatement, nil) }
        // sqlite3_prepare_v2 reports SQLITE_OK but yields NULL for input with no SQL in it.
        guard let preparedStatement = rawStatement else { throw DatabaseError.noStatementPrepared }
        let statement = Statement(database: self, statement: preparedStatement)
        openStatements.append(statement)
        do {
            try statement.bind(parameters)
        } catch {
            // The caller never receives this statement, so release it now rather than at close.
            try? finalize(statement)
            throw error
        }
        return statement
    }

    /// The `rowid` of the last row inserted
    ///
    /// - Returns: The `rowid`.
    func lastRowId() -> Int64 {
        sqlite3_last_insert_rowid(database)
    }

    /// Returns number of rows changed by the last `INSERT`, `UPDATE`, or `DELETE` statement.
    ///
    /// - Returns: Number of rows changed.
    func changes() -> Int32 {
        sqlite3_changes(database)
    }

    /// Returns number of rows changed by `INSERT`, `UPDATE`, or `DELETE` statements since the database was opened.
    ///
    /// - Returns: Number of rows changed.
    func totalChanges() -> Int32 {
        sqlite3_total_changes(database)
    }

    /// Finalize a previously prepared statement
    ///
    /// Statements not (or no longer) tracked by this database are ignored.
    ///
    /// - Parameter statement: The previously prepared statement.
    /// - Throws: SQLite error reported by `sqlite3_finalize`.
    func finalize(_ statement: Statement) throws {
        guard let index = openStatements.firstIndex(where: { $0 === statement }) else {
            return
        }
        openStatements.remove(at: index)
        try call {
            defer { statement.sqlite3_stmt = nil }
            return sqlite3_finalize(statement.sqlite3_stmt)
        }
    }

    /// The version of SQLite being used.
    ///
    /// - Returns: Version string.
    func version() -> String? {
        sqlite3_libversion()
            .flatMap { String(cString: $0) }
    }
}
// MARK: Private methods | |
fileprivate extension Database {
    /// Runs a single SQLite C call and translates its integer result.
    ///
    /// - Parameter block: Closure that invokes one SQLite C function and returns its result code.
    /// - Returns: The decoded code, when it is `.ok`, `.done`, or `.row`.
    /// - Throws: `DatabaseError.failed` with the decoded code and the connection's `sqlite3_errmsg`
    ///   text for any other result.
    @discardableResult
    func call(block: () -> (Int32)) throws -> Database.ReturnCode {
        let returnCode = Database.ReturnCode.code(for: block())
        let succeeded = returnCode == .ok || returnCode == .done || returnCode == .row
        guard succeeded else {
            let errorMessage = String(cString: sqlite3_errmsg(database))
            throw DatabaseError.failed(returnCode, errorMessage)
        }
        return returnCode
    }

    /// Finalizes every statement still listed in `openStatements`, ignoring individual failures.
    func finalizeStatements() {
        // Walk a snapshot, because `finalize(_:)` removes entries from `openStatements`.
        let pending = openStatements
        for pendingStatement in pending {
            try? finalize(pendingStatement)
        }
    }
}
// MARK: - Statement | |
/// A prepared SQLite statement, as returned by `Database.prepare(_:parameters:)`.
public class Statement {
    /// Raw statement handle; set to `nil` once the owning `Database` finalizes it.
    public fileprivate(set) var sqlite3_stmt: OpaquePointer?

    /// The connection that prepared this statement, held weakly so the statement never keeps it alive.
    private weak var database: Database?

    init(database: Database, statement: OpaquePointer) {
        self.sqlite3_stmt = statement
        self.database = database
    }

    /// Asks the owning database (if it still exists) to finalize this statement; errors are ignored.
    ///
    /// NOTE(review): `Database` keeps a strong reference to each statement in `openStatements` until
    /// it is finalized, so this likely only runs after that has already happened — confirm.
    deinit {
        try? database?.finalize(self)
    }
}
// MARK: Public methods | |
public extension Statement {
    /// Binds values to the statement's `?` placeholders, in order.
    ///
    /// Element `i` of `parameters` is bound to parameter index `i + 1`; `nil` elements bind SQL `NULL`.
    /// Passing `nil` or an empty array binds nothing.
    ///
    /// - Parameter parameters: The values to bind.
    /// - Throws: `Database.DatabaseError.notOpened` if the owning database has been deallocated or this
    ///   statement has been finalized; otherwise the SQLite error from the failing `sqlite3_bind_*` call.
    func bind(_ parameters: [DatabaseBindable?]?) throws {
        guard let parameters = parameters, !parameters.isEmpty else { return }
        // Match `step()`/`reset()`: fail loudly rather than silently skipping the binds.
        guard
            let database = database,
            let statement = sqlite3_stmt
        else {
            throw Database.DatabaseError.notOpened
        }
        for (index, value) in parameters.enumerated() {
            let offset = Int32(index + 1)
            if let value = value {
                try database.call { value.bind(to: self, offset: offset) }
            } else {
                try database.call { sqlite3_bind_null(statement, offset) }
            }
        }
    }

    /// Evaluates the prepared statement once (`sqlite3_step`).
    ///
    /// - Returns: `.row` when a result row is available, or `.done` when execution has finished
    ///   (`.ok` would also be accepted, but `sqlite3_step` does not return it).
    /// - Throws: `Database.DatabaseError.notOpened` if the database is gone or the statement was finalized;
    ///   otherwise the SQLite error for any other result code.
    @discardableResult
    func step() throws -> Database.ReturnCode {
        guard
            let database = database,
            let statement = sqlite3_stmt
        else {
            throw Database.DatabaseError.notOpened
        }
        return try database.call { sqlite3_step(statement) }
    }

    /// Resets the statement to its initial state so it can be stepped again (`sqlite3_reset`).
    ///
    /// Previously bound values are retained by SQLite; call `bind(_:)` afterwards to replace them.
    ///
    /// - Throws: `Database.DatabaseError.notOpened` if the database is gone or the statement was finalized;
    ///   otherwise the SQLite error reported by `sqlite3_reset` (which, per the SQLite documentation,
    ///   can echo an error from the most recent `step()`).
    func reset() throws {
        guard
            let database = database,
            let statement = sqlite3_stmt
        else {
            throw Database.DatabaseError.notOpened
        }
        try database.call { sqlite3_reset(statement) }
    }

    /// Determines if the particular column value is `NULL` or not.
    ///
    /// - Parameter index: The zero-based column index number.
    /// - Returns: `true` if the column's type is `SQLITE_NULL`.
    func isNull(index: Int32) -> Bool {
        sqlite3_column_type(sqlite3_stmt, index) == SQLITE_NULL
    }

    /// Retrieve the value returned for a column of the particular `index`.
    ///
    /// - Parameters:
    ///   - type: The type to be returned for the column (e.g. `Int.self`).
    ///   - index: The zero-based column index number.
    /// - Returns: The value found at the specified column index, or `nil` if the column is `NULL`
    ///   or cannot be converted to that type.
    func column<T: DatabaseBindable>(_ type: T.Type, index: Int32) -> T? {
        T(from: self, index: index)
    }

    /// Retrieve the name of the column of the particular `index`.
    ///
    /// - Parameter index: The zero-based column index number.
    /// - Returns: The name of the column or `nil` if it couldn't determine the name.
    func columnName(index: Int32) -> String? {
        sqlite3_column_name(sqlite3_stmt, index)
            .flatMap { String(cString: $0) }
    }

    /// Retrieve the name of the table column that the result column at `index` originates from.
    ///
    /// - Parameter index: The zero-based column index number.
    /// - Returns: The origin column name, or `nil` if it couldn't be determined (e.g. an expression).
    func columnOriginName(index: Int32) -> String? {
        sqlite3_column_origin_name(sqlite3_stmt, index)
            .flatMap { String(cString: $0) }
    }

    /// Retrieve the name of the table associated with the column of the particular `index`.
    ///
    /// - Parameter index: The zero-based column index number.
    /// - Returns: The name of the table or `nil` if it couldn't determine the name.
    func columnTableName(index: Int32) -> String? {
        sqlite3_column_table_name(sqlite3_stmt, index)
            .flatMap { String(cString: $0) }
    }

    /// Retrieve the name of the database (e.g. `main`) associated with the column of the particular `index`.
    ///
    /// - Parameter index: The zero-based column index number.
    /// - Returns: The name of the database or `nil` if it couldn't determine the name.
    func columnDatabaseName(index: Int32) -> String? {
        sqlite3_column_database_name(sqlite3_stmt, index)
            .flatMap { String(cString: $0) }
    }
}
// MARK: - Data binding protocol | |
/// A Swift type that can be bound to a `?` placeholder and read back from a result column.
public protocol DatabaseBindable {
    /// Creates a value from a column of the statement's current result row.
    ///
    /// - Parameters:
    ///   - statement: A prepared statement that has been stepped onto a row.
    ///   - index: The zero-based column index.
    /// - Returns: `nil` when the column cannot be represented as this type.
    init?(from statement: Statement, index: Int32)

    /// Binds this value to one parameter of a prepared statement that has not been stepped yet.
    ///
    /// - Parameters:
    ///   - statement: The prepared statement.
    ///   - offset: The one-based parameter index.
    /// - Returns: The raw SQLite result code from the underlying bind call.
    func bind(to statement: Statement, offset: Int32) -> Int32
}
// MARK: Specific type conformances | |
extension String: DatabaseBindable {
    /// Reads a text column as UTF-8.
    ///
    /// Uses the column's byte count rather than NUL termination, so values containing embedded
    /// NUL characters are returned in full. Invalid UTF-8 is repaired with U+FFFD.
    ///
    /// - Returns: `nil` for SQL `NULL`.
    public init?(from statement: Statement, index: Int32) {
        guard
            !statement.isNull(index: index),
            let pointer = sqlite3_column_text(statement.sqlite3_stmt, index)
        else { return nil }
        // Per SQLite, call sqlite3_column_bytes *after* sqlite3_column_text so the count
        // refers to the UTF-8 text just produced.
        let byteCount = Int(sqlite3_column_bytes(statement.sqlite3_stmt, index))
        self = String(decoding: UnsafeBufferPointer(start: pointer, count: byteCount), as: UTF8.self)
    }

    /// Binds the string as UTF-8 text with an explicit byte length (so embedded NULs are kept),
    /// letting SQLite copy the bytes (`SQLITE_TRANSIENT`).
    public func bind(to statement: Statement, offset: Int32) -> Int32 {
        // The 64-bit length variant lets oversized input fail with SQLITE_TOOBIG instead of trapping
        // on an Int32 conversion.
        sqlite3_bind_text64(
            statement.sqlite3_stmt,
            offset,
            self,
            sqlite3_uint64(utf8.count),
            SQLITE_TRANSIENT,
            UInt8(SQLITE_UTF8)
        )
    }
}
extension Decimal: DatabaseBindable {
    /// Reads a decimal previously stored as text.
    ///
    /// - Returns: `nil` for SQL `NULL` or text that `Decimal(string:locale:)` cannot parse.
    public init?(from statement: Statement, index: Int32) {
        // `String(from:index:)` already returns nil for SQL NULL.
        guard
            let text = String(from: statement, index: index),
            let value = Decimal(string: text, locale: .posix)
        else { return nil }
        self = value
    }

    /// Binds the decimal as text so no precision is lost to floating point.
    public func bind(to statement: Statement, offset: Int32) -> Int32 {
        // `description` is locale-independent ("." separator, no grouping, all digits), so it
        // round-trips through `Decimal(string:locale:)`. A display-oriented FormatStyle may insert
        // grouping separators or limit fraction digits, which would corrupt the stored value.
        let text = description
        return sqlite3_bind_text(statement.sqlite3_stmt, offset, text.cString(using: .utf8), -1, SQLITE_TRANSIENT)
    }
}
extension IntegerLiteralType: DatabaseBindable {
    /// Reads an integer column.
    ///
    /// - Returns: `nil` for SQL `NULL`, or when the 64-bit stored value does not fit in `Int`
    ///   (possible where `Int` is 32 bits, e.g. arm64_32). Previously that case trapped.
    public init?(from statement: Statement, index: Int32) {
        guard !statement.isNull(index: index) else { return nil }
        let value = sqlite3_column_int64(statement.sqlite3_stmt, index)
        guard let converted = Self(exactly: value) else { return nil }
        self = converted
    }

    /// Binds the value as a 64-bit integer.
    public func bind(to statement: Statement, offset: Int32) -> Int32 {
        sqlite3_bind_int64(statement.sqlite3_stmt, offset, Int64(self))
    }
}
extension BinaryFloatingPoint {
    /// Reads a column as a double and converts it to `Self`.
    ///
    /// - Returns: `nil` for SQL `NULL`.
    public init?(from statement: Statement, index: Int32) {
        guard !statement.isNull(index: index) else { return nil }
        self = Self(sqlite3_column_double(statement.sqlite3_stmt, index))
    }

    /// Binds the value as a double.
    public func bind(to statement: Statement, offset: Int32) -> Int32 {
        sqlite3_bind_double(statement.sqlite3_stmt, offset, Double(self))
    }
}

// The extension above only supplies implementations; without these declarations `Double` and
// `Float` do not actually conform, so they could not be passed as parameters or read via
// `Statement.column(_:index:)`.
extension Double: DatabaseBindable {}
extension Float: DatabaseBindable {}
extension Data: DatabaseBindable {
    /// Reads a blob column.
    ///
    /// - Returns: `nil` for SQL `NULL`; an empty `Data` for a zero-length blob (previously `nil`,
    ///   which made an empty value indistinguishable from `NULL`).
    public init?(from statement: Statement, index: Int32) {
        guard !statement.isNull(index: index) else { return nil }
        // Per SQLite, call sqlite3_column_blob before sqlite3_column_bytes.
        let bytes = sqlite3_column_blob(statement.sqlite3_stmt, index)
        let count = Int(sqlite3_column_bytes(statement.sqlite3_stmt, index))
        if count == 0 {
            // sqlite3_column_blob returns NULL for zero-length blobs.
            self = Data()
            return
        }
        guard let bytes = bytes else { return nil }
        self = Data(bytes: bytes, count: count)
    }

    /// Binds the bytes as a blob, letting SQLite copy them (`SQLITE_TRANSIENT`).
    public func bind(to statement: Statement, offset: Int32) -> Int32 {
        if isEmpty {
            // An empty buffer has a NULL base address, and sqlite3_bind_blob treats a NULL pointer
            // as SQL NULL; bind an explicit zero-length blob instead.
            return sqlite3_bind_zeroblob(statement.sqlite3_stmt, offset, 0)
        }
        return withUnsafeBytes { (buffer: UnsafeRawBufferPointer) -> Int32 in
            // The 64-bit length variant avoids trapping on an Int32 conversion for huge buffers.
            sqlite3_bind_blob64(
                statement.sqlite3_stmt,
                offset,
                buffer.baseAddress,
                sqlite3_uint64(buffer.count),
                SQLITE_TRANSIENT
            )
        }
    }
}
extension Date: DatabaseBindable {
    /// Reads a column stored as text in `Database.dateFormatter`'s format.
    ///
    /// - Returns: `nil` for SQL `NULL` or for text the formatter cannot parse.
    public init?(from statement: Statement, index: Int32) {
        if statement.isNull(index: index) { return nil }
        guard let rawText = sqlite3_column_text(statement.sqlite3_stmt, index) else { return nil }
        let timestamp = String(cString: rawText)
        guard let parsed = Database.dateFormatter.date(from: timestamp) else { return nil }
        self = parsed
    }

    /// Binds the date as text produced by `Database.dateFormatter`, copied by SQLite.
    public func bind(to statement: Statement, offset: Int32) -> Int32 {
        let timestamp = Database.dateFormatter.string(from: self)
        return sqlite3_bind_text(
            statement.sqlite3_stmt,
            offset,
            timestamp.cString(using: .utf8),
            -1,
            SQLITE_TRANSIENT
        )
    }
}
// MARK: - Locale.posix | |
extension Locale {
    /// The `en_US_POSIX` locale, for formatting and parsing that must not vary with user settings.
    static let posix: Locale = .init(identifier: "en_US_POSIX")
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
For example, the above transforms code like:
Into just: