7#include "dlinear/symbolic/symbolic_expression.h"
9namespace dlinear::drake::symbolic {
25template<
typename Result,
typename Visitor,
typename... Args>
26Result VisitPolynomial(Visitor *v,
const Expression &e, Args &&... args) {
27 assert(e.is_polynomial());
28 switch (e.get_kind()) {
29 case ExpressionKind::Constant:
return v->VisitConstant(e, std::forward<Args>(args)...);
31 case ExpressionKind::Var:
return v->VisitVariable(e, std::forward<Args>(args)...);
33 case ExpressionKind::Add:
return v->VisitAddition(e, std::forward<Args>(args)...);
35 case ExpressionKind::Mul:
return v->VisitMultiplication(e, std::forward<Args>(args)...);
37 case ExpressionKind::Div:
return v->VisitDivision(e, std::forward<Args>(args)...);
39 case ExpressionKind::Pow:
return v->VisitPow(e, std::forward<Args>(args)...);
41 case ExpressionKind::NaN:
throw std::runtime_error(
"NaN is detected while visiting an expression.");
43 case ExpressionKind::Infty:
throw std::runtime_error(
"An infinity is detected while visiting an expression.");
45 case ExpressionKind::Log:
46 case ExpressionKind::Abs:
47 case ExpressionKind::Exp:
48 case ExpressionKind::Sqrt:
49 case ExpressionKind::Sin:
50 case ExpressionKind::Cos:
51 case ExpressionKind::Tan:
52 case ExpressionKind::Asin:
53 case ExpressionKind::Acos:
54 case ExpressionKind::Atan:
55 case ExpressionKind::Atan2:
56 case ExpressionKind::Sinh:
57 case ExpressionKind::Cosh:
58 case ExpressionKind::Tanh:
59 case ExpressionKind::Min:
60 case ExpressionKind::Max:
61 case ExpressionKind::IfThenElse:
62 case ExpressionKind::UninterpretedFunction:
65 throw std::runtime_error(
"Should not be reachable.");
69 throw std::runtime_error(
"Should not be reachable.");
89template<
typename Result,
typename Visitor,
typename... Args>
90Result VisitExpression(Visitor *v,
const Expression &e, Args &&... args) {
91 switch (e.get_kind()) {
92 case ExpressionKind::Constant:
return v->VisitConstant(e, std::forward<Args>(args)...);
94 case ExpressionKind::Var:
return v->VisitVariable(e, std::forward<Args>(args)...);
96 case ExpressionKind::Add:
return v->VisitAddition(e, std::forward<Args>(args)...);
98 case ExpressionKind::Mul:
return v->VisitMultiplication(e, std::forward<Args>(args)...);
100 case ExpressionKind::Div:
return v->VisitDivision(e, std::forward<Args>(args)...);
102 case ExpressionKind::Log:
return v->VisitLog(e, std::forward<Args>(args)...);
104 case ExpressionKind::Abs:
return v->VisitAbs(e, std::forward<Args>(args)...);
106 case ExpressionKind::Exp:
return v->VisitExp(e, std::forward<Args>(args)...);
108 case ExpressionKind::Sqrt:
return v->VisitSqrt(e, std::forward<Args>(args)...);
110 case ExpressionKind::Pow:
return v->VisitPow(e, std::forward<Args>(args)...);
112 case ExpressionKind::Sin:
return v->VisitSin(e, std::forward<Args>(args)...);
114 case ExpressionKind::Cos:
return v->VisitCos(e, std::forward<Args>(args)...);
116 case ExpressionKind::Tan:
return v->VisitTan(e, std::forward<Args>(args)...);
118 case ExpressionKind::Asin:
return v->VisitAsin(e, std::forward<Args>(args)...);
120 case ExpressionKind::Acos:
return v->VisitAcos(e, std::forward<Args>(args)...);
122 case ExpressionKind::Atan:
return v->VisitAtan(e, std::forward<Args>(args)...);
124 case ExpressionKind::Atan2:
return v->VisitAtan2(e, std::forward<Args>(args)...);
126 case ExpressionKind::Sinh:
return v->VisitSinh(e, std::forward<Args>(args)...);
128 case ExpressionKind::Cosh:
return v->VisitCosh(e, std::forward<Args>(args)...);
130 case ExpressionKind::Tanh:
return v->VisitTanh(e, std::forward<Args>(args)...);
132 case ExpressionKind::Min:
return v->VisitMin(e, std::forward<Args>(args)...);
134 case ExpressionKind::Max:
return v->VisitMax(e, std::forward<Args>(args)...);
136 case ExpressionKind::IfThenElse:
return v->VisitIfThenElse(e, std::forward<Args>(args)...);
138 case ExpressionKind::Infty:
throw std::runtime_error(
"An infinity is detected while visiting an expression.");
140 case ExpressionKind::NaN:
throw std::runtime_error(
"NaN is detected while visiting an expression.");
142 case ExpressionKind::UninterpretedFunction:
return v->VisitUninterpretedFunction(e, std::forward<Args>(args)...);
146 throw std::runtime_error(
"Should not be reachable.");