diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b1823c..c008211 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ add_subdirectory(thread-pool) add_subdirectory(fmt) add_executable(discord_llama main.cpp) -target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml threadpool) +target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml threadpool sqlite3) install(TARGETS discord_llama LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/anyproc b/anyproc index 6ba351b..03b1689 160000 --- a/anyproc +++ b/anyproc @@ -1 +1 @@ -Subproject commit 6ba351b1173539254b2e641d19ca77c459a087a2 +Subproject commit 03b1689ef708b26ac139d09c3fa194195eef7707 diff --git a/example_config.txt b/example_config.txt index f110d0a..7f9a87a 100644 --- a/example_config.txt +++ b/example_config.txt @@ -1,12 +1,18 @@ token MTA0MDYxMTQzNjUwNzk1OTMyNw.Gl_iMU.jVVM3bRqBJVi8ORVpWHquOivlASGJpRySt8qFg # The following parameters are set to their defaults here and can be ommited +models_dir models language EN -inference_model 13B-ggml-model-quant.bin -translation_model 13B-ggml-model-quant.bin -prompt_file prompt.txt +threads_only true + +default_inference_model 13B-vanilla +translation_model none + +prompt_file none +instruct_prompt_file none + +persistance true mlock false pool_size 2 threads 4 -persistance true ctx_size 1012 diff --git a/example_models/13B-vanilla.txt b/example_models/13B-vanilla.txt new file mode 100644 index 0000000..537f3e0 --- /dev/null +++ b/example_models/13B-vanilla.txt @@ -0,0 +1,2 @@ +filename 13B-ggml-model-quant.bin +instruct_mode_policy forbid diff --git a/example_models/13B-vicuna-1.0.txt b/example_models/13B-vicuna-1.0.txt new file mode 100644 index 0000000..05f6516 --- /dev/null +++ b/example_models/13B-vicuna-1.0.txt @@ -0,0 +1,4 @@ +filename ggml-vicuna-13b-1.0-q4_0.bin +instruct_mode_policy force +user_prompt ### Human: +bot_prompt ### Assistant: diff --git a/example_models/13B-vicuna-1.1.txt b/example_models/13B-vicuna-1.1.txt new file mode 100644 index 0000000..6ebbb53 --- /dev/null +++ b/example_models/13B-vicuna-1.1.txt @@ -0,0 +1,4 @@ +filename ggml-vicuna-13b-1.1-q4_0.bin +instruct_mode_policy force +user_prompt HUMAN: +bot_prompt ASSISTANT: diff --git a/example_models/gpt4all-unfiltered.txt b/example_models/gpt4all-unfiltered.txt new file mode 100644 index 0000000..aafcd13 --- /dev/null +++ b/example_models/gpt4all-unfiltered.txt @@ -0,0 +1,4 @@ +filename gpt4all-lora-unfiltered-quantized.bin +instruct_mode_policy allow +user_prompt ### Instruction: +bot_prompt ### Response: diff --git a/main.cpp b/main.cpp index f0aff9c..45437ab 100644 --- a/main.cpp +++ b/main.cpp @@ -1,4 +1,5 @@ #include "Timer.hpp" +#include "sqlite_modern_cpp/sqlite_modern_cpp.h" #include #include @@ -48,9 +49,23 @@ void str_replace_in_place(std::string& subject, std::string_view search, } } +static +void clean_command_name(std::string& value) { + for (auto& c : value) { + if (c == '.') c = '_'; + if (isalpha(c)) c = tolower(c); + } +} +[[nodiscard]] static +std::string clean_command_name(std::string_view input) { + std::string fres(input); + clean_command_name(fres); + return fres; +} + class Bot { - ThreadPool tPool{1}; + ThreadPool thread_pool{1}; Timer last_message_timer; std::shared_ptr stopping; LM::InferencePool llm_pool; @@ -58,47 +73,64 @@ class Bot { std::vector my_messages; std::unordered_map users; std::thread::id llm_tid; + sqlite::database db; std::string_view language; dpp::cluster bot; +public: + struct ModelConfig { + std::string weight_path, + user_prompt, + bot_prompt; + enum class InstructModePolicy { + Allow = 0b11, + Force = 0b10, + Forbid = 0b01 + } instruct_mode_policy = InstructModePolicy::Allow; + + bool is_instruct_mode_allowed() const { + return static_cast(instruct_mode_policy) & 0b10; + } + bool is_non_instruct_mode_allowed() const { + return static_cast(instruct_mode_policy) & 0b01; + } + }; + struct BotChannelConfig { + const std::string *model_name; + const ModelConfig *model_config; + bool instruct_mode = false; + }; + struct Configuration { + std::string token, + language = "EN", + default_inference_model = "13B-vanilla", + translation_model = "none", + prompt_file = "none", + instruct_prompt_file = "none", + models_dir = "models"; + unsigned ctx_size = 1012, + pool_size = 2, + threads = 4, + persistance = true; + bool mlock = false, + threads_only = true; + const ModelConfig *default_inference_model_cfg = nullptr, + *translation_model_cfg = nullptr; + }; + +private: + const Configuration& config; + const std::unordered_map& model_configs; + struct Texts { std::string please_wait = "Please wait...", - loading = "Loading...", - initializing = "Initializing...", + thread_create_fail = "Error: I couldn't create a thread here. Do I have enough permissions?", + model_missing = "Error: The model that was used in this thread could no longer be found.", timeout = "Error: Timeout"; bool translated = false; } texts; - inline static - std::string create_text_progress_indicator(uint8_t percentage) { - static constexpr uint8_t divisor = 3, - width = 100 / divisor; - // Progress bar percentage lookup - const static auto indicator_lookup = [] () consteval { - std::array fres; - for (uint8_t it = 0; it != 101; it++) { - fres[it] = it / divisor; - } - return fres; - }(); - // Initialize string - std::string fres; - fres.resize(width+4); - fres[0] = '`'; - fres[1] = '['; - // Append progress - const uint8_t bars = indicator_lookup[percentage]; - for (uint8_t it = 0; it != width; it++) { - if (it < bars) fres[it+2] = '#'; - else fres[it+2] = ' '; - } - // Finalize and return string - fres[width+2] = ']'; - fres[width+3] = '`'; - return fres; - } - inline static bool show_console_progress(float progress) { std::cout << ' ' << unsigned(progress) << "% \r" << std::flush; @@ -178,27 +210,29 @@ class Bot { } // Must run in llama thread - void llm_restart(LM::Inference& inference) { - // Deserialize init cache - std::ifstream f("init_cache", std::ios::binary); + void llm_restart(LM::Inference& inference, const BotChannelConfig& channel_cfg) { + ENSURE_LLM_THREAD(); + // Deserialize init cache if not instruct mode without prompt file + if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") return; + std::ifstream f((*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"), std::ios::binary); inference.deserialize(f); } // Must run in llama thread - LM::Inference &llm_restart(dpp::snowflake id) { + LM::Inference &llm_restart(dpp::snowflake id, const BotChannelConfig& channel_cfg) { ENSURE_LLM_THREAD(); // Get or create inference - auto& inference = llm_pool.get_or_create_inference(id, config.inference_model, llm_get_params()); - llm_restart(inference); + auto& inference = llm_pool.get_or_create_inference(id, channel_cfg.model_config->weight_path, llm_get_params()); + llm_restart(inference, channel_cfg); return inference; } // Must run in llama thread - LM::Inference &llm_get_inference(dpp::snowflake id) { + LM::Inference &llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) { ENSURE_LLM_THREAD(); auto inference_opt = llm_pool.get_inference(id); if (!inference_opt.has_value()) { // Start new inference - inference_opt = llm_restart(id); + inference_opt = llm_restart(id, channel_cfg); } return inference_opt.value(); } @@ -210,54 +244,97 @@ class Bot { // Translate texts if (!texts.translated) { texts.please_wait = llm_translate_from_en(texts.please_wait); - texts.initializing = llm_translate_from_en(texts.initializing); - texts.loading = llm_translate_from_en(texts.loading); + texts.model_missing = llm_translate_from_en(texts.model_missing); + texts.thread_create_fail = llm_translate_from_en(texts.thread_create_fail); texts.timeout = llm_translate_from_en(texts.timeout); texts.translated = true; } - // Inference for init cache - if (!std::filesystem::exists("init_cache")) { - LM::Inference llm(config.inference_model, llm_get_params()); - std::ofstream f("init_cache", std::ios::binary); - // Add initial context - std::string prompt; - { - // Read whole file - std::ifstream f(config.prompt_file); - if (!f) { - // Clean up and abort on error - std::cerr << "Failed to open prompt file." << std::endl; - f.close(); - std::error_code ec; - std::filesystem::remove("init_cache", ec); - abort(); + // Build init caches + std::string filename; + for (const auto& [model_name, model_config] : model_configs) { + //TODO: Add hashes to regenerate these as needed + // Standard prompt + filename = model_name+"_init_cache"; + if (model_config.is_non_instruct_mode_allowed() && + !std::filesystem::exists(filename) && config.prompt_file != "none") { + std::cout << "Building init_cache for "+model_name+"..." << std::endl; + LM::Inference llm(model_config.weight_path, llm_get_params()); + // Add initial context + std::string prompt; + { + // Read whole file + std::ifstream f(config.prompt_file); + if (!f) { + // Clean up and abort on error + std::cerr << "Failed to open prompt file." << std::endl; + abort(); + } + std::ostringstream sstr; + sstr << f.rdbuf(); + prompt = sstr.str(); } - std::ostringstream sstr; - sstr << f.rdbuf(); - prompt = sstr.str(); + // Append + using namespace fmt::literals; + if (prompt.back() != '\n') prompt.push_back('\n'); + llm.append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress); + // Serialize end result + std::ofstream f(filename, std::ios::binary); + llm.serialize(f); + } + // Instruct prompt + filename = model_name+"_instruct_init_cache"; + if (model_config.is_instruct_mode_allowed() && + !std::filesystem::exists(filename) && config.instruct_prompt_file != "none") { + std::cout << "Building instruct_init_cache for "+model_name+"..." << std::endl; + LM::Inference llm(model_config.weight_path, llm_get_params()); + // Add initial context + std::string prompt; + { + // Read whole file + std::ifstream f(config.instruct_prompt_file); + if (!f) { + // Clean up and abort on error + std::cerr << "Failed to open prompt file." << std::endl; + abort(); + } + std::ostringstream sstr; + sstr << f.rdbuf(); + prompt = sstr.str(); + } + // Append + using namespace fmt::literals; + if (prompt.back() != '\n') prompt.push_back('\n'); + llm.append(fmt::format(fmt::runtime(prompt+'\n'), "bot_name"_a=bot.me.username), show_console_progress); + // Serialize end result + std::ofstream f(filename, std::ios::binary); + llm.serialize(f); } - // Append - using namespace fmt::literals; - llm.append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username)); - // Serialize end result - llm.serialize(f); } + // Report complete init + std::cout << "Init done!" << std::endl; } // Must run in llama thread - void prompt_add_msg(const dpp::message& msg) { + void prompt_add_msg(const dpp::message& msg, const BotChannelConfig& channel_cfg) { ENSURE_LLM_THREAD(); // Make sure message isn't too long if (msg.content.size() > 512) { return; } // Get inference - auto& inference = llm_get_inference(msg.channel_id); + auto& inference = llm_get_inference(msg.channel_id, channel_cfg); try { + std::string prefix; + // Instruct mode user prompt + if (channel_cfg.instruct_mode) { + inference.append('\n'+channel_cfg.model_config->user_prompt+"\n\n"); + } else { + prefix = msg.author.username+": "; + } // Format and append lines for (const auto line : str_split(msg.content, '\n')) { Timer timeout; bool timeout_exceeded = false; - inference.append(msg.author.username+": "+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) { + inference.append(prefix+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) { if (timeout.get() > 1) { std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl; timeout_exceeded = true; @@ -268,28 +345,33 @@ class Bot { if (timeout_exceeded) inference.append("\n"); } } catch (const LM::Inference::ContextLengthException&) { - llm_restart(inference); - prompt_add_msg(msg); + llm_restart(inference, channel_cfg); + prompt_add_msg(msg, channel_cfg); } } // Must run in llama thread - void prompt_add_trigger(dpp::snowflake id) { + void prompt_add_trigger(dpp::snowflake id, const BotChannelConfig& channel_cfg) { ENSURE_LLM_THREAD(); - auto& inference = llm_get_inference(id); + auto& inference = llm_get_inference(id, channel_cfg); try { - inference.append(bot.me.username+':', show_console_progress); + if (channel_cfg.instruct_mode) { + inference.append('\n'+channel_cfg.model_config->bot_prompt+"\n\n"); + } else { + inference.append(bot.me.username+':', show_console_progress); + } } catch (const LM::Inference::ContextLengthException&) { - llm_restart(inference); + llm_restart(inference, channel_cfg); } } // Must run in llama thread - void reply(dpp::snowflake id, dpp::message msg) { + void reply(dpp::snowflake id, dpp::message msg, const BotChannelConfig& channel_cfg) { ENSURE_LLM_THREAD(); - try { // Trigger LLM correctly - prompt_add_trigger(id); + try { + // Trigger LLM correctly + prompt_add_trigger(id, channel_cfg); // Get inference - auto& inference = llm_get_inference(id); + auto& inference = llm_get_inference(id, channel_cfg); // Run model Timer timeout; bool timeout_exceeded = false; @@ -315,16 +397,16 @@ class Bot { } } - bool attempt_reply(const dpp::message& msg) { + bool attempt_reply(const dpp::message& msg, const BotChannelConfig& channel_cfg) { // Reply if message contains username, mention or ID if (msg.content.find(bot.me.username) != std::string::npos) { - enqueue_reply(msg.channel_id); + enqueue_reply(msg.channel_id, channel_cfg); return true; } // Reply if message references user for (const auto msg_id : my_messages) { if (msg.message_reference.message_id == msg_id) { - enqueue_reply(msg.channel_id); + enqueue_reply(msg.channel_id, channel_cfg); return true; } } @@ -332,45 +414,40 @@ class Bot { return false; } - void enqueue_reply(dpp::snowflake id) { - bot.message_create(dpp::message(id, texts.please_wait+" :thinking:"), [this, id] (const dpp::confirmation_callback_t& ccb) { + void enqueue_reply(dpp::snowflake id, const BotChannelConfig& channel_cfg) { + bot.message_create(dpp::message(id, texts.please_wait+" :thinking:"), [=, this] (const dpp::confirmation_callback_t& ccb) { if (ccb.is_error()) return; - tPool.submit(std::bind(&Bot::reply, this, id, ccb.get())); + thread_pool.submit(std::bind(&Bot::reply, this, id, ccb.get(), channel_cfg)); }); } public: - struct Configuration { - std::string token, - language = "EN", - inference_model = "13B-ggml-model-quant.bin", - translation_model = "13B-ggml-model-quant.bin", - prompt_file = "prompt.txt"; - unsigned ctx_size = 1012, - pool_size = 2, - threads = 4, - persistance = true; - bool mlock = false; - } config; + Bot(decltype(config) cfg, decltype(model_configs) model_configs) + : config(cfg), model_configs(model_configs), bot(cfg.token), + language(cfg.language), db("database.sqlite3"), + llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) { + // Initialize database + db << "CREATE TABLE IF NOT EXISTS threads (" + " id TEXT PRIMARY KEY NOT NULL," + " model TEXT," + " instruct_mode INTEGER," + " UNIQUE(id)" + ");"; - Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), language(cfg.language), - llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) { // Configure llm_pool llm_pool.set_store_on_destruct(cfg.persistance); // Initialize thread pool - tPool.init(); + thread_pool.init(); // Prepare translator if (language != "EN") { - tPool.submit([this] () { + thread_pool.submit([this] () { + std::cout << "Preparing translator..." << std::endl; translator = std::make_unique(config.translation_model, llm_get_translation_params()); }); } - // Prepare llm - tPool.submit(std::bind(&Bot::llm_init, this)); - // Configure bot bot.on_log(dpp::utility::cout_logger()); bot.intents = dpp::i_guild_messages | dpp::i_message_content; @@ -378,6 +455,59 @@ public: // Set callbacks bot.on_ready([=, this] (const dpp::ready_t&) { //TODO: Consider removal std::cout << "Connected to Discord." << std::endl; + // Register chat command once + if (dpp::run_once()) { + for (const auto& [name, model] : model_configs) { + // Create command + dpp::slashcommand command(name, "Start a chat with me", bot.me.id); + // Add instruct mode option + if (model.instruct_mode_policy == ModelConfig::InstructModePolicy::Allow) { + command.add_option(dpp::command_option(dpp::co_boolean, "instruct_mode", "Weather to enable instruct mode", true)); + } + // Register command + bot.global_command_edit(command, [this, command] (const dpp::confirmation_callback_t& ccb) { + if (ccb.is_error()) bot.global_command_create(command); + }); + } + } + if (dpp::run_once()) { + // Prepare llm + thread_pool.submit(std::bind(&Bot::llm_init, this)); + } + }); + bot.on_slashcommand([=, this](const dpp::slashcommand_t& event) { + // Get model by name + auto res = model_configs.find(event.command.get_command_name()); + if (res == model_configs.end()) { + // Model does not exit, delete corresponding command + bot.global_command_delete(event.command.id); + return; + } + const auto& [model_name, model_config] = *res; + // Get weather to enable instruct mode + bool instruct_mode; + if (model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Allow) { + instruct_mode = std::get(event.get_parameter("instruct_mode")); + } else { + instruct_mode = model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Force; + } + // Create thread + bot.thread_create("Chat with "+model_name, event.command.channel_id, 1440, dpp::CHANNEL_PUBLIC_THREAD, true, 15, + [this, event, instruct_mode, model_name = res->first] (const dpp::confirmation_callback_t& ccb) { + // Check for error + if (ccb.is_error()) { + std::cout << "Thread creation failed: " << ccb.get_error().message << std::endl; + event.reply(dpp::message(texts.thread_create_fail).set_flags(dpp::message_flags::m_ephemeral)); + return; + } + // Get thread + const auto& thread = ccb.get(); + // Add thread to database + db << "INSERT INTO threads (id, model, instruct_mode) VALUES (?, ?, ?);" + << std::to_string(thread.id) << model_name << instruct_mode; + // Report success + event.reply(dpp::message("Okay!").set_flags(dpp::message_flags::m_ephemeral)); + }); }); bot.on_message_create([=, this] (const dpp::message_create_t& event) { // Update user cache @@ -401,16 +531,49 @@ public: for (const auto& [user_id, user] : users) { str_replace_in_place(msg.content, "<@"+std::to_string(user_id)+'>', user.username); } + // Get channel config + BotChannelConfig channel_cfg; + // Attempt to find thread first... + bool in_bot_thread = false; + db << "SELECT model, instruct_mode FROM threads " + "WHERE id = ?;" + << std::to_string(msg.channel_id) + >> [&](const std::string& model_name, int instruct_mode) { + in_bot_thread = true; + channel_cfg.instruct_mode = instruct_mode; + // Find model + auto res = model_configs.find(model_name); + if (res == model_configs.end()) { + bot.message_create(dpp::message(msg.channel_id, texts.model_missing)); + return; + } + channel_cfg.model_name = &res->first; + channel_cfg.model_config = &res->second; + }; + // Otherwise just fall back to the default model config if allowed + if (!in_bot_thread) { + if (config.threads_only) return; + channel_cfg.model_name = &config.default_inference_model; + channel_cfg.model_config = config.default_inference_model_cfg; + } // Append message - tPool.submit(std::bind(&Bot::prompt_add_msg, this, msg)); + thread_pool.submit([=, this] () { + prompt_add_msg(msg, channel_cfg); + }); // Handle message somehow... - if (msg.content == "!trigger") { + if (msg.content == "!store") { + llm_pool.store_all(); //DEBUG +# warning DEBUG CODE!!! + } else if (in_bot_thread) { + // Send a reply + enqueue_reply(msg.channel_id, channel_cfg); + } else if (msg.content == "!trigger") { // Delete message bot.message_delete(msg.id, msg.channel_id); // Send a reply - enqueue_reply(msg.channel_id); + enqueue_reply(msg.channel_id, channel_cfg); } else { - attempt_reply(msg); + attempt_reply(msg, channel_cfg); } } catch (const std::exception& e) { std::cerr << "Warning: " << e.what() << std::endl; @@ -426,6 +589,31 @@ public: }; +bool parse_bool(const std::string& value) { + if (value == "true") + return true; + if (value == "false") + return false; + std::cerr << "Failed to parse configuration file: Unknown bool (true/false): " << value << std::endl; + exit(-4); +} +Bot::ModelConfig::InstructModePolicy parse_instruct_mode_policy(const std::string& value) { + if (value == "allow") + return Bot::ModelConfig::InstructModePolicy::Allow; + if (value == "force") + return Bot::ModelConfig::InstructModePolicy::Force; + if (value == "forbid") + return Bot::ModelConfig::InstructModePolicy::Forbid; + std::cerr << "Failed to parse model configuration file: Unknown instruct mode policy (allow/force/forbid): " << value << std::endl; + exit(-4); +} + +bool file_exists(const auto& p) { + // Make sure we don't respond to some file that is actually called "none"... + if (p == "none") return false; + return std::filesystem::exists(p); +} + int main(int argc, char **argv) { // Check arguments if (argc < 2) { @@ -433,12 +621,12 @@ int main(int argc, char **argv) { return -1; } - // Parse configuration + // Parse main configuration Bot::Configuration cfg; std::ifstream cfgf(argv[1]); if (!cfgf) { std::cerr << "Failed to open configuration file: " << argv[1] << std::endl; - return -1; + exit(-1); } for (std::string key; cfgf >> key;) { // Read value @@ -451,12 +639,16 @@ int main(int argc, char **argv) { cfg.token = std::move(value); } else if (key == "language") { cfg.language = std::move(value); - } else if (key == "inference_model") { - cfg.inference_model = std::move(value); + } else if (key == "default_inference_model") { + cfg.default_inference_model = std::move(value); } else if (key == "translation_model") { cfg.translation_model = std::move(value); } else if (key == "prompt_file") { cfg.prompt_file = std::move(value); + } else if (key == "instruct_prompt_file") { + cfg.instruct_prompt_file = std::move(value); + } else if (key == "models_dir") { + cfg.models_dir = std::move(value); } else if (key == "pool_size") { cfg.pool_size = std::stoi(value); } else if (key == "threads") { @@ -464,17 +656,113 @@ int main(int argc, char **argv) { } else if (key == "ctx_size") { cfg.ctx_size = std::stoi(value); } else if (key == "mlock") { - cfg.mlock = (value=="true")?true:false; + cfg.mlock = parse_bool(value); + } else if (key == "threads_only") { + cfg.threads_only = parse_bool(value); } else if (key == "persistance") { - cfg.persistance = (value=="true")?true:false; + cfg.persistance = parse_bool(value); } else if (!key.empty() && key[0] != '#') { std::cerr << "Failed to parse configuration file: Unknown key: " << key << std::endl; - return -2; + exit(-3); } } + // Parse model configurations + std::unordered_map models; + std::filesystem::path models_dir(cfg.models_dir); + bool allow_non_instruct = false; + for (const auto& file : std::filesystem::directory_iterator(models_dir)) { + // Check that file is model config + if (file.is_directory() || + file.path().filename().extension() != ".txt") continue; + // Get model name + auto model_name = file.path().filename().string(); + model_name.erase(model_name.size()-4, 4); + clean_command_name(model_name); + // Parse model config + Bot::ModelConfig model_cfg; + std::ifstream cfgf(file.path()); + if (!cfgf) { + std::cerr << "Failed to open model configuration file: " << file << std::endl; + exit(-2); + } + std::string filename; + for (std::string key; cfgf >> key;) { + // Read value + std::string value; + std::getline(cfgf, value); + // Erase all leading spaces + while (!value.empty() && (value[0] == ' ' || value[0] == '\t')) value.erase(0, 1); + // Check key and ignore comment lines + if (key == "filename") { + filename = std::move(value); + } else if (key == "user_prompt") { + model_cfg.user_prompt = std::move(value); + } else if (key == "bot_prompt") { + model_cfg.bot_prompt = std::move(value); + } else if (key == "instruct_mode_policy") { + model_cfg.instruct_mode_policy = parse_instruct_mode_policy(value); + } else if (!key.empty() && key[0] != '#') { + std::cerr << "Failed to parse model configuration file: Unknown key: " << key << std::endl; + exit(-3); + } + } + // Get full path + model_cfg.weight_path = file.path().parent_path()/filename; + // Safety checks + if (filename.empty() || !file_exists(model_cfg.weight_path)) { + std::cerr << "Failed to parse model configuration file: Invalid weight filename: " << model_name << std::endl; + exit(-8); + } + if (model_cfg.instruct_mode_policy != Bot::ModelConfig::InstructModePolicy::Forbid && + (model_cfg.user_prompt.empty() || model_cfg.bot_prompt.empty())) { + std::cerr << "Failed to parse model configuration file: Instruct mode allowed but user prompt and bot prompt not given: " << model_name << std::endl; + exit(-9); + } + if (model_cfg.instruct_mode_policy != Bot::ModelConfig::InstructModePolicy::Force) { + allow_non_instruct = true; + } + // Add model to list + const auto& [stored_model_name, stored_model_cfg] = *models.emplace(std::move(model_name), std::move(model_cfg)).first; + // Set model pointer in config + if (stored_model_name == cfg.default_inference_model) + cfg.default_inference_model_cfg = &stored_model_cfg; + if (stored_model_name == cfg.translation_model) + cfg.translation_model_cfg = &stored_model_cfg; + } + + // Safety checks + if (cfg.language != "EN") { + if (cfg.translation_model_cfg == nullptr) { + std::cerr << "Translation model required for non-english language, but is invalid" << std::endl; + exit(-5); + } + if (cfg.translation_model_cfg->instruct_mode_policy == Bot::ModelConfig::InstructModePolicy::Force) { + std::cerr << "Translation model is required to not have instruct mode forced" << std::endl; + exit(-10); + } + } + if (allow_non_instruct && !file_exists(cfg.prompt_file)) { + std::cerr << "Prompt file required when allowing non-instruct-mode use, but is invalid" << std::endl; + exit(-11); + } + if (!cfg.threads_only) { + if (cfg.default_inference_model_cfg == nullptr) { + std::cerr << "Default model required if not threads only, but is invalid" << std::endl; + exit(-6); + } + if (cfg.default_inference_model_cfg->instruct_mode_policy == Bot::ModelConfig::InstructModePolicy::Force) { + std::cerr << "Default model must not have instruct mode forced if not threads only" << std::endl; + exit(-7); + } + } + + // Clean model names in config + clean_command_name(cfg.default_inference_model); + clean_command_name(cfg.translation_model); + // Construct and configure bot - Bot bot(cfg); + Bot bot(cfg, models); // Start bot bot.start(); diff --git a/sqlite_modern_cpp/License.txt b/sqlite_modern_cpp/License.txt new file mode 100644 index 0000000..595b1d6 --- /dev/null +++ b/sqlite_modern_cpp/License.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 aminroosta + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sqlite_modern_cpp/sqlite_modern_cpp.h b/sqlite_modern_cpp/sqlite_modern_cpp.h new file mode 100644 index 0000000..344dedc --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp.h @@ -0,0 +1,1053 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#define MODERN_SQLITE_VERSION 3002008 + +#ifdef __has_include +#if __cplusplus > 201402 && __has_include() +#define MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#elif __has_include() +#define MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT +#endif +#endif + +#ifdef __has_include +#if __cplusplus > 201402 && __has_include() +#define MODERN_SQLITE_STD_VARIANT_SUPPORT +#endif +#endif + +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#include +#endif + +#ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT +#include +#define MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#endif + +#ifdef _MODERN_SQLITE_BOOST_OPTIONAL_SUPPORT +#include +#endif + +#ifdef __ENVIRONMENT_IPHONE_OS_VERSION_MIN_REQUIRED__ +#if __ENVIRONMENT_IPHONE_OS_VERSION_MIN_REQUIRED__ < 100000 +#undef __cpp_lib_uncaught_exceptions +#endif +#endif + +#include + +#include "sqlite_modern_cpp/errors.h" +#include "sqlite_modern_cpp/utility/function_traits.h" +#include "sqlite_modern_cpp/utility/uncaught_exceptions.h" +#include "sqlite_modern_cpp/utility/utf16_utf8.h" + +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT +#include "sqlite_modern_cpp/utility/variant.h" +#endif + +namespace sqlite { + + // std::optional support for NULL values + #ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + template + using optional = std::experimental::optional; + #else + template + using optional = std::optional; + #endif + #endif + + class database; + class database_binder; + + template class binder; + + typedef std::shared_ptr connection_type; + + template::value == Element)> struct tuple_iterate { + static void iterate(Tuple& t, database_binder& db) { + get_col_from_db(db, Element, std::get(t)); + tuple_iterate::iterate(t, db); + } + }; + + template struct tuple_iterate { + static void iterate(Tuple&, database_binder&) {} + }; + + class database_binder { + + public: + // database_binder is not copyable + database_binder() = delete; + database_binder(const database_binder& other) = delete; + database_binder& operator=(const database_binder&) = delete; + + database_binder(database_binder&& other) : + _db(std::move(other._db)), + _stmt(std::move(other._stmt)), + _inx(other._inx), execution_started(other.execution_started) { } + + void execute() { + _start_execute(); + int hresult; + + while((hresult = sqlite3_step(_stmt.get())) == SQLITE_ROW) {} + + if(hresult != SQLITE_DONE) { + errors::throw_sqlite_error(hresult, sql()); + } + } + + std::string sql() { +#if SQLITE_VERSION_NUMBER >= 3014000 + auto sqlite_deleter = [](void *ptr) {sqlite3_free(ptr);}; + std::unique_ptr str(sqlite3_expanded_sql(_stmt.get()), sqlite_deleter); + return str ? str.get() : original_sql(); +#else + return original_sql(); +#endif + } + + std::string original_sql() { + return sqlite3_sql(_stmt.get()); + } + + void used(bool state) { + if(!state) { + // We may have to reset first if we haven't done so already: + _next_index(); + --_inx; + } + execution_started = state; + } + bool used() const { return execution_started; } + + private: + std::shared_ptr _db; + std::unique_ptr _stmt; + utility::UncaughtExceptionDetector _has_uncaught_exception; + + int _inx; + + bool execution_started = false; + + int _next_index() { + if(execution_started && !_inx) { + sqlite3_reset(_stmt.get()); + sqlite3_clear_bindings(_stmt.get()); + } + return ++_inx; + } + void _start_execute() { + _next_index(); + _inx = 0; + used(true); + } + + void _extract(std::function call_back) { + int hresult; + _start_execute(); + + while((hresult = sqlite3_step(_stmt.get())) == SQLITE_ROW) { + call_back(); + } + + if(hresult != SQLITE_DONE) { + errors::throw_sqlite_error(hresult, sql()); + } + } + + void _extract_single_value(std::function call_back) { + int hresult; + _start_execute(); + + if((hresult = sqlite3_step(_stmt.get())) == SQLITE_ROW) { + call_back(); + } else if(hresult == SQLITE_DONE) { + throw errors::no_rows("no rows to extract: exactly 1 row expected", sql(), SQLITE_DONE); + } + + if((hresult = sqlite3_step(_stmt.get())) == SQLITE_ROW) { + throw errors::more_rows("not all rows extracted", sql(), SQLITE_ROW); + } + + if(hresult != SQLITE_DONE) { + errors::throw_sqlite_error(hresult, sql()); + } + } + + sqlite3_stmt* _prepare(const std::u16string& sql) { + return _prepare(utility::utf16_to_utf8(sql)); + } + + sqlite3_stmt* _prepare(const std::string& sql) { + int hresult; + sqlite3_stmt* tmp = nullptr; + const char *remaining; + hresult = sqlite3_prepare_v2(_db.get(), sql.data(), -1, &tmp, &remaining); + if(hresult != SQLITE_OK) errors::throw_sqlite_error(hresult, sql); + if(!std::all_of(remaining, sql.data() + sql.size(), [](char ch) {return std::isspace(ch);})) + throw errors::more_statements("Multiple semicolon separated statements are unsupported", sql); + return tmp; + } + + template + struct is_sqlite_value : public std::integral_constant< + bool, + std::is_floating_point::value + || std::is_integral::value + || std::is_same::value + || std::is_same::value + || std::is_same::value + > { }; + template + struct is_sqlite_value< std::vector > : public std::integral_constant< + bool, + std::is_floating_point::value + || std::is_integral::value + || std::is_same::value + > { }; +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT + template + struct is_sqlite_value< std::variant > : public std::integral_constant< + bool, + true + > { }; +#endif + + + /* for vector support */ + template friend database_binder& operator <<(database_binder& db, const std::vector& val); + template friend void get_col_from_db(database_binder& db, int inx, std::vector& val); + /* for nullptr & unique_ptr support */ + friend database_binder& operator <<(database_binder& db, std::nullptr_t); + template friend database_binder& operator <<(database_binder& db, const std::unique_ptr& val); + template friend void get_col_from_db(database_binder& db, int inx, std::unique_ptr& val); +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT + template friend database_binder& operator <<(database_binder& db, const std::variant& val); + template friend void get_col_from_db(database_binder& db, int inx, std::variant& val); +#endif + template friend T operator++(database_binder& db, int); + // Overload instead of specializing function templates (http://www.gotw.ca/publications/mill17.htm) + friend database_binder& operator<<(database_binder& db, const int& val); + friend void get_col_from_db(database_binder& db, int inx, int& val); + friend database_binder& operator <<(database_binder& db, const sqlite_int64& val); + friend void get_col_from_db(database_binder& db, int inx, sqlite3_int64& i); + friend database_binder& operator <<(database_binder& db, const float& val); + friend void get_col_from_db(database_binder& db, int inx, float& f); + friend database_binder& operator <<(database_binder& db, const double& val); + friend void get_col_from_db(database_binder& db, int inx, double& d); + friend void get_col_from_db(database_binder& db, int inx, std::string & s); + friend database_binder& operator <<(database_binder& db, const std::string& txt); + friend void get_col_from_db(database_binder& db, int inx, std::u16string & w); + friend database_binder& operator <<(database_binder& db, const std::u16string& txt); + + +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template friend database_binder& operator <<(database_binder& db, const optional& val); + template friend void get_col_from_db(database_binder& db, int inx, optional& o); +#endif + +#ifdef _MODERN_SQLITE_BOOST_OPTIONAL_SUPPORT + template friend database_binder& operator <<(database_binder& db, const boost::optional& val); + template friend void get_col_from_db(database_binder& db, int inx, boost::optional& o); +#endif + + public: + + database_binder(std::shared_ptr db, std::u16string const & sql): + _db(db), + _stmt(_prepare(sql), sqlite3_finalize), + _inx(0) { + } + + database_binder(std::shared_ptr db, std::string const & sql): + _db(db), + _stmt(_prepare(sql), sqlite3_finalize), + _inx(0) { + } + + ~database_binder() noexcept(false) { + /* Will be executed if no >>op is found, but not if an exception + is in mid flight */ + if(!used() && !_has_uncaught_exception && _stmt) { + execute(); + } + } + + template + typename std::enable_if::value, void>::type operator>>( + Result& value) { + this->_extract_single_value([&value, this] { + get_col_from_db(*this, 0, value); + }); + } + + template + void operator>>(std::tuple&& values) { + this->_extract_single_value([&values, this] { + tuple_iterate>::iterate(values, *this); + }); + } + + template + typename std::enable_if::value, void>::type operator>>( + Function&& func) { + typedef utility::function_traits traits; + + this->_extract([&func, this]() { + binder::run(*this, func); + }); + } + }; + + namespace sql_function_binder { + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + } + + enum class OpenFlags { + READONLY = SQLITE_OPEN_READONLY, + READWRITE = SQLITE_OPEN_READWRITE, + CREATE = SQLITE_OPEN_CREATE, + NOMUTEX = SQLITE_OPEN_NOMUTEX, + FULLMUTEX = SQLITE_OPEN_FULLMUTEX, + SHAREDCACHE = SQLITE_OPEN_SHAREDCACHE, + PRIVATECACH = SQLITE_OPEN_PRIVATECACHE, + URI = SQLITE_OPEN_URI + }; + inline OpenFlags operator|(const OpenFlags& a, const OpenFlags& b) { + return static_cast(static_cast(a) | static_cast(b)); + } + enum class Encoding { + ANY = SQLITE_ANY, + UTF8 = SQLITE_UTF8, + UTF16 = SQLITE_UTF16 + }; + struct sqlite_config { + OpenFlags flags = OpenFlags::READWRITE | OpenFlags::CREATE; + const char *zVfs = nullptr; + Encoding encoding = Encoding::ANY; + }; + + class database { + protected: + std::shared_ptr _db; + + public: + database(const std::string &db_name, const sqlite_config &config = {}): _db(nullptr) { + sqlite3* tmp = nullptr; + auto ret = sqlite3_open_v2(db_name.data(), &tmp, static_cast(config.flags), config.zVfs); + _db = std::shared_ptr(tmp, [=](sqlite3* ptr) { sqlite3_close_v2(ptr); }); // this will close the connection eventually when no longer needed. + if(ret != SQLITE_OK) errors::throw_sqlite_error(_db ? sqlite3_extended_errcode(_db.get()) : ret); + sqlite3_extended_result_codes(_db.get(), true); + if(config.encoding == Encoding::UTF16) + *this << R"(PRAGMA encoding = "UTF-16";)"; + } + + database(const std::u16string &db_name, const sqlite_config &config = {}): _db(nullptr) { + auto db_name_utf8 = utility::utf16_to_utf8(db_name); + sqlite3* tmp = nullptr; + auto ret = sqlite3_open_v2(db_name_utf8.data(), &tmp, static_cast(config.flags), config.zVfs); + _db = std::shared_ptr(tmp, [=](sqlite3* ptr) { sqlite3_close_v2(ptr); }); // this will close the connection eventually when no longer needed. + if(ret != SQLITE_OK) errors::throw_sqlite_error(_db ? sqlite3_extended_errcode(_db.get()) : ret); + sqlite3_extended_result_codes(_db.get(), true); + if(config.encoding != Encoding::UTF8) + *this << R"(PRAGMA encoding = "UTF-16";)"; + } + + database(std::shared_ptr db): + _db(db) {} + + database_binder operator<<(const std::string& sql) { + return database_binder(_db, sql); + } + + database_binder operator<<(const char* sql) { + return *this << std::string(sql); + } + + database_binder operator<<(const std::u16string& sql) { + return database_binder(_db, sql); + } + + database_binder operator<<(const char16_t* sql) { + return *this << std::u16string(sql); + } + + connection_type connection() const { return _db; } + + sqlite3_int64 last_insert_rowid() const { + return sqlite3_last_insert_rowid(_db.get()); + } + + template + void define(const std::string &name, Function&& func) { + typedef utility::function_traits traits; + + auto funcPtr = new auto(std::forward(func)); + if(int result = sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity, SQLITE_UTF8, funcPtr, + sql_function_binder::scalar::type>, + nullptr, nullptr, [](void* ptr){ + delete static_cast(ptr); + })) + errors::throw_sqlite_error(result); + } + + template + void define(const std::string &name, StepFunction&& step, FinalFunction&& final) { + typedef utility::function_traits traits; + using ContextType = typename std::remove_reference>::type; + + auto funcPtr = new auto(std::make_pair(std::forward(step), std::forward(final))); + if(int result = sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity - 1, SQLITE_UTF8, funcPtr, nullptr, + sql_function_binder::step::type>, + sql_function_binder::final::type>, + [](void* ptr){ + delete static_cast(ptr); + })) + errors::throw_sqlite_error(result); + } + + }; + + template + class binder { + private: + template < + typename Function, + std::size_t Index + > + using nth_argument_type = typename utility::function_traits< + Function + >::template argument; + + public: + // `Boundary` needs to be defaulted to `Count` so that the `run` function + // template is not implicitly instantiated on class template instantiation. + // Look up section 14.7.1 _Implicit instantiation_ of the ISO C++14 Standard + // and the [dicussion](https://github.com/aminroosta/sqlite_modern_cpp/issues/8) + // on Github. + + template< + typename Function, + typename... Values, + std::size_t Boundary = Count + > + static typename std::enable_if<(sizeof...(Values) < Boundary), void>::type run( + database_binder& db, + Function&& function, + Values&&... values + ) { + typename std::remove_cv>::type>::type value{}; + get_col_from_db(db, sizeof...(Values), value); + + run(db, function, std::forward(values)..., std::move(value)); + } + + template< + typename Function, + typename... Values, + std::size_t Boundary = Count + > + static typename std::enable_if<(sizeof...(Values) == Boundary), void>::type run( + database_binder&, + Function&& function, + Values&&... values + ) { + function(std::move(values)...); + } + }; + + // int + inline database_binder& operator<<(database_binder& db, const int& val) { + int hresult; + if((hresult = sqlite3_bind_int(db._stmt.get(), db._next_index(), val)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + return db; + } + inline void store_result_in_db(sqlite3_context* db, const int& val) { + sqlite3_result_int(db, val); + } + inline void get_col_from_db(database_binder& db, int inx, int& val) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + val = 0; + } else { + val = sqlite3_column_int(db._stmt.get(), inx); + } + } + inline void get_val_from_db(sqlite3_value *value, int& val) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + val = 0; + } else { + val = sqlite3_value_int(value); + } + } + + // sqlite_int64 + inline database_binder& operator <<(database_binder& db, const sqlite_int64& val) { + int hresult; + if((hresult = sqlite3_bind_int64(db._stmt.get(), db._next_index(), val)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const sqlite_int64& val) { + sqlite3_result_int64(db, val); + } + inline void get_col_from_db(database_binder& db, int inx, sqlite3_int64& i) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + i = 0; + } else { + i = sqlite3_column_int64(db._stmt.get(), inx); + } + } + inline void get_val_from_db(sqlite3_value *value, sqlite3_int64& i) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + i = 0; + } else { + i = sqlite3_value_int64(value); + } + } + + // float + inline database_binder& operator <<(database_binder& db, const float& val) { + int hresult; + if((hresult = sqlite3_bind_double(db._stmt.get(), db._next_index(), double(val))) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const float& val) { + sqlite3_result_double(db, val); + } + inline void get_col_from_db(database_binder& db, int inx, float& f) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + f = 0; + } else { + f = float(sqlite3_column_double(db._stmt.get(), inx)); + } + } + inline void get_val_from_db(sqlite3_value *value, float& f) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + f = 0; + } else { + f = float(sqlite3_value_double(value)); + } + } + + // double + inline database_binder& operator <<(database_binder& db, const double& val) { + int hresult; + if((hresult = sqlite3_bind_double(db._stmt.get(), db._next_index(), val)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const double& val) { + sqlite3_result_double(db, val); + } + inline void get_col_from_db(database_binder& db, int inx, double& d) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + d = 0; + } else { + d = sqlite3_column_double(db._stmt.get(), inx); + } + } + inline void get_val_from_db(sqlite3_value *value, double& d) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + d = 0; + } else { + d = sqlite3_value_double(value); + } + } + + // vector + template inline database_binder& operator<<(database_binder& db, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + int hresult; + if((hresult = sqlite3_bind_blob(db._stmt.get(), db._next_index(), buf, bytes, SQLITE_TRANSIENT)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + return db; + } + template inline void store_result_in_db(sqlite3_context* db, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + sqlite3_result_blob(db, buf, bytes, SQLITE_TRANSIENT); + } + template inline void get_col_from_db(database_binder& db, int inx, std::vector& vec) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + vec.clear(); + } else { + int bytes = sqlite3_column_bytes(db._stmt.get(), inx); + T const* buf = reinterpret_cast(sqlite3_column_blob(db._stmt.get(), inx)); + vec = std::vector(buf, buf + bytes/sizeof(T)); + } + } + template inline void get_val_from_db(sqlite3_value *value, std::vector& vec) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + vec.clear(); + } else { + int bytes = sqlite3_value_bytes(value); + T const* buf = reinterpret_cast(sqlite3_value_blob(value)); + vec = std::vector(buf, buf + bytes/sizeof(T)); + } + } + + /* for nullptr support */ + inline database_binder& operator <<(database_binder& db, std::nullptr_t) { + int hresult; + if((hresult = sqlite3_bind_null(db._stmt.get(), db._next_index())) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + return db; + } + inline void store_result_in_db(sqlite3_context* db, std::nullptr_t) { + sqlite3_result_null(db); + } + /* for nullptr support */ + template inline database_binder& operator <<(database_binder& db, const std::unique_ptr& val) { + if(val) + db << *val; + else + db << nullptr; + return db; + } + + /* for unique_ptr support */ + template inline void get_col_from_db(database_binder& db, int inx, std::unique_ptr& _ptr_) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + _ptr_ = nullptr; + } else { + auto underling_ptr = new T(); + get_col_from_db(db, inx, *underling_ptr); + _ptr_.reset(underling_ptr); + } + } + template inline void get_val_from_db(sqlite3_value *value, std::unique_ptr& _ptr_) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + _ptr_ = nullptr; + } else { + auto underling_ptr = new T(); + get_val_from_db(value, *underling_ptr); + _ptr_.reset(underling_ptr); + } + } + + // std::string + inline void get_col_from_db(database_binder& db, int inx, std::string & s) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + s = std::string(); + } else { + sqlite3_column_bytes(db._stmt.get(), inx); + s = std::string(reinterpret_cast(sqlite3_column_text(db._stmt.get(), inx))); + } + } + inline void get_val_from_db(sqlite3_value *value, std::string & s) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + s = std::string(); + } else { + sqlite3_value_bytes(value); + s = std::string(reinterpret_cast(sqlite3_value_text(value))); + } + } + + // Convert char* to string to trigger op<<(..., const std::string ) + template inline database_binder& operator <<(database_binder& db, const char(&STR)[N]) { return db << std::string(STR); } + template inline database_binder& operator <<(database_binder& db, const char16_t(&STR)[N]) { return db << std::u16string(STR); } + + inline database_binder& operator <<(database_binder& db, const std::string& txt) { + int hresult; + if((hresult = sqlite3_bind_text(db._stmt.get(), db._next_index(), txt.data(), -1, SQLITE_TRANSIENT)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const std::string& val) { + sqlite3_result_text(db, val.data(), -1, SQLITE_TRANSIENT); + } + // std::u16string + inline void get_col_from_db(database_binder& db, int inx, std::u16string & w) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + w = std::u16string(); + } else { + sqlite3_column_bytes16(db._stmt.get(), inx); + w = std::u16string(reinterpret_cast(sqlite3_column_text16(db._stmt.get(), inx))); + } + } + inline void get_val_from_db(sqlite3_value *value, std::u16string & w) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + w = std::u16string(); + } else { + sqlite3_value_bytes16(value); + w = std::u16string(reinterpret_cast(sqlite3_value_text16(value))); + } + } + + + inline database_binder& operator <<(database_binder& db, const std::u16string& txt) { + int hresult; + if((hresult = sqlite3_bind_text16(db._stmt.get(), db._next_index(), txt.data(), -1, SQLITE_TRANSIENT)) != SQLITE_OK) { + errors::throw_sqlite_error(hresult, db.sql()); + } + + return db; + } + inline void store_result_in_db(sqlite3_context* db, const std::u16string& val) { + sqlite3_result_text16(db, val.data(), -1, SQLITE_TRANSIENT); + } + + // Other integer types + template::value>::type> + inline database_binder& operator <<(database_binder& db, const Integral& val) { + return db << static_cast(val); + } + template::type>> + inline void store_result_in_db(sqlite3_context* db, const Integral& val) { + store_result_in_db(db, static_cast(val)); + } + template::value>::type> + inline void get_col_from_db(database_binder& db, int inx, Integral& val) { + sqlite3_int64 i; + get_col_from_db(db, inx, i); + val = i; + } + template::value>::type> + inline void get_val_from_db(sqlite3_value *value, Integral& val) { + sqlite3_int64 i; + get_val_from_db(value, i); + val = i; + } + + // std::optional support for NULL values +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template inline database_binder& operator <<(database_binder& db, const optional& val) { + if(val) { + return db << std::move(*val); + } else { + return db << nullptr; + } + } + template inline void store_result_in_db(sqlite3_context* db, const optional& val) { + if(val) { + store_result_in_db(db, *val); + } + sqlite3_result_null(db); + } + + template inline void get_col_from_db(database_binder& db, int inx, optional& o) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + o = std::experimental::nullopt; + #else + o.reset(); + #endif + } else { + OptionalT v; + get_col_from_db(db, inx, v); + o = std::move(v); + } + } + template inline void get_val_from_db(sqlite3_value *value, optional& o) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + o = std::experimental::nullopt; + #else + o.reset(); + #endif + } else { + OptionalT v; + get_val_from_db(value, v); + o = std::move(v); + } + } +#endif + + // boost::optional support for NULL values +#ifdef _MODERN_SQLITE_BOOST_OPTIONAL_SUPPORT + template inline database_binder& operator <<(database_binder& db, const boost::optional& val) { + if(val) { + return db << std::move(*val); + } else { + return db << nullptr; + } + } + template inline void store_result_in_db(sqlite3_context* db, const boost::optional& val) { + if(val) { + store_result_in_db(db, *val); + } + sqlite3_result_null(db); + } + + template inline void get_col_from_db(database_binder& db, int inx, boost::optional& o) { + if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) { + o.reset(); + } else { + BoostOptionalT v; + get_col_from_db(db, inx, v); + o = std::move(v); + } + } + template inline void get_val_from_db(sqlite3_value *value, boost::optional& o) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + o.reset(); + } else { + BoostOptionalT v; + get_val_from_db(value, v); + o = std::move(v); + } + } +#endif + +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT + template inline database_binder& operator <<(database_binder& db, const std::variant& val) { + std::visit([&](auto &&opt) {db << std::forward(opt);}, val); + return db; + } + template inline void store_result_in_db(sqlite3_context* db, const std::variant& val) { + std::visit([&](auto &&opt) {store_result_in_db(db, std::forward(opt));}, val); + } + template inline void get_col_from_db(database_binder& db, int inx, std::variant& val) { + utility::variant_select(sqlite3_column_type(db._stmt.get(), inx))([&](auto v) { + get_col_from_db(db, inx, v); + val = std::move(v); + }); + } + template inline void get_val_from_db(sqlite3_value *value, std::variant& val) { + utility::variant_select(sqlite3_value_type(value))([&](auto v) { + get_val_from_db(value, v); + val = std::move(v); + }); + } +#endif + + // Some ppl are lazy so we have a operator for proper prep. statemant handling. + void inline operator++(database_binder& db, int) { db.execute(); } + + // Convert the rValue binder to a reference and call first op<<, its needed for the call that creates the binder (be carefull of recursion here!) + template database_binder&& operator << (database_binder&& db, const T& val) { db << val; return std::move(db); } + + namespace sql_function_binder { + template + struct AggregateCtxt { + T obj; + bool constructed = true; + }; + + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + if(!ctxt) return; + try { + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + step(db, count, vals, ctxt->obj); + return; + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + if(ctxt && ctxt->constructed) + ctxt->~AggregateCtxt(); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits< + typename Functions::first_type + >::template argument + >::type + >::type value{}; + get_val_from_db(vals[sizeof...(Values) - 1], value); + + step(db, count, vals, std::forward(values)..., std::move(value)); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + static_cast(sqlite3_user_data(db))->first(std::forward(values)...); + } + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + try { + if(!ctxt) return; + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + store_result_in_db(db, + static_cast(sqlite3_user_data(db))->second(ctxt->obj)); + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + if(ctxt && ctxt->constructed) + ctxt->~AggregateCtxt(); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits::template argument + >::type + >::type value{}; + get_val_from_db(vals[sizeof...(Values)], value); + + scalar(db, count, vals, std::forward(values)..., std::move(value)); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + try { + store_result_in_db(db, + (*static_cast(sqlite3_user_data(db)))(std::forward(values)...)); + } catch(sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + } + } +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/errors.h b/sqlite_modern_cpp/sqlite_modern_cpp/errors.h new file mode 100644 index 0000000..2b9ab75 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/errors.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +#include + +namespace sqlite { + + class sqlite_exception: public std::runtime_error { + public: + sqlite_exception(const char* msg, std::string sql, int code = -1): runtime_error(msg), code(code), sql(sql) {} + sqlite_exception(int code, std::string sql): runtime_error(sqlite3_errstr(code)), code(code), sql(sql) {} + int get_code() const {return code & 0xFF;} + int get_extended_code() const {return code;} + std::string get_sql() const {return sql;} + private: + int code; + std::string sql; + }; + + namespace errors { + //One more or less trivial derived error class for each SQLITE error. + //Note the following are not errors so have no classes: + //SQLITE_OK, SQLITE_NOTICE, SQLITE_WARNING, SQLITE_ROW, SQLITE_DONE + // + //Note these names are exact matches to the names of the SQLITE error codes. +#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ + class name: public sqlite_exception { using sqlite_exception::sqlite_exception; };\ + derived +#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ + class base ## _ ## sub: public base { using base::base; }; +#include "lists/error_codes.h" +#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED +#undef SQLITE_MODERN_CPP_ERROR_CODE + + //Some additional errors are here for the C++ interface + class more_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class no_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class more_statements: public sqlite_exception { using sqlite_exception::sqlite_exception; }; // Prepared statements can only contain one statement + class invalid_utf16: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + + static void throw_sqlite_error(const int& error_code, const std::string &sql = "") { + switch(error_code & 0xFF) { +#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ + case SQLITE_ ## NAME: switch(error_code) { \ + derived \ + default: throw name(error_code, sql); \ + } +#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ + case SQLITE_ ## BASE ## _ ## SUB: throw base ## _ ## sub(error_code, sql); +#include "lists/error_codes.h" +#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED +#undef SQLITE_MODERN_CPP_ERROR_CODE + default: throw sqlite_exception(error_code, sql); + } + } + } + namespace exceptions = errors; +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/lists/error_codes.h b/sqlite_modern_cpp/sqlite_modern_cpp/lists/error_codes.h new file mode 100644 index 0000000..5dfa0d3 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/lists/error_codes.h @@ -0,0 +1,93 @@ +#if SQLITE_VERSION_NUMBER < 3010000 +#define SQLITE_IOERR_VNODE (SQLITE_IOERR | (27<<8)) +#define SQLITE_IOERR_AUTH (SQLITE_IOERR | (28<<8)) +#define SQLITE_AUTH_USER (SQLITE_AUTH | (1<<8)) +#endif +SQLITE_MODERN_CPP_ERROR_CODE(ERROR,error,) +SQLITE_MODERN_CPP_ERROR_CODE(INTERNAL,internal,) +SQLITE_MODERN_CPP_ERROR_CODE(PERM,perm,) +SQLITE_MODERN_CPP_ERROR_CODE(ABORT,abort, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(ABORT,ROLLBACK,abort,rollback) +) +SQLITE_MODERN_CPP_ERROR_CODE(BUSY,busy, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,RECOVERY,busy,recovery) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,SNAPSHOT,busy,snapshot) +) +SQLITE_MODERN_CPP_ERROR_CODE(LOCKED,locked, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(LOCKED,SHAREDCACHE,locked,sharedcache) +) +SQLITE_MODERN_CPP_ERROR_CODE(NOMEM,nomem,) +SQLITE_MODERN_CPP_ERROR_CODE(READONLY,readonly,) +SQLITE_MODERN_CPP_ERROR_CODE(INTERRUPT,interrupt,) +SQLITE_MODERN_CPP_ERROR_CODE(IOERR,ioerr, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,READ,ioerr,read) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHORT_READ,ioerr,short_read) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,WRITE,ioerr,write) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSYNC,ioerr,fsync) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_FSYNC,ioerr,dir_fsync) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,TRUNCATE,ioerr,truncate) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSTAT,ioerr,fstat) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,UNLOCK,ioerr,unlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,RDLOCK,ioerr,rdlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE,ioerr,delete) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,BLOCKED,ioerr,blocked) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,NOMEM,ioerr,nomem) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,ACCESS,ioerr,access) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CHECKRESERVEDLOCK,ioerr,checkreservedlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,LOCK,ioerr,lock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CLOSE,ioerr,close) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_CLOSE,ioerr,dir_close) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMOPEN,ioerr,shmopen) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMSIZE,ioerr,shmsize) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMLOCK,ioerr,shmlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMMAP,ioerr,shmmap) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SEEK,ioerr,seek) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE_NOENT,ioerr,delete_noent) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,MMAP,ioerr,mmap) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,GETTEMPPATH,ioerr,gettemppath) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CONVPATH,ioerr,convpath) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,VNODE,ioerr,vnode) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,AUTH,ioerr,auth) +) +SQLITE_MODERN_CPP_ERROR_CODE(CORRUPT,corrupt, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CORRUPT,VTAB,corrupt,vtab) +) +SQLITE_MODERN_CPP_ERROR_CODE(NOTFOUND,notfound,) +SQLITE_MODERN_CPP_ERROR_CODE(FULL,full,) +SQLITE_MODERN_CPP_ERROR_CODE(CANTOPEN,cantopen, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,NOTEMPDIR,cantopen,notempdir) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,ISDIR,cantopen,isdir) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,FULLPATH,cantopen,fullpath) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,CONVPATH,cantopen,convpath) +) +SQLITE_MODERN_CPP_ERROR_CODE(PROTOCOL,protocol,) +SQLITE_MODERN_CPP_ERROR_CODE(EMPTY,empty,) +SQLITE_MODERN_CPP_ERROR_CODE(SCHEMA,schema,) +SQLITE_MODERN_CPP_ERROR_CODE(TOOBIG,toobig,) +SQLITE_MODERN_CPP_ERROR_CODE(CONSTRAINT,constraint, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,CHECK,constraint,check) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,COMMITHOOK,constraint,commithook) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FOREIGNKEY,constraint,foreignkey) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FUNCTION,constraint,function) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,NOTNULL,constraint,notnull) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,PRIMARYKEY,constraint,primarykey) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,TRIGGER,constraint,trigger) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,UNIQUE,constraint,unique) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,VTAB,constraint,vtab) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,ROWID,constraint,rowid) +) +SQLITE_MODERN_CPP_ERROR_CODE(MISMATCH,mismatch,) +SQLITE_MODERN_CPP_ERROR_CODE(MISUSE,misuse,) +SQLITE_MODERN_CPP_ERROR_CODE(NOLFS,nolfs,) +SQLITE_MODERN_CPP_ERROR_CODE(AUTH,auth, +) +SQLITE_MODERN_CPP_ERROR_CODE(FORMAT,format,) +SQLITE_MODERN_CPP_ERROR_CODE(RANGE,range,) +SQLITE_MODERN_CPP_ERROR_CODE(NOTADB,notadb,) +SQLITE_MODERN_CPP_ERROR_CODE(NOTICE,notice, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_WAL,notice,recover_wal) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_ROLLBACK,notice,recover_rollback) +) +SQLITE_MODERN_CPP_ERROR_CODE(WARNING,warning, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(WARNING,AUTOINDEX,warning,autoindex) +) diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/log.h b/sqlite_modern_cpp/sqlite_modern_cpp/log.h new file mode 100644 index 0000000..a8f7be2 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/log.h @@ -0,0 +1,101 @@ +#include "errors.h" + +#include + +#include +#include +#include + +namespace sqlite { + namespace detail { + template + using void_t = void; + template + struct is_callable : std::false_type {}; + template + struct is_callable()(std::declval()...))>> : std::true_type {}; + template + class FunctorOverload: public Functor, public FunctorOverload { + public: + template + FunctorOverload(Functor1 &&functor, Remaining &&... remaining): + Functor(std::forward(functor)), + FunctorOverload(std::forward(remaining)...) {} + using Functor::operator(); + using FunctorOverload::operator(); + }; + template + class FunctorOverload: public Functor { + public: + template + FunctorOverload(Functor1 &&functor): + Functor(std::forward(functor)) {} + using Functor::operator(); + }; + template + class WrapIntoFunctor: public Functor { + public: + template + WrapIntoFunctor(Functor1 &&functor): + Functor(std::forward(functor)) {} + using Functor::operator(); + }; + template + class WrapIntoFunctor { + ReturnType(*ptr)(Arguments...); + public: + WrapIntoFunctor(ReturnType(*ptr)(Arguments...)): ptr(ptr) {} + ReturnType operator()(Arguments... arguments) { return (*ptr)(std::forward(arguments)...); } + }; + inline void store_error_log_data_pointer(std::shared_ptr ptr) { + static std::shared_ptr stored; + stored = std::move(ptr); + } + template + std::shared_ptr::type> make_shared_inferred(T &&t) { + return std::make_shared::type>(std::forward(t)); + } + } + template + typename std::enable_if::value>::type + error_log(Handler &&handler); + template + typename std::enable_if::value>::type + error_log(Handler &&handler); + template + typename std::enable_if=2>::type + error_log(Handler &&...handler) { + return error_log(detail::FunctorOverload::type>...>(std::forward(handler)...)); + } + template + typename std::enable_if::value>::type + error_log(Handler &&handler) { + return error_log(std::forward(handler), [](const sqlite_exception&) {}); + } + template + typename std::enable_if::value>::type + error_log(Handler &&handler) { + auto ptr = detail::make_shared_inferred([handler = std::forward(handler)](int error_code, const char *errstr) mutable { + switch(error_code & 0xFF) { +#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ + case SQLITE_ ## NAME: switch(error_code) { \ + derived \ + default: handler(errors::name(errstr, "", error_code)); \ + };break; +#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ + case SQLITE_ ## BASE ## _ ## SUB: \ + handler(errors::base ## _ ## sub(errstr, "", error_code)); \ + break; +#include "lists/error_codes.h" +#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED +#undef SQLITE_MODERN_CPP_ERROR_CODE + default: handler(sqlite_exception(errstr, "", error_code)); \ + } + }); + + sqlite3_config(SQLITE_CONFIG_LOG, (void(*)(void*,int,const char*))[](void *functor, int error_code, const char *errstr) { + (*static_cast(functor))(error_code, errstr); + }, ptr.get()); + detail::store_error_log_data_pointer(std::move(ptr)); + } +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/sqlcipher.h b/sqlite_modern_cpp/sqlite_modern_cpp/sqlcipher.h new file mode 100644 index 0000000..da0f018 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/sqlcipher.h @@ -0,0 +1,44 @@ +#pragma once + +#ifndef SQLITE_HAS_CODEC +#define SQLITE_HAS_CODEC +#endif + +#include "../sqlite_modern_cpp.h" + +namespace sqlite { + struct sqlcipher_config : public sqlite_config { + std::string key; + }; + + class sqlcipher_database : public database { + public: + sqlcipher_database(std::string db, const sqlcipher_config &config): database(db, config) { + set_key(config.key); + } + + sqlcipher_database(std::u16string db, const sqlcipher_config &config): database(db, config) { + set_key(config.key); + } + + void set_key(const std::string &key) { + if(auto ret = sqlite3_key(_db.get(), key.data(), key.size())) + errors::throw_sqlite_error(ret); + } + + void set_key(const std::string &key, const std::string &db_name) { + if(auto ret = sqlite3_key_v2(_db.get(), db_name.c_str(), key.data(), key.size())) + errors::throw_sqlite_error(ret); + } + + void rekey(const std::string &new_key) { + if(auto ret = sqlite3_rekey(_db.get(), new_key.data(), new_key.size())) + errors::throw_sqlite_error(ret); + } + + void rekey(const std::string &new_key, const std::string &db_name) { + if(auto ret = sqlite3_rekey_v2(_db.get(), db_name.c_str(), new_key.data(), new_key.size())) + errors::throw_sqlite_error(ret); + } + }; +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/utility/function_traits.h b/sqlite_modern_cpp/sqlite_modern_cpp/utility/function_traits.h new file mode 100644 index 0000000..cd8fab0 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/utility/function_traits.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +namespace sqlite { + namespace utility { + + template struct function_traits; + + template + struct function_traits : public function_traits< + decltype(&std::remove_reference::type::operator()) + > { }; + + template < + typename ClassType, + typename ReturnType, + typename... Arguments + > + struct function_traits< + ReturnType(ClassType::*)(Arguments...) const + > : function_traits { }; + + /* support the non-const operator () + * this will work with user defined functors */ + template < + typename ClassType, + typename ReturnType, + typename... Arguments + > + struct function_traits< + ReturnType(ClassType::*)(Arguments...) + > : function_traits { }; + + template < + typename ReturnType, + typename... Arguments + > + struct function_traits< + ReturnType(*)(Arguments...) + > { + typedef ReturnType result_type; + + template + using argument = typename std::tuple_element< + Index, + std::tuple + >::type; + + static const std::size_t arity = sizeof...(Arguments); + }; + + } +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/utility/uncaught_exceptions.h b/sqlite_modern_cpp/sqlite_modern_cpp/utility/uncaught_exceptions.h new file mode 100644 index 0000000..17d6326 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/utility/uncaught_exceptions.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include + +namespace sqlite { + namespace utility { +#ifdef __cpp_lib_uncaught_exceptions + class UncaughtExceptionDetector { + public: + operator bool() { + return count != std::uncaught_exceptions(); + } + private: + int count = std::uncaught_exceptions(); + }; +#else + class UncaughtExceptionDetector { + public: + operator bool() { + return std::uncaught_exception(); + } + }; +#endif + } +} diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/utility/utf16_utf8.h b/sqlite_modern_cpp/sqlite_modern_cpp/utility/utf16_utf8.h new file mode 100644 index 0000000..ea21723 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/utility/utf16_utf8.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include + +#include "../errors.h" + +namespace sqlite { + namespace utility { + inline std::string utf16_to_utf8(const std::u16string &input) { + struct : std::codecvt { + } codecvt; + std::mbstate_t state{}; + std::string result((std::max)(input.size() * 3 / 2, std::size_t(4)), '\0'); + const char16_t *remaining_input = input.data(); + std::size_t produced_output = 0; + while(true) { + char *used_output; + switch(codecvt.out(state, remaining_input, &input[input.size()], + remaining_input, &result[produced_output], + &result[result.size() - 1] + 1, used_output)) { + case std::codecvt_base::ok: + result.resize(used_output - result.data()); + return result; + case std::codecvt_base::noconv: + // This should be unreachable + case std::codecvt_base::error: + throw errors::invalid_utf16("Invalid UTF-16 input", ""); + case std::codecvt_base::partial: + if(used_output == result.data() + produced_output) + throw errors::invalid_utf16("Unexpected end of input", ""); + produced_output = used_output - result.data(); + result.resize( + result.size() + + (std::max)((&input[input.size()] - remaining_input) * 3 / 2, + std::ptrdiff_t(4))); + } + } + } + } // namespace utility +} // namespace sqlite diff --git a/sqlite_modern_cpp/sqlite_modern_cpp/utility/variant.h b/sqlite_modern_cpp/sqlite_modern_cpp/utility/variant.h new file mode 100644 index 0000000..11a8429 --- /dev/null +++ b/sqlite_modern_cpp/sqlite_modern_cpp/utility/variant.h @@ -0,0 +1,201 @@ +#pragma once + +#include "../errors.h" +#include +#include +#include + +namespace sqlite::utility { + template + struct VariantFirstNullable { + using type = void; + }; + template + struct VariantFirstNullable { + using type = typename VariantFirstNullable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstNullable, Options...> { + using type = std::optional; + }; +#endif + template + struct VariantFirstNullable, Options...> { + using type = std::unique_ptr; + }; + template + struct VariantFirstNullable { + using type = std::nullptr_t; + }; + template + inline void variant_select_null(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("NULL is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstNullable::type()); + } + } + + template + struct VariantFirstIntegerable { + using type = void; + }; + template + struct VariantFirstIntegerable { + using type = typename VariantFirstIntegerable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstIntegerable, Options...> { + using type = std::conditional_t::type, T>, std::optional, typename VariantFirstIntegerable::type>; + }; +#endif + template + struct VariantFirstIntegerable::type, T>>, std::unique_ptr, Options...> { + using type = std::conditional_t::type, T>, std::unique_ptr, typename VariantFirstIntegerable::type>; + }; + template + struct VariantFirstIntegerable { + using type = int; + }; + template + struct VariantFirstIntegerable { + using type = sqlite_int64; + }; + template + inline auto variant_select_integer(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("Integer is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstIntegerable::type()); + } + } + + template + struct VariantFirstFloatable { + using type = void; + }; + template + struct VariantFirstFloatable { + using type = typename VariantFirstFloatable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstFloatable, Options...> { + using type = std::conditional_t::type, T>, std::optional, typename VariantFirstFloatable::type>; + }; +#endif + template + struct VariantFirstFloatable, Options...> { + using type = std::conditional_t::type, T>, std::unique_ptr, typename VariantFirstFloatable::type>; + }; + template + struct VariantFirstFloatable { + using type = float; + }; + template + struct VariantFirstFloatable { + using type = double; + }; + template + inline auto variant_select_float(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("Real is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstFloatable::type()); + } + } + + template + struct VariantFirstTextable { + using type = void; + }; + template + struct VariantFirstTextable { + using type = typename VariantFirstTextable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstTextable, Options...> { + using type = std::conditional_t::type, T>, std::optional, typename VariantFirstTextable::type>; + }; +#endif + template + struct VariantFirstTextable, Options...> { + using type = std::conditional_t::type, T>, std::unique_ptr, typename VariantFirstTextable::type>; + }; + template + struct VariantFirstTextable { + using type = std::string; + }; + template + struct VariantFirstTextable { + using type = std::u16string; + }; + template + inline void variant_select_text(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("Text is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstTextable::type()); + } + } + + template + struct VariantFirstBlobable { + using type = void; + }; + template + struct VariantFirstBlobable { + using type = typename VariantFirstBlobable::type; + }; +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT + template + struct VariantFirstBlobable, Options...> { + using type = std::conditional_t::type, T>, std::optional, typename VariantFirstBlobable::type>; + }; +#endif + template + struct VariantFirstBlobable, Options...> { + using type = std::conditional_t::type, T>, std::unique_ptr, typename VariantFirstBlobable::type>; + }; + template + struct VariantFirstBlobable>, std::vector, Options...> { + using type = std::vector; + }; + template + inline auto variant_select_blob(Callback&&callback) { + if constexpr(std::is_same_v::type, void>) { + throw errors::mismatch("Blob is unsupported by this variant.", "", SQLITE_MISMATCH); + } else { + std::forward(callback)(typename VariantFirstBlobable::type()); + } + } + + template + inline auto variant_select(int type) { + return [type](auto &&callback) { + using Callback = decltype(callback); + switch(type) { + case SQLITE_NULL: + variant_select_null(std::forward(callback)); + break; + case SQLITE_INTEGER: + variant_select_integer(std::forward(callback)); + break; + case SQLITE_FLOAT: + variant_select_float(std::forward(callback)); + break; + case SQLITE_TEXT: + variant_select_text(std::forward(callback)); + break; + case SQLITE_BLOB: + variant_select_blob(std::forward(callback)); + break; + default:; + /* assert(false); */ + } + }; + } +}