commit c875c42dafa144b50d63388b97fe7e4c719022af Author: niansa Date: Tue Oct 18 00:35:00 2022 +0200 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cf25ac6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +CMakeLists.txt.user* diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..e0aeebb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "DPP"] + path = DPP + url = https://github.com/brainboxdotcc/DPP.git +[submodule "libjustgpt"] + path = libjustgpt + url = https://gitlab.com/niansa/libjustgpt.git diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..01e79f0 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,19 @@ +cmake_minimum_required(VERSION 3.5) + +project(KeineAhnung LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_subdirectory(DPP) +add_subdirectory(libjustgpt) + +add_compile_definitions(COMPILER_ID="${CMAKE_CXX_COMPILER_ID}") +add_compile_definitions(COMPILER_VERSION="${CMAKE_CXX_COMPILER_VERSION}") +add_compile_definitions(COMPILER_PLATFORM="${CMAKE_CXX_PLATFORM_ID}") + +add_executable(KeineAhnung main.cpp) +target_link_libraries(KeineAhnung PRIVATE dpp justgpt sqlite3) + +install(TARGETS KeineAhnung + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/CMakeLists.txt.user b/CMakeLists.txt.user new file mode 100644 index 0000000..847e873 --- /dev/null +++ b/CMakeLists.txt.user @@ -0,0 +1,416 @@ + + + + + + EnvironmentId + {f3929b0b-3d39-4fa3-8d2d-2b329b63b30c} + + + ProjectExplorer.Project.ActiveTarget + 0 + + + ProjectExplorer.Project.EditorSettings + + true + false + true + + Cpp + + CppGlobal + + + + QmlJS + + QmlJSGlobal + + + 2 + UTF-8 + false + 4 + false + 80 + true + true + 1 + false + true + false + 0 + true + true + 0 + 8 + true + false + 1 + true + true + true + *.md, *.MD, Makefile + false + true + + + + ProjectExplorer.Project.PluginSettings + + + true + false + true + true + true + true + + + 0 + true + + true + true + Builtin.DefaultTidyAndClazy + 2 + + + + true + + + + + ProjectExplorer.Project.Target.0 + + Desktop + Desktop + Desktop + {913660d6-ca1c-4b66-a4da-64108a3258a2} + 0 + 0 + 0 + + Debug + -DCMAKE_GENERATOR:STRING=Unix Makefiles +-DCMAKE_BUILD_TYPE:STRING=Debug +-DCMAKE_PROJECT_INCLUDE_BEFORE:FILEPATH=%{IDE:ResourcePath}/package-manager/auto-setup.cmake +-DQT_QMAKE_EXECUTABLE:STRING=%{Qt:qmakeExecutable} +-DCMAKE_PREFIX_PATH:STRING=%{Qt:QT_INSTALL_PREFIX} +-DCMAKE_C_COMPILER:STRING=%{Compiler:Executable:C} +-DCMAKE_CXX_COMPILER:STRING=%{Compiler:Executable:Cxx} + /mnt/hhdd/Programme/OSS/build-KeineAhnung-Desktop-Debug + + + + all + + true + Build + CMakeProjectManager.MakeStep + + 1 + Build + Build + ProjectExplorer.BuildSteps.Build + + + + + clean + + true + Build + CMakeProjectManager.MakeStep + + 1 + Clean + Clean + ProjectExplorer.BuildSteps.Clean + + 2 + false + + false + + Debug + CMakeProjectManager.CMakeBuildConfiguration + + + Release + -DCMAKE_GENERATOR:STRING=Unix Makefiles +-DCMAKE_BUILD_TYPE:STRING=Release +-DCMAKE_PROJECT_INCLUDE_BEFORE:FILEPATH=%{IDE:ResourcePath}/package-manager/auto-setup.cmake +-DQT_QMAKE_EXECUTABLE:STRING=%{Qt:qmakeExecutable} +-DCMAKE_PREFIX_PATH:STRING=%{Qt:QT_INSTALL_PREFIX} +-DCMAKE_C_COMPILER:STRING=%{Compiler:Executable:C} +-DCMAKE_CXX_COMPILER:STRING=%{Compiler:Executable:Cxx} + /mnt/hhdd/Programme/OSS/build-KeineAhnung-Desktop-Release + + + + all + + true + CMakeProjectManager.MakeStep + + 1 + Build + Build + ProjectExplorer.BuildSteps.Build + + + + + clean + + true + CMakeProjectManager.MakeStep + + 1 + Clean + Clean + ProjectExplorer.BuildSteps.Clean + + 2 + false + + false + + Release + CMakeProjectManager.CMakeBuildConfiguration + + + RelWithDebInfo + -DCMAKE_GENERATOR:STRING=Unix Makefiles +-DCMAKE_BUILD_TYPE:STRING=RelWithDebInfo +-DCMAKE_PROJECT_INCLUDE_BEFORE:FILEPATH=%{IDE:ResourcePath}/package-manager/auto-setup.cmake +-DQT_QMAKE_EXECUTABLE:STRING=%{Qt:qmakeExecutable} +-DCMAKE_PREFIX_PATH:STRING=%{Qt:QT_INSTALL_PREFIX} +-DCMAKE_C_COMPILER:STRING=%{Compiler:Executable:C} +-DCMAKE_CXX_COMPILER:STRING=%{Compiler:Executable:Cxx} + /mnt/hhdd/Programme/OSS/build-KeineAhnung-Desktop-RelWithDebInfo + + + + all + + true + CMakeProjectManager.MakeStep + + 1 + Build + Build + ProjectExplorer.BuildSteps.Build + + + + + clean + + true + CMakeProjectManager.MakeStep + + 1 + Clean + Clean + ProjectExplorer.BuildSteps.Clean + + 2 + false + + false + + Release with Debug Information + CMakeProjectManager.CMakeBuildConfiguration + + + RelWithDebInfo + -DCMAKE_GENERATOR:STRING=Unix Makefiles +-DCMAKE_BUILD_TYPE:STRING=RelWithDebInfo +-DCMAKE_PROJECT_INCLUDE_BEFORE:FILEPATH=%{IDE:ResourcePath}/package-manager/auto-setup.cmake +-DQT_QMAKE_EXECUTABLE:STRING=%{Qt:qmakeExecutable} +-DCMAKE_PREFIX_PATH:STRING=%{Qt:QT_INSTALL_PREFIX} +-DCMAKE_C_COMPILER:STRING=%{Compiler:Executable:C} +-DCMAKE_CXX_COMPILER:STRING=%{Compiler:Executable:Cxx} + 0 + /mnt/hhdd/Programme/OSS/build-KeineAhnung-Desktop-Profile + + + + all + + true + CMakeProjectManager.MakeStep + + 1 + Build + Build + ProjectExplorer.BuildSteps.Build + + + + + clean + + true + CMakeProjectManager.MakeStep + + 1 + Clean + Clean + ProjectExplorer.BuildSteps.Clean + + 2 + false + + false + + Profile + CMakeProjectManager.CMakeBuildConfiguration + + + MinSizeRel + -DCMAKE_GENERATOR:STRING=Unix Makefiles +-DCMAKE_BUILD_TYPE:STRING=MinSizeRel +-DCMAKE_PROJECT_INCLUDE_BEFORE:FILEPATH=%{IDE:ResourcePath}/package-manager/auto-setup.cmake +-DQT_QMAKE_EXECUTABLE:STRING=%{Qt:qmakeExecutable} +-DCMAKE_PREFIX_PATH:STRING=%{Qt:QT_INSTALL_PREFIX} +-DCMAKE_C_COMPILER:STRING=%{Compiler:Executable:C} +-DCMAKE_CXX_COMPILER:STRING=%{Compiler:Executable:Cxx} + /mnt/hhdd/Programme/OSS/build-KeineAhnung-Desktop-MinSizeRel + + + + all + + true + CMakeProjectManager.MakeStep + + 1 + Build + Build + ProjectExplorer.BuildSteps.Build + + + + + clean + + true + CMakeProjectManager.MakeStep + + 1 + Clean + Clean + ProjectExplorer.BuildSteps.Clean + + 2 + false + + false + + Minimum Size Release + CMakeProjectManager.CMakeBuildConfiguration + + 5 + + + 0 + Deploy + Deploy + ProjectExplorer.BuildSteps.Deploy + + 1 + + false + ProjectExplorer.DefaultDeployConfiguration + + 1 + + true + true + true + + 2 + + KeineAhnung + CMakeProjectManager.CMakeRunConfiguration.KeineAhnung + KeineAhnung + MTAzMTU4Njc1MDU3NDg5OTI1Mg.GFqCYM.9qMg0Dwzpqt1EKczOLq-bWQgbXFXPJsZ_f9u54 + false + true + true + false + true + /mnt/hhdd/Programme/OSS/build-KeineAhnung-Desktop-Debug + + + true + true + true + + 2 + + soaktest + CMakeProjectManager.CMakeRunConfiguration.soaktest + soaktest + false + true + true + false + true + /mnt/hhdd/Programme/OSS/build-KeineAhnung-Desktop-Debug/DPP/library + + + true + true + true + + 2 + + unittest + CMakeProjectManager.CMakeRunConfiguration.unittest + unittest + false + true + true + false + true + /mnt/hhdd/Programme/OSS/build-KeineAhnung-Desktop-Debug/DPP/library + + + true + true + true + + 2 + + justgpt_test + CMakeProjectManager.CMakeRunConfiguration.justgpt_test + justgpt_test + false + true + true + false + true + /mnt/hhdd/Programme/OSS/build-KeineAhnung-Desktop-Debug/libjustgpt + + 4 + + + + ProjectExplorer.Project.TargetCount + 1 + + + ProjectExplorer.Project.Updater.FileVersion + 22 + + + Version + 22 + + diff --git a/DPP b/DPP new file mode 160000 index 0000000..f3c3a1a --- /dev/null +++ b/DPP @@ -0,0 +1 @@ +Subproject commit f3c3a1a9d693a5fc5debedd1dc363e082c907645 diff --git a/libjustgpt b/libjustgpt new file mode 160000 index 0000000..c0cddb2 --- /dev/null +++ b/libjustgpt @@ -0,0 +1 @@ +Subproject commit c0cddb2e598137225e478d7201a04c590c8962e0 diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..072132a --- /dev/null +++ b/main.cpp @@ -0,0 +1,308 @@ +#include +#include +#include +#include +#include +#include +#include +#ifdef __linux__ +# include +# include +#endif + +#include +#include "sqlite_modern_cpp/sqlite_modern_cpp.h" +#include + + + +const std::string_view ai_init_prompt = +R"( Hello + Hi! + How are you? + Great! + What is your favorite color? + Green! + What is your name? + I don't have a name. + How are you? + Good. + )"; + + +template +std::string intAsHex(intT i) { + std::ostringstream sstream; + sstream << "0x" << std::hex << i; + return std::move(sstream).str(); +} + +uint32_t getColorOfUser(const dpp::user& user) { + uint32_t colors[] = { + 0x424cb6, + 0x585f68, + 0x2c7c45, + 0xbc7d14, + 0xb23234 + }; + return colors[user.discriminator % 5]; +} + +namespace SysInfo { // TODO: This currently works on linux only +# ifdef __linux__ +# define SYSINFO_WORKS + size_t memUsed() { + const std::string identifier = "VmData:\t"; + std::ifstream pstatus("/proc/self/status"); + // Check for success + if (not pstatus) { + return 0; + } + // Find line + std::string currline; + while (currline.find(identifier) != 0) { + std::getline(pstatus, currline); + } + // Close pstatus + pstatus.close(); + // Get number + size_t res = 0; + for (const auto& character : currline) { + // Check if character is digit + if (character> '9' or character < '0') { + continue; + } + // Append digit to res + res += static_cast(character - '0'); + res *= 10; + } + res /= 10; + // Return result in MB + return res / 1000; + } + + size_t memTotal() { + // Get sysinfo + struct sysinfo i; + sysinfo(&i); + // Get total ram in MB + return i.totalram / 1000000; + } + + std::string kernelNameAndVersion() { + struct utsname name; + uname(&name); + return std::string(name.sysname)+" "+std::string(name.release); + } +# endif +} + + +int main(int argc, char **argv) { + // Caches + std::vector slashcommandCache; // Without IDs + std::unordered_map guildCache; + + // Initialize AI + Ai ai; + + auto chat_reply = [&](const std::string& message) -> std::string { + auto prompt = std::string(ai_init_prompt)+message+"\n"; + return ai.complete(prompt, '\n'); + }; + + // Initialize database + sqlite::database db("db.sqlite3"); + db << "CREATE TABLE IF NOT EXISTS guilds (" + " id TEXT PRIMARY KEY NOT NULL," + " globalchat_channel TEXT," + " UNIQUE(id)" + ");"; + + // Initialize bot + dpp::cluster bot(argv[1]); + + auto broadcast_message = [&](const dpp::message_create_t& message) { + dpp::message msg; + const auto& guild = guildCache.at(message.msg.guild_id); + auto messageCode = intAsHex(message.msg.id)+intAsHex(message.msg.channel_id)+std::to_string(message.msg.guild_id); + msg.add_embed(dpp::embed() + .set_title(":earth_africa: Globalchat") + .set_description(":speaking_head: ["+message.msg.author.format_username()+"](https://internal.tuxifan.net/discordInspectionData/"+messageCode+" 'Inspection Data')") + .set_thumbnail(message.msg.author.get_avatar_url()) + .set_color(getColorOfUser(message.msg.author)) + .add_field(message.msg.member.nickname.empty()?message.msg.author.username:message.msg.member.nickname, message.msg.content+"\n⠀") + .set_footer(dpp::embed_footer().set_text(guild.name)) + ); + db << "SELECT globalchat_channel FROM guilds;" + >> [&](std::string channel_id_str) { + if (channel_id_str.empty()) + return; + msg.set_channel_id(std::stoul(channel_id_str)); + bot.message_create(msg); + }; + }; + + bot.on_log(dpp::utility::cout_logger()); + + bot.on_slashcommand([&](const dpp::slashcommand_t& event) { + // Help + if (event.command.get_command_name() == "help") { + // Help command + dpp::embed embed; + embed.set_title("Help") + .set_description("Commands") + .set_footer(dpp::embed_footer().set_text("Bot made by 353535#3535")); + for (const auto& command : slashcommandCache) { + embed.add_field(command.name, command.description, true); + } + event.reply(dpp::message().add_embed(embed)); + } else if (event.command.get_command_name() == "about") { + // Info command + std::ostringstream info; + info << "**Peak server count**: " << guildCache.size() << "\n" +# ifdef SYSINFO_WORKS + "**VM memory usage**: " << SysInfo::memUsed() << " MB / " << SysInfo::memTotal() << " MB\n" + "**Kernel**: " << SysInfo::kernelNameAndVersion() << "\n" +# endif + "**Compiler**: " COMPILER_ID " " COMPILER_VERSION " (" COMPILER_PLATFORM ")\n" + "**C++ standard**: " << __cplusplus << "\n" + ; + event.reply(dpp::message() + .add_embed(dpp::embed() + .set_title("about") + .set_description(std::move(info).str()) + )); + } + + // Fun + else if (event.command.get_command_name() == "ping") { + // Ping command + event.reply("Pong!"); + } else if (event.command.get_command_name() == "chat") { + // Chat command + auto& content = std::get(event.get_parameter("message")); + event.thinking(false); + std::thread([=, &bot]() { + auto reply = chat_reply(content); + event.delete_original_response(); + bot.message_create(dpp::message(event.command.channel_id, event.command.member.get_mention()+" asked:\n> "+content+"\nI reply:\n> "+reply)); + }).detach(); + } + + // Global chat + else if (event.command.get_command_name() == "set_gc") { + // Set global chat command + auto channel_id = std::get(event.get_parameter("channel")); + // Update globalchat ID + db << "UPDATE guilds " + "SET globalchat_channel = ? " + "WHERE id = ?;" + << std::to_string(channel_id) + << std::to_string(event.command.guild_id); + // Reply + event.reply("Globalchat is now in <#"+std::to_string(channel_id)+">"); + } else if (event.command.get_command_name() == "reset_gc") { + // Update globalchat ID + db << "UPDATE guilds " + "SET globalchat_channel = NULL " + "WHERE id = ?;" + << std::to_string(event.command.guild_id); + // Reply + event.reply("Globalchat reset"); + } + }); + bot.on_message_create([&](const dpp::message_create_t& message) { + // Make sure sender isn't a bot + if (message.msg.author.is_bot()) + return; + + // Chat with bot + if (message.msg.content[0] == '-') { + auto content = message.msg.content; + content.erase(0, 1); + for (auto& c : content) { + if (c == '\n') + c = ' '; + } + auto reply = chat_reply(content); + message.reply(reply); + return; + } + + // Global chat + { + // Get current globalchat channel for server + dpp::snowflake globalchat_channel; + { + std::string globalchat_channel_str; + db << "SELECT globalchat_channel FROM guilds " + "WHERE id = ?;" + << std::to_string(message.msg.guild_id) + >> std::tie(globalchat_channel_str); + globalchat_channel = std::stoul(globalchat_channel_str); + } + // If this channel is globalchat channel, broadcast message + if (message.msg.channel_id == globalchat_channel) { + broadcast_message(message); + bot.message_delete(message.msg.id, message.msg.channel_id); + return; + } + } + }); + bot.on_guild_create([&](const dpp::guild_create_t& guild) { + // Make sure guild is in database + db << "INSERT OR IGNORE INTO guilds (id) VALUES (?);" + << std::to_string(guild.created->id); + // Add guild to cache + guildCache[guild.created->id] = *guild.created; + }); + + bot.on_ready([&](const dpp::ready_t& event) { + // Register bot commands once + if (dpp::run_once()) { + // Clear all commands first + for (const auto& command : bot.global_commands_get_sync()) { + bot.global_command_delete_sync(command.first); + } + slashcommandCache.clear(); + // Help + { + dpp::slashcommand sc("help", "Get help", bot.me.id); + bot.global_command_create(sc); + slashcommandCache.push_back(sc); + } + { + dpp::slashcommand sc("about", "About me", bot.me.id); + bot.global_command_create(sc); + slashcommandCache.push_back(sc); + } + // Fun + { + dpp::slashcommand sc("ping", "Ping pong!", bot.me.id); + bot.global_command_create(sc); + slashcommandCache.push_back(sc); + }{ + dpp::slashcommand sc("chat", "Chat with the bot", bot.me.id); + sc.add_option(dpp::command_option(dpp::command_option_type::co_string, "message", "Your message", true)); + bot.global_command_create(sc); + slashcommandCache.push_back(sc); + } + // Global chat + { + dpp::slashcommand sc("set_gc", "Set global chat location", bot.me.id); + sc.add_option(dpp::command_option(dpp::command_option_type::co_channel, "channel", "Channel", true)); + bot.global_command_create(sc); + slashcommandCache.push_back(sc); + }{ + dpp::slashcommand sc("reset_gc", "Reset global chat location", bot.me.id); + bot.global_command_create(sc); + slashcommandCache.push_back(sc); + } + } + }); + + // Configure intents and start bot + bot.intents |= dpp::intents::i_message_content | dpp::intents::i_guilds; + bot.start(dpp::st_wait); +} diff --git a/sqlite_modern_cpp/License.txt b/sqlite_modern_cpp/License.txt new file mode 100644 index 0000000..595b1d6 --- /dev/null +++ b/sqlite_modern_cpp/License.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 aminroosta + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sqlite_modern_cpp/sqlite_modern_cpp.h b/sqlite_modern_cpp/sqlite_modern_cpp.h new file mode 100644 index 0000000..344dedc --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp.h @@ -0,0 +1,1053 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#define MODERN_SQLITE_VERSION 3002008 + +#ifdef __has_include +#if __cplusplus > 201402 && __has_include() +#define MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#elif __has_include() +#define MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT +#endif +#endif + +#ifdef __has_include +#if __cplusplus > 201402 && __has_include() +#define MODERN_SQLITE_STD_VARIANT_SUPPORT +#endif +#endif + +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#include +#endif + +#ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT +#include +#define MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#endif + +#ifdef _MODERN_SQLITE_BOOST_OPTIONAL_SUPPORT +#include +#endif + +#ifdef __ENVIRONMENT_IPHONE_OS_VERSION_MIN_REQUIRED__ +#if __ENVIRONMENT_IPHONE_OS_VERSION_MIN_REQUIRED__ < 100000 +#undef __cpp_lib_uncaught_exceptions +#endif +#endif + +#include + +#include "sqlite_modern_cpp/errors.h" +#include "sqlite_modern_cpp/utility/function_traits.h" +#include "sqlite_modern_cpp/utility/uncaught_exceptions.h" +#include "sqlite_modern_cpp/utility/utf16_utf8.h" + +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT +#include "sqlite_modern_cpp/utility/variant.h" +#endif + +namespace sqlite { + + // std::optional support for NULL values + #ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + template + using optional = std::experimental::optional; + #else + template + using optional = std::optional; + #endif + #endif + + class database; + class database_binder; + + template class binder; + + typedef std::shared_ptr connection_type; + + template::value == Element)> struct tuple_iterate { + static void iterate(Tuple& t, database_binder& db) { + get_col_from_db(db, Element, std::get(t)); + tuple_iterate::iterate(t, db); + } + }; + + template struct tuple_iterate { + static void iterate(Tuple&, database_binder&) {} + }; + + class database_binder { + + public: + // database_binder is not copyable + database_binder() = delete; + database_binder(const database_binder& other) = delete; + database_binder& operator=(const database_binder&) = delete; + + database_binder(database_binder&& other) : + _db(std::move(other._db)), + _stmt(std::move(other._stmt)), + _inx(other._inx), execution_started(other.execution_started) { } + + void execute() { + _start_execute(); + int hresult; + + while((hresult = sqlite3_step(_stmt.get())) == SQLITE_ROW) {} + + if(hresult != SQLITE_DONE) { + errors::throw_sqlite_error(hresult, sql()); + } + } + + std::string sql() { +#if SQLITE_VERSION_NUMBER >= 3014000 + auto sqlite_deleter = [](void *ptr) {sqlite3_free(ptr);}; + std::unique_ptr str(sqlite3_expanded_sql(_stmt.get()), sqlite_deleter); + return str ? str.get() : original_sql(); +#else + return original_sql(); +#endif + } + + std::string original_sql() { + return sqlite3_sql(_stmt.get()); + } + + void used(bool state) { + if(!state) { + // We may have to reset first if we haven't done so already: + _next_index(); + --_inx; + } + execution_started = state; + } + bool used() const { return execution_started; } + + private: + std::shared_ptr _db; + std::unique_ptr _stmt; + utility::UncaughtExceptionDetector _has_uncaught_exception; + + int _inx; + + bool execution_started = false; + + int _next_index() { + if(execution_started && !_inx) { + sqlite3_reset(_stmt.get()); + sqlite3_clear_bindings(_stmt.get()); + } + return ++_inx; + } + void _start_execute() { + _next_index(); + _inx = 0; + used(true); + } + + void _extract(std::function call_back) { + int hresult; + _start_execute(); + + while((hresult = sqlite3_step(_stmt.get())) == SQLITE_ROW) { + call_back(); + } + + if(hresult != SQLITE_DONE) { + errors::throw_sqlite_error(hresult, sql()); + } + } + + void _extract_single_value(std::function call_back) { + int hresult; + _start_execute(); + + if((hresult = sqlite3_step(_stmt.get())) == SQLITE_ROW) { + call_back(); + } else if(hresult == SQLITE_DONE) { + throw errors::no_rows("no rows to extract: exactly 1 row expected", sql(), SQLITE_DONE); + } + + if((hresult = sqlite3_step(_stmt.get())) == SQLITE_ROW) { + throw errors::more_rows("not all rows extracted", sql(), SQLITE_ROW); + } + + if(hresult != SQLITE_DONE) { + errors::throw_sqlite_error(hresult, sql()); + } + } + + sqlite3_stmt* _prepare(const std::u16string& sql) { + return _prepare(utility::utf16_to_utf8(sql)); + } + + sqlite3_stmt* _prepare(const std::string& sql) { + int hresult; + sqlite3_stmt* tmp = nullptr; + const char *remaining; + hresult = sqlite3_prepare_v2(_db.get(), sql.data(), -1, &tmp, &remaining); + if(hresult != SQLITE_OK) errors::throw_sqlite_error(hresult, sql); + if(!std::all_of(remaining, sql.data() + sql.size(), [](char ch) {return std::isspace(ch);})) + throw errors::more_statements("Multiple semicolon separated statements are unsupported", sql); + return tmp; + } + + template + struct is_sqlite_value : public std::integral_constant< + bool, + std::is_floating_point::value + || std::is_integral::value + || std::is_same::value + || std::is_same::value + || std::is_same::value + > { }; + template + struct is_sqlite_value< std::vector > : public std::integral_constant< + bool, + std::is_floating_point::value + || std::is_integral::value + || std::is_same::value + > { }; +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT + template + struct is_sqlite_value< std::variant > : public std::integral_constant< + bool, + true + > { }; +#endif + + + /* for vector support */ + template friend database_binder& operator <<(database_binder& db, const std::vector& val); + template friend void get_col_from_db(database_binder& db, int inx, std::vector& val); + /* for nullptr & unique_ptr support */ + friend database_binder& operator <<(database_binder& db, std::nullptr_t); + template friend database_binder& operator <<(database_binder& db, const std::unique_ptr& val); + template friend void get_col_from_db(database_binder& db, int inx, std::unique_ptr& val); +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT + template friend database_binder& operator <<(database_binder& db, const std::variant& val); + template friend void get_col_from_db(database_binder& db, int inx, std::variant& val); +#endif + template friend T operator++(database_binder& db, int); + // Overload instead of specializing function templates (http://www.gotw.ca/publications/mill17.htm) + friend database_binder& operator<<(database_binder& db, const int& val); + friend void get_col_from_db(database_binder& db, int inx, int& val); + friend database_binder& operator <<(database_binder& db, const sqlite_int64& val); + friend void get_col_from_db(database_binder& db, int inx, sqlite3_int64& i); + friend database_binder& operator <<(database_binder& db, const float& val); + friend void get_col_from_db(database_binder& db, int inx, float& f); + friend database_binder& operator <<(database_binder& db, const double& val); + friend void get_col_from_db(database_binder& db, int inx, double& d); + friend void get_col_from_db(database_binder& db, int inx, std::string & s); + friend database_binder& operator <<(database_binder& db, const std::string& txt); + friend void get_col_from_db(database_binder& db, int inx, std::u16string & w); + friend database_binder& operator <<(database_binder& db, const std::u16string& txt); + + +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template friend database_binder& operator <<(database_binder& db, const optional& val); + template friend void get_col_from_db(database_binder& db, int inx, optional& o); +#endif + +#ifdef _MODERN_SQLITE_BOOST_OPTIONAL_SUPPORT + template friend database_binder& operator <<(database_binder& db, const boost::optional& val); + template friend void get_col_from_db(database_binder& db, int inx, boost::optional& o); +#endif + + public: + + database_binder(std::shared_ptr db, std::u16string const & sql): + _db(db), + _stmt(_prepare(sql), sqlite3_finalize), + _inx(0) { + } + + database_binder(std::shared_ptr db, std::string const & sql): + _db(db), + _stmt(_prepare(sql), sqlite3_finalize), + _inx(0) { + } + + ~database_binder() noexcept(false) { + /* Will be executed if no >>op is found, but not if an exception + is in mid flight */ + if(!used() && !_has_uncaught_exception && _stmt) { + execute(); + } + } + + template + typename std::enable_if::value, void>::type operator>>( + Result& value) { + this->_extract_single_value([&value, this] { + get_col_from_db(*this, 0, value); + }); + } + + template + void operator>>(std::tuple&& values) { + this->_extract_single_value([&values, this] { + tuple_iterate>::iterate(values, *this); + }); + } + + template + typename std::enable_if::value, void>::type operator>>( + Function&& func) { + typedef utility::function_traits traits; + + this->_extract([&func, this]() { + binder::run(*this, func); + }); + } + }; + + namespace sql_function_binder { + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + } + + enum class OpenFlags { + READONLY = SQLITE_OPEN_READONLY, + READWRITE = SQLITE_OPEN_READWRITE, + CREATE = SQLITE_OPEN_CREATE, + NOMUTEX = SQLITE_OPEN_NOMUTEX, + FULLMUTEX = SQLITE_OPEN_FULLMUTEX, + SHAREDCACHE = SQLITE_OPEN_SHAREDCACHE, + PRIVATECACH = SQLITE_OPEN_PRIVATECACHE, + URI = SQLITE_OPEN_URI + }; + inline OpenFlags operator|(const OpenFlags& a, const OpenFlags& b) { + return static_cast(static_cast(a) | static_cast(b)); + } + enum class Encoding { + ANY = SQLITE_ANY, + UTF8 = SQLITE_UTF8, + UTF16 = SQLITE_UTF16 + }; + struct sqlite_config { + OpenFlags flags = OpenFlags::READWRITE | OpenFlags::CREATE; + const char *zVfs = nullptr; + Encoding encoding = Encoding::ANY; + }; + + class database { + protected: + std::shared_ptr _db; + + public: + database(const std::string &db_name, const sqlite_config &config = {}): _db(nullptr) { + sqlite3* tmp = nullptr; + auto ret = sqlite3_open_v2(db_name.data(), &tmp, static_cast(config.flags), config.zVfs); + _db = std::shared_ptr(tmp, [=](sqlite3* ptr) { sqlite3_close_v2(ptr); }); // this will close the connection eventually when no longer needed. + if(ret != SQLITE_OK) errors::throw_sqlite_error(_db ? sqlite3_extended_errcode(_db.get()) : ret); + sqlite3_extended_result_codes(_db.get(), true); + if(config.encoding == Encoding::UTF16) + *this << R"(PRAGMA encoding = "UTF-16";)"; + } + + database(const std::u16string &db_name, const sqlite_config &config = {}): _db(nullptr) { + auto db_name_utf8 = utility::utf16_to_utf8(db_name); + sqlite3* tmp = nullptr; + auto ret = sqlite3_open_v2(db_name_utf8.data(), &tmp, static_cast(config.flags), config.zVfs); + _db = std::shared_ptr(tmp, [=](sqlite3* ptr) { sqlite3_close_v2(ptr); }); // this will close the connection eventually when no longer needed. + if(ret != SQLITE_OK) errors::throw_sqlite_error(_db ? sqlite3_extended_errcode(_db.get()) : ret); + sqlite3_extended_result_codes(_db.get(), true); + if(config.encoding != Encoding::UTF8) + *this << R"(PRAGMA encoding = "UTF-16";)"; + } + + database(std::shared_ptr db): + _db(db) {} + + database_binder operator<<(const std::string& sql) { + return database_binder(_db, sql); + } + + database_binder operator<<(const char* sql) { + return *this << std::string(sql); + } + + database_binder operator<<(const std::u16string& sql) { + return database_binder(_db, sql); + } + + database_binder operator<<(const char16_t* sql) { + return *this << std::u16string(sql); + } + + connection_type connection() const { return _db; } + + sqlite3_int64 last_insert_rowid() const { + return sqlite3_last_insert_rowid(_db.get()); + } + + template + void define(const std::string &name, Function&& func) { + typedef utility::function_traits traits; + + auto funcPtr = new auto(std::forward(func)); + if(int result = sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity, SQLITE_UTF8, funcPtr, + sql_function_binder::scalar::type>, + nullptr, nullptr, [](void* ptr){ + delete static_cast(ptr); + })) + errors::throw_sqlite_error(result); + } + + template + void define(const std::string &name, StepFunction&& step, FinalFunction&& final) { + typedef utility::function_traits traits; + using ContextType = typename std::remove_reference>::type; + + auto funcPtr = new auto(std::make_pair(std::forward(step), std::forward(final))); + if(int result = sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity - 1, SQLITE_UTF8, funcPtr, nullptr, + sql_function_binder::step::type>, + sql_function_binder::final::type>, + [](void* ptr){ + delete static_cast(ptr); + })) + errors::throw_sqlite_error(result); + } + + }; + + template + class binder { + private: + template < + typename Function, + std::size_t Index + > + using nth_argument_type = typename utility::function_traits< + Function + >::template argument; + + public: + // `Boundary` needs to be defaulted to `Count` so that the `run` function + // template is not implicitly instantiated on class template instantiation. + // Look up section 14.7.1 _Implicit instantiation_ of the ISO C++14 Standard + // and the [dicussion](https://github.com/aminroosta/sqlite_modern_cpp/issues/8) + // on Github. + + template< + typename Function, + typename... Values, + std::size_t Boundary = Count + > + static typename std::enable_if<(sizeof...(Values) < Boundary), void>::type run( + database_binder& db, + Function&& function, + Values&&... values + ) { + typename std::remove_cv>::type>::type value{}; + get_col_from_db(db, sizeof...(Values), value); + + run(db, function, std::forward(values)..., std::move(value)); + } + + template< + typename Function, + typename... Values, + std::size_t Boundary = Count + > + static typename std::enable_if<(sizeof...(Values) == Boundary), void>::type run( + database_binder&, + Function&& function, + Values&&... values + ) { + function(std::move(values)...); + } + }; + + // int + inline database_binder& operator<<(database_binder& db, const int& val) { + int hresult; + if((hresult = sqlite3_bind_int(db._stmt.get(), db._next_index(), val)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + return db; + } + inline void store_result_in_db(sqlite3_context* db, const int& val) { + sqlite3_result_int(db, val); + } + inline void get_col_from_db(database_binder& db, int inx, int& val) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + val = 0; + } else { + val = sqlite3_column_int(db._stmt.get(), inx); + } + } + inline void get_val_from_db(sqlite3_value *value, int& val) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + val = 0; + } else { + val = sqlite3_value_int(value); + } + } + + // sqlite_int64 + inline database_binder& operator <<(database_binder& db, const sqlite_int64& val) { + int hresult; + if((hresult = sqlite3_bind_int64(db._stmt.get(), db._next_index(), val)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const sqlite_int64& val) { + sqlite3_result_int64(db, val); + } + inline void get_col_from_db(database_binder& db, int inx, sqlite3_int64& i) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + i = 0; + } else { + i = sqlite3_column_int64(db._stmt.get(), inx); + } + } + inline void get_val_from_db(sqlite3_value *value, sqlite3_int64& i) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + i = 0; + } else { + i = sqlite3_value_int64(value); + } + } + + // float + inline database_binder& operator <<(database_binder& db, const float& val) { + int hresult; + if((hresult = sqlite3_bind_double(db._stmt.get(), db._next_index(), double(val))) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const float& val) { + sqlite3_result_double(db, val); + } + inline void get_col_from_db(database_binder& db, int inx, float& f) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + f = 0; + } else { + f = float(sqlite3_column_double(db._stmt.get(), inx)); + } + } + inline void get_val_from_db(sqlite3_value *value, float& f) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + f = 0; + } else { + f = float(sqlite3_value_double(value)); + } + } + + // double + inline database_binder& operator <<(database_binder& db, const double& val) { + int hresult; + if((hresult = sqlite3_bind_double(db._stmt.get(), db._next_index(), val)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const double& val) { + sqlite3_result_double(db, val); + } + inline void get_col_from_db(database_binder& db, int inx, double& d) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + d = 0; + } else { + d = sqlite3_column_double(db._stmt.get(), inx); + } + } + inline void get_val_from_db(sqlite3_value *value, double& d) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + d = 0; + } else { + d = sqlite3_value_double(value); + } + } + + // vector + template inline database_binder& operator<<(database_binder& db, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + int hresult; + if((hresult = sqlite3_bind_blob(db._stmt.get(), db._next_index(), buf, bytes, SQLITE_TRANSIENT)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + return db; + } + template inline void store_result_in_db(sqlite3_context* db, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + sqlite3_result_blob(db, buf, bytes, SQLITE_TRANSIENT); + } + template inline void get_col_from_db(database_binder& db, int inx, std::vector& vec) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + vec.clear(); + } else { + int bytes = sqlite3_column_bytes(db._stmt.get(), inx); + T const* buf = reinterpret_cast(sqlite3_column_blob(db._stmt.get(), inx)); + vec = std::vector(buf, buf + bytes/sizeof(T)); + } + } + template inline void get_val_from_db(sqlite3_value *value, std::vector& vec) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + vec.clear(); + } else { + int bytes = sqlite3_value_bytes(value); + T const* buf = reinterpret_cast(sqlite3_value_blob(value)); + vec = std::vector(buf, buf + bytes/sizeof(T)); + } + } + + /* for nullptr support */ + inline database_binder& operator <<(database_binder& db, std::nullptr_t) { + int hresult; + if((hresult = sqlite3_bind_null(db._stmt.get(), db._next_index())) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + return db; + } + inline void store_result_in_db(sqlite3_context* db, std::nullptr_t) { + sqlite3_result_null(db); + } + /* for nullptr support */ + template inline database_binder& operator <<(database_binder& db, const std::unique_ptr& val) { + if(val) + db << *val; + else + db << nullptr; + return db; + } + + /* for unique_ptr support */ + template inline void get_col_from_db(database_binder& db, int inx, std::unique_ptr& _ptr_) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + _ptr_ = nullptr; + } else { + auto underling_ptr = new T(); + get_col_from_db(db, inx, *underling_ptr); + _ptr_.reset(underling_ptr); + } + } + template inline void get_val_from_db(sqlite3_value *value, std::unique_ptr& _ptr_) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + _ptr_ = nullptr; + } else { + auto underling_ptr = new T(); + get_val_from_db(value, *underling_ptr); + _ptr_.reset(underling_ptr); + } + } + + // std::string + inline void get_col_from_db(database_binder& db, int inx, std::string & s) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + s = std::string(); + } else { + sqlite3_column_bytes(db._stmt.get(), inx); + s = std::string(reinterpret_cast(sqlite3_column_text(db._stmt.get(), inx))); + } + } + inline void get_val_from_db(sqlite3_value *value, std::string & s) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + s = std::string(); + } else { + sqlite3_value_bytes(value); + s = std::string(reinterpret_cast(sqlite3_value_text(value))); + } + } + + // Convert char* to string to trigger op<<(..., const std::string ) + template inline database_binder& operator <<(database_binder& db, const char(&STR)[N]) { return db << std::string(STR); } + template inline database_binder& operator <<(database_binder& db, const char16_t(&STR)[N]) { return db << std::u16string(STR); } + + inline database_binder& operator <<(database_binder& db, const std::string& txt) { + int hresult; + if((hresult = sqlite3_bind_text(db._stmt.get(), db._next_index(), txt.data(), -1, SQLITE_TRANSIENT)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const std::string& val) { + sqlite3_result_text(db, val.data(), -1, SQLITE_TRANSIENT); + } + // std::u16string + inline void get_col_from_db(database_binder& db, int inx, std::u16string & w) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + w = std::u16string(); + } else { + sqlite3_column_bytes16(db._stmt.get(), inx); + w = std::u16string(reinterpret_cast(sqlite3_column_text16(db._stmt.get(), inx))); + } + } + inline void get_val_from_db(sqlite3_value *value, std::u16string & w) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + w = std::u16string(); + } else { + sqlite3_value_bytes16(value); + w = std::u16string(reinterpret_cast(sqlite3_value_text16(value))); + } + } + + + inline database_binder& operator <<(database_binder& db, const std::u16string& txt) { + int hresult; + if((hresult = sqlite3_bind_text16(db._stmt.get(), db._next_index(), txt.data(), -1, SQLITE_TRANSIENT)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const std::u16string& val) { + sqlite3_result_text16(db, val.data(), -1, SQLITE_TRANSIENT); + } + + // Other integer types + template::value>::type> + inline database_binder& operator <<(database_binder& db, const Integral& val) { + return db << static_cast(val); + } + template::type>> + inline void store_result_in_db(sqlite3_context* db, const Integral& val) { + store_result_in_db(db, static_cast(val)); + } + template::value>::type> + inline void get_col_from_db(database_binder& db, int inx, Integral& val) { + sqlite3_int64 i; + get_col_from_db(db, inx, i); + val = i; + } + template::value>::type> + inline void get_val_from_db(sqlite3_value *value, Integral& val) { + sqlite3_int64 i; + get_val_from_db(value, i); + val = i; + } + + // std::optional support for NULL values +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template inline database_binder& operator <<(database_binder& db, const optional& val) { + if(val) { + return db << std::move(*val); + } else { + return db << nullptr; + } + } + template inline void store_result_in_db(sqlite3_context* db, const optional& val) { + if(val) { + store_result_in_db(db, *val); + } + sqlite3_result_null(db); + } + + template inline void get_col_from_db(database_binder& db, int inx, optional& o) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + o = std::experimental::nullopt; + #else + o.reset(); + #endif + } else { + OptionalT v; + get_col_from_db(db, inx, v); + o = std::move(v); + } + } + template inline void get_val_from_db(sqlite3_value *value, optional& o) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + o = std::experimental::nullopt; + #else + o.reset(); + #endif + } else { + OptionalT v; + get_val_from_db(value, v); + o = std::move(v); + } + } +#endif + + // boost::optional support for NULL values +#ifdef _MODERN_SQLITE_BOOST_OPTIONAL_SUPPORT + template inline database_binder& operator <<(database_binder& db, const boost::optional& val) { + if(val) { + return db << std::move(*val); + } else { + return db << nullptr; + } + } + template inline void store_result_in_db(sqlite3_context* db, const boost::optional& val) { + if(val) { + store_result_in_db(db, *val); + } + sqlite3_result_null(db); + } + + template inline void get_col_from_db(database_binder& db, int inx, boost::optional& o) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + o.reset(); + } else { + BoostOptionalT v; + get_col_from_db(db, inx, v); + o = std::move(v); + } + } + template inline void get_val_from_db(sqlite3_value *value, boost::optional& o) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + o.reset(); + } else { + BoostOptionalT v; + get_val_from_db(value, v); + o = std::move(v); + } + } +#endif + +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT + template inline database_binder& operator <<(database_binder& db, const std::variant& val) { + std::visit([&](auto &&opt) {db << std::forward(opt);}, val); + return db; + } + template inline void store_result_in_db(sqlite3_context* db, const std::variant& val) { + std::visit([&](auto &&opt) {store_result_in_db(db, std::forward(opt));}, val); + } + template inline void get_col_from_db(database_binder& db, int inx, std::variant& val) { + utility::variant_select(sqlite3_column_type(db._stmt.get(), inx))([&](auto v) { + get_col_from_db(db, inx, v); + val = std::move(v); + }); + } + template inline void get_val_from_db(sqlite3_value *value, std::variant& val) { + utility::variant_select(sqlite3_value_type(value))([&](auto v) { + get_val_from_db(value, v); + val = std::move(v); + }); + } +#endif + + // Some ppl are lazy so we have a operator for proper prep. statemant handling. + void inline operator++(database_binder& db, int) { db.execute(); } + + // Convert the rValue binder to a reference and call first op<<, its needed for the call that creates the binder (be carefull of recursion here!) + template database_binder&& operator << (database_binder&& db, const T& val) { db << val; return std::move(db); } + + namespace sql_function_binder { + template + struct AggregateCtxt { + T obj; + bool constructed = true; + }; + + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + if(!ctxt) return; + try { + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + step(db, count, vals, ctxt->obj); + return; + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + if(ctxt && ctxt->constructed) + ctxt->~AggregateCtxt(); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits< + typename Functions::first_type + >::template argument + >::type + >::type value{}; + get_val_from_db(vals[sizeof...(Values) - 1], value); + + step(db, count, vals, std::forward(values)..., std::move(value)); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + static_cast(sqlite3_user_data(db))->first(std::forward(values)...); + } + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + try { + if(!ctxt) return; + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + store_result_in_db(db, + static_cast(sqlite3_user_data(db))->second(ctxt->obj)); + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + if(ctxt && ctxt->constructed) + ctxt->~AggregateCtxt(); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits::template argument + >::type + >::type value{}; + get_val_from_db(vals[sizeof...(Values)], value); + + scalar(db, count, vals, std::forward(values)..., std::move(value)); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + try { + store_result_in_db(db, + (*static_cast(sqlite3_user_data(db)))(std::forward(values)...)); + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + } + } +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/errors.h b/sqlite_modern_cpp/sqlite_modern_cpp/errors.h new file mode 100644 index 0000000..2b9ab75 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/errors.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +#include + +namespace sqlite { + + class sqlite_exception: public std::runtime_error { + public: + sqlite_exception(const char* msg, std::string sql, int code = -1): runtime_error(msg), code(code), sql(sql) {} + sqlite_exception(int code, std::string sql): runtime_error(sqlite3_errstr(code)), code(code), sql(sql) {} + int get_code() const {return code & 0xFF;} + int get_extended_code() const {return code;} + std::string get_sql() const {return sql;} + private: + int code; + std::string sql; + }; + + namespace errors { + //One more or less trivial derived error class for each SQLITE error. + //Note the following are not errors so have no classes: + //SQLITE_OK, SQLITE_NOTICE, SQLITE_WARNING, SQLITE_ROW, SQLITE_DONE + // + //Note these names are exact matches to the names of the SQLITE error codes. +#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ + class name: public sqlite_exception { using sqlite_exception::sqlite_exception; };\ + derived +#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ + class base ## _ ## sub: public base { using base::base; }; +#include "lists/error_codes.h" +#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED +#undef SQLITE_MODERN_CPP_ERROR_CODE + + //Some additional errors are here for the C++ interface + class more_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class no_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class more_statements: public sqlite_exception { using sqlite_exception::sqlite_exception; }; // Prepared statements can only contain one statement + class invalid_utf16: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + + static void throw_sqlite_error(const int& error_code, const std::string &sql = "") { + switch(error_code & 0xFF) { +#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ + case SQLITE_ ## NAME: switch(error_code) { \ + derived \ + default: throw name(error_code, sql); \ + } +#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ + case SQLITE_ ## BASE ## _ ## SUB: throw base ## _ ## sub(error_code, sql); +#include "lists/error_codes.h" +#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED +#undef SQLITE_MODERN_CPP_ERROR_CODE + default: throw sqlite_exception(error_code, sql); + } + } + } + namespace exceptions = errors; +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/lists/error_codes.h b/sqlite_modern_cpp/sqlite_modern_cpp/lists/error_codes.h new file mode 100644 index 0000000..5dfa0d3 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/lists/error_codes.h @@ -0,0 +1,93 @@ +#if SQLITE_VERSION_NUMBER < 3010000 +#define SQLITE_IOERR_VNODE (SQLITE_IOERR | (27<<8)) +#define SQLITE_IOERR_AUTH (SQLITE_IOERR | (28<<8)) +#define SQLITE_AUTH_USER (SQLITE_AUTH | (1<<8)) +#endif +SQLITE_MODERN_CPP_ERROR_CODE(ERROR,error,) +SQLITE_MODERN_CPP_ERROR_CODE(INTERNAL,internal,) +SQLITE_MODERN_CPP_ERROR_CODE(PERM,perm,) +SQLITE_MODERN_CPP_ERROR_CODE(ABORT,abort, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(ABORT,ROLLBACK,abort,rollback) +) +SQLITE_MODERN_CPP_ERROR_CODE(BUSY,busy, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,RECOVERY,busy,recovery) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,SNAPSHOT,busy,snapshot) +) +SQLITE_MODERN_CPP_ERROR_CODE(LOCKED,locked, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(LOCKED,SHAREDCACHE,locked,sharedcache) +) +SQLITE_MODERN_CPP_ERROR_CODE(NOMEM,nomem,) +SQLITE_MODERN_CPP_ERROR_CODE(READONLY,readonly,) +SQLITE_MODERN_CPP_ERROR_CODE(INTERRUPT,interrupt,) +SQLITE_MODERN_CPP_ERROR_CODE(IOERR,ioerr, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,READ,ioerr,read) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHORT_READ,ioerr,short_read) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,WRITE,ioerr,write) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSYNC,ioerr,fsync) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_FSYNC,ioerr,dir_fsync) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,TRUNCATE,ioerr,truncate) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSTAT,ioerr,fstat) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,UNLOCK,ioerr,unlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,RDLOCK,ioerr,rdlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE,ioerr,delete) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,BLOCKED,ioerr,blocked) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,NOMEM,ioerr,nomem) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,ACCESS,ioerr,access) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CHECKRESERVEDLOCK,ioerr,checkreservedlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,LOCK,ioerr,lock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CLOSE,ioerr,close) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_CLOSE,ioerr,dir_close) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMOPEN,ioerr,shmopen) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMSIZE,ioerr,shmsize) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMLOCK,ioerr,shmlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMMAP,ioerr,shmmap) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SEEK,ioerr,seek) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE_NOENT,ioerr,delete_noent) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,MMAP,ioerr,mmap) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,GETTEMPPATH,ioerr,gettemppath) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CONVPATH,ioerr,convpath) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,VNODE,ioerr,vnode) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,AUTH,ioerr,auth) +) +SQLITE_MODERN_CPP_ERROR_CODE(CORRUPT,corrupt, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CORRUPT,VTAB,corrupt,vtab) +) +SQLITE_MODERN_CPP_ERROR_CODE(NOTFOUND,notfound,) +SQLITE_MODERN_CPP_ERROR_CODE(FULL,full,) +SQLITE_MODERN_CPP_ERROR_CODE(CANTOPEN,cantopen, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,NOTEMPDIR,cantopen,notempdir) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,ISDIR,cantopen,isdir) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,FULLPATH,cantopen,fullpath) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,CONVPATH,cantopen,convpath) +) +SQLITE_MODERN_CPP_ERROR_CODE(PROTOCOL,protocol,) +SQLITE_MODERN_CPP_ERROR_CODE(EMPTY,empty,) +SQLITE_MODERN_CPP_ERROR_CODE(SCHEMA,schema,) +SQLITE_MODERN_CPP_ERROR_CODE(TOOBIG,toobig,) +SQLITE_MODERN_CPP_ERROR_CODE(CONSTRAINT,constraint, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,CHECK,constraint,check) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,COMMITHOOK,constraint,commithook) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FOREIGNKEY,constraint,foreignkey) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FUNCTION,constraint,function) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,NOTNULL,constraint,notnull) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,PRIMARYKEY,constraint,primarykey) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,TRIGGER,constraint,trigger) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,UNIQUE,constraint,unique) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,VTAB,constraint,vtab) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,ROWID,constraint,rowid) +) +SQLITE_MODERN_CPP_ERROR_CODE(MISMATCH,mismatch,) +SQLITE_MODERN_CPP_ERROR_CODE(MISUSE,misuse,) +SQLITE_MODERN_CPP_ERROR_CODE(NOLFS,nolfs,) +SQLITE_MODERN_CPP_ERROR_CODE(AUTH,auth, +) +SQLITE_MODERN_CPP_ERROR_CODE(FORMAT,format,) +SQLITE_MODERN_CPP_ERROR_CODE(RANGE,range,) +SQLITE_MODERN_CPP_ERROR_CODE(NOTADB,notadb,) +SQLITE_MODERN_CPP_ERROR_CODE(NOTICE,notice, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_WAL,notice,recover_wal) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_ROLLBACK,notice,recover_rollback) +) +SQLITE_MODERN_CPP_ERROR_CODE(WARNING,warning, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(WARNING,AUTOINDEX,warning,autoindex) +) diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/log.h b/sqlite_modern_cpp/sqlite_modern_cpp/log.h new file mode 100644 index 0000000..a8f7be2 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/log.h @@ -0,0 +1,101 @@ +#include "errors.h" + +#include + +#include +#include +#include + +namespace sqlite { + namespace detail { + template + using void_t = void; + template + struct is_callable : std::false_type {}; + template + struct is_callable()(std::declval()...))>> : std::true_type {}; + template + class FunctorOverload: public Functor, public FunctorOverload { + public: + template + FunctorOverload(Functor1 &&functor, Remaining &&... remaining): + Functor(std::forward(functor)), + FunctorOverload(std::forward(remaining)...) {} + using Functor::operator(); + using FunctorOverload::operator(); + }; + template + class FunctorOverload: public Functor { + public: + template + FunctorOverload(Functor1 &&functor): + Functor(std::forward(functor)) {} + using Functor::operator(); + }; + template + class WrapIntoFunctor: public Functor { + public: + template + WrapIntoFunctor(Functor1 &&functor): + Functor(std::forward(functor)) {} + using Functor::operator(); + }; + template + class WrapIntoFunctor { + ReturnType(*ptr)(Arguments...); + public: + WrapIntoFunctor(ReturnType(*ptr)(Arguments...)): ptr(ptr) {} + ReturnType operator()(Arguments... arguments) { return (*ptr)(std::forward(arguments)...); } + }; + inline void store_error_log_data_pointer(std::shared_ptr ptr) { + static std::shared_ptr stored; + stored = std::move(ptr); + } + template + std::shared_ptr::type> make_shared_inferred(T &&t) { + return std::make_shared::type>(std::forward(t)); + } + } + template + typename std::enable_if::value>::type + error_log(Handler &&handler); + template + typename std::enable_if::value>::type + error_log(Handler &&handler); + template + typename std::enable_if=2>::type + error_log(Handler &&...handler) { + return error_log(detail::FunctorOverload::type>...>(std::forward(handler)...)); + } + template + typename std::enable_if::value>::type + error_log(Handler &&handler) { + return error_log(std::forward(handler), [](const sqlite_exception&) {}); + } + template + typename std::enable_if::value>::type + error_log(Handler &&handler) { + auto ptr = detail::make_shared_inferred([handler = std::forward(handler)](int error_code, const char *errstr) mutable { + switch(error_code & 0xFF) { +#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ + case SQLITE_ ## NAME: switch(error_code) { \ + derived \ + default: handler(errors::name(errstr, "", error_code)); \ + };break; +#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ + case SQLITE_ ## BASE ## _ ## SUB: \ + handler(errors::base ## _ ## sub(errstr, "", error_code)); \ + break; +#include "lists/error_codes.h" +#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED +#undef SQLITE_MODERN_CPP_ERROR_CODE + default: handler(sqlite_exception(errstr, "", error_code)); \ + } + }); + + sqlite3_config(SQLITE_CONFIG_LOG, (void(*)(void*,int,const char*))[](void *functor, int error_code, const char *errstr) { + (*static_cast(functor))(error_code, errstr); + }, ptr.get()); + detail::store_error_log_data_pointer(std::move(ptr)); + } +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/sqlcipher.h b/sqlite_modern_cpp/sqlite_modern_cpp/sqlcipher.h new file mode 100644 index 0000000..da0f018 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/sqlcipher.h @@ -0,0 +1,44 @@ +#pragma once + +#ifndef SQLITE_HAS_CODEC +#define SQLITE_HAS_CODEC +#endif + +#include "../sqlite_modern_cpp.h" + +namespace sqlite { + struct sqlcipher_config : public sqlite_config { + std::string key; + }; + + class sqlcipher_database : public database { + public: + sqlcipher_database(std::string db, const sqlcipher_config &config): database(db, config) { + set_key(config.key); + } + + sqlcipher_database(std::u16string db, const sqlcipher_config &config): database(db, config) { + set_key(config.key); + } + + void set_key(const std::string &key) { + if(auto ret = sqlite3_key(_db.get(), key.data(), key.size())) + errors::throw_sqlite_error(ret); + } + + void set_key(const std::string &key, const std::string &db_name) { + if(auto ret = sqlite3_key_v2(_db.get(), db_name.c_str(), key.data(), key.size())) + errors::throw_sqlite_error(ret); + } + + void rekey(const std::string &new_key) { + if(auto ret = sqlite3_rekey(_db.get(), new_key.data(), new_key.size())) + errors::throw_sqlite_error(ret); + } + + void rekey(const std::string &new_key, const std::string &db_name) { + if(auto ret = sqlite3_rekey_v2(_db.get(), db_name.c_str(), new_key.data(), new_key.size())) + errors::throw_sqlite_error(ret); + } + }; +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/utility/function_traits.h b/sqlite_modern_cpp/sqlite_modern_cpp/utility/function_traits.h new file mode 100644 index 0000000..cd8fab0 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/utility/function_traits.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +namespace sqlite { + namespace utility { + + template struct function_traits; + + template + struct function_traits : public function_traits< + decltype(&std::remove_reference::type::operator()) + > { }; + + template < + typename ClassType, + typename ReturnType, + typename... Arguments + > + struct function_traits< + ReturnType(ClassType::*)(Arguments...) const + > : function_traits { }; + + /* support the non-const operator () + * this will work with user defined functors */ + template < + typename ClassType, + typename ReturnType, + typename... Arguments + > + struct function_traits< + ReturnType(ClassType::*)(Arguments...) + > : function_traits { }; + + template < + typename ReturnType, + typename... Arguments + > + struct function_traits< + ReturnType(*)(Arguments...) + > { + typedef ReturnType result_type; + + template + using argument = typename std::tuple_element< + Index, + std::tuple + >::type; + + static const std::size_t arity = sizeof...(Arguments); + }; + + } +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/utility/uncaught_exceptions.h b/sqlite_modern_cpp/sqlite_modern_cpp/utility/uncaught_exceptions.h new file mode 100644 index 0000000..17d6326 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/utility/uncaught_exceptions.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include + +namespace sqlite { + namespace utility { +#ifdef __cpp_lib_uncaught_exceptions + class UncaughtExceptionDetector { + public: + operator bool() { + return count != std::uncaught_exceptions(); + } + private: + int count = std::uncaught_exceptions(); + }; +#else + class UncaughtExceptionDetector { + public: + operator bool() { + return std::uncaught_exception(); + } + }; +#endif + } +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/utility/utf16_utf8.h b/sqlite_modern_cpp/sqlite_modern_cpp/utility/utf16_utf8.h new file mode 100644 index 0000000..ea21723 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/utility/utf16_utf8.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include + +#include "../errors.h" + +namespace sqlite { + namespace utility { + inline std::string utf16_to_utf8(const std::u16string &input) { + struct : std::codecvt { + } codecvt; + std::mbstate_t state{}; + std::string result((std::max)(input.size() * 3 / 2, std::size_t(4)), '\0'); + const char16_t *remaining_input = input.data(); + std::size_t produced_output = 0; + while(true) { + char *used_output; + switch(codecvt.out(state, remaining_input, &input[input.size()], + remaining_input, &result[produced_output], + &result[result.size() - 1] + 1, used_output)) { + case std::codecvt_base::ok: + result.resize(used_output - result.data()); + return result; + case std::codecvt_base::noconv: + // This should be unreachable + case std::codecvt_base::error: + throw errors::invalid_utf16("Invalid UTF-16 input", ""); + case std::codecvt_base::partial: + if(used_output == result.data() + produced_output) + throw errors::invalid_utf16("Unexpected end of input", ""); + produced_output = used_output - result.data(); + result.resize( + result.size() + + (std::max)((&input[input.size()] - remaining_input) * 3 / 2, + std::ptrdiff_t(4))); + } + } + } + } // namespace utility +} // namespace sqlite diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/utility/variant.h b/sqlite_modern_cpp/sqlite_modern_cpp/utility/variant.h new file mode 100644 index 0000000..11a8429 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/utility/variant.h @@ -0,0 +1,201 @@ +#pragma once + +#include "../errors.h" +#include +#include +#include + +namespace sqlite::utility { + template + struct VariantFirstNullable { + using type = void; + }; + template + struct VariantFirstNullable { + using type = typename VariantFirstNullable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstNullable, Options...> { + using type = std::optional; + }; +#endif + template + struct VariantFirstNullable, Options...> { + using type = std::unique_ptr; + }; + template + struct VariantFirstNullable { + using type = std::nullptr_t; + }; + template + inline void variant_select_null(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("NULL is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstNullable::type()); + } + } + + template + struct VariantFirstIntegerable { + using type = void; + }; + template + struct VariantFirstIntegerable { + using type = typename VariantFirstIntegerable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstIntegerable, Options...> { + using type = std::conditional_t::type, T>, std::optional, typename VariantFirstIntegerable::type>; + }; +#endif + template + struct VariantFirstIntegerable::type, T>>, std::unique_ptr, Options...> { + using type = std::conditional_t::type, T>, std::unique_ptr, typename VariantFirstIntegerable::type>; + }; + template + struct VariantFirstIntegerable { + using type = int; + }; + template + struct VariantFirstIntegerable { + using type = sqlite_int64; + }; + template + inline auto variant_select_integer(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("Integer is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstIntegerable::type()); + } + } + + template + struct VariantFirstFloatable { + using type = void; + }; + template + struct VariantFirstFloatable { + using type = typename VariantFirstFloatable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstFloatable, Options...> { + using type = std::conditional_t::type, T>, std::optional, typename VariantFirstFloatable::type>; + }; +#endif + template + struct VariantFirstFloatable, Options...> { + using type = std::conditional_t::type, T>, std::unique_ptr, typename VariantFirstFloatable::type>; + }; + template + struct VariantFirstFloatable { + using type = float; + }; + template + struct VariantFirstFloatable { + using type = double; + }; + template + inline auto variant_select_float(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("Real is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstFloatable::type()); + } + } + + template + struct VariantFirstTextable { + using type = void; + }; + template + struct VariantFirstTextable { + using type = typename VariantFirstTextable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstTextable, Options...> { + using type = std::conditional_t::type, T>, std::optional, typename VariantFirstTextable::type>; + }; +#endif + template + struct VariantFirstTextable, Options...> { + using type = std::conditional_t::type, T>, std::unique_ptr, typename VariantFirstTextable::type>; + }; + template + struct VariantFirstTextable { + using type = std::string; + }; + template + struct VariantFirstTextable { + using type = std::u16string; + }; + template + inline void variant_select_text(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("Text is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstTextable::type()); + } + } + + template + struct VariantFirstBlobable { + using type = void; + }; + template + struct VariantFirstBlobable { + using type = typename VariantFirstBlobable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstBlobable, Options...> { + using type = std::conditional_t::type, T>, std::optional, typename VariantFirstBlobable::type>; + }; +#endif + template + struct VariantFirstBlobable, Options...> { + using type = std::conditional_t::type, T>, std::unique_ptr, typename VariantFirstBlobable::type>; + }; + template + struct VariantFirstBlobable>, std::vector, Options...> { + using type = std::vector; + }; + template + inline auto variant_select_blob(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("Blob is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstBlobable::type()); + } + } + + template + inline auto variant_select(int type) { + return [type](auto &&callback) { + using Callback = decltype(callback); + switch(type) { + case SQLITE_NULL: + variant_select_null(std::forward(callback)); + break; + case SQLITE_INTEGER: + variant_select_integer(std::forward(callback)); + break; + case SQLITE_FLOAT: + variant_select_float(std::forward(callback)); + break; + case SQLITE_TEXT: + variant_select_text(std::forward(callback)); + break; + case SQLITE_BLOB: + variant_select_blob(std::forward(callback)); + break; + default:; + /* assert(false); */ + } + }; + } +}