From b4ba5308476a460d48f6b87b56a4cbc0fcbccb85 Mon Sep 17 00:00:00 2001 From: niansa Date: Wed, 26 Apr 2023 11:00:35 +0200 Subject: [PATCH] Added simple Python bindings --- CMakeLists.txt | 10 ++++++++++ include/anyproc.hpp | 2 +- pybind.cpp | 29 +++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 pybind.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f7d8694..896cffc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) add_subdirectory(libjustlm) +set(ANYPROC_PYBIND No CACHE BOOL "If python bindings should be build") + add_library(anyproc INTERFACE) target_include_directories(anyproc INTERFACE include/) target_link_libraries(anyproc INTERFACE libjustlm) @@ -17,5 +19,13 @@ target_link_libraries(dictionary PUBLIC anyproc) add_executable(translator translator.cpp) target_link_libraries(translator PUBLIC anyproc) +if (ANYPROC_PYBIND) + find_package(Python COMPONENTS Interpreter Development) + find_package(pybind11 CONFIG) + pybind11_add_module(anyproc_py pybind.cpp) + target_link_libraries(anyproc_py PRIVATE anyproc) +endif() + + install(TARGETS anyproc LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/include/anyproc.hpp b/include/anyproc.hpp index 73f968a..9fdd6cd 100644 --- a/include/anyproc.hpp +++ b/include/anyproc.hpp @@ -87,7 +87,7 @@ public: buffer.clear(); return LM::Inference::run("\n", on_generation_tick); } - void example(std::string_view response) { + void example(std::string_view response = "") { buffer += "\n"; buffer += response; if (!response.empty()) buffer += "\n"; diff --git a/pybind.cpp b/pybind.cpp new file mode 100644 index 0000000..04914e9 --- /dev/null +++ b/pybind.cpp @@ -0,0 +1,29 @@ +#include "anyproc.hpp" +#define PYBIND11_DETAILED_ERROR_MESSAGES +#include +#include +#include +#include + +namespace py = pybind11; + + + +PYBIND11_MODULE(anyproc_py, m) { + py::class_(m, "PyEval") + .def(py::init(), py::arg("weights_path"), py::arg("params") = PyEval::get_recommended_params()) + .def("begin", &PyEval::begin, py::return_value_policy::reference_internal) + .def("load_module", py::overload_cast(&PyEval::load_module), py::arg("name"), py::return_value_policy::reference_internal) + .def("load_module", py::overload_cast(&PyEval::load_module), py::arg("name"), py::arg("alias"), py::return_value_policy::reference_internal) + .def("expression", &PyEval::expression, py::return_value_policy::reference_internal) + .def("run", &PyEval::run, py::arg("on_append_tick") = nullptr, py::arg("on_generation_tick") = nullptr) + .def("example", &PyEval::example, py::arg("response") = "") + .def("create_savestate", &PyEval::create_savestate, py::arg("sv_ref"), py::arg("on_append_tick") = nullptr) + .def("restore_savestate", &PyEval::restore_savestate, py::arg("sv_ref")); + py::class_(m, "Dictionary") + .def(py::init(), py::arg("weights_path")) + .def("lookup", &Dictionary::lookup, py::arg("word"), py::arg("what"), py::arg("on_append_tick") = nullptr, py::arg("on_generation_tick") = nullptr); + py::class_(m, "Translator") + .def(py::init(), py::arg("weights_path")) + .def("translate", &Translator::translate, py::arg("text"), py::arg("language"), py::arg("on_append_tick") = nullptr, py::arg("on_generation_tick") = nullptr); +}