mirror of
https://gitlab.com/niansa/discord_llama.git
synced 2025-03-06 20:48:25 +01:00
Added tons of features including thread creation
This commit is contained in:
parent
a0edf72a43
commit
bf568c3c8a
18 changed files with 2127 additions and 122 deletions
|
@ -11,7 +11,7 @@ add_subdirectory(thread-pool)
|
|||
add_subdirectory(fmt)
|
||||
|
||||
add_executable(discord_llama main.cpp)
|
||||
target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml threadpool)
|
||||
target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml threadpool sqlite3)
|
||||
|
||||
install(TARGETS discord_llama
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
||||
|
|
2
anyproc
2
anyproc
|
@ -1 +1 @@
|
|||
Subproject commit 6ba351b1173539254b2e641d19ca77c459a087a2
|
||||
Subproject commit 03b1689ef708b26ac139d09c3fa194195eef7707
|
|
@ -1,12 +1,18 @@
|
|||
token MTA0MDYxMTQzNjUwNzk1OTMyNw.Gl_iMU.jVVM3bRqBJVi8ORVpWHquOivlASGJpRySt8qFg
|
||||
|
||||
# The following parameters are set to their defaults here and can be ommited
|
||||
models_dir models
|
||||
language EN
|
||||
inference_model 13B-ggml-model-quant.bin
|
||||
translation_model 13B-ggml-model-quant.bin
|
||||
prompt_file prompt.txt
|
||||
threads_only true
|
||||
|
||||
default_inference_model 13B-vanilla
|
||||
translation_model none
|
||||
|
||||
prompt_file none
|
||||
instruct_prompt_file none
|
||||
|
||||
persistance true
|
||||
mlock false
|
||||
pool_size 2
|
||||
threads 4
|
||||
persistance true
|
||||
ctx_size 1012
|
||||
|
|
2
example_models/13B-vanilla.txt
Normal file
2
example_models/13B-vanilla.txt
Normal file
|
@ -0,0 +1,2 @@
|
|||
filename 13B-ggml-model-quant.bin
|
||||
instruct_mode_policy forbid
|
4
example_models/13B-vicuna-1.0.txt
Normal file
4
example_models/13B-vicuna-1.0.txt
Normal file
|
@ -0,0 +1,4 @@
|
|||
filename ggml-vicuna-13b-1.0-q4_0.bin
|
||||
instruct_mode_policy force
|
||||
user_prompt ### Human:
|
||||
bot_prompt ### Assistant:
|
4
example_models/13B-vicuna-1.1.txt
Normal file
4
example_models/13B-vicuna-1.1.txt
Normal file
|
@ -0,0 +1,4 @@
|
|||
filename ggml-vicuna-13b-1.1-q4_0.bin
|
||||
instruct_mode_policy force
|
||||
user_prompt HUMAN:
|
||||
bot_prompt ASSISTANT:
|
4
example_models/gpt4all-unfiltered.txt
Normal file
4
example_models/gpt4all-unfiltered.txt
Normal file
|
@ -0,0 +1,4 @@
|
|||
filename gpt4all-lora-unfiltered-quantized.bin
|
||||
instruct_mode_policy allow
|
||||
user_prompt ### Instruction:
|
||||
bot_prompt ### Response:
|
520
main.cpp
520
main.cpp
|
@ -1,4 +1,5 @@
|
|||
#include "Timer.hpp"
|
||||
#include "sqlite_modern_cpp/sqlite_modern_cpp.h"
|
||||
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
|
@ -48,9 +49,23 @@ void str_replace_in_place(std::string& subject, std::string_view search,
|
|||
}
|
||||
}
|
||||
|
||||
static
|
||||
void clean_command_name(std::string& value) {
|
||||
for (auto& c : value) {
|
||||
if (c == '.') c = '_';
|
||||
if (isalpha(c)) c = tolower(c);
|
||||
}
|
||||
}
|
||||
[[nodiscard]] static
|
||||
std::string clean_command_name(std::string_view input) {
|
||||
std::string fres(input);
|
||||
clean_command_name(fres);
|
||||
return fres;
|
||||
}
|
||||
|
||||
|
||||
class Bot {
|
||||
ThreadPool tPool{1};
|
||||
ThreadPool thread_pool{1};
|
||||
Timer last_message_timer;
|
||||
std::shared_ptr<bool> stopping;
|
||||
LM::InferencePool llm_pool;
|
||||
|
@ -58,47 +73,64 @@ class Bot {
|
|||
std::vector<dpp::snowflake> my_messages;
|
||||
std::unordered_map<dpp::snowflake, dpp::user> users;
|
||||
std::thread::id llm_tid;
|
||||
sqlite::database db;
|
||||
|
||||
std::string_view language;
|
||||
dpp::cluster bot;
|
||||
|
||||
public:
|
||||
struct ModelConfig {
|
||||
std::string weight_path,
|
||||
user_prompt,
|
||||
bot_prompt;
|
||||
enum class InstructModePolicy {
|
||||
Allow = 0b11,
|
||||
Force = 0b10,
|
||||
Forbid = 0b01
|
||||
} instruct_mode_policy = InstructModePolicy::Allow;
|
||||
|
||||
bool is_instruct_mode_allowed() const {
|
||||
return static_cast<unsigned>(instruct_mode_policy) & 0b10;
|
||||
}
|
||||
bool is_non_instruct_mode_allowed() const {
|
||||
return static_cast<unsigned>(instruct_mode_policy) & 0b01;
|
||||
}
|
||||
};
|
||||
struct BotChannelConfig {
|
||||
const std::string *model_name;
|
||||
const ModelConfig *model_config;
|
||||
bool instruct_mode = false;
|
||||
};
|
||||
struct Configuration {
|
||||
std::string token,
|
||||
language = "EN",
|
||||
default_inference_model = "13B-vanilla",
|
||||
translation_model = "none",
|
||||
prompt_file = "none",
|
||||
instruct_prompt_file = "none",
|
||||
models_dir = "models";
|
||||
unsigned ctx_size = 1012,
|
||||
pool_size = 2,
|
||||
threads = 4,
|
||||
persistance = true;
|
||||
bool mlock = false,
|
||||
threads_only = true;
|
||||
const ModelConfig *default_inference_model_cfg = nullptr,
|
||||
*translation_model_cfg = nullptr;
|
||||
};
|
||||
|
||||
private:
|
||||
const Configuration& config;
|
||||
const std::unordered_map<std::string, ModelConfig>& model_configs;
|
||||
|
||||
struct Texts {
|
||||
std::string please_wait = "Please wait...",
|
||||
loading = "Loading...",
|
||||
initializing = "Initializing...",
|
||||
thread_create_fail = "Error: I couldn't create a thread here. Do I have enough permissions?",
|
||||
model_missing = "Error: The model that was used in this thread could no longer be found.",
|
||||
timeout = "Error: Timeout";
|
||||
bool translated = false;
|
||||
} texts;
|
||||
|
||||
inline static
|
||||
std::string create_text_progress_indicator(uint8_t percentage) {
|
||||
static constexpr uint8_t divisor = 3,
|
||||
width = 100 / divisor;
|
||||
// Progress bar percentage lookup
|
||||
const static auto indicator_lookup = [] () consteval {
|
||||
std::array<uint8_t, 101> fres;
|
||||
for (uint8_t it = 0; it != 101; it++) {
|
||||
fres[it] = it / divisor;
|
||||
}
|
||||
return fres;
|
||||
}();
|
||||
// Initialize string
|
||||
std::string fres;
|
||||
fres.resize(width+4);
|
||||
fres[0] = '`';
|
||||
fres[1] = '[';
|
||||
// Append progress
|
||||
const uint8_t bars = indicator_lookup[percentage];
|
||||
for (uint8_t it = 0; it != width; it++) {
|
||||
if (it < bars) fres[it+2] = '#';
|
||||
else fres[it+2] = ' ';
|
||||
}
|
||||
// Finalize and return string
|
||||
fres[width+2] = ']';
|
||||
fres[width+3] = '`';
|
||||
return fres;
|
||||
}
|
||||
|
||||
inline static
|
||||
bool show_console_progress(float progress) {
|
||||
std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
|
||||
|
@ -178,27 +210,29 @@ class Bot {
|
|||
}
|
||||
|
||||
// Must run in llama thread
|
||||
void llm_restart(LM::Inference& inference) {
|
||||
// Deserialize init cache
|
||||
std::ifstream f("init_cache", std::ios::binary);
|
||||
void llm_restart(LM::Inference& inference, const BotChannelConfig& channel_cfg) {
|
||||
ENSURE_LLM_THREAD();
|
||||
// Deserialize init cache if not instruct mode without prompt file
|
||||
if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") return;
|
||||
std::ifstream f((*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"), std::ios::binary);
|
||||
inference.deserialize(f);
|
||||
}
|
||||
// Must run in llama thread
|
||||
LM::Inference &llm_restart(dpp::snowflake id) {
|
||||
LM::Inference &llm_restart(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
|
||||
ENSURE_LLM_THREAD();
|
||||
// Get or create inference
|
||||
auto& inference = llm_pool.get_or_create_inference(id, config.inference_model, llm_get_params());
|
||||
llm_restart(inference);
|
||||
auto& inference = llm_pool.get_or_create_inference(id, channel_cfg.model_config->weight_path, llm_get_params());
|
||||
llm_restart(inference, channel_cfg);
|
||||
return inference;
|
||||
}
|
||||
|
||||
// Must run in llama thread
|
||||
LM::Inference &llm_get_inference(dpp::snowflake id) {
|
||||
LM::Inference &llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
|
||||
ENSURE_LLM_THREAD();
|
||||
auto inference_opt = llm_pool.get_inference(id);
|
||||
if (!inference_opt.has_value()) {
|
||||
// Start new inference
|
||||
inference_opt = llm_restart(id);
|
||||
inference_opt = llm_restart(id, channel_cfg);
|
||||
}
|
||||
return inference_opt.value();
|
||||
}
|
||||
|
@ -210,54 +244,97 @@ class Bot {
|
|||
// Translate texts
|
||||
if (!texts.translated) {
|
||||
texts.please_wait = llm_translate_from_en(texts.please_wait);
|
||||
texts.initializing = llm_translate_from_en(texts.initializing);
|
||||
texts.loading = llm_translate_from_en(texts.loading);
|
||||
texts.model_missing = llm_translate_from_en(texts.model_missing);
|
||||
texts.thread_create_fail = llm_translate_from_en(texts.thread_create_fail);
|
||||
texts.timeout = llm_translate_from_en(texts.timeout);
|
||||
texts.translated = true;
|
||||
}
|
||||
// Inference for init cache
|
||||
if (!std::filesystem::exists("init_cache")) {
|
||||
LM::Inference llm(config.inference_model, llm_get_params());
|
||||
std::ofstream f("init_cache", std::ios::binary);
|
||||
// Add initial context
|
||||
std::string prompt;
|
||||
{
|
||||
// Read whole file
|
||||
std::ifstream f(config.prompt_file);
|
||||
if (!f) {
|
||||
// Clean up and abort on error
|
||||
std::cerr << "Failed to open prompt file." << std::endl;
|
||||
f.close();
|
||||
std::error_code ec;
|
||||
std::filesystem::remove("init_cache", ec);
|
||||
abort();
|
||||
// Build init caches
|
||||
std::string filename;
|
||||
for (const auto& [model_name, model_config] : model_configs) {
|
||||
//TODO: Add hashes to regenerate these as needed
|
||||
// Standard prompt
|
||||
filename = model_name+"_init_cache";
|
||||
if (model_config.is_non_instruct_mode_allowed() &&
|
||||
!std::filesystem::exists(filename) && config.prompt_file != "none") {
|
||||
std::cout << "Building init_cache for "+model_name+"..." << std::endl;
|
||||
LM::Inference llm(model_config.weight_path, llm_get_params());
|
||||
// Add initial context
|
||||
std::string prompt;
|
||||
{
|
||||
// Read whole file
|
||||
std::ifstream f(config.prompt_file);
|
||||
if (!f) {
|
||||
// Clean up and abort on error
|
||||
std::cerr << "Failed to open prompt file." << std::endl;
|
||||
abort();
|
||||
}
|
||||
std::ostringstream sstr;
|
||||
sstr << f.rdbuf();
|
||||
prompt = sstr.str();
|
||||
}
|
||||
std::ostringstream sstr;
|
||||
sstr << f.rdbuf();
|
||||
prompt = sstr.str();
|
||||
// Append
|
||||
using namespace fmt::literals;
|
||||
if (prompt.back() != '\n') prompt.push_back('\n');
|
||||
llm.append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress);
|
||||
// Serialize end result
|
||||
std::ofstream f(filename, std::ios::binary);
|
||||
llm.serialize(f);
|
||||
}
|
||||
// Instruct prompt
|
||||
filename = model_name+"_instruct_init_cache";
|
||||
if (model_config.is_instruct_mode_allowed() &&
|
||||
!std::filesystem::exists(filename) && config.instruct_prompt_file != "none") {
|
||||
std::cout << "Building instruct_init_cache for "+model_name+"..." << std::endl;
|
||||
LM::Inference llm(model_config.weight_path, llm_get_params());
|
||||
// Add initial context
|
||||
std::string prompt;
|
||||
{
|
||||
// Read whole file
|
||||
std::ifstream f(config.instruct_prompt_file);
|
||||
if (!f) {
|
||||
// Clean up and abort on error
|
||||
std::cerr << "Failed to open prompt file." << std::endl;
|
||||
abort();
|
||||
}
|
||||
std::ostringstream sstr;
|
||||
sstr << f.rdbuf();
|
||||
prompt = sstr.str();
|
||||
}
|
||||
// Append
|
||||
using namespace fmt::literals;
|
||||
if (prompt.back() != '\n') prompt.push_back('\n');
|
||||
llm.append(fmt::format(fmt::runtime(prompt+'\n'), "bot_name"_a=bot.me.username), show_console_progress);
|
||||
// Serialize end result
|
||||
std::ofstream f(filename, std::ios::binary);
|
||||
llm.serialize(f);
|
||||
}
|
||||
// Append
|
||||
using namespace fmt::literals;
|
||||
llm.append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username));
|
||||
// Serialize end result
|
||||
llm.serialize(f);
|
||||
}
|
||||
// Report complete init
|
||||
std::cout << "Init done!" << std::endl;
|
||||
}
|
||||
// Must run in llama thread
|
||||
void prompt_add_msg(const dpp::message& msg) {
|
||||
void prompt_add_msg(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
|
||||
ENSURE_LLM_THREAD();
|
||||
// Make sure message isn't too long
|
||||
if (msg.content.size() > 512) {
|
||||
return;
|
||||
}
|
||||
// Get inference
|
||||
auto& inference = llm_get_inference(msg.channel_id);
|
||||
auto& inference = llm_get_inference(msg.channel_id, channel_cfg);
|
||||
try {
|
||||
std::string prefix;
|
||||
// Instruct mode user prompt
|
||||
if (channel_cfg.instruct_mode) {
|
||||
inference.append('\n'+channel_cfg.model_config->user_prompt+"\n\n");
|
||||
} else {
|
||||
prefix = msg.author.username+": ";
|
||||
}
|
||||
// Format and append lines
|
||||
for (const auto line : str_split(msg.content, '\n')) {
|
||||
Timer timeout;
|
||||
bool timeout_exceeded = false;
|
||||
inference.append(msg.author.username+": "+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
|
||||
inference.append(prefix+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
|
||||
if (timeout.get<std::chrono::minutes>() > 1) {
|
||||
std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
|
||||
timeout_exceeded = true;
|
||||
|
@ -268,28 +345,33 @@ class Bot {
|
|||
if (timeout_exceeded) inference.append("\n");
|
||||
}
|
||||
} catch (const LM::Inference::ContextLengthException&) {
|
||||
llm_restart(inference);
|
||||
prompt_add_msg(msg);
|
||||
llm_restart(inference, channel_cfg);
|
||||
prompt_add_msg(msg, channel_cfg);
|
||||
}
|
||||
}
|
||||
// Must run in llama thread
|
||||
void prompt_add_trigger(dpp::snowflake id) {
|
||||
void prompt_add_trigger(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
|
||||
ENSURE_LLM_THREAD();
|
||||
auto& inference = llm_get_inference(id);
|
||||
auto& inference = llm_get_inference(id, channel_cfg);
|
||||
try {
|
||||
inference.append(bot.me.username+':', show_console_progress);
|
||||
if (channel_cfg.instruct_mode) {
|
||||
inference.append('\n'+channel_cfg.model_config->bot_prompt+"\n\n");
|
||||
} else {
|
||||
inference.append(bot.me.username+':', show_console_progress);
|
||||
}
|
||||
} catch (const LM::Inference::ContextLengthException&) {
|
||||
llm_restart(inference);
|
||||
llm_restart(inference, channel_cfg);
|
||||
}
|
||||
}
|
||||
|
||||
// Must run in llama thread
|
||||
void reply(dpp::snowflake id, dpp::message msg) {
|
||||
void reply(dpp::snowflake id, dpp::message msg, const BotChannelConfig& channel_cfg) {
|
||||
ENSURE_LLM_THREAD();
|
||||
try { // Trigger LLM correctly
|
||||
prompt_add_trigger(id);
|
||||
try {
|
||||
// Trigger LLM correctly
|
||||
prompt_add_trigger(id, channel_cfg);
|
||||
// Get inference
|
||||
auto& inference = llm_get_inference(id);
|
||||
auto& inference = llm_get_inference(id, channel_cfg);
|
||||
// Run model
|
||||
Timer timeout;
|
||||
bool timeout_exceeded = false;
|
||||
|
@ -315,16 +397,16 @@ class Bot {
|
|||
}
|
||||
}
|
||||
|
||||
bool attempt_reply(const dpp::message& msg) {
|
||||
bool attempt_reply(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
|
||||
// Reply if message contains username, mention or ID
|
||||
if (msg.content.find(bot.me.username) != std::string::npos) {
|
||||
enqueue_reply(msg.channel_id);
|
||||
enqueue_reply(msg.channel_id, channel_cfg);
|
||||
return true;
|
||||
}
|
||||
// Reply if message references user
|
||||
for (const auto msg_id : my_messages) {
|
||||
if (msg.message_reference.message_id == msg_id) {
|
||||
enqueue_reply(msg.channel_id);
|
||||
enqueue_reply(msg.channel_id, channel_cfg);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -332,45 +414,40 @@ class Bot {
|
|||
return false;
|
||||
}
|
||||
|
||||
void enqueue_reply(dpp::snowflake id) {
|
||||
bot.message_create(dpp::message(id, texts.please_wait+" :thinking:"), [this, id] (const dpp::confirmation_callback_t& ccb) {
|
||||
void enqueue_reply(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
|
||||
bot.message_create(dpp::message(id, texts.please_wait+" :thinking:"), [=, this] (const dpp::confirmation_callback_t& ccb) {
|
||||
if (ccb.is_error()) return;
|
||||
tPool.submit(std::bind(&Bot::reply, this, id, ccb.get<dpp::message>()));
|
||||
thread_pool.submit(std::bind(&Bot::reply, this, id, ccb.get<dpp::message>(), channel_cfg));
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
struct Configuration {
|
||||
std::string token,
|
||||
language = "EN",
|
||||
inference_model = "13B-ggml-model-quant.bin",
|
||||
translation_model = "13B-ggml-model-quant.bin",
|
||||
prompt_file = "prompt.txt";
|
||||
unsigned ctx_size = 1012,
|
||||
pool_size = 2,
|
||||
threads = 4,
|
||||
persistance = true;
|
||||
bool mlock = false;
|
||||
} config;
|
||||
Bot(decltype(config) cfg, decltype(model_configs) model_configs)
|
||||
: config(cfg), model_configs(model_configs), bot(cfg.token),
|
||||
language(cfg.language), db("database.sqlite3"),
|
||||
llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) {
|
||||
// Initialize database
|
||||
db << "CREATE TABLE IF NOT EXISTS threads ("
|
||||
" id TEXT PRIMARY KEY NOT NULL,"
|
||||
" model TEXT,"
|
||||
" instruct_mode INTEGER,"
|
||||
" UNIQUE(id)"
|
||||
");";
|
||||
|
||||
Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), language(cfg.language),
|
||||
llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) {
|
||||
// Configure llm_pool
|
||||
llm_pool.set_store_on_destruct(cfg.persistance);
|
||||
|
||||
// Initialize thread pool
|
||||
tPool.init();
|
||||
thread_pool.init();
|
||||
|
||||
// Prepare translator
|
||||
if (language != "EN") {
|
||||
tPool.submit([this] () {
|
||||
thread_pool.submit([this] () {
|
||||
std::cout << "Preparing translator..." << std::endl;
|
||||
translator = std::make_unique<Translator>(config.translation_model, llm_get_translation_params());
|
||||
});
|
||||
}
|
||||
|
||||
// Prepare llm
|
||||
tPool.submit(std::bind(&Bot::llm_init, this));
|
||||
|
||||
// Configure bot
|
||||
bot.on_log(dpp::utility::cout_logger());
|
||||
bot.intents = dpp::i_guild_messages | dpp::i_message_content;
|
||||
|
@ -378,6 +455,59 @@ public:
|
|||
// Set callbacks
|
||||
bot.on_ready([=, this] (const dpp::ready_t&) { //TODO: Consider removal
|
||||
std::cout << "Connected to Discord." << std::endl;
|
||||
// Register chat command once
|
||||
if (dpp::run_once<struct register_bot_commands>()) {
|
||||
for (const auto& [name, model] : model_configs) {
|
||||
// Create command
|
||||
dpp::slashcommand command(name, "Start a chat with me", bot.me.id);
|
||||
// Add instruct mode option
|
||||
if (model.instruct_mode_policy == ModelConfig::InstructModePolicy::Allow) {
|
||||
command.add_option(dpp::command_option(dpp::co_boolean, "instruct_mode", "Weather to enable instruct mode", true));
|
||||
}
|
||||
// Register command
|
||||
bot.global_command_edit(command, [this, command] (const dpp::confirmation_callback_t& ccb) {
|
||||
if (ccb.is_error()) bot.global_command_create(command);
|
||||
});
|
||||
}
|
||||
}
|
||||
if (dpp::run_once<struct LM::Inference>()) {
|
||||
// Prepare llm
|
||||
thread_pool.submit(std::bind(&Bot::llm_init, this));
|
||||
}
|
||||
});
|
||||
bot.on_slashcommand([=, this](const dpp::slashcommand_t& event) {
|
||||
// Get model by name
|
||||
auto res = model_configs.find(event.command.get_command_name());
|
||||
if (res == model_configs.end()) {
|
||||
// Model does not exit, delete corresponding command
|
||||
bot.global_command_delete(event.command.id);
|
||||
return;
|
||||
}
|
||||
const auto& [model_name, model_config] = *res;
|
||||
// Get weather to enable instruct mode
|
||||
bool instruct_mode;
|
||||
if (model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Allow) {
|
||||
instruct_mode = std::get<bool>(event.get_parameter("instruct_mode"));
|
||||
} else {
|
||||
instruct_mode = model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Force;
|
||||
}
|
||||
// Create thread
|
||||
bot.thread_create("Chat with "+model_name, event.command.channel_id, 1440, dpp::CHANNEL_PUBLIC_THREAD, true, 15,
|
||||
[this, event, instruct_mode, model_name = res->first] (const dpp::confirmation_callback_t& ccb) {
|
||||
// Check for error
|
||||
if (ccb.is_error()) {
|
||||
std::cout << "Thread creation failed: " << ccb.get_error().message << std::endl;
|
||||
event.reply(dpp::message(texts.thread_create_fail).set_flags(dpp::message_flags::m_ephemeral));
|
||||
return;
|
||||
}
|
||||
// Get thread
|
||||
const auto& thread = ccb.get<dpp::thread>();
|
||||
// Add thread to database
|
||||
db << "INSERT INTO threads (id, model, instruct_mode) VALUES (?, ?, ?);"
|
||||
<< std::to_string(thread.id) << model_name << instruct_mode;
|
||||
// Report success
|
||||
event.reply(dpp::message("Okay!").set_flags(dpp::message_flags::m_ephemeral));
|
||||
});
|
||||
});
|
||||
bot.on_message_create([=, this] (const dpp::message_create_t& event) {
|
||||
// Update user cache
|
||||
|
@ -401,16 +531,49 @@ public:
|
|||
for (const auto& [user_id, user] : users) {
|
||||
str_replace_in_place(msg.content, "<@"+std::to_string(user_id)+'>', user.username);
|
||||
}
|
||||
// Get channel config
|
||||
BotChannelConfig channel_cfg;
|
||||
// Attempt to find thread first...
|
||||
bool in_bot_thread = false;
|
||||
db << "SELECT model, instruct_mode FROM threads "
|
||||
"WHERE id = ?;"
|
||||
<< std::to_string(msg.channel_id)
|
||||
>> [&](const std::string& model_name, int instruct_mode) {
|
||||
in_bot_thread = true;
|
||||
channel_cfg.instruct_mode = instruct_mode;
|
||||
// Find model
|
||||
auto res = model_configs.find(model_name);
|
||||
if (res == model_configs.end()) {
|
||||
bot.message_create(dpp::message(msg.channel_id, texts.model_missing));
|
||||
return;
|
||||
}
|
||||
channel_cfg.model_name = &res->first;
|
||||
channel_cfg.model_config = &res->second;
|
||||
};
|
||||
// Otherwise just fall back to the default model config if allowed
|
||||
if (!in_bot_thread) {
|
||||
if (config.threads_only) return;
|
||||
channel_cfg.model_name = &config.default_inference_model;
|
||||
channel_cfg.model_config = config.default_inference_model_cfg;
|
||||
}
|
||||
// Append message
|
||||
tPool.submit(std::bind(&Bot::prompt_add_msg, this, msg));
|
||||
thread_pool.submit([=, this] () {
|
||||
prompt_add_msg(msg, channel_cfg);
|
||||
});
|
||||
// Handle message somehow...
|
||||
if (msg.content == "!trigger") {
|
||||
if (msg.content == "!store") {
|
||||
llm_pool.store_all(); //DEBUG
|
||||
# warning DEBUG CODE!!!
|
||||
} else if (in_bot_thread) {
|
||||
// Send a reply
|
||||
enqueue_reply(msg.channel_id, channel_cfg);
|
||||
} else if (msg.content == "!trigger") {
|
||||
// Delete message
|
||||
bot.message_delete(msg.id, msg.channel_id);
|
||||
// Send a reply
|
||||
enqueue_reply(msg.channel_id);
|
||||
enqueue_reply(msg.channel_id, channel_cfg);
|
||||
} else {
|
||||
attempt_reply(msg);
|
||||
attempt_reply(msg, channel_cfg);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Warning: " << e.what() << std::endl;
|
||||
|
@ -426,6 +589,31 @@ public:
|
|||
};
|
||||
|
||||
|
||||
bool parse_bool(const std::string& value) {
|
||||
if (value == "true")
|
||||
return true;
|
||||
if (value == "false")
|
||||
return false;
|
||||
std::cerr << "Failed to parse configuration file: Unknown bool (true/false): " << value << std::endl;
|
||||
exit(-4);
|
||||
}
|
||||
Bot::ModelConfig::InstructModePolicy parse_instruct_mode_policy(const std::string& value) {
|
||||
if (value == "allow")
|
||||
return Bot::ModelConfig::InstructModePolicy::Allow;
|
||||
if (value == "force")
|
||||
return Bot::ModelConfig::InstructModePolicy::Force;
|
||||
if (value == "forbid")
|
||||
return Bot::ModelConfig::InstructModePolicy::Forbid;
|
||||
std::cerr << "Failed to parse model configuration file: Unknown instruct mode policy (allow/force/forbid): " << value << std::endl;
|
||||
exit(-4);
|
||||
}
|
||||
|
||||
bool file_exists(const auto& p) {
|
||||
// Make sure we don't respond to some file that is actually called "none"...
|
||||
if (p == "none") return false;
|
||||
return std::filesystem::exists(p);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// Check arguments
|
||||
if (argc < 2) {
|
||||
|
@ -433,12 +621,12 @@ int main(int argc, char **argv) {
|
|||
return -1;
|
||||
}
|
||||
|
||||
// Parse configuration
|
||||
// Parse main configuration
|
||||
Bot::Configuration cfg;
|
||||
std::ifstream cfgf(argv[1]);
|
||||
if (!cfgf) {
|
||||
std::cerr << "Failed to open configuration file: " << argv[1] << std::endl;
|
||||
return -1;
|
||||
exit(-1);
|
||||
}
|
||||
for (std::string key; cfgf >> key;) {
|
||||
// Read value
|
||||
|
@ -451,12 +639,16 @@ int main(int argc, char **argv) {
|
|||
cfg.token = std::move(value);
|
||||
} else if (key == "language") {
|
||||
cfg.language = std::move(value);
|
||||
} else if (key == "inference_model") {
|
||||
cfg.inference_model = std::move(value);
|
||||
} else if (key == "default_inference_model") {
|
||||
cfg.default_inference_model = std::move(value);
|
||||
} else if (key == "translation_model") {
|
||||
cfg.translation_model = std::move(value);
|
||||
} else if (key == "prompt_file") {
|
||||
cfg.prompt_file = std::move(value);
|
||||
} else if (key == "instruct_prompt_file") {
|
||||
cfg.instruct_prompt_file = std::move(value);
|
||||
} else if (key == "models_dir") {
|
||||
cfg.models_dir = std::move(value);
|
||||
} else if (key == "pool_size") {
|
||||
cfg.pool_size = std::stoi(value);
|
||||
} else if (key == "threads") {
|
||||
|
@ -464,17 +656,113 @@ int main(int argc, char **argv) {
|
|||
} else if (key == "ctx_size") {
|
||||
cfg.ctx_size = std::stoi(value);
|
||||
} else if (key == "mlock") {
|
||||
cfg.mlock = (value=="true")?true:false;
|
||||
cfg.mlock = parse_bool(value);
|
||||
} else if (key == "threads_only") {
|
||||
cfg.threads_only = parse_bool(value);
|
||||
} else if (key == "persistance") {
|
||||
cfg.persistance = (value=="true")?true:false;
|
||||
cfg.persistance = parse_bool(value);
|
||||
} else if (!key.empty() && key[0] != '#') {
|
||||
std::cerr << "Failed to parse configuration file: Unknown key: " << key << std::endl;
|
||||
return -2;
|
||||
exit(-3);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse model configurations
|
||||
std::unordered_map<std::string, Bot::ModelConfig> models;
|
||||
std::filesystem::path models_dir(cfg.models_dir);
|
||||
bool allow_non_instruct = false;
|
||||
for (const auto& file : std::filesystem::directory_iterator(models_dir)) {
|
||||
// Check that file is model config
|
||||
if (file.is_directory() ||
|
||||
file.path().filename().extension() != ".txt") continue;
|
||||
// Get model name
|
||||
auto model_name = file.path().filename().string();
|
||||
model_name.erase(model_name.size()-4, 4);
|
||||
clean_command_name(model_name);
|
||||
// Parse model config
|
||||
Bot::ModelConfig model_cfg;
|
||||
std::ifstream cfgf(file.path());
|
||||
if (!cfgf) {
|
||||
std::cerr << "Failed to open model configuration file: " << file << std::endl;
|
||||
exit(-2);
|
||||
}
|
||||
std::string filename;
|
||||
for (std::string key; cfgf >> key;) {
|
||||
// Read value
|
||||
std::string value;
|
||||
std::getline(cfgf, value);
|
||||
// Erase all leading spaces
|
||||
while (!value.empty() && (value[0] == ' ' || value[0] == '\t')) value.erase(0, 1);
|
||||
// Check key and ignore comment lines
|
||||
if (key == "filename") {
|
||||
filename = std::move(value);
|
||||
} else if (key == "user_prompt") {
|
||||
model_cfg.user_prompt = std::move(value);
|
||||
} else if (key == "bot_prompt") {
|
||||
model_cfg.bot_prompt = std::move(value);
|
||||
} else if (key == "instruct_mode_policy") {
|
||||
model_cfg.instruct_mode_policy = parse_instruct_mode_policy(value);
|
||||
} else if (!key.empty() && key[0] != '#') {
|
||||
std::cerr << "Failed to parse model configuration file: Unknown key: " << key << std::endl;
|
||||
exit(-3);
|
||||
}
|
||||
}
|
||||
// Get full path
|
||||
model_cfg.weight_path = file.path().parent_path()/filename;
|
||||
// Safety checks
|
||||
if (filename.empty() || !file_exists(model_cfg.weight_path)) {
|
||||
std::cerr << "Failed to parse model configuration file: Invalid weight filename: " << model_name << std::endl;
|
||||
exit(-8);
|
||||
}
|
||||
if (model_cfg.instruct_mode_policy != Bot::ModelConfig::InstructModePolicy::Forbid &&
|
||||
(model_cfg.user_prompt.empty() || model_cfg.bot_prompt.empty())) {
|
||||
std::cerr << "Failed to parse model configuration file: Instruct mode allowed but user prompt and bot prompt not given: " << model_name << std::endl;
|
||||
exit(-9);
|
||||
}
|
||||
if (model_cfg.instruct_mode_policy != Bot::ModelConfig::InstructModePolicy::Force) {
|
||||
allow_non_instruct = true;
|
||||
}
|
||||
// Add model to list
|
||||
const auto& [stored_model_name, stored_model_cfg] = *models.emplace(std::move(model_name), std::move(model_cfg)).first;
|
||||
// Set model pointer in config
|
||||
if (stored_model_name == cfg.default_inference_model)
|
||||
cfg.default_inference_model_cfg = &stored_model_cfg;
|
||||
if (stored_model_name == cfg.translation_model)
|
||||
cfg.translation_model_cfg = &stored_model_cfg;
|
||||
}
|
||||
|
||||
// Safety checks
|
||||
if (cfg.language != "EN") {
|
||||
if (cfg.translation_model_cfg == nullptr) {
|
||||
std::cerr << "Translation model required for non-english language, but is invalid" << std::endl;
|
||||
exit(-5);
|
||||
}
|
||||
if (cfg.translation_model_cfg->instruct_mode_policy == Bot::ModelConfig::InstructModePolicy::Force) {
|
||||
std::cerr << "Translation model is required to not have instruct mode forced" << std::endl;
|
||||
exit(-10);
|
||||
}
|
||||
}
|
||||
if (allow_non_instruct && !file_exists(cfg.prompt_file)) {
|
||||
std::cerr << "Prompt file required when allowing non-instruct-mode use, but is invalid" << std::endl;
|
||||
exit(-11);
|
||||
}
|
||||
if (!cfg.threads_only) {
|
||||
if (cfg.default_inference_model_cfg == nullptr) {
|
||||
std::cerr << "Default model required if not threads only, but is invalid" << std::endl;
|
||||
exit(-6);
|
||||
}
|
||||
if (cfg.default_inference_model_cfg->instruct_mode_policy == Bot::ModelConfig::InstructModePolicy::Force) {
|
||||
std::cerr << "Default model must not have instruct mode forced if not threads only" << std::endl;
|
||||
exit(-7);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean model names in config
|
||||
clean_command_name(cfg.default_inference_model);
|
||||
clean_command_name(cfg.translation_model);
|
||||
|
||||
// Construct and configure bot
|
||||
Bot bot(cfg);
|
||||
Bot bot(cfg, models);
|
||||
|
||||
// Start bot
|
||||
bot.start();
|
||||
|
|
21
sqlite_modern_cpp/License.txt
Normal file
21
sqlite_modern_cpp/License.txt
Normal file
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2017 aminroosta
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
1053
sqlite_modern_cpp/sqlite_modern_cpp.h
Normal file
1053
sqlite_modern_cpp/sqlite_modern_cpp.h
Normal file
File diff suppressed because it is too large
Load diff
60
sqlite_modern_cpp/sqlite_modern_cpp/errors.h
Normal file
60
sqlite_modern_cpp/sqlite_modern_cpp/errors.h
Normal file
|
@ -0,0 +1,60 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <sqlite3.h>
|
||||
|
||||
namespace sqlite {
|
||||
|
||||
class sqlite_exception: public std::runtime_error {
|
||||
public:
|
||||
sqlite_exception(const char* msg, std::string sql, int code = -1): runtime_error(msg), code(code), sql(sql) {}
|
||||
sqlite_exception(int code, std::string sql): runtime_error(sqlite3_errstr(code)), code(code), sql(sql) {}
|
||||
int get_code() const {return code & 0xFF;}
|
||||
int get_extended_code() const {return code;}
|
||||
std::string get_sql() const {return sql;}
|
||||
private:
|
||||
int code;
|
||||
std::string sql;
|
||||
};
|
||||
|
||||
namespace errors {
|
||||
//One more or less trivial derived error class for each SQLITE error.
|
||||
//Note the following are not errors so have no classes:
|
||||
//SQLITE_OK, SQLITE_NOTICE, SQLITE_WARNING, SQLITE_ROW, SQLITE_DONE
|
||||
//
|
||||
//Note these names are exact matches to the names of the SQLITE error codes.
|
||||
#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \
|
||||
class name: public sqlite_exception { using sqlite_exception::sqlite_exception; };\
|
||||
derived
|
||||
#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \
|
||||
class base ## _ ## sub: public base { using base::base; };
|
||||
#include "lists/error_codes.h"
|
||||
#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED
|
||||
#undef SQLITE_MODERN_CPP_ERROR_CODE
|
||||
|
||||
//Some additional errors are here for the C++ interface
|
||||
class more_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; };
|
||||
class no_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; };
|
||||
class more_statements: public sqlite_exception { using sqlite_exception::sqlite_exception; }; // Prepared statements can only contain one statement
|
||||
class invalid_utf16: public sqlite_exception { using sqlite_exception::sqlite_exception; };
|
||||
|
||||
static void throw_sqlite_error(const int& error_code, const std::string &sql = "") {
|
||||
switch(error_code & 0xFF) {
|
||||
#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \
|
||||
case SQLITE_ ## NAME: switch(error_code) { \
|
||||
derived \
|
||||
default: throw name(error_code, sql); \
|
||||
}
|
||||
#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \
|
||||
case SQLITE_ ## BASE ## _ ## SUB: throw base ## _ ## sub(error_code, sql);
|
||||
#include "lists/error_codes.h"
|
||||
#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED
|
||||
#undef SQLITE_MODERN_CPP_ERROR_CODE
|
||||
default: throw sqlite_exception(error_code, sql);
|
||||
}
|
||||
}
|
||||
}
|
||||
namespace exceptions = errors;
|
||||
}
|
93
sqlite_modern_cpp/sqlite_modern_cpp/lists/error_codes.h
Normal file
93
sqlite_modern_cpp/sqlite_modern_cpp/lists/error_codes.h
Normal file
|
@ -0,0 +1,93 @@
|
|||
#if SQLITE_VERSION_NUMBER < 3010000
|
||||
#define SQLITE_IOERR_VNODE (SQLITE_IOERR | (27<<8))
|
||||
#define SQLITE_IOERR_AUTH (SQLITE_IOERR | (28<<8))
|
||||
#define SQLITE_AUTH_USER (SQLITE_AUTH | (1<<8))
|
||||
#endif
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(ERROR,error,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(INTERNAL,internal,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(PERM,perm,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(ABORT,abort,
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(ABORT,ROLLBACK,abort,rollback)
|
||||
)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(BUSY,busy,
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,RECOVERY,busy,recovery)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,SNAPSHOT,busy,snapshot)
|
||||
)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(LOCKED,locked,
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(LOCKED,SHAREDCACHE,locked,sharedcache)
|
||||
)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(NOMEM,nomem,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(READONLY,readonly,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(INTERRUPT,interrupt,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(IOERR,ioerr,
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,READ,ioerr,read)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHORT_READ,ioerr,short_read)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,WRITE,ioerr,write)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSYNC,ioerr,fsync)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_FSYNC,ioerr,dir_fsync)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,TRUNCATE,ioerr,truncate)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSTAT,ioerr,fstat)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,UNLOCK,ioerr,unlock)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,RDLOCK,ioerr,rdlock)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE,ioerr,delete)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,BLOCKED,ioerr,blocked)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,NOMEM,ioerr,nomem)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,ACCESS,ioerr,access)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CHECKRESERVEDLOCK,ioerr,checkreservedlock)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,LOCK,ioerr,lock)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CLOSE,ioerr,close)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_CLOSE,ioerr,dir_close)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMOPEN,ioerr,shmopen)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMSIZE,ioerr,shmsize)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMLOCK,ioerr,shmlock)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMMAP,ioerr,shmmap)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SEEK,ioerr,seek)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE_NOENT,ioerr,delete_noent)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,MMAP,ioerr,mmap)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,GETTEMPPATH,ioerr,gettemppath)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CONVPATH,ioerr,convpath)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,VNODE,ioerr,vnode)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,AUTH,ioerr,auth)
|
||||
)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(CORRUPT,corrupt,
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CORRUPT,VTAB,corrupt,vtab)
|
||||
)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(NOTFOUND,notfound,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(FULL,full,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(CANTOPEN,cantopen,
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,NOTEMPDIR,cantopen,notempdir)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,ISDIR,cantopen,isdir)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,FULLPATH,cantopen,fullpath)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,CONVPATH,cantopen,convpath)
|
||||
)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(PROTOCOL,protocol,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(EMPTY,empty,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(SCHEMA,schema,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(TOOBIG,toobig,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(CONSTRAINT,constraint,
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,CHECK,constraint,check)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,COMMITHOOK,constraint,commithook)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FOREIGNKEY,constraint,foreignkey)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FUNCTION,constraint,function)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,NOTNULL,constraint,notnull)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,PRIMARYKEY,constraint,primarykey)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,TRIGGER,constraint,trigger)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,UNIQUE,constraint,unique)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,VTAB,constraint,vtab)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,ROWID,constraint,rowid)
|
||||
)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(MISMATCH,mismatch,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(MISUSE,misuse,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(NOLFS,nolfs,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(AUTH,auth,
|
||||
)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(FORMAT,format,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(RANGE,range,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(NOTADB,notadb,)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(NOTICE,notice,
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_WAL,notice,recover_wal)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_ROLLBACK,notice,recover_rollback)
|
||||
)
|
||||
SQLITE_MODERN_CPP_ERROR_CODE(WARNING,warning,
|
||||
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(WARNING,AUTOINDEX,warning,autoindex)
|
||||
)
|
101
sqlite_modern_cpp/sqlite_modern_cpp/log.h
Normal file
101
sqlite_modern_cpp/sqlite_modern_cpp/log.h
Normal file
|
@ -0,0 +1,101 @@
|
|||
#include "errors.h"
|
||||
|
||||
#include <sqlite3.h>
|
||||
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
namespace sqlite {
|
||||
namespace detail {
|
||||
template<class>
|
||||
using void_t = void;
|
||||
template<class T, class = void>
|
||||
struct is_callable : std::false_type {};
|
||||
template<class Functor, class ...Arguments>
|
||||
struct is_callable<Functor(Arguments...), void_t<decltype(std::declval<Functor>()(std::declval<Arguments>()...))>> : std::true_type {};
|
||||
template<class Functor, class ...Functors>
|
||||
class FunctorOverload: public Functor, public FunctorOverload<Functors...> {
|
||||
public:
|
||||
template<class Functor1, class ...Remaining>
|
||||
FunctorOverload(Functor1 &&functor, Remaining &&... remaining):
|
||||
Functor(std::forward<Functor1>(functor)),
|
||||
FunctorOverload<Functors...>(std::forward<Remaining>(remaining)...) {}
|
||||
using Functor::operator();
|
||||
using FunctorOverload<Functors...>::operator();
|
||||
};
|
||||
template<class Functor>
|
||||
class FunctorOverload<Functor>: public Functor {
|
||||
public:
|
||||
template<class Functor1>
|
||||
FunctorOverload(Functor1 &&functor):
|
||||
Functor(std::forward<Functor1>(functor)) {}
|
||||
using Functor::operator();
|
||||
};
|
||||
template<class Functor>
|
||||
class WrapIntoFunctor: public Functor {
|
||||
public:
|
||||
template<class Functor1>
|
||||
WrapIntoFunctor(Functor1 &&functor):
|
||||
Functor(std::forward<Functor1>(functor)) {}
|
||||
using Functor::operator();
|
||||
};
|
||||
template<class ReturnType, class ...Arguments>
|
||||
class WrapIntoFunctor<ReturnType(*)(Arguments...)> {
|
||||
ReturnType(*ptr)(Arguments...);
|
||||
public:
|
||||
WrapIntoFunctor(ReturnType(*ptr)(Arguments...)): ptr(ptr) {}
|
||||
ReturnType operator()(Arguments... arguments) { return (*ptr)(std::forward<Arguments>(arguments)...); }
|
||||
};
|
||||
inline void store_error_log_data_pointer(std::shared_ptr<void> ptr) {
|
||||
static std::shared_ptr<void> stored;
|
||||
stored = std::move(ptr);
|
||||
}
|
||||
template<class T>
|
||||
std::shared_ptr<typename std::decay<T>::type> make_shared_inferred(T &&t) {
|
||||
return std::make_shared<typename std::decay<T>::type>(std::forward<T>(t));
|
||||
}
|
||||
}
|
||||
template<class Handler>
|
||||
typename std::enable_if<!detail::is_callable<Handler(const sqlite_exception&)>::value>::type
|
||||
error_log(Handler &&handler);
|
||||
template<class Handler>
|
||||
typename std::enable_if<detail::is_callable<Handler(const sqlite_exception&)>::value>::type
|
||||
error_log(Handler &&handler);
|
||||
template<class ...Handler>
|
||||
typename std::enable_if<sizeof...(Handler)>=2>::type
|
||||
error_log(Handler &&...handler) {
|
||||
return error_log(detail::FunctorOverload<detail::WrapIntoFunctor<typename std::decay<Handler>::type>...>(std::forward<Handler>(handler)...));
|
||||
}
|
||||
template<class Handler>
|
||||
typename std::enable_if<!detail::is_callable<Handler(const sqlite_exception&)>::value>::type
|
||||
error_log(Handler &&handler) {
|
||||
return error_log(std::forward<Handler>(handler), [](const sqlite_exception&) {});
|
||||
}
|
||||
template<class Handler>
|
||||
typename std::enable_if<detail::is_callable<Handler(const sqlite_exception&)>::value>::type
|
||||
error_log(Handler &&handler) {
|
||||
auto ptr = detail::make_shared_inferred([handler = std::forward<Handler>(handler)](int error_code, const char *errstr) mutable {
|
||||
switch(error_code & 0xFF) {
|
||||
#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \
|
||||
case SQLITE_ ## NAME: switch(error_code) { \
|
||||
derived \
|
||||
default: handler(errors::name(errstr, "", error_code)); \
|
||||
};break;
|
||||
#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \
|
||||
case SQLITE_ ## BASE ## _ ## SUB: \
|
||||
handler(errors::base ## _ ## sub(errstr, "", error_code)); \
|
||||
break;
|
||||
#include "lists/error_codes.h"
|
||||
#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED
|
||||
#undef SQLITE_MODERN_CPP_ERROR_CODE
|
||||
default: handler(sqlite_exception(errstr, "", error_code)); \
|
||||
}
|
||||
});
|
||||
|
||||
sqlite3_config(SQLITE_CONFIG_LOG, (void(*)(void*,int,const char*))[](void *functor, int error_code, const char *errstr) {
|
||||
(*static_cast<decltype(ptr.get())>(functor))(error_code, errstr);
|
||||
}, ptr.get());
|
||||
detail::store_error_log_data_pointer(std::move(ptr));
|
||||
}
|
||||
}
|
44
sqlite_modern_cpp/sqlite_modern_cpp/sqlcipher.h
Normal file
44
sqlite_modern_cpp/sqlite_modern_cpp/sqlcipher.h
Normal file
|
@ -0,0 +1,44 @@
|
|||
#pragma once
|
||||
|
||||
#ifndef SQLITE_HAS_CODEC
|
||||
#define SQLITE_HAS_CODEC
|
||||
#endif
|
||||
|
||||
#include "../sqlite_modern_cpp.h"
|
||||
|
||||
namespace sqlite {
|
||||
struct sqlcipher_config : public sqlite_config {
|
||||
std::string key;
|
||||
};
|
||||
|
||||
class sqlcipher_database : public database {
|
||||
public:
|
||||
sqlcipher_database(std::string db, const sqlcipher_config &config): database(db, config) {
|
||||
set_key(config.key);
|
||||
}
|
||||
|
||||
sqlcipher_database(std::u16string db, const sqlcipher_config &config): database(db, config) {
|
||||
set_key(config.key);
|
||||
}
|
||||
|
||||
void set_key(const std::string &key) {
|
||||
if(auto ret = sqlite3_key(_db.get(), key.data(), key.size()))
|
||||
errors::throw_sqlite_error(ret);
|
||||
}
|
||||
|
||||
void set_key(const std::string &key, const std::string &db_name) {
|
||||
if(auto ret = sqlite3_key_v2(_db.get(), db_name.c_str(), key.data(), key.size()))
|
||||
errors::throw_sqlite_error(ret);
|
||||
}
|
||||
|
||||
void rekey(const std::string &new_key) {
|
||||
if(auto ret = sqlite3_rekey(_db.get(), new_key.data(), new_key.size()))
|
||||
errors::throw_sqlite_error(ret);
|
||||
}
|
||||
|
||||
void rekey(const std::string &new_key, const std::string &db_name) {
|
||||
if(auto ret = sqlite3_rekey_v2(_db.get(), db_name.c_str(), new_key.data(), new_key.size()))
|
||||
errors::throw_sqlite_error(ret);
|
||||
}
|
||||
};
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include<type_traits>
|
||||
|
||||
namespace sqlite {
|
||||
namespace utility {
|
||||
|
||||
template<typename> struct function_traits;
|
||||
|
||||
template <typename Function>
|
||||
struct function_traits : public function_traits<
|
||||
decltype(&std::remove_reference<Function>::type::operator())
|
||||
> { };
|
||||
|
||||
template <
|
||||
typename ClassType,
|
||||
typename ReturnType,
|
||||
typename... Arguments
|
||||
>
|
||||
struct function_traits<
|
||||
ReturnType(ClassType::*)(Arguments...) const
|
||||
> : function_traits<ReturnType(*)(Arguments...)> { };
|
||||
|
||||
/* support the non-const operator ()
|
||||
* this will work with user defined functors */
|
||||
template <
|
||||
typename ClassType,
|
||||
typename ReturnType,
|
||||
typename... Arguments
|
||||
>
|
||||
struct function_traits<
|
||||
ReturnType(ClassType::*)(Arguments...)
|
||||
> : function_traits<ReturnType(*)(Arguments...)> { };
|
||||
|
||||
template <
|
||||
typename ReturnType,
|
||||
typename... Arguments
|
||||
>
|
||||
struct function_traits<
|
||||
ReturnType(*)(Arguments...)
|
||||
> {
|
||||
typedef ReturnType result_type;
|
||||
|
||||
template <std::size_t Index>
|
||||
using argument = typename std::tuple_element<
|
||||
Index,
|
||||
std::tuple<Arguments...>
|
||||
>::type;
|
||||
|
||||
static const std::size_t arity = sizeof...(Arguments);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
|
||||
namespace sqlite {
|
||||
namespace utility {
|
||||
#ifdef __cpp_lib_uncaught_exceptions
|
||||
class UncaughtExceptionDetector {
|
||||
public:
|
||||
operator bool() {
|
||||
return count != std::uncaught_exceptions();
|
||||
}
|
||||
private:
|
||||
int count = std::uncaught_exceptions();
|
||||
};
|
||||
#else
|
||||
class UncaughtExceptionDetector {
|
||||
public:
|
||||
operator bool() {
|
||||
return std::uncaught_exception();
|
||||
}
|
||||
};
|
||||
#endif
|
||||
}
|
||||
}
|
42
sqlite_modern_cpp/sqlite_modern_cpp/utility/utf16_utf8.h
Normal file
42
sqlite_modern_cpp/sqlite_modern_cpp/utility/utf16_utf8.h
Normal file
|
@ -0,0 +1,42 @@
|
|||
#pragma once
|
||||
|
||||
#include <locale>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
|
||||
#include "../errors.h"
|
||||
|
||||
namespace sqlite {
|
||||
namespace utility {
|
||||
inline std::string utf16_to_utf8(const std::u16string &input) {
|
||||
struct : std::codecvt<char16_t, char, std::mbstate_t> {
|
||||
} codecvt;
|
||||
std::mbstate_t state{};
|
||||
std::string result((std::max)(input.size() * 3 / 2, std::size_t(4)), '\0');
|
||||
const char16_t *remaining_input = input.data();
|
||||
std::size_t produced_output = 0;
|
||||
while(true) {
|
||||
char *used_output;
|
||||
switch(codecvt.out(state, remaining_input, &input[input.size()],
|
||||
remaining_input, &result[produced_output],
|
||||
&result[result.size() - 1] + 1, used_output)) {
|
||||
case std::codecvt_base::ok:
|
||||
result.resize(used_output - result.data());
|
||||
return result;
|
||||
case std::codecvt_base::noconv:
|
||||
// This should be unreachable
|
||||
case std::codecvt_base::error:
|
||||
throw errors::invalid_utf16("Invalid UTF-16 input", "");
|
||||
case std::codecvt_base::partial:
|
||||
if(used_output == result.data() + produced_output)
|
||||
throw errors::invalid_utf16("Unexpected end of input", "");
|
||||
produced_output = used_output - result.data();
|
||||
result.resize(
|
||||
result.size()
|
||||
+ (std::max)((&input[input.size()] - remaining_input) * 3 / 2,
|
||||
std::ptrdiff_t(4)));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace utility
|
||||
} // namespace sqlite
|
201
sqlite_modern_cpp/sqlite_modern_cpp/utility/variant.h
Normal file
201
sqlite_modern_cpp/sqlite_modern_cpp/utility/variant.h
Normal file
|
@ -0,0 +1,201 @@
|
|||
#pragma once
|
||||
|
||||
#include "../errors.h"
|
||||
#include <sqlite3.h>
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
|
||||
namespace sqlite::utility {
|
||||
template<typename ...Options>
|
||||
struct VariantFirstNullable {
|
||||
using type = void;
|
||||
};
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstNullable<T, Options...> {
|
||||
using type = typename VariantFirstNullable<Options...>::type;
|
||||
};
|
||||
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstNullable<std::optional<T>, Options...> {
|
||||
using type = std::optional<T>;
|
||||
};
|
||||
#endif
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstNullable<std::unique_ptr<T>, Options...> {
|
||||
using type = std::unique_ptr<T>;
|
||||
};
|
||||
template<typename ...Options>
|
||||
struct VariantFirstNullable<std::nullptr_t, Options...> {
|
||||
using type = std::nullptr_t;
|
||||
};
|
||||
template<typename Callback, typename ...Options>
|
||||
inline void variant_select_null(Callback&&callback) {
|
||||
if constexpr(std::is_same_v<typename VariantFirstNullable<Options...>::type, void>) {
|
||||
throw errors::mismatch("NULL is unsupported by this variant.", "", SQLITE_MISMATCH);
|
||||
} else {
|
||||
std::forward<Callback>(callback)(typename VariantFirstNullable<Options...>::type());
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ...Options>
|
||||
struct VariantFirstIntegerable {
|
||||
using type = void;
|
||||
};
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstIntegerable<T, Options...> {
|
||||
using type = typename VariantFirstIntegerable<Options...>::type;
|
||||
};
|
||||
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstIntegerable<std::optional<T>, Options...> {
|
||||
using type = std::conditional_t<std::is_same_v<typename VariantFirstIntegerable<T, Options...>::type, T>, std::optional<T>, typename VariantFirstIntegerable<Options...>::type>;
|
||||
};
|
||||
#endif
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstIntegerable<std::enable_if_t<std::is_same_v<typename VariantFirstIntegerable<T, Options...>::type, T>>, std::unique_ptr<T>, Options...> {
|
||||
using type = std::conditional_t<std::is_same_v<typename VariantFirstIntegerable<T, Options...>::type, T>, std::unique_ptr<T>, typename VariantFirstIntegerable<Options...>::type>;
|
||||
};
|
||||
template<typename ...Options>
|
||||
struct VariantFirstIntegerable<int, Options...> {
|
||||
using type = int;
|
||||
};
|
||||
template<typename ...Options>
|
||||
struct VariantFirstIntegerable<sqlite_int64, Options...> {
|
||||
using type = sqlite_int64;
|
||||
};
|
||||
template<typename Callback, typename ...Options>
|
||||
inline auto variant_select_integer(Callback&&callback) {
|
||||
if constexpr(std::is_same_v<typename VariantFirstIntegerable<Options...>::type, void>) {
|
||||
throw errors::mismatch("Integer is unsupported by this variant.", "", SQLITE_MISMATCH);
|
||||
} else {
|
||||
std::forward<Callback>(callback)(typename VariantFirstIntegerable<Options...>::type());
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ...Options>
|
||||
struct VariantFirstFloatable {
|
||||
using type = void;
|
||||
};
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstFloatable<T, Options...> {
|
||||
using type = typename VariantFirstFloatable<Options...>::type;
|
||||
};
|
||||
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstFloatable<std::optional<T>, Options...> {
|
||||
using type = std::conditional_t<std::is_same_v<typename VariantFirstFloatable<T, Options...>::type, T>, std::optional<T>, typename VariantFirstFloatable<Options...>::type>;
|
||||
};
|
||||
#endif
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstFloatable<std::unique_ptr<T>, Options...> {
|
||||
using type = std::conditional_t<std::is_same_v<typename VariantFirstFloatable<T, Options...>::type, T>, std::unique_ptr<T>, typename VariantFirstFloatable<Options...>::type>;
|
||||
};
|
||||
template<typename ...Options>
|
||||
struct VariantFirstFloatable<float, Options...> {
|
||||
using type = float;
|
||||
};
|
||||
template<typename ...Options>
|
||||
struct VariantFirstFloatable<double, Options...> {
|
||||
using type = double;
|
||||
};
|
||||
template<typename Callback, typename ...Options>
|
||||
inline auto variant_select_float(Callback&&callback) {
|
||||
if constexpr(std::is_same_v<typename VariantFirstFloatable<Options...>::type, void>) {
|
||||
throw errors::mismatch("Real is unsupported by this variant.", "", SQLITE_MISMATCH);
|
||||
} else {
|
||||
std::forward<Callback>(callback)(typename VariantFirstFloatable<Options...>::type());
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ...Options>
|
||||
struct VariantFirstTextable {
|
||||
using type = void;
|
||||
};
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstTextable<T, Options...> {
|
||||
using type = typename VariantFirstTextable<void, Options...>::type;
|
||||
};
|
||||
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstTextable<std::optional<T>, Options...> {
|
||||
using type = std::conditional_t<std::is_same_v<typename VariantFirstTextable<T, Options...>::type, T>, std::optional<T>, typename VariantFirstTextable<Options...>::type>;
|
||||
};
|
||||
#endif
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstTextable<std::unique_ptr<T>, Options...> {
|
||||
using type = std::conditional_t<std::is_same_v<typename VariantFirstTextable<T, Options...>::type, T>, std::unique_ptr<T>, typename VariantFirstTextable<Options...>::type>;
|
||||
};
|
||||
template<typename ...Options>
|
||||
struct VariantFirstTextable<std::string, Options...> {
|
||||
using type = std::string;
|
||||
};
|
||||
template<typename ...Options>
|
||||
struct VariantFirstTextable<std::u16string, Options...> {
|
||||
using type = std::u16string;
|
||||
};
|
||||
template<typename Callback, typename ...Options>
|
||||
inline void variant_select_text(Callback&&callback) {
|
||||
if constexpr(std::is_same_v<typename VariantFirstTextable<Options...>::type, void>) {
|
||||
throw errors::mismatch("Text is unsupported by this variant.", "", SQLITE_MISMATCH);
|
||||
} else {
|
||||
std::forward<Callback>(callback)(typename VariantFirstTextable<Options...>::type());
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ...Options>
|
||||
struct VariantFirstBlobable {
|
||||
using type = void;
|
||||
};
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstBlobable<T, Options...> {
|
||||
using type = typename VariantFirstBlobable<Options...>::type;
|
||||
};
|
||||
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstBlobable<std::optional<T>, Options...> {
|
||||
using type = std::conditional_t<std::is_same_v<typename VariantFirstBlobable<T, Options...>::type, T>, std::optional<T>, typename VariantFirstBlobable<Options...>::type>;
|
||||
};
|
||||
#endif
|
||||
template<typename T, typename ...Options>
|
||||
struct VariantFirstBlobable<std::unique_ptr<T>, Options...> {
|
||||
using type = std::conditional_t<std::is_same_v<typename VariantFirstBlobable<T, Options...>::type, T>, std::unique_ptr<T>, typename VariantFirstBlobable<Options...>::type>;
|
||||
};
|
||||
template<typename T, typename A, typename ...Options>
|
||||
struct VariantFirstBlobable<std::enable_if_t<std::is_pod_v<T>>, std::vector<T, A>, Options...> {
|
||||
using type = std::vector<T, A>;
|
||||
};
|
||||
template<typename Callback, typename ...Options>
|
||||
inline auto variant_select_blob(Callback&&callback) {
|
||||
if constexpr(std::is_same_v<typename VariantFirstBlobable<Options...>::type, void>) {
|
||||
throw errors::mismatch("Blob is unsupported by this variant.", "", SQLITE_MISMATCH);
|
||||
} else {
|
||||
std::forward<Callback>(callback)(typename VariantFirstBlobable<Options...>::type());
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ...Options>
|
||||
inline auto variant_select(int type) {
|
||||
return [type](auto &&callback) {
|
||||
using Callback = decltype(callback);
|
||||
switch(type) {
|
||||
case SQLITE_NULL:
|
||||
variant_select_null<Callback, Options...>(std::forward<Callback>(callback));
|
||||
break;
|
||||
case SQLITE_INTEGER:
|
||||
variant_select_integer<Callback, Options...>(std::forward<Callback>(callback));
|
||||
break;
|
||||
case SQLITE_FLOAT:
|
||||
variant_select_float<Callback, Options...>(std::forward<Callback>(callback));
|
||||
break;
|
||||
case SQLITE_TEXT:
|
||||
variant_select_text<Callback, Options...>(std::forward<Callback>(callback));
|
||||
break;
|
||||
case SQLITE_BLOB:
|
||||
variant_select_blob<Callback, Options...>(std::forward<Callback>(callback));
|
||||
break;
|
||||
default:;
|
||||
/* assert(false); */
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue