diff --git a/src/sequence_diagram/model/participant.cc b/src/sequence_diagram/model/participant.cc index fe9ca514..d787857c 100644 --- a/src/sequence_diagram/model/participant.cc +++ b/src/sequence_diagram/model/participant.cc @@ -63,6 +63,38 @@ template_trait::templates() const return templates_; } +int template_trait::calculate_template_specialization_match( + const template_trait &other, const std::string &full_name) const +{ + int res{}; + +// std::string left = name_and_ns(); +// // TODO: handle variadic templates +// if ((name_and_ns() != full_name) || +// (templates().size() != other.templates().size())) { +// return res; +// } + + // Iterate over all template arguments + for (auto i = 0U; i < other.templates().size(); i++) { + const auto &template_arg = templates().at(i); + const auto &other_template_arg = other.templates().at(i); + + if (template_arg == other_template_arg) { + res++; + } + else if (other_template_arg.is_specialization_of(template_arg)) { + continue; + } + else { + res = 0; + break; + } + } + + return res; +} + class_::class_(const common::model::namespace_ &using_namespace) : participant{using_namespace} { @@ -122,6 +154,9 @@ std::string class_::full_name(bool relative) const return res; } +bool operator==(const class_ &l, const class_ &r) { return l.id() == r.id(); } + + function::function(const common::model::namespace_ &using_namespace) : participant{using_namespace} { diff --git a/src/sequence_diagram/model/participant.h b/src/sequence_diagram/model/participant.h index e8a77efd..a2048ff2 100644 --- a/src/sequence_diagram/model/participant.h +++ b/src/sequence_diagram/model/participant.h @@ -39,6 +39,9 @@ struct template_trait { const std::vector & templates() const; + int calculate_template_specialization_match( + const template_trait &other, const std::string &full_name) const; + private: std::vector templates_; std::string base_template_full_name_; diff --git a/src/sequence_diagram/visitor/translation_unit_visitor.cc b/src/sequence_diagram/visitor/translation_unit_visitor.cc index 8db85314..2bb101ab 100644 --- a/src/sequence_diagram/visitor/translation_unit_visitor.cc +++ b/src/sequence_diagram/visitor/translation_unit_visitor.cc @@ -192,6 +192,60 @@ bool translation_unit_visitor::VisitClassTemplateDecl( return true; } +bool translation_unit_visitor::VisitClassTemplateSpecializationDecl( + clang::ClassTemplateSpecializationDecl *cls) +{ + if (source_manager().isInSystemHeader(cls->getSourceRange().getBegin())) + return true; + + if (!diagram().should_include(cls->getQualifiedNameAsString())) + return true; + + if(cls->isImplicit()) + LOG_DBG("!!!!!!!!!!!!!!!!!!!!!"); + + LOG_DBG("= Visiting template specialization declaration {} at {}", + cls->getQualifiedNameAsString(), + cls->getLocation().printToString(source_manager())); + + // TODO: Add support for classes defined in function/method bodies + if (cls->isLocalClass()) + return true; + + auto template_specialization_ptr = process_template_specialization(cls); + + if (!template_specialization_ptr) + return true; + + const auto cls_full_name = template_specialization_ptr->full_name(false); + const auto id = common::to_id(cls_full_name); + + template_specialization_ptr->set_id(id); + + set_ast_local_id(cls->getID(), id); + + if (!cls->isCompleteDefinition()) { + forward_declarations_.emplace( + id, std::move(template_specialization_ptr)); + return true; + } + else { + forward_declarations_.erase(id); + } + + if (diagram_.should_include(*template_specialization_ptr)) { + LOG_DBG("Adding class template specialization {} with id {}", + cls_full_name, id); + + call_expression_context_.set_caller_id(id); + call_expression_context_.update(cls); + + diagram_.add_participant(std::move(template_specialization_ptr)); + } + + return true; +} + bool translation_unit_visitor::VisitCXXMethodDecl(clang::CXXMethodDecl *m) { if (call_expression_context_.current_class_decl_ == nullptr && @@ -385,18 +439,45 @@ bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr) const auto *method_decl = method_call_expr->getMethodDecl(); std::string method_name = method_decl->getQualifiedNameAsString(); - const auto *callee_decl = - method_decl ? method_decl->getParent() : nullptr; + auto *callee_decl = method_decl ? method_decl->getParent() : nullptr; if (!(callee_decl && diagram().should_include( callee_decl->getQualifiedNameAsString()))) return true; - // TODO: The method can be called before it's declaration has been - // encountered by the visitor - for now it's not a problem - // as overloaded methods are not supported - m.to = common::to_id(method_decl->getQualifiedNameAsString()); + const auto *callee_template_specialization = + clang::dyn_cast( + callee_decl); + + if (callee_template_specialization) { + LOG_DBG("Callee is a template specialization declaration {}", + callee_template_specialization->getQualifiedNameAsString()); + + if (!get_ast_local_id(callee_template_specialization->getID())) { + callee_template_specialization->dump(); + + + call_expression_context context_backup = call_expression_context_; + + // Since this visitor will overwrite the call_expression_context_ + // we need to back it up and restore it later + VisitClassTemplateSpecializationDecl( + const_cast( + callee_template_specialization)); + + call_expression_context_ = context_backup; + } + + m.to = get_ast_local_id(callee_template_specialization->getID()) + .value(); + } + else { + // TODO: The method can be called before it's declaration has been + // encountered by the visitor - for now it's not a problem + // as overloaded methods are not supported + m.to = common::to_id(method_decl->getQualifiedNameAsString()); + } m.message_name = method_decl->getNameAsString(); m.return_type = method_call_expr->getCallReturnType(current_ast_context) .getAsString(); @@ -800,7 +881,7 @@ translation_unit_visitor::build_function_template_instantiation( // // Instantiate the template arguments // - std::optional parent; + model::template_trait *parent{nullptr}; build_template_instantiation_process_template_arguments(parent, template_base_params, decl.getTemplateSpecializationArgs()->asArray(), template_instantiation, "", decl.getPrimaryTemplate()); @@ -810,7 +891,7 @@ translation_unit_visitor::build_function_template_instantiation( void translation_unit_visitor:: build_template_instantiation_process_template_arguments( - std::optional &parent, + model::template_trait *parent, std::deque> &template_base_params, const clang::ArrayRef &template_args, model::template_trait &template_instantiation, @@ -903,7 +984,7 @@ void translation_unit_visitor:: void translation_unit_visitor:: build_template_instantiation_process_type_argument( - std::optional &parent, + model::template_trait *parent, const std::string &full_template_specialization_name, const clang::TemplateDecl *template_decl, const clang::TemplateArgument &arg, @@ -987,6 +1068,400 @@ void translation_unit_visitor:: } } +std::unique_ptr +translation_unit_visitor::process_template_specialization( + clang::ClassTemplateSpecializationDecl *cls) +{ + auto c_ptr{std::make_unique(config_.using_namespace())}; + auto &template_instantiation = *c_ptr; + + // TODO: refactor to method get_qualified_name() + auto qualified_name = cls->getQualifiedNameAsString(); + util::replace_all(qualified_name, "(anonymous namespace)", ""); + util::replace_all(qualified_name, "::::", "::"); + + common::model::namespace_ ns{qualified_name}; + ns.pop_back(); + template_instantiation.set_name(cls->getNameAsString()); + template_instantiation.set_namespace(ns); + + template_instantiation.is_struct(cls->isStruct()); + + process_comment(*cls, template_instantiation); + set_source_location(*cls, template_instantiation); + + if (template_instantiation.skip()) + return {}; + + const auto template_args_count = cls->getTemplateArgs().size(); + for (auto arg_it = 0U; arg_it < template_args_count; arg_it++) { + const auto arg = cls->getTemplateArgs().get(arg_it); + process_template_specialization_argument( + cls, template_instantiation, arg, arg_it); + } + + template_instantiation.set_id( + common::to_id(template_instantiation.full_name(false))); + + set_ast_local_id(cls->getID(), template_instantiation.id()); + + return c_ptr; +} + +void translation_unit_visitor::process_template_specialization_argument( + const clang::ClassTemplateSpecializationDecl *cls, + model::class_ &template_instantiation, const clang::TemplateArgument &arg, + size_t argument_index, bool in_parameter_pack) +{ + const auto argument_kind = arg.getKind(); + + if (argument_kind == clang::TemplateArgument::Type) { + class_diagram::model::template_parameter argument; + argument.is_template_parameter(false); + + // If this is a nested template type - add nested templates as + // template arguments + if (arg.getAsType()->getAs()) { + const auto *nested_template_type = + arg.getAsType()->getAs(); + + const auto nested_template_name = + nested_template_type->getTemplateName() + .getAsTemplateDecl() + ->getQualifiedNameAsString(); + + argument.set_name(nested_template_name); + + auto nested_template_instantiation = build_template_instantiation( + *arg.getAsType()->getAs(), + &template_instantiation); + + argument.set_id(nested_template_instantiation->id()); + + for (const auto &t : nested_template_instantiation->templates()) + argument.add_template_param(t); + + // Check if this template should be simplified (e.g. system + // template aliases such as 'std:basic_string' should be + // simply 'std::string') + simplify_system_template(argument, + argument.to_string(config().using_namespace(), false)); + } + else if (arg.getAsType()->getAs()) { + auto type_name = + common::to_string(arg.getAsType(), cls->getASTContext()); + + // clang does not provide declared template parameter/argument + // names in template specializations - so we have to extract + // them from raw source code... + if (type_name.find("type-parameter-") == 0) { + auto declaration_text = common::get_source_text_raw( + cls->getSourceRange(), source_manager()); + + declaration_text = declaration_text.substr( + declaration_text.find(cls->getNameAsString()) + + cls->getNameAsString().size() + 1); + + auto template_params = + cx::util::parse_unexposed_template_params( + declaration_text, [](const auto &t) { return t; }); + + if (template_params.size() > argument_index) + type_name = template_params[argument_index].to_string( + config().using_namespace(), false); + else { + LOG_DBG("Failed to find type specialization for argument " + "{} at index {} in declaration \n===\n{}\n===\n", + type_name, argument_index, declaration_text); + } + } + + argument.set_name(type_name); + } + else { + auto type_name = + common::to_string(arg.getAsType(), cls->getASTContext()); + if (type_name.find('<') != std::string::npos) { + // Sometimes template instantiation is reported as + // RecordType in the AST and getAs to + // TemplateSpecializationType returns null pointer so we + // have to at least make sure it's properly formatted + // (e.g. std:integral_constant, or any template + // specialization which contains it - see t00038) + process_unexposed_template_specialization_parameters( + type_name.substr(type_name.find('<') + 1, + type_name.size() - (type_name.find('<') + 2)), + argument, template_instantiation); + + argument.set_name(type_name.substr(0, type_name.find('<'))); + } + else if (type_name.find("type-parameter-") == 0) { + auto declaration_text = common::get_source_text_raw( + cls->getSourceRange(), source_manager()); + + declaration_text = declaration_text.substr( + declaration_text.find(cls->getNameAsString()) + + cls->getNameAsString().size() + 1); + + auto template_params = + cx::util::parse_unexposed_template_params( + declaration_text, [](const auto &t) { return t; }); + + if (template_params.size() > argument_index) + type_name = template_params[argument_index].to_string( + config().using_namespace(), false); + else { + LOG_DBG("Failed to find type specialization for argument " + "{} at index {} in declaration \n===\n{}\n===\n", + type_name, argument_index, declaration_text); + } + + // Otherwise just set the name for the template argument to + // whatever clang says + argument.set_name(type_name); + } + else + argument.set_name(type_name); + } + + LOG_DBG("Adding template instantiation argument {}", + argument.to_string(config().using_namespace(), false)); + + simplify_system_template( + argument, argument.to_string(config().using_namespace(), false)); + + template_instantiation.add_template(std::move(argument)); + } + else if (argument_kind == clang::TemplateArgument::Integral) { + class_diagram::model::template_parameter argument; + argument.is_template_parameter(false); + argument.set_type(std::to_string(arg.getAsIntegral().getExtValue())); + template_instantiation.add_template(std::move(argument)); + } + else if (argument_kind == clang::TemplateArgument::Expression) { + class_diagram::model::template_parameter argument; + argument.is_template_parameter(false); + argument.set_type(common::get_source_text( + arg.getAsExpr()->getSourceRange(), source_manager())); + template_instantiation.add_template(std::move(argument)); + } + else if (argument_kind == clang::TemplateArgument::TemplateExpansion) { + class_diagram::model::template_parameter argument; + argument.is_template_parameter(true); + + cls->getLocation().dump(source_manager()); + } + else if (argument_kind == clang::TemplateArgument::Pack) { + // This will only work for now if pack is at the end + size_t argument_pack_index{argument_index}; + for (const auto &template_argument : arg.getPackAsArray()) { + process_template_specialization_argument(cls, + template_instantiation, template_argument, + argument_pack_index++, true); + } + } + else { + LOG_ERROR("Unsupported template argument kind {} [{}]", arg.getKind(), + cls->getLocation().printToString(source_manager())); + } +} + +std::unique_ptr +translation_unit_visitor::build_template_instantiation( + const clang::TemplateSpecializationType &template_type_decl, + model::class_ *parent) +{ + // TODO: Make sure we only build instantiation once + + // + // Here we'll hold the template base params to replace with the + // instantiated values + // + std::deque> + template_base_params{}; + + auto *template_type_ptr = &template_type_decl; + if (template_type_decl.isTypeAlias() && + template_type_decl.getAliasedType() + ->getAs()) + template_type_ptr = template_type_decl.getAliasedType() + ->getAs(); + + auto &template_type = *template_type_ptr; + + // + // Create class_ instance to hold the template instantiation + // + auto template_instantiation_ptr = + std::make_unique(config_.using_namespace()); + auto &template_instantiation = *template_instantiation_ptr; + std::string full_template_specialization_name = common::to_string( + template_type.desugar(), + template_type.getTemplateName().getAsTemplateDecl()->getASTContext()); + + auto *template_decl{template_type.getTemplateName().getAsTemplateDecl()}; + + auto template_decl_qualified_name = + template_decl->getQualifiedNameAsString(); + + auto *class_template_decl{ + clang::dyn_cast(template_decl)}; + + if (class_template_decl && class_template_decl->getTemplatedDecl() && + class_template_decl->getTemplatedDecl()->getParent() && + class_template_decl->getTemplatedDecl()->getParent()->isRecord()) { + + common::model::namespace_ ns{ + common::get_tag_namespace(*class_template_decl->getTemplatedDecl() + ->getParent() + ->getOuterLexicalRecordContext())}; + + std::string ns_str = ns.to_string(); + std::string name = template_decl->getQualifiedNameAsString(); + if (!ns_str.empty()) { + name = name.substr(ns_str.size() + 2); + } + + util::replace_all(name, "::", "##"); + template_instantiation.set_name(name); + + template_instantiation.set_namespace(ns); + } + else { + common::model::namespace_ ns{template_decl_qualified_name}; + ns.pop_back(); + template_instantiation.set_name(template_decl->getNameAsString()); + template_instantiation.set_namespace(ns); + } + + // TODO: Refactor handling of base parameters to a separate method + + // We need this to match any possible base classes coming from template + // arguments + std::vector< + std::pair> + template_parameter_names{}; + + for (const auto *parameter : *template_decl->getTemplateParameters()) { + if (parameter->isTemplateParameter() && + (parameter->isTemplateParameterPack() || + parameter->isParameterPack())) { + template_parameter_names.emplace_back( + parameter->getNameAsString(), true); + } + else + template_parameter_names.emplace_back( + parameter->getNameAsString(), false); + } + + // Check if the primary template has any base classes + int base_index = 0; + + const auto *templated_class_decl = + clang::dyn_cast_or_null( + template_decl->getTemplatedDecl()); + + if (templated_class_decl && templated_class_decl->hasDefinition()) + for (const auto &base : templated_class_decl->bases()) { + const auto base_class_name = common::to_string( + base.getType(), templated_class_decl->getASTContext(), false); + + LOG_DBG("Found template instantiation base: {}, {}", + base_class_name, base_index); + + // Check if any of the primary template arguments has a + // parameter equal to this type + auto it = std::find_if(template_parameter_names.begin(), + template_parameter_names.end(), + [&base_class_name]( + const auto &p) { return p.first == base_class_name; }); + + if (it != template_parameter_names.end()) { + const auto ¶meter_name = it->first; + const bool is_variadic = it->second; + // Found base class which is a template parameter + LOG_DBG("Found base class which is a template parameter " + "{}, {}, {}", + parameter_name, is_variadic, + std::distance(template_parameter_names.begin(), it)); + + template_base_params.emplace_back(parameter_name, + std::distance(template_parameter_names.begin(), it), + is_variadic); + } + else { + // This is a regular base class - it is handled by + // process_template + } + base_index++; + } + + build_template_instantiation_process_template_arguments(parent, + template_base_params, template_type.template_arguments(), + template_instantiation, full_template_specialization_name, + template_decl); + + // First try to find the best match for this template in partially + // specialized templates + std::string destination{}; + std::string best_match_full_name{}; + auto full_template_name = template_instantiation.full_name(false); + int best_match{}; + common::model::diagram_element::id_t best_match_id{0}; + + for (const auto &[id, c] : diagram().participants) { + const auto *participant_as_class = + dynamic_cast(c.get()); + if ((participant_as_class != nullptr) && + (*participant_as_class == template_instantiation)) + continue; + + auto c_full_name = participant_as_class->full_name(false); + auto match = + participant_as_class->calculate_template_specialization_match( + template_instantiation, template_instantiation.name_and_ns()); + + if (match > best_match) { + best_match = match; + best_match_full_name = c_full_name; + best_match_id = participant_as_class->id(); + } + } + + auto templated_decl_id = + template_type.getTemplateName().getAsTemplateDecl()->getID(); + // auto templated_decl_local_id = + // get_ast_local_id(templated_decl_id).value_or(0); + + if (best_match_id > 0) { + destination = best_match_full_name; + } + else { + LOG_DBG("== Cannot determine global id for specialization template {} " + "- delaying until the translation unit is complete ", + templated_decl_id); + } + + template_instantiation.set_id( + common::to_id(template_instantiation_ptr->full_name(false))); + + return template_instantiation_ptr; +} + +void translation_unit_visitor:: + process_unexposed_template_specialization_parameters( + const std::string &type_name, + class_diagram::model::template_parameter &tp, model::class_ &c) const +{ + auto template_params = cx::util::parse_unexposed_template_params( + type_name, [](const std::string &t) { return t; }); + + for (auto ¶m : template_params) { + tp.add_template_param(param); + } +} + bool translation_unit_visitor::simplify_system_template( class_diagram::model::template_parameter &ct, const std::string &full_name) { diff --git a/src/sequence_diagram/visitor/translation_unit_visitor.h b/src/sequence_diagram/visitor/translation_unit_visitor.h index afb408dc..f3cdafa6 100644 --- a/src/sequence_diagram/visitor/translation_unit_visitor.h +++ b/src/sequence_diagram/visitor/translation_unit_visitor.h @@ -169,7 +169,10 @@ public: virtual bool VisitCXXRecordDecl(clang::CXXRecordDecl *cls); - bool VisitClassTemplateDecl(clang::ClassTemplateDecl *cls); + virtual bool VisitClassTemplateDecl(clang::ClassTemplateDecl *cls); + + virtual bool VisitClassTemplateSpecializationDecl( + clang::ClassTemplateSpecializationDecl *cls); virtual bool VisitFunctionDecl(clang::FunctionDecl *function_declaration); @@ -203,7 +206,7 @@ private: build_function_template_instantiation(const clang::FunctionDecl &pDecl); void build_template_instantiation_process_template_arguments( - std::optional &parent, + model::template_trait *parent, std::deque> &template_base_params, const clang::ArrayRef &template_args, model::template_trait &template_instantiation, @@ -230,13 +233,30 @@ private: class_diagram::model::template_parameter &argument) const; void build_template_instantiation_process_type_argument( - std::optional &parent, + model::template_trait *parent, const std::string &full_template_specialization_name, const clang::TemplateDecl *template_decl, const clang::TemplateArgument &arg, model::template_trait &template_instantiation, class_diagram::model::template_parameter &argument); + std::unique_ptr process_template_specialization( + clang::ClassTemplateSpecializationDecl *cls); + + void process_template_specialization_argument( + const clang::ClassTemplateSpecializationDecl *cls, + model::class_ &template_instantiation, + const clang::TemplateArgument &arg, size_t argument_index, + bool in_parameter_pack = false); + + void process_unexposed_template_specialization_parameters( + const std::string &type_name, + class_diagram::model::template_parameter &tp, model::class_ &c) const; + + std::unique_ptr build_template_instantiation( + const clang::TemplateSpecializationType &template_type_decl, + model::class_ *parent); + bool simplify_system_template(class_diagram::model::template_parameter &ct, const std::string &full_name); diff --git a/tests/t20001/t20001.cc b/tests/t20001/t20001.cc index 18878c11..98135a05 100644 --- a/tests/t20001/t20001.cc +++ b/tests/t20001/t20001.cc @@ -64,7 +64,9 @@ int tmain() A a; B b(a); - return b.wrap_add3(1, 2, 3); + auto tmp = a.add(1, 2); + + return b.wrap_add3(tmp, 2, 3); } } } diff --git a/tests/t20006/.clang-uml b/tests/t20006/.clang-uml new file mode 100644 index 00000000..dedfb8f7 --- /dev/null +++ b/tests/t20006/.clang-uml @@ -0,0 +1,14 @@ +compilation_database_dir: .. +output_directory: puml +diagrams: + t20006_sequence: + type: sequence + glob: + - ../../tests/t20006/t20006.cc + include: + namespaces: + - clanguml::t20006 + using_namespace: + - clanguml::t20006 + start_from: + - function: "clanguml::t20006::tmain()" \ No newline at end of file diff --git a/tests/t20006/t20006.cc b/tests/t20006/t20006.cc new file mode 100644 index 00000000..bcd278d6 --- /dev/null +++ b/tests/t20006/t20006.cc @@ -0,0 +1,30 @@ +#include + +namespace clanguml { +namespace t20006 { + +template struct A { + T a(T arg) { return arg; } + T a1(T arg) { return arg; } +}; + +template struct B { + T b(T arg) { return a_.a(arg); } + A a_; +}; + +template <> struct B { + std::string b(std::string arg) { return arg; } + A a_; +}; + +void tmain() +{ + B bint; + B bstring; + + bint.b(1); + bstring.b("bstring"); +} +} +} \ No newline at end of file diff --git a/tests/t20006/test_case.h b/tests/t20006/test_case.h new file mode 100644 index 00000000..d50da514 --- /dev/null +++ b/tests/t20006/test_case.h @@ -0,0 +1,47 @@ +/** + * tests/t20006/test_case.h + * + * Copyright (c) 2021-2022 Bartek Kryza + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +TEST_CASE("t20006", "[test-case][sequence]") +{ + auto [config, db] = load_config("t20006"); + + auto diagram = config.diagrams["t20006_sequence"]; + + REQUIRE(diagram->name == "t20006_sequence"); + + auto model = generate_sequence_diagram(*db, diagram); + + REQUIRE(model->name() == "t20006_sequence"); + + auto puml = generate_sequence_puml(diagram, *model); + AliasMatcher _A(puml); + + REQUIRE_THAT(puml, StartsWith("@startuml")); + REQUIRE_THAT(puml, EndsWith("@enduml\n")); + + // Check if all calls exist + REQUIRE_THAT(puml, HasCall(_A("tmain()"), _A("B"), "b")); + REQUIRE_THAT(puml, HasCall(_A("B"), _A("A"), "a")); + + REQUIRE_THAT(puml, HasCall(_A("tmain()"), _A("B"), "b")); + REQUIRE_THAT( + puml, !HasCall(_A("B"), _A("A"), "a")); + + save_puml( + "./" + config.output_directory() + "/" + diagram->name + ".puml", puml); +} \ No newline at end of file diff --git a/tests/test_cases.cc b/tests/test_cases.cc index 0433f165..79289f4b 100644 --- a/tests/test_cases.cc +++ b/tests/test_cases.cc @@ -252,6 +252,7 @@ using namespace clanguml::test::matchers; #include "t20003/test_case.h" #include "t20004/test_case.h" #include "t20005/test_case.h" +#include "t20006/test_case.h" /// /// Package diagram tests