/*
 * Mirror of https://gitlab.com/niansa/anyproc.git (synced 2025-03-06 20:49:24 +01:00).
 * File: anyproc/pybind.cpp (29 lines, 1.8 KiB, C++).
 * Captured page metadata: 1 · 0 · Fork 0.
 */

#include "anyproc.hpp"
#define PYBIND11_DETAILED_ERROR_MESSAGES
#include <justlm.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
namespace py = pybind11;
/// Python extension module `anyproc_py`.
///
/// Exposes three C++ classes declared in anyproc.hpp to Python: `PyEval`,
/// `Dictionary` and `Translator`. Optional callback parameters default to
/// `nullptr`, which pybind11 presents to Python as `None`.
PYBIND11_MODULE(anyproc_py, m) {
    // Methods that hand back a reference into the evaluator use
    // reference_internal: pybind11 returns a reference and keeps the owning
    // PyEval Python object alive for as long as the result lives.
    constexpr auto ref_to_parent = py::return_value_policy::reference_internal;

    // ---- PyEval ------------------------------------------------------------
    py::class_<PyEval> py_eval(m, "PyEval");

    // `params` falls back to whatever PyEval::get_recommended_params() yields
    // at module import time.
    py_eval.def(py::init<const std::string&, const LM::Inference::Params&>(),
                py::arg("weights_path"),
                py::arg("params") = PyEval::get_recommended_params());

    py_eval.def("begin", &PyEval::begin, ref_to_parent);

    // Two C++ overloads share the Python name `load_module`: (name) and
    // (name, alias). pybind11 tries overloads in registration order.
    py_eval.def("load_module",
                py::overload_cast<std::string_view>(&PyEval::load_module),
                py::arg("name"),
                ref_to_parent);
    py_eval.def("load_module",
                py::overload_cast<std::string_view, std::string_view>(&PyEval::load_module),
                py::arg("name"),
                py::arg("alias"),
                ref_to_parent);

    py_eval.def("expression", &PyEval::expression, ref_to_parent);

    // Both progress callbacks are optional (None from Python).
    py_eval.def("run", &PyEval::run,
                py::arg("on_append_tick") = nullptr,
                py::arg("on_generation_tick") = nullptr);

    py_eval.def("example", &PyEval::example,
                py::arg("response") = "");

    py_eval.def("create_savestate", &PyEval::create_savestate,
                py::arg("sv_ref"),
                py::arg("on_append_tick") = nullptr);
    py_eval.def("restore_savestate", &PyEval::restore_savestate,
                py::arg("sv_ref"));

    // ---- Dictionary --------------------------------------------------------
    py::class_<Dictionary> dictionary(m, "Dictionary");
    dictionary.def(py::init<const std::string&>(),
                   py::arg("weights_path"));
    dictionary.def("lookup", &Dictionary::lookup,
                   py::arg("word"),
                   py::arg("what"),
                   py::arg("on_append_tick") = nullptr,
                   py::arg("on_generation_tick") = nullptr);

    // ---- Translator --------------------------------------------------------
    py::class_<Translator> translator(m, "Translator");
    translator.def(py::init<const std::string&>(),
                   py::arg("weights_path"));
    translator.def("translate", &Translator::translate,
                   py::arg("text"),
                   py::arg("language"),
                   py::arg("on_append_tick") = nullptr,
                   py::arg("on_generation_tick") = nullptr);
}