Mirror of https://gitlab.com/niansa/discord_llama.git (synced 2025-03-06 20:48:25 +01:00)
commit e7aff482fc
parent 9cead67daf

    Added no_translate model config option
2 changed files with 11 additions and 7 deletions
@@ -3,3 +3,4 @@ instruct_mode_policy force
 user_prompt USER:
 bot_prompt ASSISTANT:
 emits_eos true
+no_translate true
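Read together with its context line, this hunk shows the full set of keys such a model config uses. A complete file enabling the new option would then look like this (only keys visible in this diff are shown):

    instruct_mode_policy force
    user_prompt USER:
    bot_prompt ASSISTANT:
    emits_eos true
    no_translate true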
main.cpp | 17 ++++++++++-------
@@ -83,7 +83,8 @@ public:
     std::string weight_path,
                 user_prompt,
                 bot_prompt;
-    bool emits_eos = false;
+    bool emits_eos = false,
+         no_translate = false;
     enum class InstructModePolicy {
         Allow = 0b11,
         Force = 0b10,
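Note the semicolon after emits_eos becoming a comma: the new flag joins the same declaration, so both members default to false. A minimal sketch of the resulting member block (the enclosing type is not visible in the hunk, so ModelConfig here is a hypothetical stand-in):

    #include <string>

    struct ModelConfig {  // hypothetical name; the diff shows only the members
        std::string weight_path,
                    user_prompt,
                    bot_prompt;
        bool emits_eos = false,
             no_translate = false;  // new: skip translation for this model
    };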
@@ -143,10 +144,10 @@ private:
 # define ENSURE_LLM_THREAD() if (std::this_thread::get_id() != llm_tid) {throw std::runtime_error("LLM execution of '"+std::string(__PRETTY_FUNCTION__)+"' on wrong thread detected");} 0
 
     // Must run in llama thread
-    std::string_view llm_translate_to_en(std::string_view text) {
+    std::string_view llm_translate_to_en(std::string_view text, bool skip = false) {
         ENSURE_LLM_THREAD();
         // Skip if there is no translator
-        if (translator == nullptr) {
+        if (translator == nullptr || skip) {
             std::cout << "(" << language << ") " << text << std::endl;
             return text;
         }
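The ENSURE_LLM_THREAD() macro in this hunk's context enforces that translation helpers only ever run on the llama thread. A self-contained sketch of how that guard behaves (the llm_tid setup is an assumption; __PRETTY_FUNCTION__ is a GCC/Clang extension):

    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <thread>

    static std::thread::id llm_tid;  // id of the only thread allowed to run inference

    #define ENSURE_LLM_THREAD() if (std::this_thread::get_id() != llm_tid) {throw std::runtime_error("LLM execution of '"+std::string(__PRETTY_FUNCTION__)+"' on wrong thread detected");} 0

    void llm_task() {
        ENSURE_LLM_THREAD();  // throws when invoked from any other thread
        std::cout << "running on the llama thread" << std::endl;
    }

    int main() {
        llm_tid = std::this_thread::get_id();  // register this thread as the llama thread
        llm_task();                            // fine: we are on the registered thread
        std::thread([] {
            try { llm_task(); } catch (const std::exception &e) {
                std::cerr << e.what() << std::endl;  // wrong-thread error
            }
        }).join();
    }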
@@ -169,10 +170,10 @@ private:
     }
 
     // Must run in llama thread
-    std::string_view llm_translate_from_en(std::string_view text) {
+    std::string_view llm_translate_from_en(std::string_view text, bool skip = false) {
         ENSURE_LLM_THREAD();
         // Skip if there is no translator
-        if (translator == nullptr) {
+        if (translator == nullptr || skip) {
             std::cout << "(" << language << ") " << text << std::endl;
             return text;
         }
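With the new parameter, passing skip = true is treated exactly like having no translator configured: the text is logged with its language tag and returned untouched. A standalone sketch of that short-circuit (translator and language are stubbed stand-ins for the bot's real members):

    #include <iostream>
    #include <string>
    #include <string_view>

    static void *translator = nullptr;   // stand-in for the bot's translator handle
    static std::string language = "EN";  // stand-in for the configured language

    // Same shape as llm_translate_to_en/llm_translate_from_en after this diff:
    // skip == true short-circuits just like a missing translator does.
    std::string_view translate(std::string_view text, bool skip = false) {
        if (translator == nullptr || skip) {
            std::cout << "(" << language << ") " << text << std::endl;
            return text;  // passed through unchanged
        }
        // ...otherwise the text would be run through the translation model...
        return text;
    }

    int main() {
        translate("Hallo Welt", /*skip=*/true);  // logged and returned as-is
    }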
@@ -336,7 +337,7 @@ private:
         for (const auto line : str_split(msg.content, '\n')) {
             Timer timeout;
             bool timeout_exceeded = false;
-            inference.append(prefix+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
+            inference.append(prefix+std::string(llm_translate_to_en(line, channel_cfg.model_config->no_translate))+'\n', [&] (float progress) {
                 if (timeout.get<std::chrono::minutes>() > 1) {
                     std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
                     timeout_exceeded = true;
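Around this call site, every input line gets its own Timer, and the progress callback flags when processing exceeds one minute. A compact sketch of the pattern (Timer is a hypothetical stand-in matching only the get<std::chrono::minutes>() call visible here; inference.append is simulated):

    #include <chrono>
    #include <functional>
    #include <iostream>

    struct Timer {  // hypothetical stand-in; only get<Unit>() is taken from the diff
        std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
        template<typename Unit> long long get() const {
            return std::chrono::duration_cast<Unit>(std::chrono::steady_clock::now() - start).count();
        }
    };

    // Simulates an inference engine reporting progress while appending a prompt
    void append(const std::function<void(float)> &on_progress) {
        for (int p = 0; p <= 100; p += 25) on_progress(static_cast<float>(p));
    }

    int main() {
        Timer timeout;
        bool timeout_exceeded = false;
        append([&] (float progress) {
            if (timeout.get<std::chrono::minutes>() > 1) {
                std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
                timeout_exceeded = true;
            }
        });
        std::cout << "timed out: " << std::boolalpha << timeout_exceeded << std::endl;
    }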
@@ -401,7 +402,7 @@ private:
             output = texts.timeout;
         }
         // Send resulting message
-        msg.content = llm_translate_from_en(output);
+        msg.content = llm_translate_from_en(output, channel_cfg.model_config->no_translate);
         bot.message_edit(msg);
         // Prepare for next message
         if (channel_cfg.model_config->emits_eos) {
@@ -724,6 +725,8 @@ int main(int argc, char **argv) {
             model_cfg.instruct_mode_policy = parse_instruct_mode_policy(value);
         } else if (key == "emits_eos") {
             model_cfg.emits_eos = parse_bool(value);
+        } else if (key == "no_translate") {
+            model_cfg.no_translate = parse_bool(value);
         } else if (!key.empty() && key[0] != '#') {
             std::cerr << "Error: Failed to parse model configuration file: Unknown key: " << key << std::endl;
             exit(-3);
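The new key slots into the same key/value dispatch as emits_eos. A minimal sketch of that configuration-parsing loop (the file layout and parse_bool semantics are assumptions inferred from the keys and calls in this diff):

    #include <fstream>
    #include <iostream>
    #include <sstream>
    #include <string>

    struct ModelConfig {  // hypothetical; mirrors the fields this diff touches
        bool emits_eos = false,
             no_translate = false;
    };

    // Assumed semantics of the project's parse_bool helper
    static bool parse_bool(const std::string &value) {
        return value == "true" || value == "1";
    }

    int main() {
        ModelConfig model_cfg;
        std::ifstream file("model_config.txt");  // hypothetical filename
        for (std::string line; std::getline(file, line); ) {
            std::istringstream ls(line);
            std::string key, value;
            ls >> key;
            std::getline(ls >> std::ws, value);  // rest of the line is the value
            if (key == "emits_eos") {
                model_cfg.emits_eos = parse_bool(value);
            } else if (key == "no_translate") {
                model_cfg.no_translate = parse_bool(value);
            } else if (!key.empty() && key[0] != '#') {  // '#' lines are comments
                std::cerr << "Error: Failed to parse model configuration file: Unknown key: " << key << std::endl;
                return -3;
            }
        }
        std::cout << std::boolalpha << "no_translate = " << model_cfg.no_translate << std::endl;
    }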