diff --git a/src/class_diagram/generators/json/class_diagram_generator.cc b/src/class_diagram/generators/json/class_diagram_generator.cc index 86403cd6..a15c88b8 100644 --- a/src/class_diagram/generators/json/class_diagram_generator.cc +++ b/src/class_diagram/generators/json/class_diagram_generator.cc @@ -86,7 +86,7 @@ void to_json(nlohmann::json &j, const class_ &c) j["methods"] = c.methods(); j["bases"] = c.parents(); - j["template_parameters"] = c.templates(); + j["template_parameters"] = c.template_params(); } void to_json(nlohmann::json &j, const enum_ &c) diff --git a/src/class_diagram/generators/plantuml/class_diagram_generator.cc b/src/class_diagram/generators/plantuml/class_diagram_generator.cc index 17e2e296..76917f83 100644 --- a/src/class_diagram/generators/plantuml/class_diagram_generator.cc +++ b/src/class_diagram/generators/plantuml/class_diagram_generator.cc @@ -166,7 +166,7 @@ void generator::generate(const class_ &c, std::ostream &ostr) const ostr << plantuml_common::to_plantuml(m.access()) << m.name(); - if (!m.templates().empty()) { + if (!m.template_params().empty()) { m.render_template_params(ostr, m_config.using_namespace(), false); } diff --git a/src/class_diagram/model/class.cc b/src/class_diagram/model/class.cc index b97e702d..93c9d5e4 100644 --- a/src/class_diagram/model/class.cc +++ b/src/class_diagram/model/class.cc @@ -111,15 +111,13 @@ bool class_::is_abstract() const [](const auto &method) { return method.is_pure_virtual(); }); } -int class_::calculate_template_specialization_match( - const class_ &other, const std::string &full_name) const +int class_::calculate_template_specialization_match(const class_ &other) const { int res{0}; const std::string left = name_and_ns(); // TODO: handle variadic templates - if ((left != full_name) || - (templates().size() != other.templates().size())) { + if (left != other.name_and_ns()) { return res; } diff --git a/src/class_diagram/model/class.h b/src/class_diagram/model/class.h index e9498a94..143b0645 100644 --- a/src/class_diagram/model/class.h +++ b/src/class_diagram/model/class.h @@ -70,8 +70,7 @@ public: bool is_abstract() const; - int calculate_template_specialization_match( - const class_ &other, const std::string &full_name) const; + int calculate_template_specialization_match(const class_ &other) const; private: bool is_struct_{false}; diff --git a/src/class_diagram/visitor/translation_unit_visitor.cc b/src/class_diagram/visitor/translation_unit_visitor.cc index 77fc8bae..f60b474a 100644 --- a/src/class_diagram/visitor/translation_unit_visitor.cc +++ b/src/class_diagram/visitor/translation_unit_visitor.cc @@ -1884,7 +1884,8 @@ void translation_unit_visitor::process_template_specialization_argument( argument->set_id(nested_template_instantiation->id()); - for (const auto &t : nested_template_instantiation->templates()) + for (const auto &t : + nested_template_instantiation->template_params()) argument->add_template_param(t); } else if (arg.getAsType()->getAs() != @@ -2118,18 +2119,19 @@ std::unique_ptr translation_unit_visitor:: int best_match{}; common::model::diagram_element::id_t best_match_id{0}; - for (const auto c : diagram().classes()) { - if (c.get() == template_instantiation) + for (const auto templ : diagram().classes()) { + if (templ.get() == template_instantiation) continue; - auto c_full_name = c.get().full_name(false); - auto match = c.get().calculate_template_specialization_match( - template_instantiation, template_instantiation.name_and_ns()); + auto c_full_name = templ.get().full_name(false); + auto match = + template_instantiation.calculate_template_specialization_match( + templ.get()); if (match > best_match) { best_match = match; best_match_full_name = c_full_name; - best_match_id = c.get().id(); + best_match_id = templ.get().id(); } } @@ -2307,18 +2309,19 @@ std::unique_ptr translation_unit_visitor::build_template_instantiation( int best_match{}; common::model::diagram_element::id_t best_match_id{0}; - for (const auto c : diagram().classes()) { - if (c.get() == template_instantiation) + for (const auto templ : diagram().classes()) { + if (templ.get() == template_instantiation) continue; - auto c_full_name = c.get().full_name(false); - auto match = c.get().calculate_template_specialization_match( - template_instantiation, template_instantiation.name_and_ns()); + auto c_full_name = templ.get().full_name(false); + auto match = + template_instantiation.calculate_template_specialization_match( + templ.get()); if (match > best_match) { best_match = match; best_match_full_name = c_full_name; - best_match_id = c.get().id(); + best_match_id = templ.get().id(); } } @@ -2542,7 +2545,7 @@ translation_unit_visitor::build_template_instantiation_process_type_argument( argument.set_id(nested_template_instantiation->id()); - for (const auto &t : nested_template_instantiation->templates()) + for (const auto &t : nested_template_instantiation->template_params()) argument.add_template_param(t); // Check if this template should be simplified (e.g. system @@ -2804,7 +2807,7 @@ void translation_unit_visitor::process_field( if (template_field_type != nullptr) { // Skip types which are template template parameters of the parent // template - for (const auto &class_template_param : c.templates()) { + for (const auto &class_template_param : c.template_params()) { if (class_template_param.name() == template_field_type->getTemplateName() .getAsTemplateDecl() @@ -2851,7 +2854,7 @@ void translation_unit_visitor::process_field( found_relationships_t nested_relationships; if (!template_instantiation_added_as_aggregation) { for (const auto &template_argument : - template_specialization.templates()) { + template_specialization.template_params()) { LOG_DBG("Looking for nested relationships from {}::{} in " "template {}", diff --git a/src/common/model/template_parameter.cc b/src/common/model/template_parameter.cc index 39f84b4a..b2abc965 100644 --- a/src/common/model/template_parameter.cc +++ b/src/common/model/template_parameter.cc @@ -109,84 +109,36 @@ void template_parameter::is_variadic(bool is_variadic) noexcept bool template_parameter::is_variadic() const noexcept { return is_variadic_; } int template_parameter::calculate_specialization_match( - const template_parameter &ct) const + const template_parameter &base_template_parameter) const { int res{0}; - if (ct.type().has_value() && type().has_value() && - !ct.is_template_parameter() && !is_template_parameter()) { + if (base_template_parameter.type().has_value() && type().has_value() && + !base_template_parameter.is_template_parameter() && + !is_template_parameter()) { - if (ct.type().value() != type().value()) + if (base_template_parameter.type().value() != type().value()) return 0; else res++; } - if (ct.is_function_template() && !is_function_template()) + if (base_template_parameter.is_function_template() && + !is_function_template()) return 0; - if (template_params().size() > 0 && ct.template_params().size() > 0) { - // More generic template params - const auto &template_params = ct.template_params(); - const auto &specialization_params = this->template_params(); - auto template_index{0U}; - auto arg_index{0U}; + if (!base_template_parameter.template_params().empty() && + !template_params().empty()) { + auto params_match = calculate_template_params_specialization_match( + template_params(), base_template_parameter.template_params()); - while (arg_index < specialization_params.size() && - template_index < template_params.size()) { - auto match = specialization_params.at(arg_index) - .calculate_specialization_match( - template_params.at(template_index)); - - if (match == 0) { - return 0; - } - - if (!template_params.at(template_index).is_variadic()) - template_index++; - - res += match; - - // Add 1 point if the current specialization param is an argument - // as it's a more specific match than 2 template params - if (!specialization_params.at(arg_index).is_template_parameter()) - res++; - - arg_index++; - } - - if (arg_index == specialization_params.size()) { - // Check also backwards to make sure that trailing non-variadic - // params match after a variadic parameter - template_index = template_params.size() - 1; - arg_index = specialization_params.size() - 1; - - while (true) { - auto match = specialization_params.at(arg_index) - .calculate_specialization_match( - template_params.at(template_index)); - if (match == 0) { - return 0; - } - - if (arg_index == 0 || template_index == 0) - break; - - arg_index--; - - if (!template_params.at(template_index).is_variadic()) - template_index--; - else - break; - } - - return res; - } - else + if (params_match == 0) return 0; - } - if ((ct.is_template_parameter() || ct.is_template_template_parameter()) && + res += params_match; + } + else if ((base_template_parameter.is_template_parameter() || + base_template_parameter.is_template_template_parameter()) && !is_template_parameter()) return 1; @@ -372,4 +324,76 @@ const std::optional &template_parameter::concept_constraint() const return concept_constraint_; } +int calculate_template_params_specialization_match( + const std::vector &specialization_params, + const std::vector &template_params) +{ + int res{0}; + + if (!specialization_params.empty() && !template_params.empty()) { + auto template_index{0U}; + auto arg_index{0U}; + + while (arg_index < specialization_params.size() && + template_index < template_params.size()) { + auto match = specialization_params.at(arg_index) + .calculate_specialization_match( + template_params.at(template_index)); + + if (match == 0) { + return 0; + } + + // Add 1 point if the current specialization param is an argument + // as it's a more specific match than 2 template params + if (!specialization_params.at(arg_index).is_template_parameter()) + res++; + + // Add 1 point if the current template param is an argument + // as it's a more specific match than 2 template params + if (!template_params.at(template_index).is_template_parameter()) + res++; + + if (!template_params.at(template_index).is_variadic()) + template_index++; + + res += match; + + arg_index++; + } + + if (arg_index == specialization_params.size()) { + // Check also backwards to make sure that trailing non-variadic + // params match after a variadic parameter + template_index = template_params.size() - 1; + arg_index = specialization_params.size() - 1; + + while (true) { + auto match = specialization_params.at(arg_index) + .calculate_specialization_match( + template_params.at(template_index)); + if (match == 0) { + return 0; + } + + if (arg_index == 0 || template_index == 0) + break; + + arg_index--; + + if (!template_params.at(template_index).is_variadic()) + template_index--; + else + break; + } + + return res; + } + else + return 0; + } + + return 0; +} + } // namespace clanguml::common::model diff --git a/src/common/model/template_parameter.h b/src/common/model/template_parameter.h index 005e3688..d99056d7 100644 --- a/src/common/model/template_parameter.h +++ b/src/common/model/template_parameter.h @@ -122,7 +122,8 @@ public: void is_variadic(bool is_variadic) noexcept; bool is_variadic() const noexcept; - int calculate_specialization_match(const template_parameter &ct) const; + int calculate_specialization_match( + const template_parameter &base_template_parameter) const; friend bool operator==( const template_parameter &l, const template_parameter &r); @@ -221,4 +222,9 @@ private: bool is_unexposed_{false}; }; + +int calculate_template_params_specialization_match( + const std::vector &specialization, + const std::vector &base_template); + } // namespace clanguml::common::model diff --git a/src/common/model/template_trait.cc b/src/common/model/template_trait.cc index 63f8931d..c3f0b0bf 100644 --- a/src/common/model/template_trait.cc +++ b/src/common/model/template_trait.cc @@ -58,42 +58,16 @@ bool template_trait::is_implicit() const { return is_implicit_; } void template_trait::set_implicit(bool implicit) { is_implicit_ = implicit; } -const std::vector &template_trait::templates() const +const std::vector &template_trait::template_params() const { return templates_; } int template_trait::calculate_template_specialization_match( - const template_trait &other) const + const template_trait &base_template) const { - int res{0}; - - // Iterate over all template arguments - for (auto i = 0U; i < other.templates().size(); i++) { - const auto &template_arg = templates().at(i); - const auto &other_template_arg = other.templates().at(i); - - if (template_arg == other_template_arg) { - res++; - - if (!template_arg.is_template_parameter()) - res++; - - if (!other_template_arg.is_template_parameter()) - res++; - } - else if (auto match = other_template_arg.calculate_specialization_match( - template_arg); - match > 0) { - res += match; - } - else { - res = 0; - break; - } - } - - return res; + return calculate_template_params_specialization_match( + template_params(), base_template.template_params()); } } // namespace clanguml::common::model \ No newline at end of file diff --git a/src/common/model/template_trait.h b/src/common/model/template_trait.h index e0904e0d..1414b330 100644 --- a/src/common/model/template_trait.h +++ b/src/common/model/template_trait.h @@ -36,7 +36,7 @@ public: void add_template(template_parameter &&tmplt); - const std::vector &templates() const; + const std::vector &template_params() const; int calculate_template_specialization_match( const template_trait &other) const; diff --git a/src/sequence_diagram/visitor/translation_unit_visitor.cc b/src/sequence_diagram/visitor/translation_unit_visitor.cc index f8ea54f5..6ac6cdcd 100644 --- a/src/sequence_diagram/visitor/translation_unit_visitor.cc +++ b/src/sequence_diagram/visitor/translation_unit_visitor.cc @@ -1715,7 +1715,8 @@ void translation_unit_visitor::process_template_specialization_argument( argument.set_id(nested_template_instantiation->id()); - for (const auto &t : nested_template_instantiation->templates()) + for (const auto &t : + nested_template_instantiation->template_params()) argument.add_template_param(t); // Check if this template should be simplified (e.g. system @@ -2009,8 +2010,8 @@ translation_unit_visitor::build_template_instantiation( auto c_full_name = participant_as_class->full_name(false); auto match = - participant_as_class->calculate_template_specialization_match( - template_instantiation); + template_instantiation.calculate_template_specialization_match( + *participant_as_class); if (match > best_match) { best_match = match; diff --git a/tests/test_model.cc b/tests/test_model.cc index 49611605..ae6a7b06 100644 --- a/tests/test_model.cc +++ b/tests/test_model.cc @@ -19,6 +19,7 @@ #include "catch.h" +#include "class_diagram/model/class.h" #include "common/model/namespace.h" #include "common/model/template_parameter.h" @@ -72,6 +73,48 @@ TEST_CASE("Test namespace_", "[unit-test]") CHECK(ns8.relative(name) == "ccc>"); } +TEST_CASE("Test class_::calculate_specialization_match", "[unit-test]") +{ + using clanguml::class_diagram::model::class_; + using clanguml::common::model::template_parameter; + + { + auto c = class_({}); + c.set_name("A"); + c.add_template(template_parameter::make_argument("int")); + c.add_template(template_parameter::make_argument("double")); + c.add_template(template_parameter::make_argument("int")); + + auto t = class_({}); + t.set_name("A"); + t.add_template( + template_parameter::make_template_type("Args", {}, true)); + t.add_template(template_parameter::make_argument("int")); + + CHECK(c.calculate_template_specialization_match(t)); + } + + { + auto c = class_({}); + c.set_name("A"); + c.add_template(template_parameter::make_argument("double")); + c.add_template(template_parameter::make_argument("int")); + + auto s = class_({}); + s.set_name("A"); + s.add_template(template_parameter::make_argument("double")); + s.add_template(template_parameter::make_template_type("V")); + + auto t = class_({}); + t.set_name("A"); + t.add_template(template_parameter::make_template_type("T")); + t.add_template(template_parameter::make_template_type("V")); + + CHECK(c.calculate_template_specialization_match(s) > + c.calculate_template_specialization_match(t)); + } +} + TEST_CASE( "Test template_parameter::calculate_specialization_match", "[unit-test]") { @@ -81,7 +124,7 @@ TEST_CASE( auto tp1 = template_parameter::make_template_type("T"); auto tp2 = template_parameter::make_argument("int"); - CHECK(tp2.calculate_specialization_match(tp1)); + CHECK(tp2.calculate_specialization_match(tp1) > 0); } { @@ -150,7 +193,7 @@ TEST_CASE( tp2.add_template_param(template_parameter::make_argument("int")); tp2.add_template_param(template_parameter::make_argument("double")); - CHECK(!tp2.calculate_specialization_match(tp1)); + CHECK(tp2.calculate_specialization_match(tp1) == 0); } { @@ -163,7 +206,7 @@ TEST_CASE( tp2.add_template_param(template_parameter::make_argument("int")); tp2.add_template_param(template_parameter::make_argument("double")); - CHECK(!tp2.calculate_specialization_match(tp1)); + CHECK(tp2.calculate_specialization_match(tp1) == 0); } {