mirror of
https://gitlab.com/niansa/anyproc.git
synced 2025-03-06 20:49:24 +01:00
256 lines
9.6 KiB
C++
256 lines
9.6 KiB
C++
#ifndef ANYPROC_HPP
|
|
#define ANYPROC_HPP
|
|
#include <iostream>
|
|
#include <string>
|
|
#include <vector>
|
|
#include <unordered_map>
|
|
#include <string_view>
|
|
#include <utility>
|
|
#include <memory>
|
|
#include <functional>
|
|
#include <fstream> // for debugging purposes, to be removed
|
|
|
|
#include <justlm.hpp>
|
|
|
|
|
|
|
|
class PyEval {
|
|
std::unique_ptr<LM::Inference> i;
|
|
std::string buffer = "Python 3.9.2 (default, Feb 28 2021, 17:03:44)\n"
|
|
"[GCC 10.2.1 20210110] on linux\n"
|
|
"Type \"help\", \"copyright\", \"credits\" or \"license\" for more information.\n"
|
|
">>> print(\"Hello world!\")\n"
|
|
"Hello world!\n";
|
|
|
|
public:
|
|
static inline const std::string escape(std::string_view str, const char quotes = '"') {
|
|
std::string fres;
|
|
fres.push_back(quotes);
|
|
for (const char c : str) {
|
|
switch (c) {
|
|
case '\n': fres.append("\\n"); break;
|
|
case '\r': fres.append("\\r"); break;
|
|
case '\t': fres.append("\\t"); break;
|
|
case '\'': case '"': if (c == quotes) fres.push_back('\\'); [[fallthrough]];
|
|
default: fres.push_back(c);
|
|
}
|
|
}
|
|
fres.push_back(quotes);
|
|
return fres;
|
|
}
|
|
static inline const std::string unescape(std::string_view str) {
|
|
unsigned backTruncateCount = 0;
|
|
while (*(str.end()-backTruncateCount-1) == ' ') backTruncateCount++;
|
|
const char quotes = str[0];
|
|
if (quotes != '"' && quotes != '\'') {
|
|
return std::string(str);
|
|
}
|
|
std::string fres;
|
|
bool after_bs = false;
|
|
for (char c : std::string_view{str.data()+1, str.size()-2-backTruncateCount}) {
|
|
if (c == '\\') {
|
|
after_bs = true;
|
|
continue;
|
|
}
|
|
if (after_bs) {
|
|
after_bs = false;
|
|
switch (c) {
|
|
case 'n': c = '\n'; break;
|
|
case 'r': c = '\r'; break;
|
|
case 't': c = '\t'; break;
|
|
default: continue;
|
|
}
|
|
}
|
|
fres.push_back(c);
|
|
}
|
|
return fres;
|
|
}
|
|
|
|
constexpr static LM::Inference::Params get_recommended_params() {
|
|
LM::Inference::Params p;
|
|
p.use_mlock = false;
|
|
p.temp = 0.0f;
|
|
return p;
|
|
}
|
|
|
|
PyEval(const std::string& weights_path, const LM::Inference::Params& p = get_recommended_params()) {
|
|
i.reset(LM::Inference::construct(weights_path, p));
|
|
}
|
|
|
|
auto& begin() {
|
|
buffer += ">>> ";
|
|
return *this;
|
|
}
|
|
|
|
auto& load_module(std::string_view name) {
|
|
buffer += "import ";
|
|
buffer += name;
|
|
return *this;
|
|
}
|
|
auto& load_module(std::string_view name, std::string_view alias) {
|
|
buffer += "import ";
|
|
buffer += name;
|
|
buffer += " as ";
|
|
buffer += alias;
|
|
return *this;
|
|
}
|
|
|
|
auto& expression(std::string_view expression) {
|
|
buffer += expression;
|
|
return *this;
|
|
}
|
|
|
|
LM_SCHEDULABLE(std::string) run(const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
|
|
buffer += "\n";
|
|
LM_COAWAIT i->append(buffer, on_append_tick);
|
|
buffer.clear();
|
|
std::ofstream("prompt.txt") << i->get_prompt();
|
|
LM_CORETURN LM_COAWAIT i->run("\n", on_generation_tick);
|
|
}
|
|
void example(std::string_view response = "") {
|
|
buffer += "\n";
|
|
buffer += response;
|
|
if (!response.empty()) buffer += "\n";
|
|
}
|
|
void finish() {
|
|
buffer += "\n";
|
|
}
|
|
|
|
LM_SCHEDULABLE(void) create_savestate(LM::Inference::Savestate &sv, const std::function<bool (float)> &on_append_tick = nullptr) {
|
|
if (!buffer.empty()) {
|
|
LM_COAWAIT i->append(buffer, on_append_tick);
|
|
buffer.clear();
|
|
}
|
|
LM_COAWAIT i->create_savestate(sv);
|
|
}
|
|
LM_SCHEDULABLE(void) restore_savestate(const LM::Inference::Savestate &sv) {
|
|
LM_COAWAIT i->restore_savestate(sv);
|
|
}
|
|
|
|
LM::Inference& get_inference() {
|
|
return *i;
|
|
}
|
|
};
|
|
|
|
class Dictionary : PyEval {
|
|
static inline auto word_lookup_exprgen(const std::string& escaped_word) {
|
|
return "dict.word_lookup("+escaped_word+")";
|
|
}
|
|
|
|
public:
|
|
constexpr static LM::Inference::Params get_params() {
|
|
auto p = get_recommended_params();
|
|
p.top_k = 5;
|
|
p.top_p = 0.2f;
|
|
p.temp = 0.2f;
|
|
p.repeat_penalty = 1.2f;
|
|
p.n_repeat_last = 64;
|
|
return p;
|
|
}
|
|
|
|
Dictionary(const std::string& weights_path, const LM::Inference::Params& params = get_params()) : PyEval(weights_path, params) {
|
|
begin()
|
|
.load_module("huge_dictionary", "dict")
|
|
.finish();
|
|
begin()
|
|
.expression(word_lookup_exprgen("\"Treehouse\"")+".description")
|
|
.example("'A small house, especially one for children to play in, built or placed up in the branches of a tree.'");
|
|
begin()
|
|
.expression(word_lookup_exprgen("\"Python\"")+".description")
|
|
.example("'Any of several Old World boa constrictors of the subfamily Pythoninae, often growing to a length of more than 20 feet (6 meters): the Indian python, Python molurus, is endangered.'");
|
|
begin()
|
|
.expression(word_lookup_exprgen("\"C\"")+".description")
|
|
.example("'The 3rd letter of the alphabet.'");
|
|
begin()
|
|
.expression(word_lookup_exprgen("\"Appletree\"")+".syllables")
|
|
.example("['Ap', 'ple', 'tree']");
|
|
}
|
|
|
|
LM_SCHEDULABLE(std::string) lookup(std::string_view word, const std::string& what, const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
|
|
LM_CORETURN LM_COAWAIT begin().expression(word_lookup_exprgen(escape(word))+'.'+what)
|
|
.run(on_append_tick, on_generation_tick);
|
|
}
|
|
};
|
|
|
|
class Translator : PyEval {
|
|
LM::Inference::Savestate sv;
|
|
std::unordered_map<size_t/*hash*/, std::string> cache;
|
|
|
|
static inline auto translation_exprgen(const std::string& escaped_text, const std::string& escaped_language) {
|
|
return "translator.translate("+escaped_text+", "+escaped_language+")";
|
|
}
|
|
|
|
public:
|
|
constexpr static LM::Inference::Params get_params() {
|
|
auto p = get_recommended_params();
|
|
p.top_k = 5;
|
|
p.top_p = 0.2f;
|
|
p.temp = 0.1f;
|
|
p.repeat_penalty = 1.2f;
|
|
p.n_repeat_last = 64;
|
|
p.n_eos_ignores = 1;
|
|
return p;
|
|
}
|
|
|
|
Translator(const std::string& weights_path, const LM::Inference::Params& params = get_params()) : PyEval(weights_path, params) {
|
|
begin()
|
|
.load_module("deepl_scrape", "translator")
|
|
.finish();
|
|
begin()
|
|
.expression("# Using translator.translate, we can reliably translate any language into any other language. The meaning of the text is going to stay the same and if the text is already in the target language, it remains unchanged.")
|
|
.finish();
|
|
begin()
|
|
.expression(translation_exprgen("\"Treehouse\"", "\"DE\""))
|
|
.example("'Baumhaus'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Mir fiel ein Ball auf den Fuss!\"", "\"EN\""))
|
|
.example("'A ball fell onto my foot!'");
|
|
begin()
|
|
.expression(translation_exprgen("\"He's not an NPC, he's just another user here :/\"", "\"DE\""))
|
|
.example("'Er ist kein NPC, er ist hier nur ein normaler Benutzer :/'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Du bist ein sehr freundlicher Mensch :-)\"", "\"EN\""))
|
|
.example("'You are a very kind human :-)'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Hi\"", "\"EN\""))
|
|
.example("'Hi'");
|
|
begin()
|
|
.expression(translation_exprgen("\"What is the root of nine\"", "\"DE\""))
|
|
.example("'Was ist die Wurzel von neun'");
|
|
begin()
|
|
.expression(translation_exprgen("\"How long until school starts?\"", "\"IT\""))
|
|
.example("'Quanto manca all'inizio della scuola?'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Poisoning someone else means killing him/her, so you shouldn't ever try that!\"", "\"DE\""))
|
|
.example("'Jemanden zu vergiften bedeutet, ihn/sie zu töten, also solltest du das niemals versuchen!'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Please wait...\"", "\"DE\""))
|
|
.example("'Bitte warten...'");
|
|
}
|
|
|
|
LM_SCHEDULABLE(std::string) translate(std::string_view text, std::string_view language, const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
|
|
// Hash
|
|
size_t hash = std::hash<std::string_view>{}(text) + std::hash<std::string_view>{}(language);
|
|
|
|
// Cache lookup
|
|
auto res = cache.find(hash);
|
|
if (res != cache.end()) {
|
|
LM_CORETURN res->second;
|
|
}
|
|
|
|
// Restore savestate
|
|
if (sv.is_valid()) LM_COAWAIT restore_savestate(sv);
|
|
else LM_COAWAIT create_savestate(sv, on_append_tick);
|
|
|
|
// Run inference
|
|
auto fres = unescape(LM_COAWAIT begin().expression(translation_exprgen(escape(text), escape(language)))
|
|
.run(on_append_tick, on_generation_tick));
|
|
|
|
// Add to cache
|
|
cache[hash] = fres;
|
|
|
|
// Return final result;
|
|
LM_CORETURN fres;
|
|
}
|
|
};
|
|
#endif // ANYPROC_HPP
|