diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 14470b7..b7af600 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -11,19 +11,19 @@ include(FetchContent) FetchContent_Declare( rds2cpp GIT_REPOSITORY https://github.com/LTLA/rds2cpp - GIT_TAG master + GIT_TAG v1.3.1 ) FetchContent_Declare( byteme GIT_REPOSITORY https://github.com/LTLA/byteme - GIT_TAG master + GIT_TAG v2.1.3 ) FetchContent_Declare( sanisizer GIT_REPOSITORY https://github.com/LTLA/sanisizer - GIT_TAG master + GIT_TAG v0.2.2 ) FetchContent_MakeAvailable(byteme) diff --git a/lib/src/rdswrapper.cpp b/lib/src/rdswrapper.cpp index 2ed6aa2..a7f2ab8 100644 --- a/lib/src/rdswrapper.cpp +++ b/lib/src/rdswrapper.cpp @@ -10,10 +10,12 @@ namespace py = pybind11; class RdsReader { private: const rds2cpp::RObject* ptr; + const std::vector* symbols_ptr; public: - RdsReader(const rds2cpp::RObject* p) : ptr(p) { + RdsReader(const rds2cpp::RObject* p, const std::vector* syms) : ptr(p), symbols_ptr(syms) { if (!p) throw std::runtime_error("Null pointer passed to 'RdsReader'."); + if (!syms) throw std::runtime_error("Null symbols pointer passed to 'RdsReader'."); } std::string get_rtype() const { @@ -69,23 +71,36 @@ class RdsReader { throw std::runtime_error("Invalid type for 'string_arr'"); } const auto& data = static_cast(ptr)->data; - return py::cast(data); + py::list result; + for (const auto& s : data) { + if (s.value.has_value()) { + result.append(s.value.value()); + } else { + result.append(py::none()); + } + } + return result; } py::list get_attribute_names() const { if (!ptr) throw std::runtime_error("Null pointer in 'get_attribute_names'"); - return py::cast(get_attributes().names); + const auto& attrs = get_attributes(); + py::list names; + for (const auto& attr : attrs) { + names.append(resolve_symbol(attr.name)); + } + return names; } py::object load_attribute_by_name(const std::string& name) const { if (!ptr) throw std::runtime_error("Null pointer in 'load_attribute_by_name'"); - const auto& attributes = get_attributes(); - auto it = std::find(attributes.names.begin(), attributes.names.end(), name); - if (it == attributes.names.end()) { - throw std::runtime_error("Attribute not found: " + name); + const auto& attrs = get_attributes(); + for (const auto& attr : attrs) { + if (resolve_symbol(attr.name) == name) { + return py::cast(new RdsReader(attr.value.get(), symbols_ptr)); + } } - size_t index = std::distance(attributes.names.begin(), it); - return py::cast(new RdsReader(attributes.values[index].get())); + throw std::runtime_error("Attribute not found: " + name); } py::object load_vec_element(int index) const { @@ -97,7 +112,7 @@ class RdsReader { if (index < 0 || static_cast(index) >= data.size()) { throw std::out_of_range("Vector index out of range"); } - return py::cast(new RdsReader(data[index].get())); + return py::cast(new RdsReader(data[index].get(), symbols_ptr)); } std::string get_package_name() const { @@ -126,7 +141,14 @@ class RdsReader { } private: - const rds2cpp::Attributes& get_attributes() const { + std::string resolve_symbol(const rds2cpp::SymbolIndex& sym) const { + if (sym.index >= symbols_ptr->size()) { + throw std::runtime_error("Symbol index out of range"); + } + return (*symbols_ptr)[sym.index].name; + } + + const std::vector& get_attributes() const { if (!ptr) throw std::runtime_error("Null pointer in get_attributes"); switch (ptr->type()) { case rds2cpp::SEXPType::INT: return static_cast(ptr)->attributes; @@ -153,7 +175,7 @@ class RdsObject { if (!parsed || !parsed->object) { throw std::runtime_error("Failed to parse RDS file"); } - reader = std::make_unique(parsed->object.get()); + reader = std::make_unique(parsed->object.get(), &parsed->symbols); } catch (const std::exception& e) { throw std::runtime_error(std::string("Error in 'RdsObject' constructor: ") + e.what()); } @@ -181,11 +203,10 @@ class RdaObject { py::list get_object_names() const { if (!parsed) throw std::runtime_error("Null parsed in 'get_object_names'"); - const auto& pairlist = parsed->contents; py::list names; - for (size_t i = 0; i < pairlist.tag_names.size(); ++i) { - if (pairlist.has_tag[i]) { - names.append(pairlist.tag_names[i]); + for (const auto& obj : parsed->objects) { + if (obj.name.index < parsed->symbols.size()) { + names.append(parsed->symbols[obj.name.index].name); } else { names.append(py::none()); } @@ -195,24 +216,23 @@ class RdaObject { int get_object_count() const { if (!parsed) throw std::runtime_error("Null parsed in 'get_object_count'"); - return static_cast(parsed->contents.data.size()); + return static_cast(parsed->objects.size()); } RdsReader* get_object_by_index(int index) const { if (!parsed) throw std::runtime_error("Null parsed in 'get_object_by_index'"); - const auto& data = parsed->contents.data; - if (index < 0 || static_cast(index) >= data.size()) { + if (index < 0 || static_cast(index) >= parsed->objects.size()) { throw std::out_of_range("Object index out of range"); } - return new RdsReader(data[index].get()); + return new RdsReader(parsed->objects[index].value.get(), &parsed->symbols); } RdsReader* get_object_by_name(const std::string& name) const { if (!parsed) throw std::runtime_error("Null parsed in 'get_object_by_name'"); - const auto& pairlist = parsed->contents; - for (size_t i = 0; i < pairlist.tag_names.size(); ++i) { - if (pairlist.has_tag[i] && pairlist.tag_names[i] == name) { - return new RdsReader(pairlist.data[i].get()); + for (const auto& obj : parsed->objects) { + if (obj.name.index < parsed->symbols.size() && + parsed->symbols[obj.name.index].name == name) { + return new RdsReader(obj.value.get(), &parsed->symbols); } } throw std::runtime_error("Object not found: " + name); @@ -234,7 +254,6 @@ PYBIND11_MODULE(lib_rds_parser, m) { .def("get_object_by_name", &RdaObject::get_object_by_name, py::return_value_policy::take_ownership, py::keep_alive<0, 1>()); py::class_(m, "RdsReader") - .def(py::init()) .def("get_rtype", &RdsReader::get_rtype) .def("get_rsize", &RdsReader::get_rsize) .def("get_numeric_data", &RdsReader::get_numeric_data)