mirror of
https://gitlab.com/niansa/libjustchat.git
synced 2025-03-06 20:48:31 +01:00
95 lines
3.6 KiB
C++
95 lines
3.6 KiB
C++
#include "model_config.hpp"

#include <cctype>
#include <filesystem>
#include <iostream>
|
|
|
|
|
|
|
|
namespace LM {
|
|
namespace Chat {
|
|
void ModelConfig::fill(KeyValueMap &&map, bool ignore_extra) {
|
|
for (auto& [key, value] : map) {
|
|
if (key == "model_file")
|
|
model_file = std::move(value);
|
|
else if (key == "prompt_prefix")
|
|
prompt.prefix = parse_string(value);
|
|
else if (key == "prompt_suffix")
|
|
prompt.suffix = parse_string(value);
|
|
else if (key == "language")
|
|
language = std::move(value);
|
|
else if (key == "allow_system_prompt")
|
|
allow_system_prompt = parse_bool(value);
|
|
else if (key == "strict_prompt")
|
|
strict_prompt = parse_bool(value);
|
|
else if (key == "max_context_size")
|
|
max_context_size = std::stoul(value);
|
|
else if (key == "repeat_last")
|
|
repeat_last = std::stoul(value);
|
|
else if (key == "emits_eos")
|
|
emits_eos = parse_bool(value);
|
|
else if (key == "top_k")
|
|
top_k = std::stoi(value);
|
|
else if (key == "mirostat_version")
|
|
mirostat_version = std::stoi(value);
|
|
else if (key == "top_p")
|
|
top_p = std::stof(value);
|
|
else if (key == "temp")
|
|
temp = std::stof(value);
|
|
else if (key == "mirostat_learning_rate")
|
|
mirostat_learning_rate = std::stof(value);
|
|
else if (key == "mirostat_target_entropy")
|
|
mirostat_target_entropy = std::stof(value);
|
|
else if (key == "repeat_penalty")
|
|
repeat_penalty = std::stof(value);
|
|
else if (!ignore_extra)
|
|
throw Exception("Error: Failed to parse texts file: Unknown key: "+key);
|
|
}
|
|
|
|
// Make path absolute relative to config file
|
|
if (!model_file.empty() && model_file.find(std::filesystem::path::preferred_separator) == model_file.npos)
|
|
model_file = std::filesystem::path(config_file).parent_path()/model_file;
|
|
}
|
|
|
|
void ModelConfig::check() const {
|
|
if (!file_exists(model_file))
|
|
throw Exception("Needs valid model file");
|
|
|
|
if (prompt.prefix.empty())
|
|
throw Exception("There should be a prompt prefix, use \"none\" to enforce empty");
|
|
if (prompt.prefix == "none" || prompt.prefix == "\"none\"")
|
|
prompt.prefix.clear();
|
|
|
|
if (prompt.suffix.empty())
|
|
throw Exception("There should be a prompt suffix, use \"none\" to enforce empty");
|
|
if (prompt.suffix == "none" || prompt.suffix == "\"none\"")
|
|
prompt.suffix.clear();
|
|
|
|
if (language.size() != 2 || !islower(language[0]) || !islower(language[1]))
|
|
throw Exception("Specified language needs to be lowercase two-letter code (example: en)");
|
|
|
|
if (mirostat_version.has_value()) {
|
|
if (mirostat_version.value() > 2)
|
|
throw Exception("Mirostat version must be 2 or below, use 0 to disable");
|
|
if (mirostat_version.value() != 0 && (top_p.has_value() || top_k.has_value()))
|
|
throw Exception("Can't combine top_p/top_k sampling with mirostat");
|
|
}
|
|
|
|
if (top_p.has_value() && top_p.value() < 0.0f)
|
|
throw Exception("The top_p must be a positive value!");
|
|
|
|
if (temp.has_value() && temp.value() <= 0.0f)
|
|
throw Exception("The temperature must be a value above 0!");
|
|
|
|
if (mirostat_learning_rate.has_value() && mirostat_learning_rate.value() <= 0.0f)
|
|
throw Exception("The learning rate must be a value above 0!");
|
|
|
|
if (mirostat_target_entropy.has_value() && mirostat_target_entropy.value() <= 0.0f)
|
|
throw Exception("The target entropy must be a value above 0!");
|
|
}
|
|
|
|
void ModelConfig::parse(const std::string &file) {
|
|
config_file = file;
|
|
Configuration::parse(file);
|
|
}
|
|
}
|
|
}
|