diff --git a/src/sequence_diagram/generators/plantuml/sequence_diagram_generator.cc b/src/sequence_diagram/generators/plantuml/sequence_diagram_generator.cc index 2c3a6b79..a425c010 100644 --- a/src/sequence_diagram/generators/plantuml/sequence_diagram_generator.cc +++ b/src/sequence_diagram/generators/plantuml/sequence_diagram_generator.cc @@ -78,10 +78,10 @@ void generator::generate_return(const message &m, std::ostream &ostr) const } void generator::generate_activity(const activity &a, std::ostream &ostr, - std::set &visited) const + std::vector &visited) const { for (const auto &m : a.messages) { - visited.emplace(m.from); + visited.push_back(m.from); const auto &to = m_model.get_participant(m.to); if (!to) @@ -94,12 +94,16 @@ void generator::generate_activity(const activity &a, std::ostream &ostr, ostr << "activate " << to.value().alias() << std::endl; if (m_model.sequences.find(m.to) != m_model.sequences.end()) { - if (visited.find(m.to) == + if (std::find(visited.begin(), visited.end(), m.to) == visited.end()) { // break infinite recursion on recursive calls LOG_DBG("Creating activity {} --> {} - missing sequence {}", m.from, m.to, m.to); generate_activity(m_model.sequences[m.to], ostr, visited); } +// else { +// // clear the visited list after breaking the loop +// visited.clear(); +// } } else LOG_DBG("Skipping activity {} --> {} - missing sequence {}", m.from, @@ -107,6 +111,8 @@ void generator::generate_activity(const activity &a, std::ostream &ostr, generate_return(m, ostr); + visited.pop_back(); + ostr << "deactivate " << to.value().alias() << std::endl; } } @@ -182,9 +188,18 @@ void generator::generate(std::ostream &ostr) const break; } } - std::set visited_participants; + std::vector visited_participants; + + const auto& from = m_model.get_participant(start_from); + + generate_participant(ostr, start_from); + + ostr << "activate " << from.value().alias() << std::endl; + generate_activity( m_model.sequences[start_from], ostr, visited_participants); + + ostr << "deactivate " << from.value().alias() << std::endl; } else { // TODO: Add support for other sequence start location types diff --git a/src/sequence_diagram/generators/plantuml/sequence_diagram_generator.h b/src/sequence_diagram/generators/plantuml/sequence_diagram_generator.h index fc3c82b9..e2148bee 100644 --- a/src/sequence_diagram/generators/plantuml/sequence_diagram_generator.h +++ b/src/sequence_diagram/generators/plantuml/sequence_diagram_generator.h @@ -56,7 +56,7 @@ public: void generate_activity(const clanguml::sequence_diagram::model::activity &a, std::ostream &ostr, - std::set &visited) const; + std::vector &visited) const; void generate(std::ostream &ostr) const; diff --git a/src/sequence_diagram/model/participant.h b/src/sequence_diagram/model/participant.h index 2720bef5..d2abc452 100644 --- a/src/sequence_diagram/model/participant.h +++ b/src/sequence_diagram/model/participant.h @@ -115,11 +115,16 @@ public: void is_alias(bool alias) { is_alias_ = alias; } + bool is_lambda() const { return is_lambda_; } + + void is_lambda(bool is_lambda) { is_lambda_ = is_lambda; } + private: bool is_struct_{false}; bool is_template_{false}; bool is_template_instantiation_{false}; bool is_alias_{false}; + bool is_lambda_{false}; std::map type_aliases_; @@ -127,6 +132,12 @@ private: std::string full_name_; }; +struct lambda : public class_ { + using class_::class_; + + std::string type_name() const override { return "lambda"; } +}; + struct function : public participant { function(const common::model::namespace_ &using_namespace); diff --git a/src/sequence_diagram/visitor/translation_unit_visitor.cc b/src/sequence_diagram/visitor/translation_unit_visitor.cc index b0bb88d3..409b84fd 100644 --- a/src/sequence_diagram/visitor/translation_unit_visitor.cc +++ b/src/sequence_diagram/visitor/translation_unit_visitor.cc @@ -99,6 +99,9 @@ bool translation_unit_visitor::VisitCXXRecordDecl(clang::CXXRecordDecl *cls) if (cls->isLocalClass()) return true; + LOG_DBG("Visiting class declaration at {}", + cls->getBeginLoc().printToString(source_manager())); + // Build the class declaration and store it in the diagram, even // if we don't need it for any of the participants of this diagram auto c_ptr = create_class_declaration(cls); @@ -400,6 +403,73 @@ bool translation_unit_visitor::VisitFunctionTemplateDecl( return true; } +bool translation_unit_visitor::VisitLambdaExpr(clang::LambdaExpr *expr) +{ + const auto lambda_full_name = + expr->getLambdaClass()->getCanonicalDecl()->getNameAsString(); + + LOG_DBG("Visiting lambda expression {} at {}", lambda_full_name, + expr->getBeginLoc().printToString(source_manager())); + + LOG_DBG("Lambda call operator ID {} - lambda class ID {}, class call " + "operator ID {}", + expr->getCallOperator()->getID(), expr->getLambdaClass()->getID(), + expr->getLambdaClass()->getLambdaCallOperator()->getID()); + + // Create lambda class participant + auto *cls = expr->getLambdaClass(); + auto c_ptr = create_class_declaration(cls); + + if (!c_ptr) + return true; + + const auto cls_id = c_ptr->id(); + + set_unique_id(cls->getID(), cls_id); + + // Create lambda class operator() participant + auto m_ptr = std::make_unique( + config().using_namespace()); + + common::model::namespace_ ns{c_ptr->get_namespace()}; + auto method_name = "operator()"; + m_ptr->set_method_name(method_name); + ns.pop_back(); + + m_ptr->set_class_id(cls_id); + m_ptr->set_class_full_name(c_ptr->full_name(false)); + + diagram().add_participant(std::move(c_ptr)); + + m_ptr->set_id(common::to_id( + get_participant(cls_id).value().full_name(false) + "::" + method_name)); + + context().enter_lambda_expression(m_ptr->id()); + + set_unique_id(expr->getCallOperator()->getID(), m_ptr->id()); + + diagram().add_participant(std::move(m_ptr)); + + [[maybe_unused]] const auto is_generic_lambda = expr->isGenericLambda(); + + return true; +} + +bool translation_unit_visitor::TraverseLambdaExpr(clang::LambdaExpr *expr) +{ + const auto lambda_full_name = + expr->getLambdaClass()->getCanonicalDecl()->getNameAsString(); + + RecursiveASTVisitor::TraverseLambdaExpr(expr); + + LOG_DBG("Leaving lambda expression {} at {}", lambda_full_name, + expr->getBeginLoc().printToString(source_manager())); + + context().leave_lambda_expression(); + + return true; +} + bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr) { using clanguml::common::model::message_t; @@ -424,6 +494,10 @@ bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr) m.type = message_t::kCall; m.from = context().caller_id(); + if (context().lambda_caller_id() != 0) { + m.from = context().lambda_caller_id(); + } + const auto ¤t_ast_context = *context().get_ast_context(); LOG_DBG("Visiting call expression at {}", @@ -433,6 +507,26 @@ bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr) clang::dyn_cast_or_null(expr); operator_call_expr != nullptr) { // TODO: Handle C++ operator calls + + LOG_DBG("Operator call expression to {} at {}", + expr->getCalleeDecl()->getID(), + expr->getBeginLoc().printToString(source_manager())); + + auto maybe_id = get_unique_id(expr->getCalleeDecl()->getID()); + if (maybe_id.has_value()) { + // Found operator() call to a participant + // auto maybe_participant = get_participant(maybe_id.value()); + // if (maybe_participant.has_value()) { + m.to = maybe_id.value(); + m.message_name = "operator()"; + //} + } + else { + m.to = expr->getCalleeDecl()->getID(); + m.message_name = "operator()"; + } + + if (clang::dyn_cast(expr)) { } } // // Call to a class method @@ -637,6 +731,9 @@ bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr) auto *ftd = clang::dyn_cast_or_null< clang::FunctionTemplateDecl>(decl); + if (!get_unique_id(ftd->getID()).has_value()) + continue; + m.to = get_unique_id(ftd->getID()).value(); auto message_name = diagram() @@ -751,11 +848,16 @@ translation_unit_visitor::create_class_declaration(clang::CXXRecordDecl *cls) auto qualified_name = cls->getQualifiedNameAsString(); // common::get_qualified_name(*cls); - if (!diagram().should_include(qualified_name)) - return {}; + if (!cls->isLambda()) + if (!diagram().should_include(qualified_name)) + return {}; auto ns = common::get_tag_namespace(*cls); + if (cls->isLambda() && + !diagram().should_include(ns.to_string() + "::lambda")) + return {}; + const auto *parent = cls->getParent(); if (parent && parent->isRecord()) { @@ -815,6 +917,23 @@ translation_unit_visitor::create_class_declaration(clang::CXXRecordDecl *cls) c.nested(true); } + else if (cls->isLambda()) { + c.is_lambda(true); + if (cls->getParent()) { + auto parent_full_name = get_participant(context().caller_id()) + .value() + .full_name_no_ns(); + + const auto location = cls->getLocation(); + const auto type_name = + fmt::format("{}##(lambda {}:{})", parent_full_name, + source_manager().getSpellingLineNumber(location), + source_manager().getSpellingColumnNumber(location)); + c.set_name(type_name); + c.set_namespace(ns); + c.set_id(common::to_id(c.full_name(false))); + } + } else { c.set_name(common::get_tag_name(*cls)); c.set_namespace(ns); @@ -1208,6 +1327,35 @@ void translation_unit_visitor::process_template_specialization_argument( simplify_system_template(argument, argument.to_string(config().using_namespace(), false)); } + else if (arg.getAsType()->getAsCXXRecordDecl()) { + if (arg.getAsType()->getAsCXXRecordDecl()->isLambda()) { + if (get_unique_id( + arg.getAsType()->getAsCXXRecordDecl()->getID()) + .has_value()) { + argument.set_name(get_participant( + get_unique_id( + arg.getAsType()->getAsCXXRecordDecl()->getID()) + .value()) + .value() + .full_name(false)); + } + else { + auto parent_full_name = + get_participant(context().caller_id()) + .value() + .full_name_no_ns(); + + const auto location = + arg.getAsType()->getAsCXXRecordDecl()->getLocation(); + const auto type_name = + fmt::format("{}##(lambda {}:{})", parent_full_name, + source_manager().getSpellingLineNumber(location), + source_manager().getSpellingColumnNumber(location)); + + argument.set_name(type_name); + } + } + } else if (arg.getAsType()->getAs()) { auto type_name = common::to_string(arg.getAsType(), cls->getASTContext()); @@ -1312,6 +1460,41 @@ void translation_unit_visitor::process_template_specialization_argument( cls->getLocation().dump(source_manager()); } + // else if (arg.getKind() == clang::TemplateArgument::Expression) { + // if (clang::dyn_cast(arg.getAsExpr()) != + // nullptr) { + // class_diagram::model::template_parameter argument; + //// const auto location = + //// arg.getAsType()->getAsCXXRecordDecl()->getLocation(); + //// + //// auto type_name = fmt::format("(lambda {}:{}:{})", + //// source_manager().getFilename(location).str(), + //// source_manager().getSpellingLineNumber(location), + //// source_manager().getSpellingColumnNumber(location)); + //// + //// argument.set_name(type_name); + // + // if (get_unique_id( + // arg.getAsType()->getAsCXXRecordDecl()->getID()) + // .has_value()) { + // argument.set_name(get_participant( + // get_unique_id( + // arg.getAsType()->getAsCXXRecordDecl()->getID()) + // .value()) + // .value() + // .full_name(false)); + // } + // else { + // const auto location = + // arg.getAsType()->getAsCXXRecordDecl()->getLocation(); + // auto type_name = fmt::format("(lambda {}:{}:{})", + // source_manager().getFilename(location).str(), + // source_manager().getSpellingLineNumber(location), + // source_manager().getSpellingColumnNumber(location)); + // argument.set_name(type_name); + // } + // } + // } else if (argument_kind == clang::TemplateArgument::Pack) { // This will only work for now if pack is at the end size_t argument_pack_index{argument_index}; @@ -1534,4 +1717,15 @@ bool translation_unit_visitor::simplify_system_template( else return false; } + +void translation_unit_visitor::finalize() +{ + for (auto &[id, activity] : diagram().sequences) { + for (auto &m : activity.messages) { + if (local_ast_id_map_.find(m.to) != local_ast_id_map_.end()) { + m.to = local_ast_id_map_.at(m.to); + } + } + } +} } diff --git a/src/sequence_diagram/visitor/translation_unit_visitor.h b/src/sequence_diagram/visitor/translation_unit_visitor.h index 02d42516..a5223121 100644 --- a/src/sequence_diagram/visitor/translation_unit_visitor.h +++ b/src/sequence_diagram/visitor/translation_unit_visitor.h @@ -25,6 +25,8 @@ #include #include +#include + namespace clanguml::sequence_diagram::visitor { std::string to_string(const clang::FunctionTemplateDecl *decl); @@ -150,12 +152,39 @@ struct call_expression_context { std::int64_t caller_id() const { return current_caller_id_; } + std::int64_t lambda_caller_id() const + { + if(current_lambda_caller_id_.empty()) + return 0; + + return current_lambda_caller_id_.top(); + } + void set_caller_id(std::int64_t id) { LOG_DBG("Setting current caller id to {}", id); current_caller_id_ = id; } + void enter_lambda_expression(std::int64_t id) + { + LOG_DBG("Setting current lambda caller id to {}", id); + + assert(id != 0); + + current_lambda_caller_id_.push(id); + } + + void leave_lambda_expression() + { + assert(!current_lambda_caller_id_.empty()); + + LOG_DBG("Leaving current lambda expression id to {}", + current_lambda_caller_id_.top()); + + current_lambda_caller_id_.pop(); + } + clang::CXXRecordDecl *current_class_decl_; clang::ClassTemplateDecl *current_class_template_decl_; clang::ClassTemplateSpecializationDecl @@ -166,6 +195,7 @@ struct call_expression_context { private: std::int64_t current_caller_id_; + std::stack current_lambda_caller_id_; }; class translation_unit_visitor @@ -180,6 +210,10 @@ public: virtual bool VisitCallExpr(clang::CallExpr *expr); + virtual bool VisitLambdaExpr(clang::LambdaExpr *expr); + + virtual bool TraverseLambdaExpr(clang::LambdaExpr *expr); + virtual bool VisitCXXMethodDecl(clang::CXXMethodDecl *method); virtual bool VisitCXXRecordDecl(clang::CXXRecordDecl *cls); @@ -200,7 +234,7 @@ public: call_expression_context &context(); - void finalize() { } + void finalize(); template common::optional_ref get_participant(const clang::Decl *decl) diff --git a/tests/t20012/.clang-uml b/tests/t20012/.clang-uml new file mode 100644 index 00000000..bd288238 --- /dev/null +++ b/tests/t20012/.clang-uml @@ -0,0 +1,14 @@ +compilation_database_dir: .. +output_directory: puml +diagrams: + t20012_sequence: + type: sequence + glob: + - ../../tests/t20012/t20012.cc + include: + namespaces: + - clanguml::t20012 + using_namespace: + - clanguml::t20012 + start_from: + - function: "clanguml::t20012::tmain()" \ No newline at end of file diff --git a/tests/t20012/t20012.cc b/tests/t20012/t20012.cc new file mode 100644 index 00000000..488211a3 --- /dev/null +++ b/tests/t20012/t20012.cc @@ -0,0 +1,73 @@ +#include +#include + +namespace clanguml { +namespace t20012 { +struct A { + void a() { aa(); } + + void aa() { aaa(); } + + void aaa() { } +}; + +struct B { + void b() { bb(); } + + void bb() { bbb(); } + + void bbb() { } +}; + +struct C { + void c() { cc(); } + + void cc() { ccc(); } + + void ccc() { } +}; + +template struct R { + R(F &&f) + : f_{std::move(f)} + { + } + + void r() { f_(); } + + F f_; +}; + +void tmain() +{ + A a; + B b; + C c; + + // The activity shouldn't be marked at the lambda definition, but + // wherever it is actually called... + auto alambda = [&a, &b]() { + a.a(); + b.b(); + }; + + // ...like here + alambda(); + + // There should be no call to B in the sequence diagram as the blambda + // is never called + [[maybe_unused]] auto blambda = [&b]() { b.b(); }; + + // Nested lambdas should also work + auto clambda = [alambda, &c]() { + c.c(); + alambda(); + }; + clambda(); + + R r{[&c]() { c.c(); }}; + + r.r(); +} +} +} \ No newline at end of file diff --git a/tests/t20012/test_case.h b/tests/t20012/test_case.h new file mode 100644 index 00000000..16681581 --- /dev/null +++ b/tests/t20012/test_case.h @@ -0,0 +1,66 @@ +/** + * tests/t20012/test_case.h + * + * Copyright (c) 2021-2022 Bartek Kryza + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +TEST_CASE("t20012", "[test-case][sequence]") +{ + auto [config, db] = load_config("t20012"); + + auto diagram = config.diagrams["t20012_sequence"]; + + REQUIRE(diagram->name == "t20012_sequence"); + + auto model = generate_sequence_diagram(*db, diagram); + + REQUIRE(model->name() == "t20012_sequence"); + + auto puml = generate_sequence_puml(diagram, *model); + AliasMatcher _A(puml); + + REQUIRE_THAT(puml, StartsWith("@startuml")); + REQUIRE_THAT(puml, EndsWith("@enduml\n")); + + // Check if all calls exist + REQUIRE_THAT(puml, + HasCall(_A("tmain()"), _A("tmain()##(lambda 49:20)"), "operator()")); + REQUIRE_THAT(puml, HasCall(_A("tmain()##(lambda 49:20)"), _A("A"), "a")); + REQUIRE_THAT(puml, HasCall(_A("A"), _A("A"), "aa")); + REQUIRE_THAT(puml, HasCall(_A("A"), _A("A"), "aaa")); + + REQUIRE_THAT(puml, HasCall(_A("tmain()##(lambda 49:20)"), _A("B"), "b")); + REQUIRE_THAT(puml, HasCall(_A("B"), _A("B"), "bb")); + REQUIRE_THAT(puml, HasCall(_A("B"), _A("B"), "bbb")); + + REQUIRE_THAT(puml, HasCall(_A("tmain()##(lambda 62:20)"), _A("C"), "c")); + REQUIRE_THAT(puml, HasCall(_A("C"), _A("C"), "cc")); + REQUIRE_THAT(puml, HasCall(_A("C"), _A("C"), "ccc")); + REQUIRE_THAT(puml, + HasCall(_A("tmain()##(lambda 62:20)"), _A("tmain()##(lambda 49:20)"), + "operator()")); + + REQUIRE_THAT(puml, HasCall(_A("C"), _A("C"), "ccc")); + + REQUIRE_THAT(puml, HasCall(_A("tmain()"), _A("R"), "r")); + REQUIRE_THAT(puml, + HasCall(_A("R"), _A("tmain()##(lambda 68:9)"), + "operator()")); + REQUIRE_THAT( + puml, HasCall(_A("tmain()##(lambda 68:9)"), _A("C"), "c")); + + save_puml( + "./" + config.output_directory() + "/" + diagram->name + ".puml", puml); +} \ No newline at end of file diff --git a/tests/test_cases.cc b/tests/test_cases.cc index 287a9539..af89443c 100644 --- a/tests/test_cases.cc +++ b/tests/test_cases.cc @@ -258,6 +258,7 @@ using namespace clanguml::test::matchers; #include "t20009/test_case.h" #include "t20010/test_case.h" #include "t20011/test_case.h" +#include "t20012/test_case.h" /// /// Package diagram tests