1
0
Fork 0
mirror of https://gitlab.com/niansa/discord_llama.git synced 2025-03-06 20:48:25 +01:00

Added tons of features including thread creation

This commit is contained in:
niansa 2023-04-25 15:55:46 +02:00
parent a0edf72a43
commit bf568c3c8a
18 changed files with 2127 additions and 122 deletions

View file

@ -11,7 +11,7 @@ add_subdirectory(thread-pool)
add_subdirectory(fmt)
add_executable(discord_llama main.cpp)
target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml threadpool)
target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml threadpool sqlite3)
install(TARGETS discord_llama
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})

@ -1 +1 @@
Subproject commit 6ba351b1173539254b2e641d19ca77c459a087a2
Subproject commit 03b1689ef708b26ac139d09c3fa194195eef7707

View file

@ -1,12 +1,18 @@
token MTA0MDYxMTQzNjUwNzk1OTMyNw.Gl_iMU.jVVM3bRqBJVi8ORVpWHquOivlASGJpRySt8qFg
# The following parameters are set to their defaults here and can be ommited
models_dir models
language EN
inference_model 13B-ggml-model-quant.bin
translation_model 13B-ggml-model-quant.bin
prompt_file prompt.txt
threads_only true
default_inference_model 13B-vanilla
translation_model none
prompt_file none
instruct_prompt_file none
persistance true
mlock false
pool_size 2
threads 4
persistance true
ctx_size 1012

View file

@ -0,0 +1,2 @@
filename 13B-ggml-model-quant.bin
instruct_mode_policy forbid

View file

@ -0,0 +1,4 @@
filename ggml-vicuna-13b-1.0-q4_0.bin
instruct_mode_policy force
user_prompt ### Human:
bot_prompt ### Assistant:

View file

@ -0,0 +1,4 @@
filename ggml-vicuna-13b-1.1-q4_0.bin
instruct_mode_policy force
user_prompt HUMAN:
bot_prompt ASSISTANT:

View file

@ -0,0 +1,4 @@
filename gpt4all-lora-unfiltered-quantized.bin
instruct_mode_policy allow
user_prompt ### Instruction:
bot_prompt ### Response:

520
main.cpp
View file

@ -1,4 +1,5 @@
#include "Timer.hpp"
#include "sqlite_modern_cpp/sqlite_modern_cpp.h"
#include <string>
#include <string_view>
@ -48,9 +49,23 @@ void str_replace_in_place(std::string& subject, std::string_view search,
}
}
static
void clean_command_name(std::string& value) {
for (auto& c : value) {
if (c == '.') c = '_';
if (isalpha(c)) c = tolower(c);
}
}
[[nodiscard]] static
std::string clean_command_name(std::string_view input) {
std::string fres(input);
clean_command_name(fres);
return fres;
}
class Bot {
ThreadPool tPool{1};
ThreadPool thread_pool{1};
Timer last_message_timer;
std::shared_ptr<bool> stopping;
LM::InferencePool llm_pool;
@ -58,47 +73,64 @@ class Bot {
std::vector<dpp::snowflake> my_messages;
std::unordered_map<dpp::snowflake, dpp::user> users;
std::thread::id llm_tid;
sqlite::database db;
std::string_view language;
dpp::cluster bot;
public:
struct ModelConfig {
std::string weight_path,
user_prompt,
bot_prompt;
enum class InstructModePolicy {
Allow = 0b11,
Force = 0b10,
Forbid = 0b01
} instruct_mode_policy = InstructModePolicy::Allow;
bool is_instruct_mode_allowed() const {
return static_cast<unsigned>(instruct_mode_policy) & 0b10;
}
bool is_non_instruct_mode_allowed() const {
return static_cast<unsigned>(instruct_mode_policy) & 0b01;
}
};
struct BotChannelConfig {
const std::string *model_name;
const ModelConfig *model_config;
bool instruct_mode = false;
};
struct Configuration {
std::string token,
language = "EN",
default_inference_model = "13B-vanilla",
translation_model = "none",
prompt_file = "none",
instruct_prompt_file = "none",
models_dir = "models";
unsigned ctx_size = 1012,
pool_size = 2,
threads = 4,
persistance = true;
bool mlock = false,
threads_only = true;
const ModelConfig *default_inference_model_cfg = nullptr,
*translation_model_cfg = nullptr;
};
private:
const Configuration& config;
const std::unordered_map<std::string, ModelConfig>& model_configs;
struct Texts {
std::string please_wait = "Please wait...",
loading = "Loading...",
initializing = "Initializing...",
thread_create_fail = "Error: I couldn't create a thread here. Do I have enough permissions?",
model_missing = "Error: The model that was used in this thread could no longer be found.",
timeout = "Error: Timeout";
bool translated = false;
} texts;
inline static
std::string create_text_progress_indicator(uint8_t percentage) {
static constexpr uint8_t divisor = 3,
width = 100 / divisor;
// Progress bar percentage lookup
const static auto indicator_lookup = [] () consteval {
std::array<uint8_t, 101> fres;
for (uint8_t it = 0; it != 101; it++) {
fres[it] = it / divisor;
}
return fres;
}();
// Initialize string
std::string fres;
fres.resize(width+4);
fres[0] = '`';
fres[1] = '[';
// Append progress
const uint8_t bars = indicator_lookup[percentage];
for (uint8_t it = 0; it != width; it++) {
if (it < bars) fres[it+2] = '#';
else fres[it+2] = ' ';
}
// Finalize and return string
fres[width+2] = ']';
fres[width+3] = '`';
return fres;
}
inline static
bool show_console_progress(float progress) {
std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
@ -178,27 +210,29 @@ class Bot {
}
// Must run in llama thread
void llm_restart(LM::Inference& inference) {
// Deserialize init cache
std::ifstream f("init_cache", std::ios::binary);
void llm_restart(LM::Inference& inference, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Deserialize init cache if not instruct mode without prompt file
if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") return;
std::ifstream f((*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"), std::ios::binary);
inference.deserialize(f);
}
// Must run in llama thread
LM::Inference &llm_restart(dpp::snowflake id) {
LM::Inference &llm_restart(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Get or create inference
auto& inference = llm_pool.get_or_create_inference(id, config.inference_model, llm_get_params());
llm_restart(inference);
auto& inference = llm_pool.get_or_create_inference(id, channel_cfg.model_config->weight_path, llm_get_params());
llm_restart(inference, channel_cfg);
return inference;
}
// Must run in llama thread
LM::Inference &llm_get_inference(dpp::snowflake id) {
LM::Inference &llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
auto inference_opt = llm_pool.get_inference(id);
if (!inference_opt.has_value()) {
// Start new inference
inference_opt = llm_restart(id);
inference_opt = llm_restart(id, channel_cfg);
}
return inference_opt.value();
}
@ -210,54 +244,97 @@ class Bot {
// Translate texts
if (!texts.translated) {
texts.please_wait = llm_translate_from_en(texts.please_wait);
texts.initializing = llm_translate_from_en(texts.initializing);
texts.loading = llm_translate_from_en(texts.loading);
texts.model_missing = llm_translate_from_en(texts.model_missing);
texts.thread_create_fail = llm_translate_from_en(texts.thread_create_fail);
texts.timeout = llm_translate_from_en(texts.timeout);
texts.translated = true;
}
// Inference for init cache
if (!std::filesystem::exists("init_cache")) {
LM::Inference llm(config.inference_model, llm_get_params());
std::ofstream f("init_cache", std::ios::binary);
// Add initial context
std::string prompt;
{
// Read whole file
std::ifstream f(config.prompt_file);
if (!f) {
// Clean up and abort on error
std::cerr << "Failed to open prompt file." << std::endl;
f.close();
std::error_code ec;
std::filesystem::remove("init_cache", ec);
abort();
// Build init caches
std::string filename;
for (const auto& [model_name, model_config] : model_configs) {
//TODO: Add hashes to regenerate these as needed
// Standard prompt
filename = model_name+"_init_cache";
if (model_config.is_non_instruct_mode_allowed() &&
!std::filesystem::exists(filename) && config.prompt_file != "none") {
std::cout << "Building init_cache for "+model_name+"..." << std::endl;
LM::Inference llm(model_config.weight_path, llm_get_params());
// Add initial context
std::string prompt;
{
// Read whole file
std::ifstream f(config.prompt_file);
if (!f) {
// Clean up and abort on error
std::cerr << "Failed to open prompt file." << std::endl;
abort();
}
std::ostringstream sstr;
sstr << f.rdbuf();
prompt = sstr.str();
}
std::ostringstream sstr;
sstr << f.rdbuf();
prompt = sstr.str();
// Append
using namespace fmt::literals;
if (prompt.back() != '\n') prompt.push_back('\n');
llm.append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress);
// Serialize end result
std::ofstream f(filename, std::ios::binary);
llm.serialize(f);
}
// Instruct prompt
filename = model_name+"_instruct_init_cache";
if (model_config.is_instruct_mode_allowed() &&
!std::filesystem::exists(filename) && config.instruct_prompt_file != "none") {
std::cout << "Building instruct_init_cache for "+model_name+"..." << std::endl;
LM::Inference llm(model_config.weight_path, llm_get_params());
// Add initial context
std::string prompt;
{
// Read whole file
std::ifstream f(config.instruct_prompt_file);
if (!f) {
// Clean up and abort on error
std::cerr << "Failed to open prompt file." << std::endl;
abort();
}
std::ostringstream sstr;
sstr << f.rdbuf();
prompt = sstr.str();
}
// Append
using namespace fmt::literals;
if (prompt.back() != '\n') prompt.push_back('\n');
llm.append(fmt::format(fmt::runtime(prompt+'\n'), "bot_name"_a=bot.me.username), show_console_progress);
// Serialize end result
std::ofstream f(filename, std::ios::binary);
llm.serialize(f);
}
// Append
using namespace fmt::literals;
llm.append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username));
// Serialize end result
llm.serialize(f);
}
// Report complete init
std::cout << "Init done!" << std::endl;
}
// Must run in llama thread
void prompt_add_msg(const dpp::message& msg) {
void prompt_add_msg(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Make sure message isn't too long
if (msg.content.size() > 512) {
return;
}
// Get inference
auto& inference = llm_get_inference(msg.channel_id);
auto& inference = llm_get_inference(msg.channel_id, channel_cfg);
try {
std::string prefix;
// Instruct mode user prompt
if (channel_cfg.instruct_mode) {
inference.append('\n'+channel_cfg.model_config->user_prompt+"\n\n");
} else {
prefix = msg.author.username+": ";
}
// Format and append lines
for (const auto line : str_split(msg.content, '\n')) {
Timer timeout;
bool timeout_exceeded = false;
inference.append(msg.author.username+": "+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
inference.append(prefix+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
if (timeout.get<std::chrono::minutes>() > 1) {
std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
timeout_exceeded = true;
@ -268,28 +345,33 @@ class Bot {
if (timeout_exceeded) inference.append("\n");
}
} catch (const LM::Inference::ContextLengthException&) {
llm_restart(inference);
prompt_add_msg(msg);
llm_restart(inference, channel_cfg);
prompt_add_msg(msg, channel_cfg);
}
}
// Must run in llama thread
void prompt_add_trigger(dpp::snowflake id) {
void prompt_add_trigger(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
auto& inference = llm_get_inference(id);
auto& inference = llm_get_inference(id, channel_cfg);
try {
inference.append(bot.me.username+':', show_console_progress);
if (channel_cfg.instruct_mode) {
inference.append('\n'+channel_cfg.model_config->bot_prompt+"\n\n");
} else {
inference.append(bot.me.username+':', show_console_progress);
}
} catch (const LM::Inference::ContextLengthException&) {
llm_restart(inference);
llm_restart(inference, channel_cfg);
}
}
// Must run in llama thread
void reply(dpp::snowflake id, dpp::message msg) {
void reply(dpp::snowflake id, dpp::message msg, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
try { // Trigger LLM correctly
prompt_add_trigger(id);
try {
// Trigger LLM correctly
prompt_add_trigger(id, channel_cfg);
// Get inference
auto& inference = llm_get_inference(id);
auto& inference = llm_get_inference(id, channel_cfg);
// Run model
Timer timeout;
bool timeout_exceeded = false;
@ -315,16 +397,16 @@ class Bot {
}
}
bool attempt_reply(const dpp::message& msg) {
bool attempt_reply(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
// Reply if message contains username, mention or ID
if (msg.content.find(bot.me.username) != std::string::npos) {
enqueue_reply(msg.channel_id);
enqueue_reply(msg.channel_id, channel_cfg);
return true;
}
// Reply if message references user
for (const auto msg_id : my_messages) {
if (msg.message_reference.message_id == msg_id) {
enqueue_reply(msg.channel_id);
enqueue_reply(msg.channel_id, channel_cfg);
return true;
}
}
@ -332,45 +414,40 @@ class Bot {
return false;
}
void enqueue_reply(dpp::snowflake id) {
bot.message_create(dpp::message(id, texts.please_wait+" :thinking:"), [this, id] (const dpp::confirmation_callback_t& ccb) {
void enqueue_reply(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
bot.message_create(dpp::message(id, texts.please_wait+" :thinking:"), [=, this] (const dpp::confirmation_callback_t& ccb) {
if (ccb.is_error()) return;
tPool.submit(std::bind(&Bot::reply, this, id, ccb.get<dpp::message>()));
thread_pool.submit(std::bind(&Bot::reply, this, id, ccb.get<dpp::message>(), channel_cfg));
});
}
public:
struct Configuration {
std::string token,
language = "EN",
inference_model = "13B-ggml-model-quant.bin",
translation_model = "13B-ggml-model-quant.bin",
prompt_file = "prompt.txt";
unsigned ctx_size = 1012,
pool_size = 2,
threads = 4,
persistance = true;
bool mlock = false;
} config;
Bot(decltype(config) cfg, decltype(model_configs) model_configs)
: config(cfg), model_configs(model_configs), bot(cfg.token),
language(cfg.language), db("database.sqlite3"),
llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) {
// Initialize database
db << "CREATE TABLE IF NOT EXISTS threads ("
" id TEXT PRIMARY KEY NOT NULL,"
" model TEXT,"
" instruct_mode INTEGER,"
" UNIQUE(id)"
");";
Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), language(cfg.language),
llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) {
// Configure llm_pool
llm_pool.set_store_on_destruct(cfg.persistance);
// Initialize thread pool
tPool.init();
thread_pool.init();
// Prepare translator
if (language != "EN") {
tPool.submit([this] () {
thread_pool.submit([this] () {
std::cout << "Preparing translator..." << std::endl;
translator = std::make_unique<Translator>(config.translation_model, llm_get_translation_params());
});
}
// Prepare llm
tPool.submit(std::bind(&Bot::llm_init, this));
// Configure bot
bot.on_log(dpp::utility::cout_logger());
bot.intents = dpp::i_guild_messages | dpp::i_message_content;
@ -378,6 +455,59 @@ public:
// Set callbacks
bot.on_ready([=, this] (const dpp::ready_t&) { //TODO: Consider removal
std::cout << "Connected to Discord." << std::endl;
// Register chat command once
if (dpp::run_once<struct register_bot_commands>()) {
for (const auto& [name, model] : model_configs) {
// Create command
dpp::slashcommand command(name, "Start a chat with me", bot.me.id);
// Add instruct mode option
if (model.instruct_mode_policy == ModelConfig::InstructModePolicy::Allow) {
command.add_option(dpp::command_option(dpp::co_boolean, "instruct_mode", "Weather to enable instruct mode", true));
}
// Register command
bot.global_command_edit(command, [this, command] (const dpp::confirmation_callback_t& ccb) {
if (ccb.is_error()) bot.global_command_create(command);
});
}
}
if (dpp::run_once<struct LM::Inference>()) {
// Prepare llm
thread_pool.submit(std::bind(&Bot::llm_init, this));
}
});
bot.on_slashcommand([=, this](const dpp::slashcommand_t& event) {
// Get model by name
auto res = model_configs.find(event.command.get_command_name());
if (res == model_configs.end()) {
// Model does not exit, delete corresponding command
bot.global_command_delete(event.command.id);
return;
}
const auto& [model_name, model_config] = *res;
// Get weather to enable instruct mode
bool instruct_mode;
if (model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Allow) {
instruct_mode = std::get<bool>(event.get_parameter("instruct_mode"));
} else {
instruct_mode = model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Force;
}
// Create thread
bot.thread_create("Chat with "+model_name, event.command.channel_id, 1440, dpp::CHANNEL_PUBLIC_THREAD, true, 15,
[this, event, instruct_mode, model_name = res->first] (const dpp::confirmation_callback_t& ccb) {
// Check for error
if (ccb.is_error()) {
std::cout << "Thread creation failed: " << ccb.get_error().message << std::endl;
event.reply(dpp::message(texts.thread_create_fail).set_flags(dpp::message_flags::m_ephemeral));
return;
}
// Get thread
const auto& thread = ccb.get<dpp::thread>();
// Add thread to database
db << "INSERT INTO threads (id, model, instruct_mode) VALUES (?, ?, ?);"
<< std::to_string(thread.id) << model_name << instruct_mode;
// Report success
event.reply(dpp::message("Okay!").set_flags(dpp::message_flags::m_ephemeral));
});
});
bot.on_message_create([=, this] (const dpp::message_create_t& event) {
// Update user cache
@ -401,16 +531,49 @@ public:
for (const auto& [user_id, user] : users) {
str_replace_in_place(msg.content, "<@"+std::to_string(user_id)+'>', user.username);
}
// Get channel config
BotChannelConfig channel_cfg;
// Attempt to find thread first...
bool in_bot_thread = false;
db << "SELECT model, instruct_mode FROM threads "
"WHERE id = ?;"
<< std::to_string(msg.channel_id)
>> [&](const std::string& model_name, int instruct_mode) {
in_bot_thread = true;
channel_cfg.instruct_mode = instruct_mode;
// Find model
auto res = model_configs.find(model_name);
if (res == model_configs.end()) {
bot.message_create(dpp::message(msg.channel_id, texts.model_missing));
return;
}
channel_cfg.model_name = &res->first;
channel_cfg.model_config = &res->second;
};
// Otherwise just fall back to the default model config if allowed
if (!in_bot_thread) {
if (config.threads_only) return;
channel_cfg.model_name = &config.default_inference_model;
channel_cfg.model_config = config.default_inference_model_cfg;
}
// Append message
tPool.submit(std::bind(&Bot::prompt_add_msg, this, msg));
thread_pool.submit([=, this] () {
prompt_add_msg(msg, channel_cfg);
});
// Handle message somehow...
if (msg.content == "!trigger") {
if (msg.content == "!store") {
llm_pool.store_all(); //DEBUG
# warning DEBUG CODE!!!
} else if (in_bot_thread) {
// Send a reply
enqueue_reply(msg.channel_id, channel_cfg);
} else if (msg.content == "!trigger") {
// Delete message
bot.message_delete(msg.id, msg.channel_id);
// Send a reply
enqueue_reply(msg.channel_id);
enqueue_reply(msg.channel_id, channel_cfg);
} else {
attempt_reply(msg);
attempt_reply(msg, channel_cfg);
}
} catch (const std::exception& e) {
std::cerr << "Warning: " << e.what() << std::endl;
@ -426,6 +589,31 @@ public:
};
bool parse_bool(const std::string& value) {
if (value == "true")
return true;
if (value == "false")
return false;
std::cerr << "Failed to parse configuration file: Unknown bool (true/false): " << value << std::endl;
exit(-4);
}
Bot::ModelConfig::InstructModePolicy parse_instruct_mode_policy(const std::string& value) {
if (value == "allow")
return Bot::ModelConfig::InstructModePolicy::Allow;
if (value == "force")
return Bot::ModelConfig::InstructModePolicy::Force;
if (value == "forbid")
return Bot::ModelConfig::InstructModePolicy::Forbid;
std::cerr << "Failed to parse model configuration file: Unknown instruct mode policy (allow/force/forbid): " << value << std::endl;
exit(-4);
}
bool file_exists(const auto& p) {
// Make sure we don't respond to some file that is actually called "none"...
if (p == "none") return false;
return std::filesystem::exists(p);
}
int main(int argc, char **argv) {
// Check arguments
if (argc < 2) {
@ -433,12 +621,12 @@ int main(int argc, char **argv) {
return -1;
}
// Parse configuration
// Parse main configuration
Bot::Configuration cfg;
std::ifstream cfgf(argv[1]);
if (!cfgf) {
std::cerr << "Failed to open configuration file: " << argv[1] << std::endl;
return -1;
exit(-1);
}
for (std::string key; cfgf >> key;) {
// Read value
@ -451,12 +639,16 @@ int main(int argc, char **argv) {
cfg.token = std::move(value);
} else if (key == "language") {
cfg.language = std::move(value);
} else if (key == "inference_model") {
cfg.inference_model = std::move(value);
} else if (key == "default_inference_model") {
cfg.default_inference_model = std::move(value);
} else if (key == "translation_model") {
cfg.translation_model = std::move(value);
} else if (key == "prompt_file") {
cfg.prompt_file = std::move(value);
} else if (key == "instruct_prompt_file") {
cfg.instruct_prompt_file = std::move(value);
} else if (key == "models_dir") {
cfg.models_dir = std::move(value);
} else if (key == "pool_size") {
cfg.pool_size = std::stoi(value);
} else if (key == "threads") {
@ -464,17 +656,113 @@ int main(int argc, char **argv) {
} else if (key == "ctx_size") {
cfg.ctx_size = std::stoi(value);
} else if (key == "mlock") {
cfg.mlock = (value=="true")?true:false;
cfg.mlock = parse_bool(value);
} else if (key == "threads_only") {
cfg.threads_only = parse_bool(value);
} else if (key == "persistance") {
cfg.persistance = (value=="true")?true:false;
cfg.persistance = parse_bool(value);
} else if (!key.empty() && key[0] != '#') {
std::cerr << "Failed to parse configuration file: Unknown key: " << key << std::endl;
return -2;
exit(-3);
}
}
// Parse model configurations
std::unordered_map<std::string, Bot::ModelConfig> models;
std::filesystem::path models_dir(cfg.models_dir);
bool allow_non_instruct = false;
for (const auto& file : std::filesystem::directory_iterator(models_dir)) {
// Check that file is model config
if (file.is_directory() ||
file.path().filename().extension() != ".txt") continue;
// Get model name
auto model_name = file.path().filename().string();
model_name.erase(model_name.size()-4, 4);
clean_command_name(model_name);
// Parse model config
Bot::ModelConfig model_cfg;
std::ifstream cfgf(file.path());
if (!cfgf) {
std::cerr << "Failed to open model configuration file: " << file << std::endl;
exit(-2);
}
std::string filename;
for (std::string key; cfgf >> key;) {
// Read value
std::string value;
std::getline(cfgf, value);
// Erase all leading spaces
while (!value.empty() && (value[0] == ' ' || value[0] == '\t')) value.erase(0, 1);
// Check key and ignore comment lines
if (key == "filename") {
filename = std::move(value);
} else if (key == "user_prompt") {
model_cfg.user_prompt = std::move(value);
} else if (key == "bot_prompt") {
model_cfg.bot_prompt = std::move(value);
} else if (key == "instruct_mode_policy") {
model_cfg.instruct_mode_policy = parse_instruct_mode_policy(value);
} else if (!key.empty() && key[0] != '#') {
std::cerr << "Failed to parse model configuration file: Unknown key: " << key << std::endl;
exit(-3);
}
}
// Get full path
model_cfg.weight_path = file.path().parent_path()/filename;
// Safety checks
if (filename.empty() || !file_exists(model_cfg.weight_path)) {
std::cerr << "Failed to parse model configuration file: Invalid weight filename: " << model_name << std::endl;
exit(-8);
}
if (model_cfg.instruct_mode_policy != Bot::ModelConfig::InstructModePolicy::Forbid &&
(model_cfg.user_prompt.empty() || model_cfg.bot_prompt.empty())) {
std::cerr << "Failed to parse model configuration file: Instruct mode allowed but user prompt and bot prompt not given: " << model_name << std::endl;
exit(-9);
}
if (model_cfg.instruct_mode_policy != Bot::ModelConfig::InstructModePolicy::Force) {
allow_non_instruct = true;
}
// Add model to list
const auto& [stored_model_name, stored_model_cfg] = *models.emplace(std::move(model_name), std::move(model_cfg)).first;
// Set model pointer in config
if (stored_model_name == cfg.default_inference_model)
cfg.default_inference_model_cfg = &stored_model_cfg;
if (stored_model_name == cfg.translation_model)
cfg.translation_model_cfg = &stored_model_cfg;
}
// Safety checks
if (cfg.language != "EN") {
if (cfg.translation_model_cfg == nullptr) {
std::cerr << "Translation model required for non-english language, but is invalid" << std::endl;
exit(-5);
}
if (cfg.translation_model_cfg->instruct_mode_policy == Bot::ModelConfig::InstructModePolicy::Force) {
std::cerr << "Translation model is required to not have instruct mode forced" << std::endl;
exit(-10);
}
}
if (allow_non_instruct && !file_exists(cfg.prompt_file)) {
std::cerr << "Prompt file required when allowing non-instruct-mode use, but is invalid" << std::endl;
exit(-11);
}
if (!cfg.threads_only) {
if (cfg.default_inference_model_cfg == nullptr) {
std::cerr << "Default model required if not threads only, but is invalid" << std::endl;
exit(-6);
}
if (cfg.default_inference_model_cfg->instruct_mode_policy == Bot::ModelConfig::InstructModePolicy::Force) {
std::cerr << "Default model must not have instruct mode forced if not threads only" << std::endl;
exit(-7);
}
}
// Clean model names in config
clean_command_name(cfg.default_inference_model);
clean_command_name(cfg.translation_model);
// Construct and configure bot
Bot bot(cfg);
Bot bot(cfg, models);
// Start bot
bot.start();

View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2017 aminroosta
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,60 @@
#pragma once
#include <string>
#include <stdexcept>
#include <sqlite3.h>
namespace sqlite {
class sqlite_exception: public std::runtime_error {
public:
sqlite_exception(const char* msg, std::string sql, int code = -1): runtime_error(msg), code(code), sql(sql) {}
sqlite_exception(int code, std::string sql): runtime_error(sqlite3_errstr(code)), code(code), sql(sql) {}
int get_code() const {return code & 0xFF;}
int get_extended_code() const {return code;}
std::string get_sql() const {return sql;}
private:
int code;
std::string sql;
};
namespace errors {
//One more or less trivial derived error class for each SQLITE error.
//Note the following are not errors so have no classes:
//SQLITE_OK, SQLITE_NOTICE, SQLITE_WARNING, SQLITE_ROW, SQLITE_DONE
//
//Note these names are exact matches to the names of the SQLITE error codes.
#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \
class name: public sqlite_exception { using sqlite_exception::sqlite_exception; };\
derived
#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \
class base ## _ ## sub: public base { using base::base; };
#include "lists/error_codes.h"
#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED
#undef SQLITE_MODERN_CPP_ERROR_CODE
//Some additional errors are here for the C++ interface
class more_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; };
class no_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; };
class more_statements: public sqlite_exception { using sqlite_exception::sqlite_exception; }; // Prepared statements can only contain one statement
class invalid_utf16: public sqlite_exception { using sqlite_exception::sqlite_exception; };
static void throw_sqlite_error(const int& error_code, const std::string &sql = "") {
switch(error_code & 0xFF) {
#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \
case SQLITE_ ## NAME: switch(error_code) { \
derived \
default: throw name(error_code, sql); \
}
#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \
case SQLITE_ ## BASE ## _ ## SUB: throw base ## _ ## sub(error_code, sql);
#include "lists/error_codes.h"
#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED
#undef SQLITE_MODERN_CPP_ERROR_CODE
default: throw sqlite_exception(error_code, sql);
}
}
}
namespace exceptions = errors;
}

View file

@ -0,0 +1,93 @@
#if SQLITE_VERSION_NUMBER < 3010000
#define SQLITE_IOERR_VNODE (SQLITE_IOERR | (27<<8))
#define SQLITE_IOERR_AUTH (SQLITE_IOERR | (28<<8))
#define SQLITE_AUTH_USER (SQLITE_AUTH | (1<<8))
#endif
SQLITE_MODERN_CPP_ERROR_CODE(ERROR,error,)
SQLITE_MODERN_CPP_ERROR_CODE(INTERNAL,internal,)
SQLITE_MODERN_CPP_ERROR_CODE(PERM,perm,)
SQLITE_MODERN_CPP_ERROR_CODE(ABORT,abort,
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(ABORT,ROLLBACK,abort,rollback)
)
SQLITE_MODERN_CPP_ERROR_CODE(BUSY,busy,
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,RECOVERY,busy,recovery)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,SNAPSHOT,busy,snapshot)
)
SQLITE_MODERN_CPP_ERROR_CODE(LOCKED,locked,
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(LOCKED,SHAREDCACHE,locked,sharedcache)
)
SQLITE_MODERN_CPP_ERROR_CODE(NOMEM,nomem,)
SQLITE_MODERN_CPP_ERROR_CODE(READONLY,readonly,)
SQLITE_MODERN_CPP_ERROR_CODE(INTERRUPT,interrupt,)
SQLITE_MODERN_CPP_ERROR_CODE(IOERR,ioerr,
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,READ,ioerr,read)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHORT_READ,ioerr,short_read)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,WRITE,ioerr,write)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSYNC,ioerr,fsync)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_FSYNC,ioerr,dir_fsync)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,TRUNCATE,ioerr,truncate)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSTAT,ioerr,fstat)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,UNLOCK,ioerr,unlock)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,RDLOCK,ioerr,rdlock)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE,ioerr,delete)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,BLOCKED,ioerr,blocked)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,NOMEM,ioerr,nomem)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,ACCESS,ioerr,access)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CHECKRESERVEDLOCK,ioerr,checkreservedlock)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,LOCK,ioerr,lock)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CLOSE,ioerr,close)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_CLOSE,ioerr,dir_close)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMOPEN,ioerr,shmopen)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMSIZE,ioerr,shmsize)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMLOCK,ioerr,shmlock)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMMAP,ioerr,shmmap)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SEEK,ioerr,seek)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE_NOENT,ioerr,delete_noent)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,MMAP,ioerr,mmap)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,GETTEMPPATH,ioerr,gettemppath)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CONVPATH,ioerr,convpath)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,VNODE,ioerr,vnode)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,AUTH,ioerr,auth)
)
SQLITE_MODERN_CPP_ERROR_CODE(CORRUPT,corrupt,
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CORRUPT,VTAB,corrupt,vtab)
)
SQLITE_MODERN_CPP_ERROR_CODE(NOTFOUND,notfound,)
SQLITE_MODERN_CPP_ERROR_CODE(FULL,full,)
SQLITE_MODERN_CPP_ERROR_CODE(CANTOPEN,cantopen,
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,NOTEMPDIR,cantopen,notempdir)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,ISDIR,cantopen,isdir)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,FULLPATH,cantopen,fullpath)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,CONVPATH,cantopen,convpath)
)
SQLITE_MODERN_CPP_ERROR_CODE(PROTOCOL,protocol,)
SQLITE_MODERN_CPP_ERROR_CODE(EMPTY,empty,)
SQLITE_MODERN_CPP_ERROR_CODE(SCHEMA,schema,)
SQLITE_MODERN_CPP_ERROR_CODE(TOOBIG,toobig,)
SQLITE_MODERN_CPP_ERROR_CODE(CONSTRAINT,constraint,
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,CHECK,constraint,check)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,COMMITHOOK,constraint,commithook)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FOREIGNKEY,constraint,foreignkey)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FUNCTION,constraint,function)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,NOTNULL,constraint,notnull)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,PRIMARYKEY,constraint,primarykey)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,TRIGGER,constraint,trigger)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,UNIQUE,constraint,unique)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,VTAB,constraint,vtab)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,ROWID,constraint,rowid)
)
SQLITE_MODERN_CPP_ERROR_CODE(MISMATCH,mismatch,)
SQLITE_MODERN_CPP_ERROR_CODE(MISUSE,misuse,)
SQLITE_MODERN_CPP_ERROR_CODE(NOLFS,nolfs,)
SQLITE_MODERN_CPP_ERROR_CODE(AUTH,auth,
)
SQLITE_MODERN_CPP_ERROR_CODE(FORMAT,format,)
SQLITE_MODERN_CPP_ERROR_CODE(RANGE,range,)
SQLITE_MODERN_CPP_ERROR_CODE(NOTADB,notadb,)
SQLITE_MODERN_CPP_ERROR_CODE(NOTICE,notice,
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_WAL,notice,recover_wal)
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_ROLLBACK,notice,recover_rollback)
)
SQLITE_MODERN_CPP_ERROR_CODE(WARNING,warning,
SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(WARNING,AUTOINDEX,warning,autoindex)
)

View file

@ -0,0 +1,101 @@
#include "errors.h"
#include <sqlite3.h>
#include <utility>
#include <tuple>
#include <type_traits>
namespace sqlite {
namespace detail {
template<class>
using void_t = void;
template<class T, class = void>
struct is_callable : std::false_type {};
template<class Functor, class ...Arguments>
struct is_callable<Functor(Arguments...), void_t<decltype(std::declval<Functor>()(std::declval<Arguments>()...))>> : std::true_type {};
template<class Functor, class ...Functors>
class FunctorOverload: public Functor, public FunctorOverload<Functors...> {
public:
template<class Functor1, class ...Remaining>
FunctorOverload(Functor1 &&functor, Remaining &&... remaining):
Functor(std::forward<Functor1>(functor)),
FunctorOverload<Functors...>(std::forward<Remaining>(remaining)...) {}
using Functor::operator();
using FunctorOverload<Functors...>::operator();
};
template<class Functor>
class FunctorOverload<Functor>: public Functor {
public:
template<class Functor1>
FunctorOverload(Functor1 &&functor):
Functor(std::forward<Functor1>(functor)) {}
using Functor::operator();
};
template<class Functor>
class WrapIntoFunctor: public Functor {
public:
template<class Functor1>
WrapIntoFunctor(Functor1 &&functor):
Functor(std::forward<Functor1>(functor)) {}
using Functor::operator();
};
template<class ReturnType, class ...Arguments>
class WrapIntoFunctor<ReturnType(*)(Arguments...)> {
ReturnType(*ptr)(Arguments...);
public:
WrapIntoFunctor(ReturnType(*ptr)(Arguments...)): ptr(ptr) {}
ReturnType operator()(Arguments... arguments) { return (*ptr)(std::forward<Arguments>(arguments)...); }
};
inline void store_error_log_data_pointer(std::shared_ptr<void> ptr) {
static std::shared_ptr<void> stored;
stored = std::move(ptr);
}
template<class T>
std::shared_ptr<typename std::decay<T>::type> make_shared_inferred(T &&t) {
return std::make_shared<typename std::decay<T>::type>(std::forward<T>(t));
}
}
template<class Handler>
typename std::enable_if<!detail::is_callable<Handler(const sqlite_exception&)>::value>::type
error_log(Handler &&handler);
template<class Handler>
typename std::enable_if<detail::is_callable<Handler(const sqlite_exception&)>::value>::type
error_log(Handler &&handler);
template<class ...Handler>
typename std::enable_if<sizeof...(Handler)>=2>::type
error_log(Handler &&...handler) {
return error_log(detail::FunctorOverload<detail::WrapIntoFunctor<typename std::decay<Handler>::type>...>(std::forward<Handler>(handler)...));
}
template<class Handler>
typename std::enable_if<!detail::is_callable<Handler(const sqlite_exception&)>::value>::type
error_log(Handler &&handler) {
return error_log(std::forward<Handler>(handler), [](const sqlite_exception&) {});
}
template<class Handler>
typename std::enable_if<detail::is_callable<Handler(const sqlite_exception&)>::value>::type
error_log(Handler &&handler) {
auto ptr = detail::make_shared_inferred([handler = std::forward<Handler>(handler)](int error_code, const char *errstr) mutable {
switch(error_code & 0xFF) {
#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \
case SQLITE_ ## NAME: switch(error_code) { \
derived \
default: handler(errors::name(errstr, "", error_code)); \
};break;
#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \
case SQLITE_ ## BASE ## _ ## SUB: \
handler(errors::base ## _ ## sub(errstr, "", error_code)); \
break;
#include "lists/error_codes.h"
#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED
#undef SQLITE_MODERN_CPP_ERROR_CODE
default: handler(sqlite_exception(errstr, "", error_code)); \
}
});
sqlite3_config(SQLITE_CONFIG_LOG, (void(*)(void*,int,const char*))[](void *functor, int error_code, const char *errstr) {
(*static_cast<decltype(ptr.get())>(functor))(error_code, errstr);
}, ptr.get());
detail::store_error_log_data_pointer(std::move(ptr));
}
}

View file

@ -0,0 +1,44 @@
#pragma once
#ifndef SQLITE_HAS_CODEC
#define SQLITE_HAS_CODEC
#endif
#include "../sqlite_modern_cpp.h"
namespace sqlite {
struct sqlcipher_config : public sqlite_config {
std::string key;
};
class sqlcipher_database : public database {
public:
sqlcipher_database(std::string db, const sqlcipher_config &config): database(db, config) {
set_key(config.key);
}
sqlcipher_database(std::u16string db, const sqlcipher_config &config): database(db, config) {
set_key(config.key);
}
void set_key(const std::string &key) {
if(auto ret = sqlite3_key(_db.get(), key.data(), key.size()))
errors::throw_sqlite_error(ret);
}
void set_key(const std::string &key, const std::string &db_name) {
if(auto ret = sqlite3_key_v2(_db.get(), db_name.c_str(), key.data(), key.size()))
errors::throw_sqlite_error(ret);
}
void rekey(const std::string &new_key) {
if(auto ret = sqlite3_rekey(_db.get(), new_key.data(), new_key.size()))
errors::throw_sqlite_error(ret);
}
void rekey(const std::string &new_key, const std::string &db_name) {
if(auto ret = sqlite3_rekey_v2(_db.get(), db_name.c_str(), new_key.data(), new_key.size()))
errors::throw_sqlite_error(ret);
}
};
}

View file

@ -0,0 +1,55 @@
#pragma once
#include <tuple>
#include<type_traits>
namespace sqlite {
namespace utility {
template<typename> struct function_traits;
template <typename Function>
struct function_traits : public function_traits<
decltype(&std::remove_reference<Function>::type::operator())
> { };
template <
typename ClassType,
typename ReturnType,
typename... Arguments
>
struct function_traits<
ReturnType(ClassType::*)(Arguments...) const
> : function_traits<ReturnType(*)(Arguments...)> { };
/* support the non-const operator ()
* this will work with user defined functors */
template <
typename ClassType,
typename ReturnType,
typename... Arguments
>
struct function_traits<
ReturnType(ClassType::*)(Arguments...)
> : function_traits<ReturnType(*)(Arguments...)> { };
template <
typename ReturnType,
typename... Arguments
>
struct function_traits<
ReturnType(*)(Arguments...)
> {
typedef ReturnType result_type;
template <std::size_t Index>
using argument = typename std::tuple_element<
Index,
std::tuple<Arguments...>
>::type;
static const std::size_t arity = sizeof...(Arguments);
};
}
}

View file

@ -0,0 +1,27 @@
#pragma once
#include <cassert>
#include <exception>
#include <iostream>
namespace sqlite {
namespace utility {
#ifdef __cpp_lib_uncaught_exceptions
class UncaughtExceptionDetector {
public:
operator bool() {
return count != std::uncaught_exceptions();
}
private:
int count = std::uncaught_exceptions();
};
#else
class UncaughtExceptionDetector {
public:
operator bool() {
return std::uncaught_exception();
}
};
#endif
}
}

View file

@ -0,0 +1,42 @@
#pragma once
#include <locale>
#include <string>
#include <algorithm>
#include "../errors.h"
namespace sqlite {
namespace utility {
inline std::string utf16_to_utf8(const std::u16string &input) {
struct : std::codecvt<char16_t, char, std::mbstate_t> {
} codecvt;
std::mbstate_t state{};
std::string result((std::max)(input.size() * 3 / 2, std::size_t(4)), '\0');
const char16_t *remaining_input = input.data();
std::size_t produced_output = 0;
while(true) {
char *used_output;
switch(codecvt.out(state, remaining_input, &input[input.size()],
remaining_input, &result[produced_output],
&result[result.size() - 1] + 1, used_output)) {
case std::codecvt_base::ok:
result.resize(used_output - result.data());
return result;
case std::codecvt_base::noconv:
// This should be unreachable
case std::codecvt_base::error:
throw errors::invalid_utf16("Invalid UTF-16 input", "");
case std::codecvt_base::partial:
if(used_output == result.data() + produced_output)
throw errors::invalid_utf16("Unexpected end of input", "");
produced_output = used_output - result.data();
result.resize(
result.size()
+ (std::max)((&input[input.size()] - remaining_input) * 3 / 2,
std::ptrdiff_t(4)));
}
}
}
} // namespace utility
} // namespace sqlite

View file

@ -0,0 +1,201 @@
#pragma once
#include "../errors.h"
#include <sqlite3.h>
#include <optional>
#include <variant>
namespace sqlite::utility {
template<typename ...Options>
struct VariantFirstNullable {
using type = void;
};
template<typename T, typename ...Options>
struct VariantFirstNullable<T, Options...> {
using type = typename VariantFirstNullable<Options...>::type;
};
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
template<typename T, typename ...Options>
struct VariantFirstNullable<std::optional<T>, Options...> {
using type = std::optional<T>;
};
#endif
template<typename T, typename ...Options>
struct VariantFirstNullable<std::unique_ptr<T>, Options...> {
using type = std::unique_ptr<T>;
};
template<typename ...Options>
struct VariantFirstNullable<std::nullptr_t, Options...> {
using type = std::nullptr_t;
};
template<typename Callback, typename ...Options>
inline void variant_select_null(Callback&&callback) {
if constexpr(std::is_same_v<typename VariantFirstNullable<Options...>::type, void>) {
throw errors::mismatch("NULL is unsupported by this variant.", "", SQLITE_MISMATCH);
} else {
std::forward<Callback>(callback)(typename VariantFirstNullable<Options...>::type());
}
}
template<typename ...Options>
struct VariantFirstIntegerable {
using type = void;
};
template<typename T, typename ...Options>
struct VariantFirstIntegerable<T, Options...> {
using type = typename VariantFirstIntegerable<Options...>::type;
};
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
template<typename T, typename ...Options>
struct VariantFirstIntegerable<std::optional<T>, Options...> {
using type = std::conditional_t<std::is_same_v<typename VariantFirstIntegerable<T, Options...>::type, T>, std::optional<T>, typename VariantFirstIntegerable<Options...>::type>;
};
#endif
template<typename T, typename ...Options>
struct VariantFirstIntegerable<std::enable_if_t<std::is_same_v<typename VariantFirstIntegerable<T, Options...>::type, T>>, std::unique_ptr<T>, Options...> {
using type = std::conditional_t<std::is_same_v<typename VariantFirstIntegerable<T, Options...>::type, T>, std::unique_ptr<T>, typename VariantFirstIntegerable<Options...>::type>;
};
template<typename ...Options>
struct VariantFirstIntegerable<int, Options...> {
using type = int;
};
template<typename ...Options>
struct VariantFirstIntegerable<sqlite_int64, Options...> {
using type = sqlite_int64;
};
template<typename Callback, typename ...Options>
inline auto variant_select_integer(Callback&&callback) {
if constexpr(std::is_same_v<typename VariantFirstIntegerable<Options...>::type, void>) {
throw errors::mismatch("Integer is unsupported by this variant.", "", SQLITE_MISMATCH);
} else {
std::forward<Callback>(callback)(typename VariantFirstIntegerable<Options...>::type());
}
}
template<typename ...Options>
struct VariantFirstFloatable {
using type = void;
};
template<typename T, typename ...Options>
struct VariantFirstFloatable<T, Options...> {
using type = typename VariantFirstFloatable<Options...>::type;
};
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
template<typename T, typename ...Options>
struct VariantFirstFloatable<std::optional<T>, Options...> {
using type = std::conditional_t<std::is_same_v<typename VariantFirstFloatable<T, Options...>::type, T>, std::optional<T>, typename VariantFirstFloatable<Options...>::type>;
};
#endif
template<typename T, typename ...Options>
struct VariantFirstFloatable<std::unique_ptr<T>, Options...> {
using type = std::conditional_t<std::is_same_v<typename VariantFirstFloatable<T, Options...>::type, T>, std::unique_ptr<T>, typename VariantFirstFloatable<Options...>::type>;
};
template<typename ...Options>
struct VariantFirstFloatable<float, Options...> {
using type = float;
};
template<typename ...Options>
struct VariantFirstFloatable<double, Options...> {
using type = double;
};
template<typename Callback, typename ...Options>
inline auto variant_select_float(Callback&&callback) {
if constexpr(std::is_same_v<typename VariantFirstFloatable<Options...>::type, void>) {
throw errors::mismatch("Real is unsupported by this variant.", "", SQLITE_MISMATCH);
} else {
std::forward<Callback>(callback)(typename VariantFirstFloatable<Options...>::type());
}
}
template<typename ...Options>
struct VariantFirstTextable {
using type = void;
};
template<typename T, typename ...Options>
struct VariantFirstTextable<T, Options...> {
using type = typename VariantFirstTextable<void, Options...>::type;
};
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
template<typename T, typename ...Options>
struct VariantFirstTextable<std::optional<T>, Options...> {
using type = std::conditional_t<std::is_same_v<typename VariantFirstTextable<T, Options...>::type, T>, std::optional<T>, typename VariantFirstTextable<Options...>::type>;
};
#endif
template<typename T, typename ...Options>
struct VariantFirstTextable<std::unique_ptr<T>, Options...> {
using type = std::conditional_t<std::is_same_v<typename VariantFirstTextable<T, Options...>::type, T>, std::unique_ptr<T>, typename VariantFirstTextable<Options...>::type>;
};
template<typename ...Options>
struct VariantFirstTextable<std::string, Options...> {
using type = std::string;
};
template<typename ...Options>
struct VariantFirstTextable<std::u16string, Options...> {
using type = std::u16string;
};
template<typename Callback, typename ...Options>
inline void variant_select_text(Callback&&callback) {
if constexpr(std::is_same_v<typename VariantFirstTextable<Options...>::type, void>) {
throw errors::mismatch("Text is unsupported by this variant.", "", SQLITE_MISMATCH);
} else {
std::forward<Callback>(callback)(typename VariantFirstTextable<Options...>::type());
}
}
template<typename ...Options>
struct VariantFirstBlobable {
using type = void;
};
template<typename T, typename ...Options>
struct VariantFirstBlobable<T, Options...> {
using type = typename VariantFirstBlobable<Options...>::type;
};
#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT
template<typename T, typename ...Options>
struct VariantFirstBlobable<std::optional<T>, Options...> {
using type = std::conditional_t<std::is_same_v<typename VariantFirstBlobable<T, Options...>::type, T>, std::optional<T>, typename VariantFirstBlobable<Options...>::type>;
};
#endif
template<typename T, typename ...Options>
struct VariantFirstBlobable<std::unique_ptr<T>, Options...> {
using type = std::conditional_t<std::is_same_v<typename VariantFirstBlobable<T, Options...>::type, T>, std::unique_ptr<T>, typename VariantFirstBlobable<Options...>::type>;
};
template<typename T, typename A, typename ...Options>
struct VariantFirstBlobable<std::enable_if_t<std::is_pod_v<T>>, std::vector<T, A>, Options...> {
using type = std::vector<T, A>;
};
template<typename Callback, typename ...Options>
inline auto variant_select_blob(Callback&&callback) {
if constexpr(std::is_same_v<typename VariantFirstBlobable<Options...>::type, void>) {
throw errors::mismatch("Blob is unsupported by this variant.", "", SQLITE_MISMATCH);
} else {
std::forward<Callback>(callback)(typename VariantFirstBlobable<Options...>::type());
}
}
template<typename ...Options>
inline auto variant_select(int type) {
return [type](auto &&callback) {
using Callback = decltype(callback);
switch(type) {
case SQLITE_NULL:
variant_select_null<Callback, Options...>(std::forward<Callback>(callback));
break;
case SQLITE_INTEGER:
variant_select_integer<Callback, Options...>(std::forward<Callback>(callback));
break;
case SQLITE_FLOAT:
variant_select_float<Callback, Options...>(std::forward<Callback>(callback));
break;
case SQLITE_TEXT:
variant_select_text<Callback, Options...>(std::forward<Callback>(callback));
break;
case SQLITE_BLOB:
variant_select_blob<Callback, Options...>(std::forward<Callback>(callback));
break;
default:;
/* assert(false); */
}
};
}
}