Mirror of https://gitlab.com/niansa/anyproc.git (synced 2025-03-06 20:49:24 +01:00)
224 lines · 8.4 KiB · C++
#ifndef ANYPROC_HPP
|
|
#define ANYPROC_HPP
|
|
#include <iostream>
|
|
#include <string>
|
|
#include <vector>
|
|
#include <unordered_map>
|
|
#include <string_view>
|
|
#include <utility>
|
|
#include <functional>
|
|
|
|
#include <justlm.hpp>
|
|
|
|
|
|
|
|
class PyEval : LM::Inference {
|
|
std::string buffer = "Python 3.9.2 (default, Feb 28 2021, 17:03:44)\n"
|
|
"[GCC 10.2.1 20210110] on linux\n"
|
|
"Type \"help\", \"copyright\", \"credits\" or \"license\" for more information.\n"
|
|
">>> print(\"Hello world!\")\n"
|
|
"Hello world!\n";
|
|
|
|
public:
|
|
static inline const std::string escape(std::string_view str, const char quotes = '"') {
|
|
std::string fres;
|
|
fres.push_back(quotes);
|
|
for (const char c : str) {
|
|
if (c == quotes) fres.push_back('\\');
|
|
fres.push_back(c);
|
|
}
|
|
fres.push_back(quotes);
|
|
return fres;
|
|
}
|
|
static inline const std::string unescape(std::string_view str) {
|
|
unsigned backTruncateCount = 0;
|
|
while (*(str.end()-backTruncateCount-1) == ' ') backTruncateCount++;
|
|
const char quotes = str[0];
|
|
if (quotes != '"' && quotes != '\'') {
|
|
return std::string(str);
|
|
}
|
|
std::string fres;
|
|
for (const char c : std::string_view{str.data()+1, str.size()-2-backTruncateCount}) {
|
|
if (c != '\\') fres.push_back(c);
|
|
}
|
|
return fres;
|
|
}
|
|
|
|
constexpr static LM::Inference::Params get_recommended_params() {
|
|
LM::Inference::Params p;
|
|
p.use_mlock = false;
|
|
p.repeat_penalty = 1.2f;
|
|
p.n_repeat_last = 64;
|
|
p.temp = 0.4f;
|
|
p.top_k = 1;
|
|
p.top_p = 0.f;
|
|
return p;
|
|
}
|
|
|
|
PyEval(const std::string& weights_path, const Params& p = get_recommended_params())
|
|
: LM::Inference(weights_path, p) {}
|
|
|
|
auto& begin() {
|
|
buffer += ">>> ";
|
|
return *this;
|
|
}
|
|
|
|
auto& load_module(std::string_view name) {
|
|
buffer += "import ";
|
|
buffer += name;
|
|
return *this;
|
|
}
|
|
auto& load_module(std::string_view name, std::string_view alias) {
|
|
buffer += "import ";
|
|
buffer += name;
|
|
buffer += " as ";
|
|
buffer += alias;
|
|
return *this;
|
|
}
|
|
|
|
auto& expression(std::string_view expression) {
|
|
buffer += expression;
|
|
return *this;
|
|
}
|
|
|
|
auto run(const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
|
|
buffer += "\n";
|
|
append(buffer, on_append_tick);
|
|
buffer.clear();
|
|
return LM::Inference::run("\n", on_generation_tick);
|
|
}
|
|
void example(std::string_view response = "") {
|
|
buffer += "\n";
|
|
buffer += response;
|
|
if (!response.empty()) buffer += "\n";
|
|
}
|
|
|
|
void create_savestate(Savestate &sv, const std::function<bool (float)> &on_append_tick = nullptr) {
|
|
if (!buffer.empty()) {
|
|
append(buffer, on_append_tick);
|
|
buffer.clear();
|
|
}
|
|
LM::Inference::create_savestate(sv);
|
|
}
|
|
void restore_savestate(const Savestate &sv) {
|
|
LM::Inference::restore_savestate(sv);
|
|
}
|
|
};
|
|
|
|
class Dictionary : PyEval {
|
|
static inline auto word_lookup_exprgen(const std::string& escaped_word) {
|
|
return "dict.word_lookup("+escaped_word+")";
|
|
}
|
|
|
|
public:
|
|
constexpr static LM::Inference::Params get_params() {
|
|
auto p = get_recommended_params();
|
|
p.top_k = 5;
|
|
p.top_p = 0.2f;
|
|
p.temp = 0.2f;
|
|
return p;
|
|
}
|
|
|
|
Dictionary(const std::string& weights_path, const LM::Inference::Params& params = get_params()) : PyEval(weights_path, params) {
|
|
begin()
|
|
.load_module("huge_dictionary", "dict")
|
|
.example("");
|
|
begin()
|
|
.expression(word_lookup_exprgen("\"Treehouse\"")+".description")
|
|
.example("'A small house, especially one for children to play in, built or placed up in the branches of a tree.'");
|
|
begin()
|
|
.expression(word_lookup_exprgen("\"Python\"")+".description")
|
|
.example("'Any of several Old World boa constrictors of the subfamily Pythoninae, often growing to a length of more than 20 feet (6 meters): the Indian python, Python molurus, is endangered.'");
|
|
begin()
|
|
.expression(word_lookup_exprgen("\"C\"")+".description")
|
|
.example("'The 3rd letter of the alphabet.'");
|
|
begin()
|
|
.expression(word_lookup_exprgen("\"Appletree\"")+".syllables")
|
|
.example("['Ap', 'ple', 'tree']");
|
|
}
|
|
|
|
std::string lookup(std::string_view word, const std::string& what, const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
|
|
return begin().expression(word_lookup_exprgen(escape(word))+'.'+what)
|
|
.run(on_append_tick, on_generation_tick);
|
|
}
|
|
};
|
|
|
|
class Translator : PyEval {
|
|
LM::Inference::Savestate sv;
|
|
std::unordered_map<size_t/*hash*/, std::string> cache;
|
|
|
|
static inline auto translation_exprgen(const std::string& escaped_text, const std::string& escaped_language) {
|
|
return "translator.translate("+escaped_text+", "+escaped_language+")";
|
|
}
|
|
|
|
public:
|
|
constexpr static LM::Inference::Params get_params() {
|
|
auto p = get_recommended_params();
|
|
p.top_k = 5;
|
|
p.top_p = 0.2f;
|
|
p.temp = 0.1f;
|
|
return p;
|
|
}
|
|
|
|
Translator(const std::string& weights_path, const LM::Inference::Params& params = get_params()) : PyEval(weights_path, params) {
|
|
begin()
|
|
.load_module("deepl_scrape", "translator")
|
|
.example("");
|
|
begin()
|
|
.expression("# Using translator.translate, we can reliably translate any language into any other language. The meaning of the text is going to stay the same and if the text already is in the target language, it remains unchanged.")
|
|
.example("");
|
|
begin()
|
|
.expression(translation_exprgen("\"Treehouse\"", "\"DE\""))
|
|
.example("'Baumhaus'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Mir fiel ein Ball auf den Fuss!\"", "\"EN\""))
|
|
.example("'A ball fell onto my foot!'");
|
|
begin()
|
|
.expression(translation_exprgen("\"He's not an NPC, he's just another user here :/\"", "\"DE\""))
|
|
.example("'Er ist kein NPC, er ist hier nur ein normaler Benutzer :/'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Du bist ein sehr freundlicher Mensch :-)\"", "\"EN\""))
|
|
.example("'You are a very kind human :-)'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Hi\"", "\"EN\""))
|
|
.example("'Hi'");
|
|
begin()
|
|
.expression(translation_exprgen("\"What is the root of nine\"", "\"DE\""))
|
|
.example("'Was ist die Wurzel von neun'");
|
|
begin()
|
|
.expression(translation_exprgen("\"How long until school starts?\"", "\"IT\""))
|
|
.example("'Quanto manca all'inizio della scuola?'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Poisoning someone else means killing him/her, so you shouldn't ever try that!\"", "\"DE\""))
|
|
.example("'Jemanden zu vergiften bedeutet, ihn/sie zu töten, also solltest du das niemals versuchen!'");
|
|
begin()
|
|
.expression(translation_exprgen("\"Please wait...\"", "\"DE\""))
|
|
.example("'Bitte warten...'");
|
|
}
|
|
|
|
std::string translate(std::string_view text, std::string_view language, const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
|
|
// Hash
|
|
size_t hash = std::hash<std::string_view>{}(text) + std::hash<std::string_view>{}(language);
|
|
|
|
// Cache lookup
|
|
auto res = cache.find(hash);
|
|
if (res != cache.end()) {
|
|
return res->second;
|
|
}
|
|
|
|
// Restore savestate
|
|
if (sv.is_valid()) restore_savestate(sv);
|
|
else create_savestate(sv, on_append_tick);
|
|
|
|
// Run inference
|
|
auto fres = unescape(begin().expression(translation_exprgen(escape(text), escape(language)))
|
|
.run(on_append_tick, on_generation_tick));
|
|
|
|
// Add to cache
|
|
cache[hash] = fres;
|
|
|
|
// Return final result;
|
|
return fres;
|
|
}
|
|
};
|
|
#endif // ANYPROC_HPP
|