Adding handling of lambda expressions in sequence diagrams

This commit is contained in:
Bartek Kryza
2022-12-04 01:33:02 +01:00
parent 459baa326c
commit d1d4d5e0e7
9 changed files with 416 additions and 8 deletions

View File

@@ -78,10 +78,10 @@ void generator::generate_return(const message &m, std::ostream &ostr) const
} }
void generator::generate_activity(const activity &a, std::ostream &ostr, void generator::generate_activity(const activity &a, std::ostream &ostr,
std::set<common::model::diagram_element::id_t> &visited) const std::vector<common::model::diagram_element::id_t> &visited) const
{ {
for (const auto &m : a.messages) { for (const auto &m : a.messages) {
visited.emplace(m.from); visited.push_back(m.from);
const auto &to = m_model.get_participant<model::participant>(m.to); const auto &to = m_model.get_participant<model::participant>(m.to);
if (!to) if (!to)
@@ -94,12 +94,16 @@ void generator::generate_activity(const activity &a, std::ostream &ostr,
ostr << "activate " << to.value().alias() << std::endl; ostr << "activate " << to.value().alias() << std::endl;
if (m_model.sequences.find(m.to) != m_model.sequences.end()) { if (m_model.sequences.find(m.to) != m_model.sequences.end()) {
if (visited.find(m.to) == if (std::find(visited.begin(), visited.end(), m.to) ==
visited.end()) { // break infinite recursion on recursive calls visited.end()) { // break infinite recursion on recursive calls
LOG_DBG("Creating activity {} --> {} - missing sequence {}", LOG_DBG("Creating activity {} --> {} - missing sequence {}",
m.from, m.to, m.to); m.from, m.to, m.to);
generate_activity(m_model.sequences[m.to], ostr, visited); generate_activity(m_model.sequences[m.to], ostr, visited);
} }
// else {
// // clear the visited list after breaking the loop
// visited.clear();
// }
} }
else else
LOG_DBG("Skipping activity {} --> {} - missing sequence {}", m.from, LOG_DBG("Skipping activity {} --> {} - missing sequence {}", m.from,
@@ -107,6 +111,8 @@ void generator::generate_activity(const activity &a, std::ostream &ostr,
generate_return(m, ostr); generate_return(m, ostr);
visited.pop_back();
ostr << "deactivate " << to.value().alias() << std::endl; ostr << "deactivate " << to.value().alias() << std::endl;
} }
} }
@@ -182,9 +188,18 @@ void generator::generate(std::ostream &ostr) const
break; break;
} }
} }
std::set<common::model::diagram_element::id_t> visited_participants; std::vector<common::model::diagram_element::id_t> visited_participants;
const auto& from = m_model.get_participant<model::participant>(start_from);
generate_participant(ostr, start_from);
ostr << "activate " << from.value().alias() << std::endl;
generate_activity( generate_activity(
m_model.sequences[start_from], ostr, visited_participants); m_model.sequences[start_from], ostr, visited_participants);
ostr << "deactivate " << from.value().alias() << std::endl;
} }
else { else {
// TODO: Add support for other sequence start location types // TODO: Add support for other sequence start location types

View File

@@ -56,7 +56,7 @@ public:
void generate_activity(const clanguml::sequence_diagram::model::activity &a, void generate_activity(const clanguml::sequence_diagram::model::activity &a,
std::ostream &ostr, std::ostream &ostr,
std::set<common::model::diagram_element::id_t> &visited) const; std::vector<common::model::diagram_element::id_t> &visited) const;
void generate(std::ostream &ostr) const; void generate(std::ostream &ostr) const;

View File

@@ -115,11 +115,16 @@ public:
void is_alias(bool alias) { is_alias_ = alias; } void is_alias(bool alias) { is_alias_ = alias; }
bool is_lambda() const { return is_lambda_; }
void is_lambda(bool is_lambda) { is_lambda_ = is_lambda; }
private: private:
bool is_struct_{false}; bool is_struct_{false};
bool is_template_{false}; bool is_template_{false};
bool is_template_instantiation_{false}; bool is_template_instantiation_{false};
bool is_alias_{false}; bool is_alias_{false};
bool is_lambda_{false};
std::map<std::string, clanguml::class_diagram::model::type_alias> std::map<std::string, clanguml::class_diagram::model::type_alias>
type_aliases_; type_aliases_;
@@ -127,6 +132,12 @@ private:
std::string full_name_; std::string full_name_;
}; };
struct lambda : public class_ {
using class_::class_;
std::string type_name() const override { return "lambda"; }
};
struct function : public participant { struct function : public participant {
function(const common::model::namespace_ &using_namespace); function(const common::model::namespace_ &using_namespace);

View File

@@ -99,6 +99,9 @@ bool translation_unit_visitor::VisitCXXRecordDecl(clang::CXXRecordDecl *cls)
if (cls->isLocalClass()) if (cls->isLocalClass())
return true; return true;
LOG_DBG("Visiting class declaration at {}",
cls->getBeginLoc().printToString(source_manager()));
// Build the class declaration and store it in the diagram, even // Build the class declaration and store it in the diagram, even
// if we don't need it for any of the participants of this diagram // if we don't need it for any of the participants of this diagram
auto c_ptr = create_class_declaration(cls); auto c_ptr = create_class_declaration(cls);
@@ -400,6 +403,73 @@ bool translation_unit_visitor::VisitFunctionTemplateDecl(
return true; return true;
} }
bool translation_unit_visitor::VisitLambdaExpr(clang::LambdaExpr *expr)
{
const auto lambda_full_name =
expr->getLambdaClass()->getCanonicalDecl()->getNameAsString();
LOG_DBG("Visiting lambda expression {} at {}", lambda_full_name,
expr->getBeginLoc().printToString(source_manager()));
LOG_DBG("Lambda call operator ID {} - lambda class ID {}, class call "
"operator ID {}",
expr->getCallOperator()->getID(), expr->getLambdaClass()->getID(),
expr->getLambdaClass()->getLambdaCallOperator()->getID());
// Create lambda class participant
auto *cls = expr->getLambdaClass();
auto c_ptr = create_class_declaration(cls);
if (!c_ptr)
return true;
const auto cls_id = c_ptr->id();
set_unique_id(cls->getID(), cls_id);
// Create lambda class operator() participant
auto m_ptr = std::make_unique<sequence_diagram::model::method>(
config().using_namespace());
common::model::namespace_ ns{c_ptr->get_namespace()};
auto method_name = "operator()";
m_ptr->set_method_name(method_name);
ns.pop_back();
m_ptr->set_class_id(cls_id);
m_ptr->set_class_full_name(c_ptr->full_name(false));
diagram().add_participant(std::move(c_ptr));
m_ptr->set_id(common::to_id(
get_participant(cls_id).value().full_name(false) + "::" + method_name));
context().enter_lambda_expression(m_ptr->id());
set_unique_id(expr->getCallOperator()->getID(), m_ptr->id());
diagram().add_participant(std::move(m_ptr));
[[maybe_unused]] const auto is_generic_lambda = expr->isGenericLambda();
return true;
}
bool translation_unit_visitor::TraverseLambdaExpr(clang::LambdaExpr *expr)
{
const auto lambda_full_name =
expr->getLambdaClass()->getCanonicalDecl()->getNameAsString();
RecursiveASTVisitor<translation_unit_visitor>::TraverseLambdaExpr(expr);
LOG_DBG("Leaving lambda expression {} at {}", lambda_full_name,
expr->getBeginLoc().printToString(source_manager()));
context().leave_lambda_expression();
return true;
}
bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr) bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr)
{ {
using clanguml::common::model::message_t; using clanguml::common::model::message_t;
@@ -424,6 +494,10 @@ bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr)
m.type = message_t::kCall; m.type = message_t::kCall;
m.from = context().caller_id(); m.from = context().caller_id();
if (context().lambda_caller_id() != 0) {
m.from = context().lambda_caller_id();
}
const auto &current_ast_context = *context().get_ast_context(); const auto &current_ast_context = *context().get_ast_context();
LOG_DBG("Visiting call expression at {}", LOG_DBG("Visiting call expression at {}",
@@ -433,6 +507,26 @@ bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr)
clang::dyn_cast_or_null<clang::CXXOperatorCallExpr>(expr); clang::dyn_cast_or_null<clang::CXXOperatorCallExpr>(expr);
operator_call_expr != nullptr) { operator_call_expr != nullptr) {
// TODO: Handle C++ operator calls // TODO: Handle C++ operator calls
LOG_DBG("Operator call expression to {} at {}",
expr->getCalleeDecl()->getID(),
expr->getBeginLoc().printToString(source_manager()));
auto maybe_id = get_unique_id(expr->getCalleeDecl()->getID());
if (maybe_id.has_value()) {
// Found operator() call to a participant
// auto maybe_participant = get_participant(maybe_id.value());
// if (maybe_participant.has_value()) {
m.to = maybe_id.value();
m.message_name = "operator()";
//}
}
else {
m.to = expr->getCalleeDecl()->getID();
m.message_name = "operator()";
}
if (clang::dyn_cast<clang::ImplicitCastExpr>(expr)) { }
} }
// //
// Call to a class method // Call to a class method
@@ -637,6 +731,9 @@ bool translation_unit_visitor::VisitCallExpr(clang::CallExpr *expr)
auto *ftd = clang::dyn_cast_or_null< auto *ftd = clang::dyn_cast_or_null<
clang::FunctionTemplateDecl>(decl); clang::FunctionTemplateDecl>(decl);
if (!get_unique_id(ftd->getID()).has_value())
continue;
m.to = get_unique_id(ftd->getID()).value(); m.to = get_unique_id(ftd->getID()).value();
auto message_name = auto message_name =
diagram() diagram()
@@ -751,11 +848,16 @@ translation_unit_visitor::create_class_declaration(clang::CXXRecordDecl *cls)
auto qualified_name = auto qualified_name =
cls->getQualifiedNameAsString(); // common::get_qualified_name(*cls); cls->getQualifiedNameAsString(); // common::get_qualified_name(*cls);
if (!diagram().should_include(qualified_name)) if (!cls->isLambda())
return {}; if (!diagram().should_include(qualified_name))
return {};
auto ns = common::get_tag_namespace(*cls); auto ns = common::get_tag_namespace(*cls);
if (cls->isLambda() &&
!diagram().should_include(ns.to_string() + "::lambda"))
return {};
const auto *parent = cls->getParent(); const auto *parent = cls->getParent();
if (parent && parent->isRecord()) { if (parent && parent->isRecord()) {
@@ -815,6 +917,23 @@ translation_unit_visitor::create_class_declaration(clang::CXXRecordDecl *cls)
c.nested(true); c.nested(true);
} }
else if (cls->isLambda()) {
c.is_lambda(true);
if (cls->getParent()) {
auto parent_full_name = get_participant(context().caller_id())
.value()
.full_name_no_ns();
const auto location = cls->getLocation();
const auto type_name =
fmt::format("{}##(lambda {}:{})", parent_full_name,
source_manager().getSpellingLineNumber(location),
source_manager().getSpellingColumnNumber(location));
c.set_name(type_name);
c.set_namespace(ns);
c.set_id(common::to_id(c.full_name(false)));
}
}
else { else {
c.set_name(common::get_tag_name(*cls)); c.set_name(common::get_tag_name(*cls));
c.set_namespace(ns); c.set_namespace(ns);
@@ -1208,6 +1327,35 @@ void translation_unit_visitor::process_template_specialization_argument(
simplify_system_template(argument, simplify_system_template(argument,
argument.to_string(config().using_namespace(), false)); argument.to_string(config().using_namespace(), false));
} }
else if (arg.getAsType()->getAsCXXRecordDecl()) {
if (arg.getAsType()->getAsCXXRecordDecl()->isLambda()) {
if (get_unique_id(
arg.getAsType()->getAsCXXRecordDecl()->getID())
.has_value()) {
argument.set_name(get_participant(
get_unique_id(
arg.getAsType()->getAsCXXRecordDecl()->getID())
.value())
.value()
.full_name(false));
}
else {
auto parent_full_name =
get_participant(context().caller_id())
.value()
.full_name_no_ns();
const auto location =
arg.getAsType()->getAsCXXRecordDecl()->getLocation();
const auto type_name =
fmt::format("{}##(lambda {}:{})", parent_full_name,
source_manager().getSpellingLineNumber(location),
source_manager().getSpellingColumnNumber(location));
argument.set_name(type_name);
}
}
}
else if (arg.getAsType()->getAs<clang::TemplateTypeParmType>()) { else if (arg.getAsType()->getAs<clang::TemplateTypeParmType>()) {
auto type_name = auto type_name =
common::to_string(arg.getAsType(), cls->getASTContext()); common::to_string(arg.getAsType(), cls->getASTContext());
@@ -1312,6 +1460,41 @@ void translation_unit_visitor::process_template_specialization_argument(
cls->getLocation().dump(source_manager()); cls->getLocation().dump(source_manager());
} }
// else if (arg.getKind() == clang::TemplateArgument::Expression) {
// if (clang::dyn_cast<clang::LambdaExpr>(arg.getAsExpr()) !=
// nullptr) {
// class_diagram::model::template_parameter argument;
//// const auto location =
//// arg.getAsType()->getAsCXXRecordDecl()->getLocation();
////
//// auto type_name = fmt::format("(lambda {}:{}:{})",
//// source_manager().getFilename(location).str(),
//// source_manager().getSpellingLineNumber(location),
//// source_manager().getSpellingColumnNumber(location));
////
//// argument.set_name(type_name);
//
// if (get_unique_id(
// arg.getAsType()->getAsCXXRecordDecl()->getID())
// .has_value()) {
// argument.set_name(get_participant(
// get_unique_id(
// arg.getAsType()->getAsCXXRecordDecl()->getID())
// .value())
// .value()
// .full_name(false));
// }
// else {
// const auto location =
// arg.getAsType()->getAsCXXRecordDecl()->getLocation();
// auto type_name = fmt::format("(lambda {}:{}:{})",
// source_manager().getFilename(location).str(),
// source_manager().getSpellingLineNumber(location),
// source_manager().getSpellingColumnNumber(location));
// argument.set_name(type_name);
// }
// }
// }
else if (argument_kind == clang::TemplateArgument::Pack) { else if (argument_kind == clang::TemplateArgument::Pack) {
// This will only work for now if pack is at the end // This will only work for now if pack is at the end
size_t argument_pack_index{argument_index}; size_t argument_pack_index{argument_index};
@@ -1534,4 +1717,15 @@ bool translation_unit_visitor::simplify_system_template(
else else
return false; return false;
} }
void translation_unit_visitor::finalize()
{
for (auto &[id, activity] : diagram().sequences) {
for (auto &m : activity.messages) {
if (local_ast_id_map_.find(m.to) != local_ast_id_map_.end()) {
m.to = local_ast_id_map_.at(m.to);
}
}
}
}
} }

View File

@@ -25,6 +25,8 @@
#include <clang/AST/RecursiveASTVisitor.h> #include <clang/AST/RecursiveASTVisitor.h>
#include <clang/Basic/SourceManager.h> #include <clang/Basic/SourceManager.h>
#include <stack>
namespace clanguml::sequence_diagram::visitor { namespace clanguml::sequence_diagram::visitor {
std::string to_string(const clang::FunctionTemplateDecl *decl); std::string to_string(const clang::FunctionTemplateDecl *decl);
@@ -150,12 +152,39 @@ struct call_expression_context {
std::int64_t caller_id() const { return current_caller_id_; } std::int64_t caller_id() const { return current_caller_id_; }
std::int64_t lambda_caller_id() const
{
if(current_lambda_caller_id_.empty())
return 0;
return current_lambda_caller_id_.top();
}
void set_caller_id(std::int64_t id) void set_caller_id(std::int64_t id)
{ {
LOG_DBG("Setting current caller id to {}", id); LOG_DBG("Setting current caller id to {}", id);
current_caller_id_ = id; current_caller_id_ = id;
} }
void enter_lambda_expression(std::int64_t id)
{
LOG_DBG("Setting current lambda caller id to {}", id);
assert(id != 0);
current_lambda_caller_id_.push(id);
}
void leave_lambda_expression()
{
assert(!current_lambda_caller_id_.empty());
LOG_DBG("Leaving current lambda expression id to {}",
current_lambda_caller_id_.top());
current_lambda_caller_id_.pop();
}
clang::CXXRecordDecl *current_class_decl_; clang::CXXRecordDecl *current_class_decl_;
clang::ClassTemplateDecl *current_class_template_decl_; clang::ClassTemplateDecl *current_class_template_decl_;
clang::ClassTemplateSpecializationDecl clang::ClassTemplateSpecializationDecl
@@ -166,6 +195,7 @@ struct call_expression_context {
private: private:
std::int64_t current_caller_id_; std::int64_t current_caller_id_;
std::stack<std::int64_t> current_lambda_caller_id_;
}; };
class translation_unit_visitor class translation_unit_visitor
@@ -180,6 +210,10 @@ public:
virtual bool VisitCallExpr(clang::CallExpr *expr); virtual bool VisitCallExpr(clang::CallExpr *expr);
virtual bool VisitLambdaExpr(clang::LambdaExpr *expr);
virtual bool TraverseLambdaExpr(clang::LambdaExpr *expr);
virtual bool VisitCXXMethodDecl(clang::CXXMethodDecl *method); virtual bool VisitCXXMethodDecl(clang::CXXMethodDecl *method);
virtual bool VisitCXXRecordDecl(clang::CXXRecordDecl *cls); virtual bool VisitCXXRecordDecl(clang::CXXRecordDecl *cls);
@@ -200,7 +234,7 @@ public:
call_expression_context &context(); call_expression_context &context();
void finalize() { } void finalize();
template <typename T = model::participant> template <typename T = model::participant>
common::optional_ref<T> get_participant(const clang::Decl *decl) common::optional_ref<T> get_participant(const clang::Decl *decl)

14
tests/t20012/.clang-uml Normal file
View File

@@ -0,0 +1,14 @@
compilation_database_dir: ..
output_directory: puml
diagrams:
t20012_sequence:
type: sequence
glob:
- ../../tests/t20012/t20012.cc
include:
namespaces:
- clanguml::t20012
using_namespace:
- clanguml::t20012
start_from:
- function: "clanguml::t20012::tmain()"

73
tests/t20012/t20012.cc Normal file
View File

@@ -0,0 +1,73 @@
#include <functional>
#include <utility>
namespace clanguml {
namespace t20012 {
struct A {
void a() { aa(); }
void aa() { aaa(); }
void aaa() { }
};
struct B {
void b() { bb(); }
void bb() { bbb(); }
void bbb() { }
};
struct C {
void c() { cc(); }
void cc() { ccc(); }
void ccc() { }
};
template <typename F> struct R {
R(F &&f)
: f_{std::move(f)}
{
}
void r() { f_(); }
F f_;
};
void tmain()
{
A a;
B b;
C c;
// The activity shouldn't be marked at the lambda definition, but
// wherever it is actually called...
auto alambda = [&a, &b]() {
a.a();
b.b();
};
// ...like here
alambda();
// There should be no call to B in the sequence diagram as the blambda
// is never called
[[maybe_unused]] auto blambda = [&b]() { b.b(); };
// Nested lambdas should also work
auto clambda = [alambda, &c]() {
c.c();
alambda();
};
clambda();
R r{[&c]() { c.c(); }};
r.r();
}
}
}

66
tests/t20012/test_case.h Normal file
View File

@@ -0,0 +1,66 @@
/**
* tests/t20012/test_case.h
*
* Copyright (c) 2021-2022 Bartek Kryza <bkryza@gmail.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
TEST_CASE("t20012", "[test-case][sequence]")
{
auto [config, db] = load_config("t20012");
auto diagram = config.diagrams["t20012_sequence"];
REQUIRE(diagram->name == "t20012_sequence");
auto model = generate_sequence_diagram(*db, diagram);
REQUIRE(model->name() == "t20012_sequence");
auto puml = generate_sequence_puml(diagram, *model);
AliasMatcher _A(puml);
REQUIRE_THAT(puml, StartsWith("@startuml"));
REQUIRE_THAT(puml, EndsWith("@enduml\n"));
// Check if all calls exist
REQUIRE_THAT(puml,
HasCall(_A("tmain()"), _A("tmain()##(lambda 49:20)"), "operator()"));
REQUIRE_THAT(puml, HasCall(_A("tmain()##(lambda 49:20)"), _A("A"), "a"));
REQUIRE_THAT(puml, HasCall(_A("A"), _A("A"), "aa"));
REQUIRE_THAT(puml, HasCall(_A("A"), _A("A"), "aaa"));
REQUIRE_THAT(puml, HasCall(_A("tmain()##(lambda 49:20)"), _A("B"), "b"));
REQUIRE_THAT(puml, HasCall(_A("B"), _A("B"), "bb"));
REQUIRE_THAT(puml, HasCall(_A("B"), _A("B"), "bbb"));
REQUIRE_THAT(puml, HasCall(_A("tmain()##(lambda 62:20)"), _A("C"), "c"));
REQUIRE_THAT(puml, HasCall(_A("C"), _A("C"), "cc"));
REQUIRE_THAT(puml, HasCall(_A("C"), _A("C"), "ccc"));
REQUIRE_THAT(puml,
HasCall(_A("tmain()##(lambda 62:20)"), _A("tmain()##(lambda 49:20)"),
"operator()"));
REQUIRE_THAT(puml, HasCall(_A("C"), _A("C"), "ccc"));
REQUIRE_THAT(puml, HasCall(_A("tmain()"), _A("R<R##(lambda 68:9)>"), "r"));
REQUIRE_THAT(puml,
HasCall(_A("R<R##(lambda 68:9)>"), _A("tmain()##(lambda 68:9)"),
"operator()"));
REQUIRE_THAT(
puml, HasCall(_A("tmain()##(lambda 68:9)"), _A("C"), "c"));
save_puml(
"./" + config.output_directory() + "/" + diagram->name + ".puml", puml);
}

View File

@@ -258,6 +258,7 @@ using namespace clanguml::test::matchers;
#include "t20009/test_case.h" #include "t20009/test_case.h"
#include "t20010/test_case.h" #include "t20010/test_case.h"
#include "t20011/test_case.h" #include "t20011/test_case.h"
#include "t20012/test_case.h"
/// ///
/// Package diagram tests /// Package diagram tests