Refactored template specialization matching

This commit is contained in:
Bartek Kryza
2023-04-05 21:57:44 +02:00
parent cb74864d0d
commit 38928cf86f
11 changed files with 174 additions and 126 deletions

View File

@@ -86,7 +86,7 @@ void to_json(nlohmann::json &j, const class_ &c)
j["methods"] = c.methods(); j["methods"] = c.methods();
j["bases"] = c.parents(); j["bases"] = c.parents();
j["template_parameters"] = c.templates(); j["template_parameters"] = c.template_params();
} }
void to_json(nlohmann::json &j, const enum_ &c) void to_json(nlohmann::json &j, const enum_ &c)

View File

@@ -166,7 +166,7 @@ void generator::generate(const class_ &c, std::ostream &ostr) const
ostr << plantuml_common::to_plantuml(m.access()) << m.name(); ostr << plantuml_common::to_plantuml(m.access()) << m.name();
if (!m.templates().empty()) { if (!m.template_params().empty()) {
m.render_template_params(ostr, m_config.using_namespace(), false); m.render_template_params(ostr, m_config.using_namespace(), false);
} }

View File

@@ -111,15 +111,13 @@ bool class_::is_abstract() const
[](const auto &method) { return method.is_pure_virtual(); }); [](const auto &method) { return method.is_pure_virtual(); });
} }
int class_::calculate_template_specialization_match( int class_::calculate_template_specialization_match(const class_ &other) const
const class_ &other, const std::string &full_name) const
{ {
int res{0}; int res{0};
const std::string left = name_and_ns(); const std::string left = name_and_ns();
// TODO: handle variadic templates // TODO: handle variadic templates
if ((left != full_name) || if (left != other.name_and_ns()) {
(templates().size() != other.templates().size())) {
return res; return res;
} }

View File

@@ -70,8 +70,7 @@ public:
bool is_abstract() const; bool is_abstract() const;
int calculate_template_specialization_match( int calculate_template_specialization_match(const class_ &other) const;
const class_ &other, const std::string &full_name) const;
private: private:
bool is_struct_{false}; bool is_struct_{false};

View File

@@ -1884,7 +1884,8 @@ void translation_unit_visitor::process_template_specialization_argument(
argument->set_id(nested_template_instantiation->id()); argument->set_id(nested_template_instantiation->id());
for (const auto &t : nested_template_instantiation->templates()) for (const auto &t :
nested_template_instantiation->template_params())
argument->add_template_param(t); argument->add_template_param(t);
} }
else if (arg.getAsType()->getAs<clang::TemplateTypeParmType>() != else if (arg.getAsType()->getAs<clang::TemplateTypeParmType>() !=
@@ -2118,18 +2119,19 @@ std::unique_ptr<class_> translation_unit_visitor::
int best_match{}; int best_match{};
common::model::diagram_element::id_t best_match_id{0}; common::model::diagram_element::id_t best_match_id{0};
for (const auto c : diagram().classes()) { for (const auto templ : diagram().classes()) {
if (c.get() == template_instantiation) if (templ.get() == template_instantiation)
continue; continue;
auto c_full_name = c.get().full_name(false); auto c_full_name = templ.get().full_name(false);
auto match = c.get().calculate_template_specialization_match( auto match =
template_instantiation, template_instantiation.name_and_ns()); template_instantiation.calculate_template_specialization_match(
templ.get());
if (match > best_match) { if (match > best_match) {
best_match = match; best_match = match;
best_match_full_name = c_full_name; best_match_full_name = c_full_name;
best_match_id = c.get().id(); best_match_id = templ.get().id();
} }
} }
@@ -2307,18 +2309,19 @@ std::unique_ptr<class_> translation_unit_visitor::build_template_instantiation(
int best_match{}; int best_match{};
common::model::diagram_element::id_t best_match_id{0}; common::model::diagram_element::id_t best_match_id{0};
for (const auto c : diagram().classes()) { for (const auto templ : diagram().classes()) {
if (c.get() == template_instantiation) if (templ.get() == template_instantiation)
continue; continue;
auto c_full_name = c.get().full_name(false); auto c_full_name = templ.get().full_name(false);
auto match = c.get().calculate_template_specialization_match( auto match =
template_instantiation, template_instantiation.name_and_ns()); template_instantiation.calculate_template_specialization_match(
templ.get());
if (match > best_match) { if (match > best_match) {
best_match = match; best_match = match;
best_match_full_name = c_full_name; best_match_full_name = c_full_name;
best_match_id = c.get().id(); best_match_id = templ.get().id();
} }
} }
@@ -2542,7 +2545,7 @@ translation_unit_visitor::build_template_instantiation_process_type_argument(
argument.set_id(nested_template_instantiation->id()); argument.set_id(nested_template_instantiation->id());
for (const auto &t : nested_template_instantiation->templates()) for (const auto &t : nested_template_instantiation->template_params())
argument.add_template_param(t); argument.add_template_param(t);
// Check if this template should be simplified (e.g. system // Check if this template should be simplified (e.g. system
@@ -2804,7 +2807,7 @@ void translation_unit_visitor::process_field(
if (template_field_type != nullptr) { if (template_field_type != nullptr) {
// Skip types which are template template parameters of the parent // Skip types which are template template parameters of the parent
// template // template
for (const auto &class_template_param : c.templates()) { for (const auto &class_template_param : c.template_params()) {
if (class_template_param.name() == if (class_template_param.name() ==
template_field_type->getTemplateName() template_field_type->getTemplateName()
.getAsTemplateDecl() .getAsTemplateDecl()
@@ -2851,7 +2854,7 @@ void translation_unit_visitor::process_field(
found_relationships_t nested_relationships; found_relationships_t nested_relationships;
if (!template_instantiation_added_as_aggregation) { if (!template_instantiation_added_as_aggregation) {
for (const auto &template_argument : for (const auto &template_argument :
template_specialization.templates()) { template_specialization.template_params()) {
LOG_DBG("Looking for nested relationships from {}::{} in " LOG_DBG("Looking for nested relationships from {}::{} in "
"template {}", "template {}",

View File

@@ -109,84 +109,36 @@ void template_parameter::is_variadic(bool is_variadic) noexcept
bool template_parameter::is_variadic() const noexcept { return is_variadic_; } bool template_parameter::is_variadic() const noexcept { return is_variadic_; }
int template_parameter::calculate_specialization_match( int template_parameter::calculate_specialization_match(
const template_parameter &ct) const const template_parameter &base_template_parameter) const
{ {
int res{0}; int res{0};
if (ct.type().has_value() && type().has_value() && if (base_template_parameter.type().has_value() && type().has_value() &&
!ct.is_template_parameter() && !is_template_parameter()) { !base_template_parameter.is_template_parameter() &&
!is_template_parameter()) {
if (ct.type().value() != type().value()) if (base_template_parameter.type().value() != type().value())
return 0; return 0;
else else
res++; res++;
} }
if (ct.is_function_template() && !is_function_template()) if (base_template_parameter.is_function_template() &&
!is_function_template())
return 0; return 0;
if (template_params().size() > 0 && ct.template_params().size() > 0) { if (!base_template_parameter.template_params().empty() &&
// More generic template params !template_params().empty()) {
const auto &template_params = ct.template_params(); auto params_match = calculate_template_params_specialization_match(
const auto &specialization_params = this->template_params(); template_params(), base_template_parameter.template_params());
auto template_index{0U};
auto arg_index{0U};
while (arg_index < specialization_params.size() && if (params_match == 0)
template_index < template_params.size()) {
auto match = specialization_params.at(arg_index)
.calculate_specialization_match(
template_params.at(template_index));
if (match == 0) {
return 0; return 0;
res += params_match;
} }
else if ((base_template_parameter.is_template_parameter() ||
if (!template_params.at(template_index).is_variadic()) base_template_parameter.is_template_template_parameter()) &&
template_index++;
res += match;
// Add 1 point if the current specialization param is an argument
// as it's a more specific match than 2 template params
if (!specialization_params.at(arg_index).is_template_parameter())
res++;
arg_index++;
}
if (arg_index == specialization_params.size()) {
// Check also backwards to make sure that trailing non-variadic
// params match after a variadic parameter
template_index = template_params.size() - 1;
arg_index = specialization_params.size() - 1;
while (true) {
auto match = specialization_params.at(arg_index)
.calculate_specialization_match(
template_params.at(template_index));
if (match == 0) {
return 0;
}
if (arg_index == 0 || template_index == 0)
break;
arg_index--;
if (!template_params.at(template_index).is_variadic())
template_index--;
else
break;
}
return res;
}
else
return 0;
}
if ((ct.is_template_parameter() || ct.is_template_template_parameter()) &&
!is_template_parameter()) !is_template_parameter())
return 1; return 1;
@@ -372,4 +324,76 @@ const std::optional<std::string> &template_parameter::concept_constraint() const
return concept_constraint_; return concept_constraint_;
} }
int calculate_template_params_specialization_match(
const std::vector<template_parameter> &specialization_params,
const std::vector<template_parameter> &template_params)
{
int res{0};
if (!specialization_params.empty() && !template_params.empty()) {
auto template_index{0U};
auto arg_index{0U};
while (arg_index < specialization_params.size() &&
template_index < template_params.size()) {
auto match = specialization_params.at(arg_index)
.calculate_specialization_match(
template_params.at(template_index));
if (match == 0) {
return 0;
}
// Add 1 point if the current specialization param is an argument
// as it's a more specific match than 2 template params
if (!specialization_params.at(arg_index).is_template_parameter())
res++;
// Add 1 point if the current template param is an argument
// as it's a more specific match than 2 template params
if (!template_params.at(template_index).is_template_parameter())
res++;
if (!template_params.at(template_index).is_variadic())
template_index++;
res += match;
arg_index++;
}
if (arg_index == specialization_params.size()) {
// Check also backwards to make sure that trailing non-variadic
// params match after a variadic parameter
template_index = template_params.size() - 1;
arg_index = specialization_params.size() - 1;
while (true) {
auto match = specialization_params.at(arg_index)
.calculate_specialization_match(
template_params.at(template_index));
if (match == 0) {
return 0;
}
if (arg_index == 0 || template_index == 0)
break;
arg_index--;
if (!template_params.at(template_index).is_variadic())
template_index--;
else
break;
}
return res;
}
else
return 0;
}
return 0;
}
} // namespace clanguml::common::model } // namespace clanguml::common::model

View File

@@ -122,7 +122,8 @@ public:
void is_variadic(bool is_variadic) noexcept; void is_variadic(bool is_variadic) noexcept;
bool is_variadic() const noexcept; bool is_variadic() const noexcept;
int calculate_specialization_match(const template_parameter &ct) const; int calculate_specialization_match(
const template_parameter &base_template_parameter) const;
friend bool operator==( friend bool operator==(
const template_parameter &l, const template_parameter &r); const template_parameter &l, const template_parameter &r);
@@ -221,4 +222,9 @@ private:
bool is_unexposed_{false}; bool is_unexposed_{false};
}; };
int calculate_template_params_specialization_match(
const std::vector<template_parameter> &specialization,
const std::vector<template_parameter> &base_template);
} // namespace clanguml::common::model } // namespace clanguml::common::model

View File

@@ -58,42 +58,16 @@ bool template_trait::is_implicit() const { return is_implicit_; }
void template_trait::set_implicit(bool implicit) { is_implicit_ = implicit; } void template_trait::set_implicit(bool implicit) { is_implicit_ = implicit; }
const std::vector<template_parameter> &template_trait::templates() const const std::vector<template_parameter> &template_trait::template_params() const
{ {
return templates_; return templates_;
} }
int template_trait::calculate_template_specialization_match( int template_trait::calculate_template_specialization_match(
const template_trait &other) const const template_trait &base_template) const
{ {
int res{0}; return calculate_template_params_specialization_match(
template_params(), base_template.template_params());
// Iterate over all template arguments
for (auto i = 0U; i < other.templates().size(); i++) {
const auto &template_arg = templates().at(i);
const auto &other_template_arg = other.templates().at(i);
if (template_arg == other_template_arg) {
res++;
if (!template_arg.is_template_parameter())
res++;
if (!other_template_arg.is_template_parameter())
res++;
}
else if (auto match = other_template_arg.calculate_specialization_match(
template_arg);
match > 0) {
res += match;
}
else {
res = 0;
break;
}
}
return res;
} }
} // namespace clanguml::common::model } // namespace clanguml::common::model

View File

@@ -36,7 +36,7 @@ public:
void add_template(template_parameter &&tmplt); void add_template(template_parameter &&tmplt);
const std::vector<template_parameter> &templates() const; const std::vector<template_parameter> &template_params() const;
int calculate_template_specialization_match( int calculate_template_specialization_match(
const template_trait &other) const; const template_trait &other) const;

View File

@@ -1715,7 +1715,8 @@ void translation_unit_visitor::process_template_specialization_argument(
argument.set_id(nested_template_instantiation->id()); argument.set_id(nested_template_instantiation->id());
for (const auto &t : nested_template_instantiation->templates()) for (const auto &t :
nested_template_instantiation->template_params())
argument.add_template_param(t); argument.add_template_param(t);
// Check if this template should be simplified (e.g. system // Check if this template should be simplified (e.g. system
@@ -2009,8 +2010,8 @@ translation_unit_visitor::build_template_instantiation(
auto c_full_name = participant_as_class->full_name(false); auto c_full_name = participant_as_class->full_name(false);
auto match = auto match =
participant_as_class->calculate_template_specialization_match( template_instantiation.calculate_template_specialization_match(
template_instantiation); *participant_as_class);
if (match > best_match) { if (match > best_match) {
best_match = match; best_match = match;

View File

@@ -19,6 +19,7 @@
#include "catch.h" #include "catch.h"
#include "class_diagram/model/class.h"
#include "common/model/namespace.h" #include "common/model/namespace.h"
#include "common/model/template_parameter.h" #include "common/model/template_parameter.h"
@@ -72,6 +73,48 @@ TEST_CASE("Test namespace_", "[unit-test]")
CHECK(ns8.relative(name) == "ccc<std::unique_ptr<ddd>>"); CHECK(ns8.relative(name) == "ccc<std::unique_ptr<ddd>>");
} }
TEST_CASE("Test class_::calculate_specialization_match", "[unit-test]")
{
using clanguml::class_diagram::model::class_;
using clanguml::common::model::template_parameter;
{
auto c = class_({});
c.set_name("A");
c.add_template(template_parameter::make_argument("int"));
c.add_template(template_parameter::make_argument("double"));
c.add_template(template_parameter::make_argument("int"));
auto t = class_({});
t.set_name("A");
t.add_template(
template_parameter::make_template_type("Args", {}, true));
t.add_template(template_parameter::make_argument("int"));
CHECK(c.calculate_template_specialization_match(t));
}
{
auto c = class_({});
c.set_name("A");
c.add_template(template_parameter::make_argument("double"));
c.add_template(template_parameter::make_argument("int"));
auto s = class_({});
s.set_name("A");
s.add_template(template_parameter::make_argument("double"));
s.add_template(template_parameter::make_template_type("V"));
auto t = class_({});
t.set_name("A");
t.add_template(template_parameter::make_template_type("T"));
t.add_template(template_parameter::make_template_type("V"));
CHECK(c.calculate_template_specialization_match(s) >
c.calculate_template_specialization_match(t));
}
}
TEST_CASE( TEST_CASE(
"Test template_parameter::calculate_specialization_match", "[unit-test]") "Test template_parameter::calculate_specialization_match", "[unit-test]")
{ {
@@ -81,7 +124,7 @@ TEST_CASE(
auto tp1 = template_parameter::make_template_type("T"); auto tp1 = template_parameter::make_template_type("T");
auto tp2 = template_parameter::make_argument("int"); auto tp2 = template_parameter::make_argument("int");
CHECK(tp2.calculate_specialization_match(tp1)); CHECK(tp2.calculate_specialization_match(tp1) > 0);
} }
{ {
@@ -150,7 +193,7 @@ TEST_CASE(
tp2.add_template_param(template_parameter::make_argument("int")); tp2.add_template_param(template_parameter::make_argument("int"));
tp2.add_template_param(template_parameter::make_argument("double")); tp2.add_template_param(template_parameter::make_argument("double"));
CHECK(!tp2.calculate_specialization_match(tp1)); CHECK(tp2.calculate_specialization_match(tp1) == 0);
} }
{ {
@@ -163,7 +206,7 @@ TEST_CASE(
tp2.add_template_param(template_parameter::make_argument("int")); tp2.add_template_param(template_parameter::make_argument("int"));
tp2.add_template_param(template_parameter::make_argument("double")); tp2.add_template_param(template_parameter::make_argument("double"));
CHECK(!tp2.calculate_specialization_match(tp1)); CHECK(tp2.calculate_specialization_match(tp1) == 0);
} }
{ {