// 1
// 0
// Fork 0
// Mirror of https://gitlab.com/niansa/libjustchat.git (synced 2025-03-06 20:48:31 +01:00)
// File: libjustchat/model_config.cpp
// Last modified: 2023-06-10 13:36:32 +02:00
//
// 95 lines
// 3.6 KiB
// C++

#include "model_config.hpp"

#include <cctype>
#include <filesystem>
#include <iostream>
namespace LM {
namespace Chat {
/// Populates this model configuration from already-parsed key/value pairs.
///
/// @param map          Key/value pairs read from the config file. It is taken
///                     by rvalue reference; some string values are moved out.
/// @param ignore_extra If false, an unrecognized key throws an Exception;
///                     if true, unrecognized keys are silently skipped.
/// @throws Exception on an unknown key (unless @p ignore_extra is set).
/// @throws std::invalid_argument / std::out_of_range propagated unchanged from
///         std::stoul / std::stoi / std::stof when a numeric value is malformed.
void ModelConfig::fill(KeyValueMap &&map, bool ignore_extra) {
    for (auto& [key, value] : map) {
        if (key == "model_file")
            model_file = std::move(value);
        else if (key == "prompt_prefix")
            prompt.prefix = parse_string(value);
        else if (key == "prompt_suffix")
            prompt.suffix = parse_string(value);
        else if (key == "language")
            language = std::move(value);
        else if (key == "allow_system_prompt")
            allow_system_prompt = parse_bool(value);
        else if (key == "strict_prompt")
            strict_prompt = parse_bool(value);
        else if (key == "max_context_size")
            max_context_size = std::stoul(value);
        else if (key == "repeat_last")
            repeat_last = std::stoul(value);
        else if (key == "emits_eos")
            emits_eos = parse_bool(value);
        else if (key == "top_k")
            top_k = std::stoi(value);
        else if (key == "mirostat_version")
            mirostat_version = std::stoi(value);
        else if (key == "top_p")
            top_p = std::stof(value);
        else if (key == "temp")
            temp = std::stof(value);
        else if (key == "mirostat_learning_rate")
            mirostat_learning_rate = std::stof(value);
        else if (key == "mirostat_target_entropy")
            mirostat_target_entropy = std::stof(value);
        else if (key == "repeat_penalty")
            repeat_penalty = std::stof(value);
        else if (!ignore_extra) {
            // NOTE(review): the message says "texts file" although this parses a
            // model config; kept byte-identical in case callers match on it.
            throw Exception("Error: Failed to parse texts file: Unknown key: "+key);
        }
    }

    // A model_file that is a bare filename (no directory component at all) is
    // resolved against the directory containing the config file. Paths with
    // any directory part, relative or absolute, are left as written. The
    // result is only absolute if config_file itself is.
    //
    // has_parent_path() recognizes every separator the platform accepts (both
    // '/' and '\\' on Windows, plus drive prefixes like "C:"). Searching for
    // preferred_separator alone misses those.
    // .string() yields std::string everywhere. The implicit path -> string_type
    // conversion produces std::wstring on Windows and would not compile.
    if (!model_file.empty() && !std::filesystem::path(model_file).has_parent_path())
        model_file = (std::filesystem::path(config_file).parent_path() / model_file).string();
}
/// Validates the loaded configuration and normalizes the prompt affixes.
///
/// Checks, in order:
/// - the model file exists;
/// - the prompt prefix and suffix are non-empty. The literal `none` (or `"none"`
///   with quotes) is the explicit way to request an empty affix and is cleared
///   here;
/// - the language is a two-letter lowercase code;
/// - the mirostat settings are consistent: version at most 2, and no
///   top_p/top_k when mirostat is enabled;
/// - the optional sampling parameters are in range.
///
/// @throws Exception describing the first check that fails.
void ModelConfig::check() const {
    if (!file_exists(model_file))
        throw Exception("Needs valid model file");

    // NOTE(review): this const member function clears prompt.prefix/suffix, so
    // `prompt` is presumably declared `mutable` in model_config.hpp. Confirm.
    if (prompt.prefix.empty())
        throw Exception("There should be a prompt prefix, use \"none\" to enforce empty");
    if (prompt.prefix == "none" || prompt.prefix == "\"none\"")
        prompt.prefix.clear();
    if (prompt.suffix.empty())
        throw Exception("There should be a prompt suffix, use \"none\" to enforce empty");
    if (prompt.suffix == "none" || prompt.suffix == "\"none\"")
        prompt.suffix.clear();

    // Convert to unsigned char before calling std::islower. Passing a negative
    // char (any non-ASCII byte where char is signed) to the <cctype> classifiers
    // is undefined behavior.
    const auto is_lower = [](char c) {
        return std::islower(static_cast<unsigned char>(c)) != 0;
    };
    if (language.size() != 2 || !is_lower(language[0]) || !is_lower(language[1]))
        throw Exception("Specified language needs to be lowercase two-letter code (example: en)");

    if (mirostat_version.has_value()) {
        if (mirostat_version.value() > 2)
            throw Exception("Mirostat version must be 2 or below, use 0 to disable");
        // Mirostat replaces top_p/top_k sampling, so combining them is rejected.
        if (mirostat_version.value() != 0 && (top_p.has_value() || top_k.has_value()))
            throw Exception("Can't combine top_p/top_k sampling with mirostat");
    }
    if (top_p.has_value() && top_p.value() < 0.0f)
        throw Exception("The top_p must be a positive value!");
    if (temp.has_value() && temp.value() <= 0.0f)
        throw Exception("The temperature must be a value above 0!");
    if (mirostat_learning_rate.has_value() && mirostat_learning_rate.value() <= 0.0f)
        throw Exception("The learning rate must be a value above 0!");
    if (mirostat_target_entropy.has_value() && mirostat_target_entropy.value() <= 0.0f)
        throw Exception("The target entropy must be a value above 0!");
}
void ModelConfig::parse(const std::string &file) {
config_file = file;
Configuration::parse(file);
}
}
}