mirror of https://gitlab.com/niansa/libjustchat.git synced 2025-03-06 20:48:31 +01:00

Initial commit

This commit is contained in:
niansa 2023-06-10 00:49:41 +02:00
commit 7968f66759
12 changed files with 450 additions and 0 deletions

75
.gitignore vendored Normal file

@@ -0,0 +1,75 @@
# This file is used to ignore files which are generated
# ----------------------------------------------------------------------------
*~
*.autosave
*.a
*.core
*.moc
*.o
*.obj
*.orig
*.rej
*.so
*.so.*
*_pch.h.cpp
*_resource.rc
*.qm
.#*
*.*#
core
!core/
tags
.DS_Store
.directory
*.debug
Makefile*
*.prl
*.app
moc_*.cpp
ui_*.h
qrc_*.cpp
Thumbs.db
*.res
*.rc
/.qmake.cache
/.qmake.stash
# qtcreator generated files
*.pro.user*
CMakeLists.txt.user*
# xemacs temporary files
*.flc
# Vim temporary files
.*.swp
# Visual Studio generated files
*.ib_pdb_index
*.idb
*.ilk
*.pdb
*.sln
*.suo
*.vcproj
*vcproj.*.*.user
*.ncb
*.sdf
*.opensdf
*.vcxproj
*vcxproj.*
# MinGW generated files
*.Debug
*.Release
# Python byte code
*.pyc
# Binaries
# --------
*.dll
*.exe
CMakeLists.txt.user*

6
.gitmodules vendored Normal file

@@ -0,0 +1,6 @@
[submodule "libjustlm"]
	path = libjustlm
	url = https://gitlab.com/niansa/libjustlm.git
[submodule "commoncpp"]
	path = commoncpp
	url = https://gitlab.com/niansa/commoncpp.git

20
CMakeLists.txt Normal file

@@ -0,0 +1,20 @@
cmake_minimum_required(VERSION 3.5)
project(libjustchat LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_subdirectory(libjustlm)
add_subdirectory(commoncpp)

add_executable(justchat
    main.cpp
    # include/justchat/global_config.hpp global_config.cpp
    include/justchat/model_config.hpp model_config.cpp
    include/justchat/chat.hpp chat.cpp
)

target_include_directories(justchat PUBLIC include/)
target_include_directories(justchat PRIVATE include/justchat/)

target_link_libraries(justchat PUBLIC commoncpp justlm)

75
chat.cpp Normal file

@@ -0,0 +1,75 @@
#include "chat.hpp"
#include <iostream>

namespace LM {
namespace Chat {
LM::Inference::Params Inference::get_params() const {
    LM::Inference::Params fres;
    if (config.max_context_size.has_value())
        fres.n_ctx = config.max_context_size.value();
    if (config.repeat_last.has_value())
        fres.n_repeat_last = config.repeat_last.value();
    if (config.eos_ignores.has_value())
        fres.n_eos_ignores = config.eos_ignores.value();
    if (config.top_k.has_value())
        fres.top_k = config.top_k.value();
    if (config.top_p.has_value())
        fres.top_p = config.top_p.value();
    if (config.temp.has_value())
        fres.temp = config.temp.value();
    if (config.mirostat_learning_rate.has_value())
        fres.mirostat_learning_rate = config.mirostat_learning_rate.value();
    if (config.mirostat_target_entropy.has_value())
        fres.mirostat_target_entropy = config.mirostat_target_entropy.value();
    if (config.repeat_penalty.has_value())
        fres.repeat_penalty = config.repeat_penalty.value();
    if (config.mirostat_version.has_value())
        fres.prefer_mirostat = config.mirostat_version.value();
    return fres;
}

Inference::Inference(LM::InferencePool& pool, const std::string& config_path) : pool(&pool) {
    config.parse(config_path);
}
Inference::Inference(const std::string &config_path) {
    config.parse(config_path);
}

void Inference::reset() {
    if (pool) {
        const auto id = get_id();
        pool->delete_inference(id);
        pool->create_inference(id, config.model_file, get_params());
    } else
        inference = LM::Inference::construct(config.model_file, get_params());
}

LM_SCHEDULABLE(std::string) Inference::instruct(const std::string& instruction, const std::function<bool (std::string_view)> on_generate, const std::function<bool (float, bool)> on_evaluate) {
    auto inference = get_underlaying();
    bool non_cancelled = true;

    // Append prompt prefix
    inference->append(config.prompt_prefix, [on_evaluate] (float progress) -> bool {
        on_evaluate(progress/3.0f, true);
        return true; // Can't be cancelled here
    });

    // Append prompt
    inference->append(instruction, [on_evaluate, &non_cancelled] (float progress) -> bool {
        return non_cancelled = on_evaluate(progress/3.0f+100.0f/3.0f, false);
    });

    // Append prompt suffix
    inference->append(config.prompt_suffix, [on_evaluate] (float progress) -> bool {
        on_evaluate(progress/3.0f+(100.0f/3.0f*2.0f), true);
        return true; // Can't be cancelled here
    });

    // Check for cancellation
    if (!non_cancelled) LM_CORETURN "";
    // Run inference
    LM_CORETURN inference->run(config.prompt_prefix, [on_generate] (const char *token) {
        return on_generate(token);
    });
}
}
}
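
Note on the callback math in instruct() above: evaluation progress is reported to on_evaluate in three equal slices covering the prompt prefix, the user instruction, and the prompt suffix, and only the middle slice can cancel the call. A minimal standalone sketch of that mapping (the helper below is illustrative only and not part of libjustchat):

#include <iostream>

// stage: 0 = prompt prefix, 1 = user instruction, 2 = prompt suffix
// local: per-stage progress passed to the append() callback, 0..100
float overall_progress(int stage, float local) {
    return local / 3.0f + stage * (100.0f / 3.0f);
}

int main() {
    std::cout << overall_progress(0, 100.0f) << '\n'; // ~33.3 once the prefix is appended
    std::cout << overall_progress(1, 50.0f) << '\n';  // 50 halfway through the instruction
    std::cout << overall_progress(2, 100.0f) << '\n'; // 100 when the suffix is done and generation starts
}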

1
commoncpp Submodule

@@ -0,0 +1 @@
Subproject commit ec148c3c3447b26b3213b0277566445456693bdf

18
global_config.cpp Normal file

@@ -0,0 +1,18 @@
#include "global_config.hpp"

#include <filesystem>


namespace LM {
namespace Chat {
void GlobalConfig::fill(KeyValueMap &&map, bool ignore_extra) {
    for (auto& [key, value] : map) {
        if (key == "system_prompt")
            system_prompt = std::move(value);
        else if (!ignore_extra)
            throw Exception("Error: Failed to parse texts file: Unknown key: "+key);
    }
}

}
}

64
include/justchat/chat.hpp Normal file

@@ -0,0 +1,64 @@
#ifndef CHAT_HPP
#define CHAT_HPP
#include "model_config.hpp"

#include <string>
#include <string_view>
#include <memory>
#include <functional>

#include <commoncpp/utils.hpp>
#include <justlm.hpp>
#include <justlm_pool.hpp>

namespace LM {
namespace Chat {
class Inference {
    ModelConfig config;
    LM::Inference *inference = nullptr;
    LM::InferencePool *pool = nullptr;

    LM::Inference::Params get_params() const;

    size_t get_id() const {
        return static_cast<size_t>(reinterpret_cast<uintptr_t>(this)); // Dirty hack that just works
    }

public:
    class OptionallySharedInference {
        std::shared_ptr<LM::Inference> s_ptr;
        LM::Inference *r_ptr;

    public:
        OptionallySharedInference(std::shared_ptr<LM::Inference> &&s)
                : s_ptr(std::move(s)) {
            r_ptr = s_ptr.get();
        }
        OptionallySharedInference(LM::Inference *r)
                : r_ptr(std::move(r)) {}

        LM::Inference *operator ->() const {
            return r_ptr;
        }
    };

    Inference(const std::string& config_path);
    Inference(LM::InferencePool& pool, const std::string& config_path);

    void reset();

    LM_SCHEDULABLE(std::string) instruct(const std::string& instruction, const std::function<bool(std::string_view)> on_generate = nullptr, const std::function<bool(float, bool)> on_evaluate = nullptr);

    void start() {
        reset();
    }

    OptionallySharedInference get_underlaying() const {
        if (inference) return inference;
        if (pool) return pool->get_inference(get_id());
        common::utils::unreachable();
    }
};
}
}
#endif // CHAT_HPP
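
Both Inference constructors parse the same model config; the pool-backed one keys its slot by get_id() and lets an LM::InferencePool own the underlying LM::Inference. A minimal usage sketch, assuming a pool reference is already available and that LM_SCHEDULABLE(std::string) collapses to a plain std::string in a non-coroutine build of libjustlm (the same assumption main.cpp below makes when calling instruct() synchronously):

#include "chat.hpp"
#include <iostream>
#include <string>

// Sketch only: drives a pool-backed chat the same way main.cpp drives a standalone one.
void chat_once(LM::InferencePool &pool, const std::string &config_path) {
    LM::Chat::Inference chat(pool, config_path); // inference lives in the pool, keyed by get_id()
    chat.start();                                // reset() creates the pooled inference
    chat.instruct("Say hello.",
        [] (std::string_view token) {
            std::cout << token << std::flush;    // stream generated tokens
            return true;                         // keep generating
        },
        [] (float progress, bool) {
            std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
            return true;                         // keep evaluating
        });
    std::cout << std::endl;
}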

19
include/justchat/global_config.hpp Normal file

@@ -0,0 +1,19 @@
#ifndef GLOBAL_CONFIG_HPP
#define GLOBAL_CONFIG_HPP
#include <string>
#include <optional>
#include <commoncpp/config.hpp>


namespace LM {
namespace Chat {
class GlobalConfig final : public common::Configuration {
protected:
    void fill(KeyValueMap&&, bool ignore_extra = false) override;

public:
    std::string system_prompt;
};
}
}
#endif // GLOBAL_CONFIG_HPP

44
include/justchat/model_config.hpp Normal file

@@ -0,0 +1,44 @@
#ifndef MODEL_CONFIG_HPP
#define MODEL_CONFIG_HPP
#include <string>
#include <optional>
#include <commoncpp/config.hpp>


namespace LM {
namespace Chat {
class ModelConfig final : public common::Configuration {
protected:
    void fill(KeyValueMap&&, bool ignore_extra = false) override;
    void check() const override;

public:
    ModelConfig() {
        ignore_environment = true;
    }

    void parse(const std::string& file) override;

    std::string config_file;

    std::string model_file,
                language = "en";
    mutable std::string prompt_prefix = "### Human:\n",
                        prompt_suffix = "\n### Assistant:\n";
    bool allow_system_prompt = true,
         strict_prompt = false;

    std::optional<unsigned> max_context_size,
                            repeat_last, // How many tokens to repeat-penalize
                            eos_ignores, // How many times to ignore EOS
                            top_k,
                            mirostat_version; // Version of mirostat to use; 0 for none
    std::optional<float> top_p,
                         temp,
                         mirostat_learning_rate,
                         mirostat_target_entropy,
                         repeat_penalty;
};
}
}
#endif // MODEL_CONFIG_HPP

1
libjustlm Submodule

@@ -0,0 +1 @@
Subproject commit 0199db02b7a81e31d394a23b671c65786a664893

32
main.cpp Normal file

@@ -0,0 +1,32 @@
#include "chat.hpp"
#include <iostream>
#include <string>

int main(int argc, char **argv) {
    // Check usage
    if (argc != 2) {
        std::cerr << "Usage: " << argv[0] << " <model config path>" << std::endl;
        return EXIT_FAILURE;
    }

    // Get args
    const auto model_config = argv[1];

    // Construct chat model
    LM::Chat::Inference model(model_config);
    model.start();

    std::string instruction;
    for (;;) {
        std::cout << "> ";
        std::getline(std::cin, instruction);
        model.instruct(instruction, [] (auto token) {
            std::cout << token << std::flush;
            return true;
        }, [] (float progress, bool) {
            std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
            return true;
        });
        std::cout << '\n';
    }
}

95
model_config.cpp Normal file

@@ -0,0 +1,95 @@
#include "model_config.hpp"
#include <filesystem>
#include <iostream>


namespace LM {
namespace Chat {
void ModelConfig::fill(KeyValueMap &&map, bool ignore_extra) {
    for (auto& [key, value] : map) {
        if (key == "model_file")
            model_file = std::move(value);
        else if (key == "prompt_prefix")
            prompt_prefix = parse_string(value);
        else if (key == "prompt_suffix")
            prompt_suffix = parse_string(value);
        else if (key == "language")
            language = std::move(value);
        else if (key == "allow_system_prompt")
            allow_system_prompt = parse_bool(value);
        else if (key == "strict_prompt")
            strict_prompt = parse_bool(value);
        else if (key == "max_context_size")
            max_context_size = std::stoul(value);
        else if (key == "repeat_last")
            repeat_last = std::stoul(value);
        else if (key == "eos_ignores")
            eos_ignores = std::stoi(value); // -1 for "infinite"
        else if (key == "top_k")
            top_k = std::stoi(value);
        else if (key == "mirostat_version")
            mirostat_version = std::stoi(value);
        else if (key == "top_p")
            top_p = std::stof(value);
        else if (key == "temp")
            temp = std::stof(value);
        else if (key == "mirostat_learning_rate")
            mirostat_learning_rate = std::stof(value);
        else if (key == "mirostat_target_entropy")
            mirostat_target_entropy = std::stof(value);
        else if (key == "repeat_penalty")
            repeat_penalty = std::stof(value);
        else if (!ignore_extra)
            throw Exception("Error: Failed to parse texts file: Unknown key: "+key);
    }

    // Make path absolute relative to config file
    if (!model_file.empty() && model_file.find(std::filesystem::path::preferred_separator) == model_file.npos)
        model_file = std::filesystem::path(config_file).parent_path()/model_file;
}

void ModelConfig::check() const {
    if (!file_exists(model_file))
        throw Exception("Needs valid model file");

    if (prompt_prefix.empty())
        throw Exception("There should be a prompt prefix, use \"none\" to enforce empty");
    if (prompt_prefix == "none" || prompt_prefix == "\"none\"")
        prompt_prefix.clear();
    if (prompt_suffix.empty())
        throw Exception("There should be a prompt suffix, use \"none\" to enforce empty");
    if (prompt_suffix == "none" || prompt_suffix == "\"none\"")
        prompt_suffix.clear();

    if (language.size() != 2 || !islower(language[0]) || !islower(language[1]))
        throw Exception("Specified language needs to be lowercase two-letter code (example: en)");

    if (mirostat_version.has_value()) {
        if (mirostat_version.value() > 2)
            throw Exception("Mirostat version must be 2 or below, use 0 to disable");
        if (mirostat_version.value() != 0 && (top_p.has_value() || top_k.has_value()))
            throw Exception("Can't combine top_p/top_k sampling with mirostat");
    }

    if (top_p.has_value() && top_p.value() < 0.0f)
        throw Exception("The top_p must be a positive value!");
    if (temp.has_value() && temp.value() <= 0.0f)
        throw Exception("The temperature must be a value above 0!");
    if (mirostat_learning_rate.has_value() && mirostat_learning_rate.value() <= 0.0f)
        throw Exception("The learning rate must be a value above 0!");
    if (mirostat_target_entropy.has_value() && mirostat_target_entropy.value() <= 0.0f)
        throw Exception("The target entropy must be a value above 0!");
}

void ModelConfig::parse(const std::string &file) {
    config_file = file;
    Configuration::parse(file);
}
}
}
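
The path handling in fill() above resolves a bare model_file (one containing no path separator) next to the config file itself, and check() rejects inconsistent prompt and sampler settings before any model is loaded. A minimal sketch of loading a config on its own, assuming common::Configuration::parse() invokes fill() and check(), and that the Exception it throws derives from std::exception (neither detail is shown in this commit):

#include "model_config.hpp"
#include <iostream>

int main(int argc, char **argv) {
    if (argc != 2) return 1;
    LM::Chat::ModelConfig config;
    try {
        config.parse(argv[1]); // stores config_file, then fills and validates the fields
    } catch (const std::exception &e) {
        std::cerr << "Bad model config: " << e.what() << std::endl;
        return 1;
    }
    std::cout << "model_file: " << config.model_file << '\n'
              << "language:   " << config.language << '\n'
              << "prefix:     " << config.prompt_prefix
              << "suffix:     " << config.prompt_suffix << std::endl;
}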