mirror of
https://gitlab.com/niansa/libjustchat.git
synced 2025-03-06 20:48:31 +01:00
Initial commit
This commit is contained in:
commit
7968f66759
12 changed files with 450 additions and 0 deletions
75
.gitignore
vendored
Normal file
75
.gitignore
vendored
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
# This file is used to ignore files which are generated
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
*~
|
||||||
|
*.autosave
|
||||||
|
*.a
|
||||||
|
*.core
|
||||||
|
*.moc
|
||||||
|
*.o
|
||||||
|
*.obj
|
||||||
|
*.orig
|
||||||
|
*.rej
|
||||||
|
*.so
|
||||||
|
*.so.*
|
||||||
|
*_pch.h.cpp
|
||||||
|
*_resource.rc
|
||||||
|
*.qm
|
||||||
|
.#*
|
||||||
|
*.*#
|
||||||
|
core
|
||||||
|
!core/
|
||||||
|
tags
|
||||||
|
.DS_Store
|
||||||
|
.directory
|
||||||
|
*.debug
|
||||||
|
Makefile*
|
||||||
|
*.prl
|
||||||
|
*.app
|
||||||
|
moc_*.cpp
|
||||||
|
ui_*.h
|
||||||
|
qrc_*.cpp
|
||||||
|
Thumbs.db
|
||||||
|
*.res
|
||||||
|
*.rc
|
||||||
|
/.qmake.cache
|
||||||
|
/.qmake.stash
|
||||||
|
|
||||||
|
# qtcreator generated files
|
||||||
|
*.pro.user*
|
||||||
|
CMakeLists.txt.user*
|
||||||
|
|
||||||
|
# xemacs temporary files
|
||||||
|
*.flc
|
||||||
|
|
||||||
|
# Vim temporary files
|
||||||
|
.*.swp
|
||||||
|
|
||||||
|
# Visual Studio generated files
|
||||||
|
*.ib_pdb_index
|
||||||
|
*.idb
|
||||||
|
*.ilk
|
||||||
|
*.pdb
|
||||||
|
*.sln
|
||||||
|
*.suo
|
||||||
|
*.vcproj
|
||||||
|
*vcproj.*.*.user
|
||||||
|
*.ncb
|
||||||
|
*.sdf
|
||||||
|
*.opensdf
|
||||||
|
*.vcxproj
|
||||||
|
*vcxproj.*
|
||||||
|
|
||||||
|
# MinGW generated files
|
||||||
|
*.Debug
|
||||||
|
*.Release
|
||||||
|
|
||||||
|
# Python byte code
|
||||||
|
*.pyc
|
||||||
|
|
||||||
|
# Binaries
|
||||||
|
# --------
|
||||||
|
*.dll
|
||||||
|
*.exe
|
||||||
|
|
||||||
|
CMakeLists.txt.user*
|
6
.gitmodules
vendored
Normal file
6
.gitmodules
vendored
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
[submodule "libjustlm"]
|
||||||
|
path = libjustlm
|
||||||
|
url = https://gitlab.com/niansa/libjustlm.git
|
||||||
|
[submodule "commoncpp"]
|
||||||
|
path = commoncpp
|
||||||
|
url = https://gitlab.com/niansa/commoncpp.git
|
20
CMakeLists.txt
Normal file
20
CMakeLists.txt
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
cmake_minimum_required(VERSION 3.5)
|
||||||
|
|
||||||
|
project(libjustchat LANGUAGES CXX)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
|
add_subdirectory(libjustlm)
|
||||||
|
add_subdirectory(commoncpp)
|
||||||
|
|
||||||
|
add_executable(justchat
|
||||||
|
main.cpp
|
||||||
|
# include/justchat/global_config.hpp global_config.cpp
|
||||||
|
include/justchat/model_config.hpp model_config.cpp
|
||||||
|
include/justchat/chat.hpp chat.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(justchat PUBLIC include/)
|
||||||
|
target_include_directories(justchat PRIVATE include/justchat/)
|
||||||
|
target_link_libraries(justchat PUBLIC commoncpp justlm)
|
75
chat.cpp
Normal file
75
chat.cpp
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
#include "chat.hpp"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
|
namespace LM {
|
||||||
|
namespace Chat {
|
||||||
|
LM::Inference::Params Inference::get_params() const {
|
||||||
|
LM::Inference::Params fres;
|
||||||
|
if (config.max_context_size.has_value())
|
||||||
|
fres.n_ctx = config.max_context_size.value();
|
||||||
|
if (config.repeat_last.has_value())
|
||||||
|
fres.n_repeat_last = config.repeat_last.value();
|
||||||
|
if (config.eos_ignores.has_value())
|
||||||
|
fres.n_eos_ignores = config.eos_ignores.value();
|
||||||
|
if (config.top_k.has_value())
|
||||||
|
fres.top_k = config.top_k.value();
|
||||||
|
if (config.top_p.has_value())
|
||||||
|
fres.top_p = config.top_p.value();
|
||||||
|
if (config.temp.has_value())
|
||||||
|
fres.temp = config.temp.value();
|
||||||
|
if (config.mirostat_learning_rate.has_value())
|
||||||
|
fres.mirostat_learning_rate = config.mirostat_learning_rate.value();
|
||||||
|
if (config.mirostat_target_entropy.has_value())
|
||||||
|
fres.mirostat_target_entropy = config.mirostat_target_entropy.value();
|
||||||
|
if (config.repeat_penalty.has_value())
|
||||||
|
fres.repeat_penalty = config.repeat_penalty.value();
|
||||||
|
if (config.mirostat_version.has_value())
|
||||||
|
fres.prefer_mirostat = config.mirostat_version.value();
|
||||||
|
return fres;
|
||||||
|
}
|
||||||
|
|
||||||
|
Inference::Inference(LM::InferencePool& pool, const std::string& config_path) : pool(&pool) {
|
||||||
|
config.parse(config_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
Inference::Inference(const std::string &config_path) {
|
||||||
|
config.parse(config_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Inference::reset() {
|
||||||
|
if (pool) {
|
||||||
|
const auto id = get_id();
|
||||||
|
pool->delete_inference(id);
|
||||||
|
pool->create_inference(id, config.model_file, get_params());
|
||||||
|
} else
|
||||||
|
inference = LM::Inference::construct(config.model_file, get_params());
|
||||||
|
}
|
||||||
|
|
||||||
|
LM_SCHEDULABLE(std::string) Inference::instruct(const std::string& instruction, const std::function<bool (std::string_view)> on_generate, const std::function<bool (float, bool)> on_evaluate) {
|
||||||
|
auto inference = get_underlaying();
|
||||||
|
bool non_cancelled = true;
|
||||||
|
// Append prompt prefix
|
||||||
|
inference->append(config.prompt_prefix, [on_evaluate] (float progress) -> bool {
|
||||||
|
on_evaluate(progress/3.0f, true);
|
||||||
|
return true; // Can't be cancelled here
|
||||||
|
});
|
||||||
|
// Append prompt
|
||||||
|
inference->append(instruction, [on_evaluate, &non_cancelled] (float progress) -> bool {
|
||||||
|
return non_cancelled = on_evaluate(progress/3.0f+100.0f/3.0f, false);
|
||||||
|
});
|
||||||
|
// Append prompt suffix
|
||||||
|
inference->append(config.prompt_suffix, [on_evaluate] (float progress) -> bool {
|
||||||
|
on_evaluate(progress/3.0f+(100.0f/3.0f*2.0f), true);
|
||||||
|
return true; // Can't be cancelled here
|
||||||
|
});
|
||||||
|
// Check for cancellation
|
||||||
|
if (!non_cancelled) LM_CORETURN "";
|
||||||
|
// Run inference
|
||||||
|
LM_CORETURN inference->run(config.prompt_prefix, [on_generate] (const char *token) {
|
||||||
|
return on_generate(token);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
1
commoncpp
Submodule
1
commoncpp
Submodule
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit ec148c3c3447b26b3213b0277566445456693bdf
|
18
global_config.cpp
Normal file
18
global_config.cpp
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
#include "global_config.hpp"
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace LM {
|
||||||
|
namespace Chat {
|
||||||
|
void GlobalConfig::fill(KeyValueMap &&map, bool ignore_extra) {
|
||||||
|
for (auto& [key, value] : map) {
|
||||||
|
if (key == "system_prompt")
|
||||||
|
system_prompt = std::move(value);
|
||||||
|
else if (!ignore_extra)
|
||||||
|
throw Exception("Error: Failed to parse texts file: Unknown key: "+key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
64
include/justchat/chat.hpp
Normal file
64
include/justchat/chat.hpp
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
#ifndef CHAT_HPP
|
||||||
|
#define CHAT_HPP
|
||||||
|
#include "model_config.hpp"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
#include <memory>
|
||||||
|
#include <functional>
|
||||||
|
#include <commoncpp/utils.hpp>
|
||||||
|
#include <justlm.hpp>
|
||||||
|
#include <justlm_pool.hpp>
|
||||||
|
|
||||||
|
|
||||||
|
namespace LM {
|
||||||
|
namespace Chat {
|
||||||
|
class Inference {
|
||||||
|
ModelConfig config;
|
||||||
|
|
||||||
|
LM::Inference *inference = nullptr;
|
||||||
|
LM::InferencePool *pool = nullptr;
|
||||||
|
|
||||||
|
LM::Inference::Params get_params() const;
|
||||||
|
|
||||||
|
size_t get_id() const {
|
||||||
|
return static_cast<size_t>(reinterpret_cast<uintptr_t>(this)); // Dirty hack that just works
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
class OptionallySharedInference {
|
||||||
|
std::shared_ptr<LM::Inference> s_ptr;
|
||||||
|
LM::Inference *r_ptr;
|
||||||
|
|
||||||
|
public:
|
||||||
|
OptionallySharedInference(std::shared_ptr<LM::Inference> &&s)
|
||||||
|
: s_ptr(std::move(s)) {
|
||||||
|
r_ptr = s_ptr.get();
|
||||||
|
}
|
||||||
|
OptionallySharedInference(LM::Inference *r)
|
||||||
|
: r_ptr(std::move(r)) {}
|
||||||
|
|
||||||
|
LM::Inference *operator ->() const {
|
||||||
|
return r_ptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Inference(const std::string& config_path);
|
||||||
|
Inference(LM::InferencePool& pool, const std::string& config_path);
|
||||||
|
|
||||||
|
void reset();
|
||||||
|
LM_SCHEDULABLE(std::string) instruct(const std::string& instruction, const std::function<bool(std::string_view)> on_generate = nullptr, const std::function<bool(float, bool)> on_evaluate = nullptr);
|
||||||
|
|
||||||
|
void start() {
|
||||||
|
reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
OptionallySharedInference get_underlaying() const {
|
||||||
|
if (inference) return inference;
|
||||||
|
if (pool) return pool->get_inference(get_id());
|
||||||
|
common::utils::unreachable();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif // CHAT_HPP
|
19
include/justchat/global_config.hpp
Normal file
19
include/justchat/global_config.hpp
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
#ifndef MODEL_CONFIG_HPP
|
||||||
|
#define MODEL_CONFIG_HPP
|
||||||
|
#include <string>
|
||||||
|
#include <optional>
|
||||||
|
#include <commoncpp/config.hpp>
|
||||||
|
|
||||||
|
|
||||||
|
namespace LM {
|
||||||
|
namespace Chat {
|
||||||
|
class GlobalConfig final : public common::Configuration {
|
||||||
|
protected:
|
||||||
|
void fill(KeyValueMap&&, bool ignore_extra = false) override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
std::string system_prompt;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif // MODEL_CONFIG_HPP
|
44
include/justchat/model_config.hpp
Normal file
44
include/justchat/model_config.hpp
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
#ifndef MODEL_CONFIG_HPP
|
||||||
|
#define MODEL_CONFIG_HPP
|
||||||
|
#include <string>
|
||||||
|
#include <optional>
|
||||||
|
#include <commoncpp/config.hpp>
|
||||||
|
|
||||||
|
|
||||||
|
namespace LM {
|
||||||
|
namespace Chat {
|
||||||
|
class ModelConfig final : public common::Configuration {
|
||||||
|
protected:
|
||||||
|
void fill(KeyValueMap&&, bool ignore_extra = false) override;
|
||||||
|
void check() const override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ModelConfig() {
|
||||||
|
ignore_environment = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void parse(const std::string& file) override;
|
||||||
|
|
||||||
|
std::string config_file;
|
||||||
|
|
||||||
|
std::string model_file,
|
||||||
|
language = "en";
|
||||||
|
std::string mutable prompt_prefix = "### Human:\n",
|
||||||
|
prompt_suffix = "\n### Assistant:\n";
|
||||||
|
bool allow_system_prompt = true,
|
||||||
|
strict_prompt = false;
|
||||||
|
|
||||||
|
std::optional<unsigned> max_context_size,
|
||||||
|
repeat_last, // How many tokens to repeat-penalize
|
||||||
|
eos_ignores, // How many times to ignore EOS
|
||||||
|
top_k,
|
||||||
|
mirostat_version; // Version of mirostat to use; 0 for none
|
||||||
|
std::optional<float> top_p,
|
||||||
|
temp,
|
||||||
|
mirostat_learning_rate,
|
||||||
|
mirostat_target_entropy,
|
||||||
|
repeat_penalty;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif // MODEL_CONFIG_HPP
|
1
libjustlm
Submodule
1
libjustlm
Submodule
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 0199db02b7a81e31d394a23b671c65786a664893
|
32
main.cpp
Normal file
32
main.cpp
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
#include "chat.hpp"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
// Check usage
|
||||||
|
if (argc != 2) {
|
||||||
|
std::cerr << "Usage: " << argv[0] << " <model config path>" << std::endl;
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
// Get args
|
||||||
|
const auto model_config = argv[1];
|
||||||
|
// Construct chat model
|
||||||
|
LM::Chat::Inference model(model_config);
|
||||||
|
model.start();
|
||||||
|
std::string instruction;
|
||||||
|
for (;;) {
|
||||||
|
std::cout << "> ";
|
||||||
|
std::getline(std::cin, instruction);
|
||||||
|
model.instruct(instruction, [] (auto token) {
|
||||||
|
std::cout << token << std::flush;
|
||||||
|
return true;
|
||||||
|
}, [] (float progress, bool) {
|
||||||
|
std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
std::cout << '\n';
|
||||||
|
}
|
||||||
|
}
|
95
model_config.cpp
Normal file
95
model_config.cpp
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
#include "model_config.hpp"
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace LM {
|
||||||
|
namespace Chat {
|
||||||
|
void ModelConfig::fill(KeyValueMap &&map, bool ignore_extra) {
|
||||||
|
for (auto& [key, value] : map) {
|
||||||
|
if (key == "model_file")
|
||||||
|
model_file = std::move(value);
|
||||||
|
else if (key == "prompt_prefix")
|
||||||
|
prompt_prefix = parse_string(value);
|
||||||
|
else if (key == "prompt_suffix")
|
||||||
|
prompt_suffix = parse_string(value);
|
||||||
|
else if (key == "language")
|
||||||
|
language = std::move(value);
|
||||||
|
else if (key == "allow_system_prompt")
|
||||||
|
allow_system_prompt = parse_bool(value);
|
||||||
|
else if (key == "strict_prompt")
|
||||||
|
strict_prompt = parse_bool(value);
|
||||||
|
else if (key == "max_context_size")
|
||||||
|
max_context_size = std::stoul(value);
|
||||||
|
else if (key == "repeat_last")
|
||||||
|
repeat_last = std::stoul(value);
|
||||||
|
else if (key == "eos_ignores")
|
||||||
|
eos_ignores = std::stoi(value); // -1 for "infinite"
|
||||||
|
else if (key == "top_k")
|
||||||
|
top_k = std::stoi(value);
|
||||||
|
else if (key == "mirostat_version")
|
||||||
|
mirostat_version = std::stoi(value);
|
||||||
|
else if (key == "top_p")
|
||||||
|
top_p = std::stof(value);
|
||||||
|
else if (key == "temp")
|
||||||
|
temp = std::stof(value);
|
||||||
|
else if (key == "mirostat_learning_rate")
|
||||||
|
mirostat_learning_rate = std::stof(value);
|
||||||
|
else if (key == "mirostat_target_entropy")
|
||||||
|
mirostat_target_entropy = std::stof(value);
|
||||||
|
else if (key == "repeat_penalty")
|
||||||
|
repeat_penalty = std::stof(value);
|
||||||
|
else if (!ignore_extra)
|
||||||
|
throw Exception("Error: Failed to parse texts file: Unknown key: "+key);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make path absolute relative to config file
|
||||||
|
if (!model_file.empty() && model_file.find(std::filesystem::path::preferred_separator) == model_file.npos)
|
||||||
|
model_file = std::filesystem::path(config_file).parent_path()/model_file;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ModelConfig::check() const {
|
||||||
|
if (!file_exists(model_file))
|
||||||
|
throw Exception("Needs valid model file");
|
||||||
|
|
||||||
|
if (prompt_prefix.empty())
|
||||||
|
throw Exception("There should be a prompt prefix, use \"none\" to enforce empty");
|
||||||
|
if (prompt_prefix == "none" || prompt_prefix == "\"none\"")
|
||||||
|
prompt_prefix.clear();
|
||||||
|
|
||||||
|
if (prompt_suffix.empty())
|
||||||
|
throw Exception("There should be a prompt suffix, use \"none\" to enforce empty");
|
||||||
|
if (prompt_suffix == "none" || prompt_suffix == "\"none\"")
|
||||||
|
prompt_suffix.clear();
|
||||||
|
|
||||||
|
if (language.size() != 2 || !islower(language[0]) || !islower(language[1]))
|
||||||
|
throw Exception("Specified language needs to be lowercase two-letter code (example: en)");
|
||||||
|
|
||||||
|
if (mirostat_version.has_value()) {
|
||||||
|
if (mirostat_version.value() > 2)
|
||||||
|
throw Exception("Mirostat version must be 2 or below, use 0 to disable");
|
||||||
|
if (mirostat_version.value() != 0 && (top_p.has_value() || top_k.has_value()))
|
||||||
|
throw Exception("Can't combine top_p/top_k sampling with mirostat");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (top_p.has_value() && top_p.value() < 0.0f)
|
||||||
|
throw Exception("The top_p must be a positive value!");
|
||||||
|
|
||||||
|
if (temp.has_value() && temp.value() <= 0.0f)
|
||||||
|
throw Exception("The temperature must be a value above 0!");
|
||||||
|
|
||||||
|
if (mirostat_learning_rate.has_value() && mirostat_learning_rate.value() <= 0.0f)
|
||||||
|
throw Exception("The learning rate must be a value above 0!");
|
||||||
|
|
||||||
|
if (mirostat_target_entropy.has_value() && mirostat_target_entropy.value() <= 0.0f)
|
||||||
|
throw Exception("The target entropy must be a value above 0!");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ModelConfig::parse(const std::string &file) {
|
||||||
|
config_file = file;
|
||||||
|
Configuration::parse(file);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue