Fold work from SQLite into common

- Checkpoint before common moved to top level
This commit is contained in:
2025-02-17 12:05:26 -05:00
parent 3826009dfa
commit 4249159252
16 changed files with 332 additions and 394 deletions

View File

@@ -28,13 +28,15 @@ enum class AutoId {
/**
* Generate a string of random hex characters
*
* @param length The length of the string
* @param length The length of the string (optional; defaults to configured length)
* @return A string of random hex characters of the requested length
*/
fun generateRandomString(length: Int): String =
kotlin.random.Random.nextBytes((length + 2) / 2)
.joinToString("") { String.format("%02x", it) }
.substring(0, length)
fun generateRandomString(length: Int? = null): String =
(length ?: Configuration.idStringLength).let { len ->
kotlin.random.Random.nextBytes((len + 2) / 2)
.joinToString("") { String.format("%02x", it) }
.substring(0, len)
}
/**
* Determine if a document needs an automatic ID applied

View File

@@ -6,4 +6,20 @@ package solutions.bitbadger.documents.common
* @property op The operation for the field comparison
* @property value The value against which the comparison will be made
*/
class Comparison<T>(val op: Op, val value: T)
class Comparison<T>(val op: Op, val value: T) {
/** Is the value for this comparison a numeric value? */
val isNumeric: Boolean
get() =
if (op == Op.IN || op == Op.BETWEEN) {
val values = value as? Collection<*>
if (values.isNullOrEmpty()) {
false
} else {
val first = values.elementAt(0)
first is Byte || first is Short || first is Int || first is Long
}
} else {
value is Byte || value is Short || value is Int || value is Long
}
}

View File

@@ -26,8 +26,15 @@ object Configuration {
/** The length of automatic random hex character string */
var idStringLength = 16
/** The derived dialect value from the connection string */
private var dialectValue: Dialect? = null
/** The connection string for the JDBC connection */
var connectionString: String? = null
set(value) {
field = value
dialectValue = if (value.isNullOrBlank()) null else Dialect.deriveFromConnectionString(value)
}
/**
* Retrieve a new connection to the configured database
@@ -42,23 +49,14 @@ object Configuration {
return DriverManager.getConnection(connectionString)
}
private var dialectValue: Dialect? = null
/** The dialect in use */
val dialect: Dialect
get() {
if (dialectValue == null) {
if (connectionString == null) {
throw IllegalArgumentException("Please provide a connection string before attempting data access")
}
val it = connectionString!!
dialectValue = when {
it.contains("sqlite") -> Dialect.SQLITE
it.contains("postgresql") -> Dialect.POSTGRESQL
else -> throw IllegalArgumentException("Cannot determine dialect from [$it]")
}
}
return dialectValue!!
}
/**
* The dialect in use
*
* @param process The process being attempted
* @return The dialect for the current connection
* @throws DocumentException If the dialect has not been set
*/
fun dialect(process: String? = null): Dialect =
dialectValue ?: throw DocumentException(
"Database mode not set" + if (process == null) "" else "; cannot $process")
}

View File

@@ -7,5 +7,22 @@ enum class Dialect {
/** PostgreSQL */
POSTGRESQL,
/** SQLite */
SQLITE
SQLITE;
companion object {
/**
* Derive the dialect from the given connection string
*
* @param connectionString The connection string from which the dialect will be derived
* @return The dialect for the connection string
* @throws DocumentException If the dialect cannot be determined
*/
fun deriveFromConnectionString(connectionString: String): Dialect =
when {
connectionString.contains("sqlite") -> SQLITE
connectionString.contains("postgresql") -> POSTGRESQL
else -> throw DocumentException("Cannot determine dialect from [$connectionString]")
}
}
}

View File

@@ -8,7 +8,7 @@ package solutions.bitbadger.documents.common
* @property parameterName The name of the parameter to use in the query (optional, generated if missing)
* @property qualifier A table qualifier to use to address the `data` field (useful for multi-table queries)
*/
class Field<T>(
class Field<T> private constructor(
val name: String,
val comparison: Comparison<T>,
val parameterName: String? = null,
@@ -42,7 +42,53 @@ class Field<T>(
fun path(dialect: Dialect, format: FieldFormat = FieldFormat.SQL): String =
(if (qualifier == null) "" else "${qualifier}.") + nameToPath(name, dialect, format)
/** Parameters to bind each value of `IN` and `IN_ARRAY` operations */
private val inParameterNames: String
get() {
val values = if (comparison.op == Op.IN) {
comparison.value as Collection<*>
} else {
val parts = comparison.value as Pair<*, *>
parts.second as Collection<*>
}
return List(values.size) { idx -> "${parameterName}_$idx" }.joinToString(", ")
}
/**
* Create a `WHERE` clause fragment for this field
*
* @return The `WHERE` clause for this field
* @throws DocumentException If the field has no parameter name or the database dialect has not been set
*/
fun toWhere(): String {
if (parameterName == null && !listOf(Op.EXISTS, Op.NOT_EXISTS).contains(comparison.op))
throw DocumentException("Parameter for $name must be specified")
val dialect = Configuration.dialect("make field WHERE clause")
val fieldName = path(dialect, if (comparison.op == Op.IN_ARRAY) FieldFormat.JSON else FieldFormat.SQL)
val fieldPath = when (dialect) {
Dialect.POSTGRESQL -> if (comparison.isNumeric) "($fieldName)::numeric" else fieldName
Dialect.SQLITE -> fieldName
}
val criteria = when (comparison.op) {
in listOf(Op.EXISTS, Op.NOT_EXISTS) -> ""
Op.BETWEEN -> " ${parameterName}min AND ${parameterName}max"
Op.IN -> " ($inParameterNames)"
Op.IN_ARRAY -> if (dialect == Dialect.POSTGRESQL) " ARRAY['$inParameterNames']" else ""
else -> " $parameterName"
}
@Suppress("UNCHECKED_CAST")
return if (dialect == Dialect.SQLITE && comparison.op == Op.IN_ARRAY) {
val (table, _) = comparison.value as? Pair<String, *> ?: throw DocumentException("InArray field invalid")
"EXISTS (SELECT 1 FROM json_each($table.data, '$.$name') WHERE value IN ($inParameterNames)"
} else {
"$fieldPath ${comparison.op.sql} $criteria"
}
}
companion object {
/**
* Create a field equality comparison
*
@@ -50,8 +96,8 @@ class Field<T>(
* @param value The value for the comparison
* @return A `Field` with the given comparison
*/
fun <T> equal(name: String, value: T): Field<T> =
Field<T>(name, Comparison(Op.EQUAL, value))
fun <T> equal(name: String, value: T) =
Field(name, Comparison(Op.EQUAL, value))
/**
* Create a field greater-than comparison
@@ -60,7 +106,7 @@ class Field<T>(
* @param value The value for the comparison
* @return A `Field` with the given comparison
*/
fun <T> greater(name: String, value: T): Field<T> =
fun <T> greater(name: String, value: T) =
Field(name, Comparison(Op.GREATER, value))
/**
@@ -70,7 +116,7 @@ class Field<T>(
* @param value The value for the comparison
* @return A `Field` with the given comparison
*/
fun <T> greaterOrEqual(name: String, value: T): Field<T> =
fun <T> greaterOrEqual(name: String, value: T) =
Field(name, Comparison(Op.GREATER_OR_EQUAL, value))
/**
@@ -80,7 +126,7 @@ class Field<T>(
* @param value The value for the comparison
* @return A `Field` with the given comparison
*/
fun <T> less(name: String, value: T): Field<T> =
fun <T> less(name: String, value: T) =
Field(name, Comparison(Op.LESS, value))
/**
@@ -90,7 +136,7 @@ class Field<T>(
* @param value The value for the comparison
* @return A `Field` with the given comparison
*/
fun <T> lessOrEqual(name: String, value: T): Field<T> =
fun <T> lessOrEqual(name: String, value: T) =
Field(name, Comparison(Op.LESS_OR_EQUAL, value))
/**
@@ -100,7 +146,7 @@ class Field<T>(
* @param value The value for the comparison
* @return A `Field` with the given comparison
*/
fun <T> notEqual(name: String, value: T): Field<T> =
fun <T> notEqual(name: String, value: T) =
Field(name, Comparison(Op.NOT_EQUAL, value))
/**
@@ -111,7 +157,7 @@ class Field<T>(
* @param maxValue The upper value for the comparison
* @return A `Field` with the given comparison
*/
fun <T> between(name: String, minValue: T, maxValue: T): Field<Pair<T, T>> =
fun <T> between(name: String, minValue: T, maxValue: T) =
Field(name, Comparison(Op.BETWEEN, Pair(minValue, maxValue)))
/**
@@ -121,7 +167,7 @@ class Field<T>(
* @param values The values for the comparison
* @return A `Field` with the given comparison
*/
fun <T> any(name: String, values: List<T>): Field<List<T>> =
fun <T> any(name: String, values: List<T>) =
Field(name, Comparison(Op.IN, values))
/**
@@ -132,16 +178,16 @@ class Field<T>(
* @param values The values for the comparison
* @return A `Field` with the given comparison
*/
fun <T> inArray(name: String, tableName: String, values: List<T>): Field<Pair<String, List<T>>> =
fun <T> inArray(name: String, tableName: String, values: List<T>) =
Field(name, Comparison(Op.IN_ARRAY, Pair(tableName, values)))
fun exists(name: String): Field<String> =
fun exists(name: String) =
Field(name, Comparison(Op.EXISTS, ""))
fun notExists(name: String): Field<String> =
fun notExists(name: String) =
Field(name, Comparison(Op.NOT_EXISTS, ""))
fun named(name: String): Field<String> =
fun named(name: String) =
Field(name, Comparison(Op.EQUAL, ""))
fun nameToPath(name: String, dialect: Dialect, format: FieldFormat): String {

View File

@@ -1,14 +0,0 @@
//TIP To <b>Run</b> code, press <shortcut actionId="Run"/> or
// click the <icon src="AllIcons.Actions.Execute"/> icon in the gutter.
fun main() {
val name = "Kotlin"
//TIP Press <shortcut actionId="ShowIntentionActions"/> with your caret at the highlighted text
// to see how IntelliJ IDEA suggests fixing it.
println("Hello, " + name + "!")
for (i in 1..5) {
//TIP Press <shortcut actionId="Debug"/> to start debugging your code. We have set one <icon src="AllIcons.Debugger.Db_set_breakpoint"/> breakpoint
// for you, but you can always add more by pressing <shortcut actionId="ToggleLineBreakpoint"/>.
println("i = $i")
}
}

View File

@@ -12,6 +12,19 @@ import java.sql.Types
*/
object Parameters {
/**
* Assign parameter names to any fields that do not have them assigned
*
* @param fields The collection of fields to be named
* @return The collection of fields with parameter names assigned
*/
fun nameFields(fields: Collection<Field<*>>): Collection<Field<*>> {
val name = ParameterName()
return fields.map {
if (it.name.isBlank()) it.withParameterName(name.derive(null)) else it
}
}
/**
* Replace the parameter names in the query with question marks
*

View File

@@ -9,9 +9,84 @@ object Query {
* @param where The `WHERE` clause for the statement
* @return The two parts of the query combined with `WHERE`
*/
fun statementWhere(statement: String, where: String): String =
fun statementWhere(statement: String, where: String) =
"$statement WHERE $where"
/**
* Functions to create `WHERE` clause fragments
*/
object Where {
/**
* Create a `WHERE` clause fragment to query by one or more fields
*
* @param fields The fields to be queried
* @param howMatched How the fields should be matched (optional, defaults to `ALL`)
* @return A `WHERE` clause fragment to match the given fields
*/
fun byFields(fields: Collection<Field<*>>, howMatched: FieldMatch? = null) =
fields.joinToString(" ${(howMatched ?: FieldMatch.ALL).sql} ") { it.toWhere() }
/**
* Create a `WHERE` clause fragment to retrieve a document by its ID
*
* @param parameterName The parameter name to use for the ID placeholder (optional, defaults to ":id")
* @param docId The ID value (optional; used for type determinations, string assumed if not provided)
*/
fun <TKey> byId(parameterName: String = ":id", docId: TKey? = null) =
byFields(listOf(Field.equal(Configuration.idField, docId ?: "").withParameterName(parameterName)))
/**
* Create a `WHERE` clause fragment to implement a JSON containment query (PostgreSQL only)
*
* @param parameterName The parameter name to use for the JSON placeholder (optional, defaults to ":criteria")
* @return A `WHERE` clause fragment to implement a JSON containment criterion
* @throws DocumentException If called against a SQLite database
*/
fun jsonContains(parameterName: String = ":criteria") =
when (Configuration.dialect("create containment WHERE clause")) {
Dialect.POSTGRESQL -> "data @> $parameterName"
Dialect.SQLITE -> throw DocumentException("JSON containment is not supported")
}
/**
* Create a `WHERE` clause fragment to implement a JSON path match query (PostgreSQL only)
*
* @param parameterName The parameter name to use for the placeholder (optional, defaults to ":path")
* @return A `WHERE` clause fragment to implement a JSON path match criterion
* @throws DocumentException If called against a SQLite database
*/
fun jsonPathMatches(parameterName: String = ":path") =
when (Configuration.dialect("create JSON path match WHERE clause")) {
Dialect.POSTGRESQL -> "jsonb_path_exists(data, $parameterName::jsonpath)"
Dialect.SQLITE -> throw DocumentException("JSON path match is not supported")
}
}
/**
* Create a query by a document's ID
*
* @param statement The SQL statement to be run against a document by its ID
* @param docId The ID of the document targeted
* @returns A query addressing a document by its ID
*/
fun <TKey> byId(statement: String, docId: TKey) =
statementWhere(statement, Where.byId(docId = docId))
/**
* Create a query on JSON fields
*
* @param statement The SQL statement to be run against matching fields
* @param howMatched Whether to match any or all of the field conditions
* @param fields The field conditions to be matched
* @return A query addressing documents by field matching conditions
*/
fun byFields(statement: String, howMatched: FieldMatch, fields: Collection<Field<*>>) =
Query.statementWhere(statement, Where.byFields(fields, howMatched))
/**
* Functions to create queries to define tables and indexes
*/
object Definition {
/**
@@ -24,6 +99,18 @@ object Query {
fun ensureTableFor(tableName: String, dataType: String): String =
"CREATE TABLE IF NOT EXISTS $tableName (data $dataType NOT NULL)"
/**
* SQL statement to create a document table in the current dialect
*
* @param tableName The name of the table to create (may include schema)
* @return A query to create a document table
*/
fun ensureTable(tableName: String) =
when (Configuration.dialect("create table creation query")) {
Dialect.POSTGRESQL -> ensureTableFor(tableName, "JSONB")
Dialect.SQLITE -> ensureTableFor(tableName, "TEXT")
}
/**
* Split a schema and table name
*
@@ -71,8 +158,27 @@ object Query {
* @param tableName The table into which to insert (may include schema)
* @return A query to insert a document
*/
fun insert(tableName: String): String =
"INSERT INTO $tableName VALUES (:data)"
fun insert(tableName: String, autoId: AutoId? = null): String {
val id = Configuration.idField
val values = when (Configuration.dialect("create INSERT statement")) {
Dialect.POSTGRESQL -> when (autoId ?: AutoId.DISABLED) {
AutoId.DISABLED -> ":data"
AutoId.NUMBER -> ":data::jsonb || ('{\"$id\":' || " +
"(SELECT COALESCE(MAX((data->>'$id')::numeric), 0) + 1 " +
"FROM $tableName) || '}')::jsonb"
AutoId.UUID -> ":data::jsonb || '{\"$id\":\"${AutoId.generateUUID()}\"}'"
AutoId.RANDOM_STRING -> ":data::jsonb || '{\"$id\":\"${AutoId.generateRandomString()}\"}'"
}
Dialect.SQLITE -> when (autoId ?: AutoId.DISABLED) {
AutoId.DISABLED -> ":data"
AutoId.NUMBER -> "json_set(:data, '$.$id', " +
"(SELECT coalesce(max(data->>'$id'), 0) + 1 FROM $tableName))"
AutoId.UUID -> "json_set(:data, '$.$id', '${AutoId.generateUUID()}')"
AutoId.RANDOM_STRING -> "json_set(:data, '$.$id', '${AutoId.generateRandomString()}')"
}
}
return "INSERT INTO $tableName VALUES ($values)"
}
/**
* Query to save a document, inserting it if it does not exist and updating it if it does (AKA "upsert")
@@ -81,7 +187,8 @@ object Query {
* @return A query to save a document
*/
fun save(tableName: String): String =
"${insert(tableName)} ON CONFLICT ((data->>'${Configuration.idField}')) DO UPDATE SET data = EXCLUDED.data"
insert(tableName, AutoId.DISABLED) +
" ON CONFLICT ((data->>'${Configuration.idField}')) DO UPDATE SET data = EXCLUDED.data"
/**
* Query to count documents in a table (this query has no `WHERE` clause)
@@ -120,6 +227,66 @@ object Query {
fun update(tableName: String): String =
"UPDATE $tableName SET data = :data"
/**
* Functions to create queries to patch (partially update) JSON documents
*/
object Patch {
/**
* Create an `UPDATE` statement to patch documents
*
* @param tableName The table to be updated
* @param where The `WHERE` clause for the query
* @return A query to patch documents
*/
private fun patch(tableName: String, where: String): String {
val setValue = when (Configuration.dialect("create patch query")) {
Dialect.POSTGRESQL -> "data || :data"
Dialect.SQLITE -> "json_patch(data, json(:data))"
}
return statementWhere("UPDATE $tableName SET data = $setValue", where)
}
/**
* A query to patch (partially update) a JSON document by its ID
*
* @param tableName The name of the table where the document is stored
* @param docId The ID of the document to be updated (optional, used for type checking)
* @return A query to patch a JSON document by its ID
*/
fun <TKey> byId(tableName: String, docId: TKey? = null) =
patch(tableName, Where.byId(docId = docId))
/**
* A query to patch (partially update) a JSON document using field match criteria
*
* @param tableName The name of the table where the documents are stored
* @param fields The field criteria
* @param howMatched How the fields should be matched (optional, defaults to `ALL`)
* @return A query to patch JSON documents by field match criteria
*/
fun byFields(tableName: String, fields: Collection<Field<*>>, howMatched: FieldMatch? = null) =
patch(tableName, Where.byFields(fields, howMatched))
/**
* A query to patch (partially update) a JSON document by JSON containment (PostgreSQL only)
*
* @param tableName The name of the table where the document is stored
* @return A query to patch JSON documents by JSON containment
*/
fun <TKey> byContains(tableName: String) =
patch(tableName, Where.jsonContains())
/**
* A query to patch (partially update) a JSON document by JSON path match (PostgreSQL only)
*
* @param tableName The name of the table where the document is stored
* @return A query to patch JSON documents by JSON path match
*/
fun <TKey> byJsonPath(tableName: String) =
patch(tableName, Where.jsonPathMatches())
}
/**
* Query to delete documents from a table (this query has no `WHERE` clause)
*

View File

@@ -56,7 +56,7 @@ object Results {
* @return The count from the row
*/
fun toCount(rs: ResultSet): Long =
when (Configuration.dialect) {
when (Configuration.dialect()) {
Dialect.POSTGRESQL -> rs.getInt("it").toLong()
Dialect.SQLITE -> rs.getLong("it")
}
@@ -68,7 +68,7 @@ object Results {
* @return The true/false value from the row
*/
fun toExists(rs: ResultSet): Boolean =
when (Configuration.dialect) {
when (Configuration.dialect()) {
Dialect.POSTGRESQL -> rs.getBoolean("it")
Dialect.SQLITE -> toCount(rs) > 0L
}

View File

@@ -67,14 +67,30 @@ class QueryTest {
@Test
@DisplayName("insert generates correctly")
fun insert() {
assertEquals("INSERT INTO $tbl VALUES (:data)", Query.insert(tbl), "INSERT statement not constructed correctly")
try {
Configuration.connectionString = "postgresql"
assertEquals(
"INSERT INTO $tbl VALUES (:data)",
Query.insert(tbl),
"INSERT statement not constructed correctly"
)
} finally {
Configuration.connectionString = null
}
}
@Test
@DisplayName("save generates correctly")
fun save() {
assertEquals("INSERT INTO $tbl VALUES (:data) ON CONFLICT ((data->>'id')) DO UPDATE SET data = EXCLUDED.data",
Query.save(tbl), "INSERT ON CONFLICT UPDATE statement not constructed correctly")
try {
Configuration.connectionString = "postgresql"
assertEquals(
"INSERT INTO $tbl VALUES (:data) ON CONFLICT ((data->>'id')) DO UPDATE SET data = EXCLUDED.data",
Query.save(tbl), "INSERT ON CONFLICT UPDATE statement not constructed correctly"
)
} finally {
Configuration.connectionString = null
}
}
@Test