1
0
Fork 0
mirror of https://gitlab.com/niansa/anyproc.git synced 2025-03-06 20:49:24 +01:00

Added simple Python bindings

This commit is contained in:
niansa/tuxifan 2023-04-26 11:00:35 +02:00
parent 7ea62dcd2b
commit b4ba530847
3 changed files with 40 additions and 1 deletions

View file

@ -7,6 +7,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
add_subdirectory(libjustlm)
set(ANYPROC_PYBIND No CACHE BOOL "If python bindings should be build")
add_library(anyproc INTERFACE)
target_include_directories(anyproc INTERFACE include/)
target_link_libraries(anyproc INTERFACE libjustlm)
@ -17,5 +19,13 @@ target_link_libraries(dictionary PUBLIC anyproc)
add_executable(translator translator.cpp)
target_link_libraries(translator PUBLIC anyproc)
if (ANYPROC_PYBIND)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG)
pybind11_add_module(anyproc_py pybind.cpp)
target_link_libraries(anyproc_py PRIVATE anyproc)
endif()
install(TARGETS anyproc
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})

View file

@ -87,7 +87,7 @@ public:
buffer.clear();
return LM::Inference::run("\n", on_generation_tick);
}
void example(std::string_view response) {
void example(std::string_view response = "") {
buffer += "\n";
buffer += response;
if (!response.empty()) buffer += "\n";

29
pybind.cpp Normal file
View file

@ -0,0 +1,29 @@
#include "anyproc.hpp"
#define PYBIND11_DETAILED_ERROR_MESSAGES
#include <justlm.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
namespace py = pybind11;
PYBIND11_MODULE(anyproc_py, m) {
py::class_<PyEval>(m, "PyEval")
.def(py::init<const std::string&, const LM::Inference::Params&>(), py::arg("weights_path"), py::arg("params") = PyEval::get_recommended_params())
.def("begin", &PyEval::begin, py::return_value_policy::reference_internal)
.def("load_module", py::overload_cast<std::string_view>(&PyEval::load_module), py::arg("name"), py::return_value_policy::reference_internal)
.def("load_module", py::overload_cast<std::string_view, std::string_view>(&PyEval::load_module), py::arg("name"), py::arg("alias"), py::return_value_policy::reference_internal)
.def("expression", &PyEval::expression, py::return_value_policy::reference_internal)
.def("run", &PyEval::run, py::arg("on_append_tick") = nullptr, py::arg("on_generation_tick") = nullptr)
.def("example", &PyEval::example, py::arg("response") = "")
.def("create_savestate", &PyEval::create_savestate, py::arg("sv_ref"), py::arg("on_append_tick") = nullptr)
.def("restore_savestate", &PyEval::restore_savestate, py::arg("sv_ref"));
py::class_<Dictionary>(m, "Dictionary")
.def(py::init<const std::string&>(), py::arg("weights_path"))
.def("lookup", &Dictionary::lookup, py::arg("word"), py::arg("what"), py::arg("on_append_tick") = nullptr, py::arg("on_generation_tick") = nullptr);
py::class_<Translator>(m, "Translator")
.def(py::init<const std::string&>(), py::arg("weights_path"))
.def("translate", &Translator::translate, py::arg("text"), py::arg("language"), py::arg("on_append_tick") = nullptr, py::arg("on_generation_tick") = nullptr);
}