1
0
Fork 0
mirror of https://gitlab.com/niansa/anyproc.git synced 2025-03-06 20:49:24 +01:00

Added CoSched support

This commit is contained in:
niansa/tuxifan 2023-05-04 15:23:16 +02:00
parent c90cea8ac8
commit a184fef764
2 changed files with 31 additions and 18 deletions

View file

@ -6,6 +6,11 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(ANYPROC_PYBIND No CACHE BOOL "If python bindings should be build")
set(ANYPROC_COSCHED No CACHE BOOL "If CoSched should be made use of")
if (ANYPROC_COSCHED)
set(CMAKE_CXX_STANDARD 20)
endif()
add_library(anyproc INTERFACE)
target_include_directories(anyproc INTERFACE include/)
@ -17,11 +22,19 @@ target_link_libraries(dictionary PUBLIC anyproc)
add_executable(translator translator.cpp)
target_link_libraries(translator PUBLIC anyproc)
if (ANYPROC_COSCHED)
target_link_libraries(anyproc INTERFACE cosched)
endif()
if (ANYPROC_PYBIND)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG)
pybind11_add_module(anyproc_py pybind.cpp)
target_link_libraries(anyproc_py PRIVATE anyproc)
if (ANYPROC_COSCHED)
message(FATAL_ERROR "Pybind can't be enabled in combination with CoSched")
endif()
endif()
install(TARGETS anyproc

View file

@ -103,11 +103,11 @@ public:
return *this;
}
auto run(const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
LM_SCHEDULABLE(std::string) run(const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
buffer += "\n";
i->append(buffer, on_append_tick);
LM_COAWAIT i->append(buffer, on_append_tick);
buffer.clear();
return i->run("\n", on_generation_tick);
LM_CORETURN LM_COAWAIT i->run("\n", on_generation_tick);
}
void example(std::string_view response = "") {
buffer += "\n";
@ -115,15 +115,15 @@ public:
if (!response.empty()) buffer += "\n";
}
void create_savestate(LM::Inference::Savestate &sv, const std::function<bool (float)> &on_append_tick = nullptr) {
LM_SCHEDULABLE(void) create_savestate(LM::Inference::Savestate &sv, const std::function<bool (float)> &on_append_tick = nullptr) {
if (!buffer.empty()) {
i->append(buffer, on_append_tick);
LM_COAWAIT i->append(buffer, on_append_tick);
buffer.clear();
}
i->create_savestate(sv);
LM_COAWAIT i->create_savestate(sv);
}
void restore_savestate(const LM::Inference::Savestate &sv) {
i->restore_savestate(sv);
LM_SCHEDULABLE(void) restore_savestate(const LM::Inference::Savestate &sv) {
LM_COAWAIT i->restore_savestate(sv);
}
LM::Inference& get_inference() {
@ -163,9 +163,9 @@ public:
.example("['Ap', 'ple', 'tree']");
}
std::string lookup(std::string_view word, const std::string& what, const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
return begin().expression(word_lookup_exprgen(escape(word))+'.'+what)
.run(on_append_tick, on_generation_tick);
LM_SCHEDULABLE(std::string) lookup(std::string_view word, const std::string& what, const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
LM_CORETURN LM_COAWAIT begin().expression(word_lookup_exprgen(escape(word))+'.'+what)
.run(on_append_tick, on_generation_tick);
}
};
@ -223,29 +223,29 @@ public:
.example("'Bitte warten...'");
}
std::string translate(std::string_view text, std::string_view language, const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
LM_SCHEDULABLE(std::string) translate(std::string_view text, std::string_view language, const std::function<bool (float)> &on_append_tick = nullptr, const std::function<bool (const char *generated)>& on_generation_tick = nullptr) {
// Hash
size_t hash = std::hash<std::string_view>{}(text) + std::hash<std::string_view>{}(language);
// Cache lookup
auto res = cache.find(hash);
if (res != cache.end()) {
return res->second;
LM_CORETURN res->second;
}
// Restore savestate
if (sv.is_valid()) restore_savestate(sv);
else create_savestate(sv, on_append_tick);
if (sv.is_valid()) LM_COAWAIT restore_savestate(sv);
else LM_COAWAIT create_savestate(sv, on_append_tick);
// Run inference
auto fres = unescape(begin().expression(translation_exprgen(escape(text), escape(language)))
.run(on_append_tick, on_generation_tick));
auto fres = unescape(LM_COAWAIT begin().expression(translation_exprgen(escape(text), escape(language)))
.run(on_append_tick, on_generation_tick));
// Add to cache
cache[hash] = fres;
// Return final result;
return fres;
LM_CORETURN fres;
}
};
#endif // ANYPROC_HPP