1
0
Fork 0
mirror of https://gitlab.com/niansa/anyproc.git synced 2025-03-06 20:49:24 +01:00

Allow specifying custom parameters

This commit is contained in:
niansa 2023-04-23 15:31:50 +02:00
parent b4294524d2
commit bee9170bca
2 changed files with 16 additions and 16 deletions

View file

@ -44,7 +44,7 @@ public:
return fres;
}
static inline const LM::Inference::Params get_recommended_params() {
constexpr static LM::Inference::Params get_recommended_params() {
LM::Inference::Params p;
p.use_mlock = false;
p.repeat_penalty = 1.2f;
@ -106,7 +106,12 @@ public:
};
class Dictionary : PyEval {
static inline LM::Inference::Params get_params() {
static inline auto word_lookup_exprgen(const std::string& escaped_word) {
return "dict.word_lookup("+escaped_word+")";
}
public:
constexpr static LM::Inference::Params get_params() {
auto p = get_recommended_params();
p.top_k = 5;
p.top_p = 0.2f;
@ -114,12 +119,7 @@ class Dictionary : PyEval {
return p;
}
static inline auto word_lookup_exprgen(const std::string& escaped_word) {
return "dict.word_lookup("+escaped_word+")";
}
public:
Dictionary(const std::string& weights_path) : PyEval(weights_path, get_params()) {
Dictionary(const std::string& weights_path, const LM::Inference::Params& params = get_params()) : PyEval(weights_path, params) {
begin()
.load_module("huge_dictionary", "dict")
.example("");
@ -147,7 +147,12 @@ class Translator : PyEval {
LM::Inference::Savestate sv;
std::unordered_map<size_t/*hash*/, std::string> cache;
static inline LM::Inference::Params get_params() {
static inline auto translation_exprgen(const std::string& escaped_text, const std::string& escaped_language) {
return "translator.translate("+escaped_text+", "+escaped_language+")";
}
public:
constexpr static LM::Inference::Params get_params() {
auto p = get_recommended_params();
p.top_k = 5;
p.top_p = 0.2f;
@ -155,12 +160,7 @@ class Translator : PyEval {
return p;
}
static inline auto translation_exprgen(const std::string& escaped_text, const std::string& escaped_language) {
return "translator.translate("+escaped_text+", "+escaped_language+")";
}
public:
Translator(const std::string& weights_path) : PyEval(weights_path, get_params()) {
Translator(const std::string& weights_path, const LM::Inference::Params& params = get_params()) : PyEval(weights_path, params) {
begin()
.load_module("deepl_scrape", "translator")
.example("");

@ -1 +1 @@
Subproject commit d09f892120663c812fb36d9260b5505d4c3b93e8
Subproject commit 0466774286980697b843bd5a5b9382771de22887