Mirror of https://gitlab.com/niansa/libjustchat.git (synced 2025-03-06 20:48:31 +01:00)

Initial commit
commit 7968f66759
12 changed files with 450 additions and 0 deletions

.gitignore (vendored, new file, 75 lines)
@@ -0,0 +1,75 @@
# This file is used to ignore files which are generated
# ----------------------------------------------------------------------------

*~
*.autosave
*.a
*.core
*.moc
*.o
*.obj
*.orig
*.rej
*.so
*.so.*
*_pch.h.cpp
*_resource.rc
*.qm
.#*
*.*#
core
!core/
tags
.DS_Store
.directory
*.debug
Makefile*
*.prl
*.app
moc_*.cpp
ui_*.h
qrc_*.cpp
Thumbs.db
*.res
*.rc
/.qmake.cache
/.qmake.stash

# qtcreator generated files
*.pro.user*
CMakeLists.txt.user*

# xemacs temporary files
*.flc

# Vim temporary files
.*.swp

# Visual Studio generated files
*.ib_pdb_index
*.idb
*.ilk
*.pdb
*.sln
*.suo
*.vcproj
*vcproj.*.*.user
*.ncb
*.sdf
*.opensdf
*.vcxproj
*vcxproj.*

# MinGW generated files
*.Debug
*.Release

# Python byte code
*.pyc

# Binaries
# --------
*.dll
*.exe

CMakeLists.txt.user*

.gitmodules (vendored, new file, 6 lines)
@@ -0,0 +1,6 @@
[submodule "libjustlm"]
	path = libjustlm
	url = https://gitlab.com/niansa/libjustlm.git
[submodule "commoncpp"]
	path = commoncpp
	url = https://gitlab.com/niansa/commoncpp.git
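
Both dependencies are vendored as git submodules, so a fresh checkout needs git clone --recursive (or git submodule update --init after a plain clone) before the CMake build below can configure.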

CMakeLists.txt (new file, 20 lines)
@@ -0,0 +1,20 @@
cmake_minimum_required(VERSION 3.5)

project(libjustchat LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_subdirectory(libjustlm)
add_subdirectory(commoncpp)

add_executable(justchat
    main.cpp
    # include/justchat/global_config.hpp global_config.cpp
    include/justchat/model_config.hpp model_config.cpp
    include/justchat/chat.hpp chat.cpp
)

target_include_directories(justchat PUBLIC include/)
target_include_directories(justchat PRIVATE include/justchat/)
target_link_libraries(justchat PUBLIC commoncpp justlm)
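
Note that global_config.cpp and its header are part of this commit but commented out of the source list above, so GlobalConfig (and the system_prompt key it parses) is not compiled into justchat yet.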

chat.cpp (new file, 75 lines)
@@ -0,0 +1,75 @@
#include "chat.hpp"

#include <iostream>


namespace LM {
namespace Chat {
LM::Inference::Params Inference::get_params() const {
    LM::Inference::Params fres;
    if (config.max_context_size.has_value())
        fres.n_ctx = config.max_context_size.value();
    if (config.repeat_last.has_value())
        fres.n_repeat_last = config.repeat_last.value();
    if (config.eos_ignores.has_value())
        fres.n_eos_ignores = config.eos_ignores.value();
    if (config.top_k.has_value())
        fres.top_k = config.top_k.value();
    if (config.top_p.has_value())
        fres.top_p = config.top_p.value();
    if (config.temp.has_value())
        fres.temp = config.temp.value();
    if (config.mirostat_learning_rate.has_value())
        fres.mirostat_learning_rate = config.mirostat_learning_rate.value();
    if (config.mirostat_target_entropy.has_value())
        fres.mirostat_target_entropy = config.mirostat_target_entropy.value();
    if (config.repeat_penalty.has_value())
        fres.repeat_penalty = config.repeat_penalty.value();
    if (config.mirostat_version.has_value())
        fres.prefer_mirostat = config.mirostat_version.value();
    return fres;
}

Inference::Inference(LM::InferencePool& pool, const std::string& config_path) : pool(&pool) {
    config.parse(config_path);
}

Inference::Inference(const std::string &config_path) {
    config.parse(config_path);
}

void Inference::reset() {
    if (pool) {
        const auto id = get_id();
        pool->delete_inference(id);
        pool->create_inference(id, config.model_file, get_params());
    } else
        inference = LM::Inference::construct(config.model_file, get_params());
}

LM_SCHEDULABLE(std::string) Inference::instruct(const std::string& instruction, const std::function<bool (std::string_view)> on_generate, const std::function<bool (float, bool)> on_evaluate) {
    auto inference = get_underlaying();
    bool non_cancelled = true;
    // Append prompt prefix
    inference->append(config.prompt_prefix, [on_evaluate] (float progress) -> bool {
        on_evaluate(progress/3.0f, true);
        return true; // Can't be cancelled here
    });
    // Append prompt
    inference->append(instruction, [on_evaluate, &non_cancelled] (float progress) -> bool {
        return non_cancelled = on_evaluate(progress/3.0f+100.0f/3.0f, false);
    });
    // Append prompt suffix
    inference->append(config.prompt_suffix, [on_evaluate] (float progress) -> bool {
        on_evaluate(progress/3.0f+(100.0f/3.0f*2.0f), true);
        return true; // Can't be cancelled here
    });
    // Check for cancellation
    if (!non_cancelled) LM_CORETURN "";
    // Run inference
    LM_CORETURN inference->run(config.prompt_prefix, [on_generate] (const char *token) {
        return on_generate(token);
    });
}
}
}
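
The pooled constructor above is not exercised anywhere in this commit; the sketch below shows how it could be used. It is a sketch only: the LM::InferencePool constructor arguments and the config file names are assumptions, not part of this commit, and the real signature lives in justlm_pool.hpp in the libjustlm submodule.

// Sketch: two chats sharing one inference pool.
#include <justchat/chat.hpp>
#include <justlm_pool.hpp>

int main() {
    // Hypothetical pool construction; check justlm_pool.hpp for the
    // actual constructor signature.
    LM::InferencePool pool(2, "chats");
    // Hypothetical config paths; each Inference parses its own ModelConfig.
    LM::Chat::Inference alice(pool, "alice.conf");
    LM::Chat::Inference bob(pool, "bob.conf");
    alice.start(); // reset(): delete_inference(id), then create_inference(id, ...)
    bob.start();
    // get_id() keys each pool slot by object address, so the two chats
    // do not collide while both objects are alive.
}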

commoncpp (submodule, new)
@@ -0,0 +1 @@
Subproject commit ec148c3c3447b26b3213b0277566445456693bdf

global_config.cpp (new file, 18 lines)
@@ -0,0 +1,18 @@
#include "global_config.hpp"

#include <filesystem>



namespace LM {
namespace Chat {
void GlobalConfig::fill(KeyValueMap &&map, bool ignore_extra) {
    for (auto& [key, value] : map) {
        if (key == "system_prompt")
            system_prompt = std::move(value);
        else if (!ignore_extra)
            throw Exception("Error: Failed to parse texts file: Unknown key: "+key);
    }
}
}
}

include/justchat/chat.hpp (new file, 64 lines)
@@ -0,0 +1,64 @@
#ifndef CHAT_HPP
#define CHAT_HPP
#include "model_config.hpp"

#include <string>
#include <string_view>
#include <memory>
#include <functional>
#include <commoncpp/utils.hpp>
#include <justlm.hpp>
#include <justlm_pool.hpp>


namespace LM {
namespace Chat {
class Inference {
    ModelConfig config;

    LM::Inference *inference = nullptr;
    LM::InferencePool *pool = nullptr;

    LM::Inference::Params get_params() const;

    size_t get_id() const {
        return static_cast<size_t>(reinterpret_cast<uintptr_t>(this)); // Dirty hack that just works
    }

public:
    class OptionallySharedInference {
        std::shared_ptr<LM::Inference> s_ptr;
        LM::Inference *r_ptr;

    public:
        OptionallySharedInference(std::shared_ptr<LM::Inference> &&s)
            : s_ptr(std::move(s)) {
            r_ptr = s_ptr.get();
        }
        OptionallySharedInference(LM::Inference *r)
            : r_ptr(std::move(r)) {}

        LM::Inference *operator ->() const {
            return r_ptr;
        }
    };

    Inference(const std::string& config_path);
    Inference(LM::InferencePool& pool, const std::string& config_path);

    void reset();
    LM_SCHEDULABLE(std::string) instruct(const std::string& instruction, const std::function<bool(std::string_view)> on_generate = nullptr, const std::function<bool(float, bool)> on_evaluate = nullptr);

    void start() {
        reset();
    }

    OptionallySharedInference get_underlaying() const {
        if (inference) return inference;
        if (pool) return pool->get_inference(get_id());
        common::utils::unreachable();
    }
};
}
}
#endif // CHAT_HPP
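
The OptionallySharedInference wrapper smooths over the two ownership models above: the pooled path hands out a shared_ptr that must stay alive for the duration of the call, while the standalone path only holds a raw pointer it owns outright. Storing both and exposing a single operator-> lets callers such as instruct() stay agnostic about where the inference actually lives.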

include/justchat/global_config.hpp (new file, 19 lines)
@@ -0,0 +1,19 @@
#ifndef GLOBAL_CONFIG_HPP
#define GLOBAL_CONFIG_HPP
#include <string>
#include <optional>
#include <commoncpp/config.hpp>


namespace LM {
namespace Chat {
class GlobalConfig final : public common::Configuration {
protected:
    void fill(KeyValueMap&&, bool ignore_extra = false) override;

public:
    std::string system_prompt;
};
}
}
#endif // GLOBAL_CONFIG_HPP

include/justchat/model_config.hpp (new file, 44 lines)
@@ -0,0 +1,44 @@
#ifndef MODEL_CONFIG_HPP
#define MODEL_CONFIG_HPP
#include <string>
#include <optional>
#include <commoncpp/config.hpp>


namespace LM {
namespace Chat {
class ModelConfig final : public common::Configuration {
protected:
    void fill(KeyValueMap&&, bool ignore_extra = false) override;
    void check() const override;

public:
    ModelConfig() {
        ignore_environment = true;
    }

    void parse(const std::string& file) override;

    std::string config_file;

    std::string model_file,
                language = "en";
    std::string mutable prompt_prefix = "### Human:\n",
                        prompt_suffix = "\n### Assistant:\n";
    bool allow_system_prompt = true,
         strict_prompt = false;

    std::optional<unsigned> max_context_size,
                            repeat_last, // How many tokens to repeat-penalize
                            eos_ignores, // How many times to ignore EOS
                            top_k,
                            mirostat_version; // Version of mirostat to use; 0 for none
    std::optional<float> top_p,
                         temp,
                         mirostat_learning_rate,
                         mirostat_target_entropy,
                         repeat_penalty;
};
}
}
#endif // MODEL_CONFIG_HPP

libjustlm (submodule, new)
@@ -0,0 +1 @@
Subproject commit 0199db02b7a81e31d394a23b671c65786a664893

main.cpp (new file, 32 lines)
@@ -0,0 +1,32 @@
#include "chat.hpp"

#include <iostream>
#include <string>



int main(int argc, char **argv) {
    // Check usage
    if (argc != 2) {
        std::cerr << "Usage: " << argv[0] << " <model config path>" << std::endl;
        return EXIT_FAILURE;
    }
    // Get args
    const auto model_config = argv[1];
    // Construct chat model
    LM::Chat::Inference model(model_config);
    model.start();
    std::string instruction;
    for (;;) {
        std::cout << "> ";
        std::getline(std::cin, instruction);
        model.instruct(instruction, [] (auto token) {
            std::cout << token << std::flush;
            return true;
        }, [] (float progress, bool) {
            std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
            return true;
        });
        std::cout << '\n';
    }
}
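
A note on the instruct() call above: it is declared LM_SCHEDULABLE(std::string), which in a plain build of libjustlm should collapse to an ordinary std::string return, so the REPL simply discards the full response and relies on the on_generate callback for output. In a coroutine-enabled build the result would presumably need awaiting; that depends on how libjustlm is configured, which this commit leaves to the submodule.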

model_config.cpp (new file, 95 lines)
@@ -0,0 +1,95 @@
#include "model_config.hpp"

#include <filesystem>
#include <iostream>



namespace LM {
namespace Chat {
void ModelConfig::fill(KeyValueMap &&map, bool ignore_extra) {
    for (auto& [key, value] : map) {
        if (key == "model_file")
            model_file = std::move(value);
        else if (key == "prompt_prefix")
            prompt_prefix = parse_string(value);
        else if (key == "prompt_suffix")
            prompt_suffix = parse_string(value);
        else if (key == "language")
            language = std::move(value);
        else if (key == "allow_system_prompt")
            allow_system_prompt = parse_bool(value);
        else if (key == "strict_prompt")
            strict_prompt = parse_bool(value);
        else if (key == "max_context_size")
            max_context_size = std::stoul(value);
        else if (key == "repeat_last")
            repeat_last = std::stoul(value);
        else if (key == "eos_ignores")
            eos_ignores = std::stoi(value); // -1 for "infinite"
        else if (key == "top_k")
            top_k = std::stoi(value);
        else if (key == "mirostat_version")
            mirostat_version = std::stoi(value);
        else if (key == "top_p")
            top_p = std::stof(value);
        else if (key == "temp")
            temp = std::stof(value);
        else if (key == "mirostat_learning_rate")
            mirostat_learning_rate = std::stof(value);
        else if (key == "mirostat_target_entropy")
            mirostat_target_entropy = std::stof(value);
        else if (key == "repeat_penalty")
            repeat_penalty = std::stof(value);
        else if (!ignore_extra)
            throw Exception("Error: Failed to parse texts file: Unknown key: "+key);
    }

    // Make path absolute relative to config file
    if (!model_file.empty() && model_file.find(std::filesystem::path::preferred_separator) == model_file.npos)
        model_file = std::filesystem::path(config_file).parent_path()/model_file;
}

void ModelConfig::check() const {
    if (!file_exists(model_file))
        throw Exception("Needs valid model file");

    if (prompt_prefix.empty())
        throw Exception("There should be a prompt prefix, use \"none\" to enforce empty");
    if (prompt_prefix == "none" || prompt_prefix == "\"none\"")
        prompt_prefix.clear();

    if (prompt_suffix.empty())
        throw Exception("There should be a prompt suffix, use \"none\" to enforce empty");
    if (prompt_suffix == "none" || prompt_suffix == "\"none\"")
        prompt_suffix.clear();

    if (language.size() != 2 || !islower(language[0]) || !islower(language[1]))
        throw Exception("Specified language needs to be lowercase two-letter code (example: en)");

    if (mirostat_version.has_value()) {
        if (mirostat_version.value() > 2)
            throw Exception("Mirostat version must be 2 or below, use 0 to disable");
        if (mirostat_version.value() != 0 && (top_p.has_value() || top_k.has_value()))
            throw Exception("Can't combine top_p/top_k sampling with mirostat");
    }

    if (top_p.has_value() && top_p.value() < 0.0f)
        throw Exception("The top_p must be a positive value!");

    if (temp.has_value() && temp.value() <= 0.0f)
        throw Exception("The temperature must be a value above 0!");

    if (mirostat_learning_rate.has_value() && mirostat_learning_rate.value() <= 0.0f)
        throw Exception("The learning rate must be a value above 0!");

    if (mirostat_target_entropy.has_value() && mirostat_target_entropy.value() <= 0.0f)
        throw Exception("The target entropy must be a value above 0!");
}

void ModelConfig::parse(const std::string &file) {
    config_file = file;
    Configuration::parse(file);
}
}
}
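
Putting fill() and check() together, a hypothetical model config might look like the sketch below. The "key value" line syntax is an assumption: the actual tokenization is done by common::Configuration in the commoncpp submodule, which this diff does not show. Only the key names and value types are taken from fill() above, and the model file name is made up.

// Hypothetical config writer/loader sketch.
#include <fstream>
#include <justchat/model_config.hpp>

int main() {
    // Write a tiny config; "key value" lines are an assumption about
    // commoncpp's parser.
    std::ofstream("model.conf") <<
        "model_file ggml-model.bin\n" // bare name: fill() resolves it next to model.conf
        "temp 0.7\n"
        "top_k 40\n"
        "repeat_penalty 1.1\n";
    LM::Chat::ModelConfig config;
    // parse() records config_file, then defers to Configuration::parse;
    // check() (if run by Configuration::parse) rejects the config unless
    // ggml-model.bin actually exists.
    config.parse("model.conf");
}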