22static_assert(std::endian::native == std::endian::little,
"Only little-endian systems are supported for onnx parsing");
25inline void invalid_number_of_inputs([[maybe_unused]] const ::onnx::NodeProto& node,
26 [[maybe_unused]]
const int actualNumberOfInputs,
27 [[maybe_unused]]
const int lowerBound, [[maybe_unused]]
const int upperBound) {
28 if (lowerBound == upperBound) {
29 DLINEAR_RUNTIME_ERROR_FMT(
"Onnx operation '{}' expected to have exactly {} inputs, but found {}", node.op_type(),
30 lowerBound, actualNumberOfInputs);
32 DLINEAR_RUNTIME_ERROR_FMT(
"Onnx operation '{}' expected to have between {} and {} inputs, but found {}",
33 node.op_type(), lowerBound, upperBound, actualNumberOfInputs);
37inline const ::onnx::AttributeProto* FindAttribute(const ::onnx::NodeProto& node,
const std::string& name,
38 ::onnx::AttributeProto_AttributeType expectedType,
39 bool throw_on_missing =
false) {
40 for (const ::onnx::AttributeProto& attr : node.attribute()) {
41 if (attr.name() == name) {
42 if (attr.type() != expectedType) {
43 DLINEAR_RUNTIME_ERROR_FMT(
"Attribute '{}' must be of type {}", name,
44 AttributeProto_AttributeType_Name(expectedType));
50 DLINEAR_RUNTIME_ERROR_FMT(
"Onnx node of type {} is missing the expected attribute {}", node.op_type(), name);
59template <IsAnyOf<
bool,
float, std::
int64_t, std::
string, std::vector<
float>, std::vector<std::
int64_t>,
60 std::vector<std::
string>, const ::onnx::TensorProto*>
62constexpr ::onnx::AttributeProto_AttributeType GetAttributeType() {
63 if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, std::int64_t>) {
64 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INT;
65 }
else if constexpr (std::is_same_v<T, float>) {
66 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT;
67 }
else if constexpr (std::is_same_v<T, std::string>) {
68 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING;
69 }
else if constexpr (std::is_same_v<T, std::vector<float>>) {
70 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS;
71 }
else if constexpr (std::is_same_v<T, std::vector<std::int64_t>>) {
72 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS;
73 }
else if constexpr (std::is_same_v<T, std::vector<std::string>>) {
74 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS;
75 }
else if constexpr (std::is_same_v<T, const ::onnx::TensorProto*>) {
76 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR;
78 DLINEAR_UNREACHABLE();
86template <IsAnyOf<
bool,
float, std::
int64_t, std::
string, std::vector<
float>, std::vector<std::
int64_t>,
87 std::vector<std::
string>, const ::onnx::TensorProto*>
89T GetAttributeValue(const ::onnx::AttributeProto* attr) {
90 DLINEAR_ASSERT(attr !=
nullptr,
"AttributeProto must not be null");
91 if constexpr (std::is_same_v<T, bool>) {
92 return attr->i() != 0;
93 }
else if constexpr (std::is_same_v<T, std::int64_t>) {
95 }
else if constexpr (std::is_same_v<T, float>) {
97 }
else if constexpr (std::is_same_v<T, std::string>) {
99 }
else if constexpr (std::is_same_v<T, std::vector<float>>) {
100 return std::vector<float>{attr->floats().begin(), attr->floats().end()};
101 }
else if constexpr (std::is_same_v<T, std::vector<std::int64_t>>) {
102 return std::vector<std::int64_t>{attr->ints().begin(), attr->ints().end()};
103 }
else if constexpr (std::is_same_v<T, std::vector<std::string>>) {
104 return std::vector<std::string>{attr->strings().begin(), attr->strings().end()};
105 }
else if constexpr (std::is_same_v<T, const ::onnx::TensorProto*>) {
108 DLINEAR_UNREACHABLE();
// --- Fragment of OnnxDriver::ParseStreamCore (signature and surrounding lines
// lost in extraction) --- Parses the protobuf ModelProto from the input
// stream; only the error-logging path of the failure branch is visible here.
116  const bool res =
model_.ParseFromIstream(&in);
118  DLINEAR_ERROR(
"OnnxDriver::ParseStreamCore(): Failed to parse model from input stream");
// --- Fragment of OnnxDriver::ParseFile (signature and tail lost in
// extraction) --- Opens the model file in binary mode; on failure an error is
// logged (the early-return, presumably, is on a lost line — TODO confirm).
126  std::ifstream input(filename, std::ios::binary);
127  if (!input.is_open()) {
128    DLINEAR_ERROR_FMT(
"OnnxDriver::ParseFile({}): Failed to open file", filename);
// --- Fragment of OnnxDriver::ParseGraph (signature and several lines lost in
// extraction) --- Records which tensors are initializers, then registers the
// graph's inputs and outputs as value infos.
135  DLINEAR_TRACE(
"OnnxDriver::ParseGraph()");
136  if (!
model_.has_graph()) DLINEAR_RUNTIME_ERROR(
"ModelProto must have a graph");
137  std::unordered_set<std::string> initializers;
// Initializer tensors are constants baked into the model; remember their names.
138  for (const ::onnx::TensorProto& tensor :
model_.graph().initializer()) {
140    initializers.insert(tensor.name());
// NOTE(review): the body of the input loop (and the closing brace of the loop
// above) was lost here; presumably each input is passed to AddValueInfo —
// confirm against the full source.
142  for (const ::onnx::ValueInfoProto& input :
model_.graph().input()) {
145  for (const ::onnx::ValueInfoProto& output :
model_.graph().output())
AddValueInfo(output);
// OnnxDriver::GetAttribute<T>: fetches the attribute `name` of a node, falling
// back to `default_value` when the attribute is absent. The function's
// signature line was lost in extraction; only the template header and the body
// survive here.
150template <IsAnyOf<
bool,
float, std::
int64_t, std::
string, std::vector<
float>, std::vector<std::
int64_t>,
151    std::vector<std::
string>, const ::onnx::TensorProto*>
154                          const std::optional<T>& default_value)
const {
// Without a default the attribute is mandatory: FindAttribute is told to throw
// when it is missing, so the value() below cannot fire on an empty optional.
155  const ::onnx::AttributeProto*
const attr =
156      FindAttribute(node, name, GetAttributeType<T>(), !default_value.has_value());
157  return attr ==
nullptr ? default_value.value() : GetAttributeValue<T>(attr);
160void OnnxDriver::EnsureInput(const ::onnx::NodeProto& node,
const int lb,
const int ub) {
161 if (node.input_size() < lb || node.input_size() > ub) invalid_number_of_inputs(node, node.input_size(), lb, ub);
163void OnnxDriver::EnsureInput(const ::onnx::NodeProto& node,
const int exact) {
164 if (node.input_size() != exact) invalid_number_of_inputs(node, node.input_size(), exact, exact);
// --- Fragment of OnnxDriver::AddInitializer (signature and body tail lost in
// extraction) --- Validates an initializer tensor before it is registered as a
// constant input.
168  DLINEAR_ASSERT(tensor.has_name(),
"TensorProto must have a name");
169  DLINEAR_ASSERT(tensor.has_data_type(),
"TensorProto must have a data_type");
170  DLINEAR_TRACE_FMT(
"AddInitializer({})", tensor.name());
// --- Fragment of OnnxDriver::AddFormula (signature and the lines defining
// var_tensor / final_tensor were lost in extraction) --- Asserts element-wise
// equality between the output's variable tensor and the computed tensor.
176  DLINEAR_DEBUG_FMT(
"AddFormula({})", output_name);
179  DLINEAR_TRACE_FMT(
"AddFormula({}): {} == {}", output_name, var_tensor, final_tensor);
// (var_tensor == final_tensor) yields one Formula per element; each is asserted.
180  for (
const Formula& f : (var_tensor == final_tensor)) Assert(f);
181  DLINEAR_TRACE_FMT(
"Added formula for {}. Current assertions: {}", output_name,
context_.
assertions());
185void OnnxDriver::AddNode<NodeOpType::Abs>(const ::onnx::NodeProto& node) {
186 DLINEAR_ASSERT(node.op_type() ==
"Abs",
"NodeProto must have an op_type of Abs");
187 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
188 EnsureInput(node, 1);
190 const std::string& input = node.input(0);
191 const std::string& output = node.output(0);
192 available_inputs_.emplace(output,
Tensor{available_inputs_.at(input)}.
Abs());
193 DLINEAR_DEBUG_FMT(
"Abs node: {} = |{}|", output, input);
194 DLINEAR_TRACE_FMT(
"{} = |{}|", available_inputs_.at(output), available_inputs_.at(input));
199void OnnxDriver::AddNode<NodeOpType::Add>(const ::onnx::NodeProto& node) {
200 DLINEAR_ASSERT(node.op_type() ==
"Add",
"NodeProto must have an op_type of Add");
201 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
202 EnsureInput(node, 2);
204 const std::string& input1 = node.input(0);
205 const std::string& input2 = node.input(1);
206 const std::string& output = node.output(0);
207 available_inputs_.emplace(output, available_inputs_.at(input1) + available_inputs_.at(input2));
208 DLINEAR_DEBUG_FMT(
"Add node: {} = {} + {}", output, input1, input2);
209 DLINEAR_TRACE_FMT(
"{} = {} + {}", available_inputs_.at(output), available_inputs_.at(input1),
210 available_inputs_.at(input2));
215void OnnxDriver::AddNode<NodeOpType::Concat>(const ::onnx::NodeProto& node) {
216 DLINEAR_ASSERT(node.op_type() ==
"Concat",
"NodeProto must have an op_type of Concat");
217 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
218 EnsureInput(node, 1, 2147483647);
220 const auto axis = GetAttribute<std::int64_t>(node,
"axis");
221 const std::string& output = node.output(0);
222 const std::string& input1 = node.input(0);
223 const std::vector<std::string> inputs(node.input().begin() + 1, node.input().end());
224 std::vector<Tensor> tensors;
225 tensors.reserve(inputs.size());
226 std::transform(inputs.begin(), inputs.end(), std::back_inserter(tensors),
227 [
this](
const std::string& input) { return available_inputs_.at(input); });
229 available_inputs_.emplace(output, Tensor{available_inputs_.at(input1)}.Concat(tensors, axis));
230 DLINEAR_DEBUG_FMT(
"Concat node: {} = concat({})", output, inputs);
231 DLINEAR_TRACE_FMT(
"{} = concat({})", available_inputs_.at(output), tensors);
// Handles a Constant node: stores the node's single attribute (tensor, float
// or integer) as a new available input tensor.
// NOTE(review): extraction lost several lines in this block — the leading
// `template <>`, the `if (attr.has_t()) {` guard before the tensor branch, and
// the statements constructing the scalar Tensor `c` used by the float and
// integer branches. Confirm against the full source before editing.
236void OnnxDriver::AddNode<NodeOpType::Constant>(const ::onnx::NodeProto& node) {
237  DLINEAR_ASSERT(node.op_type() ==
"Constant",
"NodeProto must have an op_type of Constant");
238  DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
239  DLINEAR_ASSERT(node.attribute_size() == 1,
"NodeProto must have exactly 1 attribute");
241  const std::string& output = node.output(0);
242  const ::onnx::AttributeProto& attr = node.attribute(0);
// Tensor-valued constant (guard on a lost line).
244  available_inputs_.emplace(output, Tensor{attr.t()});
245  }
else if (attr.has_f()) {
// Scalar float constant; `c` is built on a lost line.
248  available_inputs_.emplace(output, std::move(c));
249  }
else if (attr.has_i()) {
// Scalar integer constant; `c` is built on a lost line.
252  available_inputs_.emplace(output, std::move(c));
254  DLINEAR_RUNTIME_ERROR(
"Constant node must have a tensor, float, or integer attribute");
256  DLINEAR_DEBUG_FMT(
"Constant node: {}", output);
257  DLINEAR_TRACE_FMT(
"{}", available_inputs_.at(output));
// Handles a Conv node: 2D convolution of input x with weights w, honouring
// the auto_pad / dilations / group / kernel_shape / pads / strides attributes.
// NOTE(review): extraction lost lines in this block — the leading
// `template <>`, the first SAME_LOWER pad assignment (original line 294),
// the `conv += b;` bias addition (around lines 308-309), and several closing
// braces. Confirm against the full source before editing.
262void OnnxDriver::AddNode<NodeOpType::Conv>(const ::onnx::NodeProto& node) {
263  DLINEAR_ASSERT(node.op_type() ==
"Conv",
"NodeProto must have an op_type of Conv");
264  DLINEAR_ASSERT(node.input_size() == 2 || node.input_size() == 3,
"NodeProto must have [2-3] inputs");
265  DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
267  const std::string& input1 = node.input(0);
268  const std::string& input2 = node.input(1);
270  const std::string& output = node.output(0);
271  const Tensor& x = available_inputs_.at(input1);
272  const Tensor& w = available_inputs_.at(input2);
// Attribute defaults follow the onnx Conv specification.
274  std::string auto_pad{GetAttribute<std::string>(node,
"auto_pad",
"NOTSET")};
275  std::vector<std::int64_t> dilations{GetAttribute<std::vector<std::int64_t>>(node,
"dilations", {{1, 1}})};
276  const auto group = GetAttribute<std::int64_t>(node,
"group", 1);
// kernel_shape defaults to the spatial dimensions of the weight tensor.
277  std::vector<std::int64_t> kernel_shape{GetAttribute<std::vector<std::int64_t>>(
278      node,
"kernel_shape", {{w.values().shape().begin() + 2, w.values().shape().end()}})};
279  std::vector<std::int64_t> pads{GetAttribute<std::vector<std::int64_t>>(node,
"pads", {{0, 0, 0, 0}})};
280  std::vector<std::int64_t> strides{GetAttribute<std::vector<std::int64_t>>(node,
"strides", {{1, 1}})};
// auto_pad overrides explicit pads: VALID means no padding, the SAME_* modes
// compute per-axis padding so the output keeps ceil(input/stride) size.
282  if (auto_pad !=
"NOTSET") {
284    pads.assign(2 * strides.size(), 0);
285    if (auto_pad !=
"VALID") {
286      for (std::size_t i = 0; i < strides.size(); ++i) {
287        const std::int64_t out_dim =
288            (
static_cast<std::int64_t
>(x.values().shape()[i + 2]) + strides[i] - 1) / strides[i];
289        const std::int64_t fks = kernel_shape[i] / 2;
290        const std::int64_t lks = kernel_shape[i] / 2 - (kernel_shape[i] & 1 ? 0 : 1);
291        const std::int64_t pad = out_dim * strides[i] + fks * dilations[i] + lks * dilations[i] -
292                                 static_cast<std::int64_t
>(x.values().shape()[i + 2]);
// SAME_LOWER puts the extra padding element at the start, SAME_UPPER at the end.
293        if (auto_pad ==
"SAME_LOWER") {
295          pads[i + strides.size()] = pad / 2 + (pad & 1);
296        }
else if (auto_pad ==
"SAME_UPPER") {
297          pads[i] = pad / 2 + (pad & 1);
298          pads[i + strides.size()] = pad / 2;
304  Tensor conv{x.Convolution(w, dilations, group, kernel_shape, pads, strides)};
// Optional third input: bias, broadcast over the channel dimension.
305  if (node.input_size() > 2) {
306    Tensor& b = available_inputs_.at(node.input(2));
307    b.Reshape({1,
static_cast<std::int64_t
>(b.size()), 1, 1});
310  available_inputs_.emplace(output, std::move(conv));
312  DLINEAR_DEBUG_FMT(
"Conv node: {} <- conv({}, {}, {}, {}, {}, {}, {}, {})", output, input1, input2, auto_pad,
313                    dilations, group, kernel_shape, pads, strides);
314  DLINEAR_TRACE_FMT(
"{} <- conv({}, {})", available_inputs_.at(output), x, w);
319void OnnxDriver::AddNode<NodeOpType::Dropout>(const ::onnx::NodeProto& node) {
320 DLINEAR_ASSERT(node.op_type() ==
"Dropout",
"NodeProto must have an op_type of Dropout");
321 DLINEAR_ASSERT(node.output_size() == 1 || node.output_size() == 2,
"NodeProto must have [1-2] output");
322 EnsureInput(node, 1, 3);
324 if (node.input_size() == 3 && available_inputs_.at(node.input(2)).values().size() > 0 &&
325 available_inputs_.at(node.input(2))[0] != 0) {
326 DLINEAR_RUNTIME_ERROR(
"training_mode must be false in Dropout node");
329 const std::string& input = node.input(0);
330 const std::string& output = node.output(0);
331 available_inputs_.emplace(output, available_inputs_.at(input));
332 DLINEAR_DEBUG_FMT(
"Dropout node: {} = {}", output, input);
333 DLINEAR_TRACE_FMT(
"{} = {}", available_inputs_.at(output), available_inputs_.at(input));
338void OnnxDriver::AddNode<NodeOpType::Flatten>(const ::onnx::NodeProto& node) {
339 DLINEAR_ASSERT(node.op_type() ==
"Flatten",
"NodeProto must have an op_type of Flatten");
340 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
341 EnsureInput(node, 1);
343 const std::string& input = node.input(0);
344 const std::string& output = node.output(0);
345 const std::int64_t axis = GetAttribute<std::int64_t>(node,
"axis", 1);
346 available_inputs_.emplace(output, Tensor{available_inputs_.at(input)}.Flatten(axis));
347 DLINEAR_DEBUG_FMT(
"Flatten node: {} <- {}", output, input);
348 DLINEAR_TRACE_FMT(
"{} <- {}", available_inputs_.at(output), available_inputs_.at(input));
353void OnnxDriver::AddNode<NodeOpType::Gather>(const ::onnx::NodeProto& node) {
354 DLINEAR_ASSERT(node.op_type() ==
"Gather",
"NodeProto must have an op_type of Gather");
355 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
356 EnsureInput(node, 2);
358 const std::string& input1 = node.input(0);
359 const std::string& input2 = node.input(1);
360 const std::string& output = node.output(0);
361 const std::int64_t axis = GetAttribute<std::int64_t>(node,
"axis", 0);
362 available_inputs_.emplace(output, available_inputs_.at(input1).Gather(available_inputs_.at(input2), axis));
364 DLINEAR_DEBUG_FMT(
"Gather node: {} = {}[{}, axis = {}]", output, input1, input2, axis);
365 DLINEAR_TRACE_FMT(
"{} = {}[{}, axis = {}]", available_inputs_.at(output), available_inputs_.at(input1),
366 available_inputs_.at(input2), axis);
371void OnnxDriver::AddNode<NodeOpType::Gemm>(const ::onnx::NodeProto& node) {
372 DLINEAR_ASSERT(node.op_type() ==
"Gemm",
"NodeProto must have an op_type of Abs");
373 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
374 EnsureInput(node, 2, 3);
376 const std::string& input1 = node.input(0);
377 const std::string& input2 = node.input(1);
378 const std::string& output = node.output(0);
379 const float alpha = GetAttribute<float>(node,
"alpha", 1);
380 const bool transA = GetAttribute<bool>(node,
"transA",
false);
381 const bool transB = GetAttribute<bool>(node,
"transB",
false);
383 Tensor A{available_inputs_.at(input1)};
384 if (transA) A.Transpose();
385 Tensor
B{available_inputs_.at(input2)};
386 if (transB)
B.Transpose();
387 Tensor gemm = A.MatMul(
B) * alpha;
389 if (node.attribute_size() == 2) {
390 DLINEAR_DEBUG_FMT(
"Gemm node: {} = {} * {} x {}", output, alpha, input1, input2);
391 DLINEAR_TRACE_FMT(
"{} = {} * {} x {}", gemm, alpha, available_inputs_.at(input1), available_inputs_.at(input2));
394 if (node.input_size() == 3) {
395 const auto beta = GetAttribute<float>(node,
"beta", 1);
396 const std::string& input3 = node.input(2);
397 gemm += available_inputs_.at(input3) * beta;
398 DLINEAR_DEBUG_FMT(
"Gemm node: {} = {} * {} x {} + {} * {}", output, alpha, input1, input2, beta, input3);
399 DLINEAR_TRACE_FMT(
"{} = {} * {} x {} + {} * {}", gemm, alpha, available_inputs_.at(input1),
400 available_inputs_.at(input2), beta, available_inputs_.at(input3));
403 available_inputs_.emplace(output, gemm);
409void OnnxDriver::AddNode<NodeOpType::Identity>(const ::onnx::NodeProto& node) {
410 DLINEAR_ASSERT(node.op_type() ==
"Identity",
"NodeProto must have an op_type of Identity");
411 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
412 EnsureInput(node, 1);
414 const std::string& input = node.input(0);
415 const std::string& output = node.output(0);
416 available_inputs_.emplace(output, available_inputs_.at(input));
417 DLINEAR_DEBUG_FMT(
"Identity node: {} = {}", output, input);
418 DLINEAR_TRACE_FMT(
"{} = {}", available_inputs_.at(output), available_inputs_.at(input));
// Handles a LeakyRelu node: each element becomes e for e >= 0 and alpha * e
// otherwise, encoded as a fresh variable plus a piecewise-linear constraint.
// NOTE(review): extraction lost lines in this block — the leading
// `template <>`, the `#ifndef NDEBUG` / `#else` preprocessor lines that select
// between the two Elementwise lambdas and the two relu_var declarations below,
// the return statements of the is_true/is_false branches, and closing braces.
// Also note the debug message reads "Relu node" although this is LeakyRelu —
// likely a copy-paste slip; confirm before changing the log text.
423void OnnxDriver::AddNode<NodeOpType::LeakyRelu>(const ::onnx::NodeProto& node) {
424  DLINEAR_ASSERT(node.op_type() ==
"LeakyRelu",
"NodeProto must have an op_type of LeakyRelu");
425  DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
426  EnsureInput(node, 1);
// Negative-slope coefficient (onnx default 0.01).
427  const float alpha = GetAttribute<float>(node,
"alpha", 0.01);
429  const std::string& input = node.input(0);
430  const std::string& output = node.output(0);
431  Tensor relu = Tensor{available_inputs_.at(input)};
// Debug variant: counter gives each introduced variable a readable name.
434  std::size_t counter = 0;
435  relu.Elementwise([alpha, &counter, &input,
this](
const Expression& e) {
437  relu.Elementwise([alpha,
this](
const Expression& e) {
// Constant-sign expressions can be resolved without a new variable.
439    const Formula condition{e > 0};
441    if (is_true(condition)) {
443    }
else if (is_false(condition)) {
447    const Variable relu_var{fmt::format(
"{}_leaky_relu_{}", input, ++counter)};
449    const Variable relu_var{
"lr"};
// relu_var == (e >= 0 ? e : alpha * e), plus a guided constraint for the solver.
452    context_.AssertPiecewiseLinearFunction(relu_var, e >= 0, e, alpha * e);
453    context_.AddGuidedConstraint(
454        std::make_unique<LeakyReluConstraint>(relu_var, e, alpha, context_.predicate_abstractor()));
455    return Expression{relu_var};
457  available_inputs_.emplace(output, relu);
458  DLINEAR_DEBUG_FMT(
"Relu node: {} = 0 if input < 0 else {} * {}", output, alpha, input);
459  DLINEAR_TRACE_FMT(
"{}", relu);
464void OnnxDriver::AddNode<NodeOpType::MatMul>(const ::onnx::NodeProto& node) {
465 DLINEAR_ASSERT(node.op_type() ==
"MatMul",
"NodeProto must have an op_type of MatMul");
466 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
467 EnsureInput(node, 2);
469 const std::string& input1 = node.input(0);
470 const std::string& input2 = node.input(1);
471 const std::string& output = node.output(0);
472 available_inputs_.emplace(output, available_inputs_.at(input1).MatMul(available_inputs_.at(input2)));
473 DLINEAR_DEBUG_FMT(
"MatMul node: {} = {} x {}", output, input1, input2);
474 DLINEAR_TRACE_FMT(
"{} = {} x {}", available_inputs_.at(output), available_inputs_.at(input1),
475 available_inputs_.at(input2));
480void OnnxDriver::AddNode<NodeOpType::Mul>(const ::onnx::NodeProto& node) {
481 DLINEAR_ASSERT(node.op_type() ==
"Mul",
"NodeProto must have an op_type of Mul");
482 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
483 EnsureInput(node, 2);
485 const std::string& input1 = node.input(0);
486 const std::string& input2 = node.input(1);
487 const std::string& output = node.output(0);
488 available_inputs_.emplace(output, available_inputs_.at(input1) * available_inputs_.at(input2));
489 DLINEAR_DEBUG_FMT(
"Mul node: {} = {} * {}", output, input1, input2);
490 DLINEAR_TRACE_FMT(
"{} = {} * {}", available_inputs_.at(output), available_inputs_.at(input1),
491 available_inputs_.at(input2));
496void OnnxDriver::AddNode<NodeOpType::Reshape>(const ::onnx::NodeProto& node) {
497 DLINEAR_ASSERT(node.op_type() ==
"Reshape",
"NodeProto must have an op_type of Reshape");
498 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
499 EnsureInput(node, 2);
501 const std::string& input1 = node.input(0);
502 const std::string& input2 = node.input(1);
503 const std::string& output = node.output(0);
504 const bool allow_zero = GetAttribute<bool>(node,
"allowzero",
false);
506 const Tensor& shape = available_inputs_.at(input2);
507 available_inputs_.emplace(output, Tensor{available_inputs_.at(input1)}.Reshape(shape, allow_zero));
508 DLINEAR_DEBUG_FMT(
"Reshape node: {} = reshape({}, {})", output, input1, input2);
509 DLINEAR_TRACE_FMT(
"{} = reshape({}, {})", available_inputs_.at(output), available_inputs_.at(input1),
510 available_inputs_.at(input2));
// Handles a Relu node: each element becomes max(0, e), encoded as a fresh
// variable constrained by a piecewise-linear function.
// NOTE(review): extraction lost lines in this block — the leading
// `template <>`, the `#ifndef NDEBUG` / `#else` preprocessor lines selecting
// between the two Elementwise lambdas and the two relu_var declarations, the
// return of the is_true branch, and several closing braces.
515void OnnxDriver::AddNode<NodeOpType::Relu>(const ::onnx::NodeProto& node) {
516  DLINEAR_ASSERT(node.op_type() ==
"Relu",
"NodeProto must have an op_type of Relu");
517  DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
518  EnsureInput(node, 1);
520  const std::string& input = node.input(0);
521  const std::string& output = node.output(0);
522  Tensor relu = Tensor{available_inputs_.at(input)};
// Debug variant: counter gives each introduced variable a readable name.
525  std::size_t counter = 0;
526  relu.Elementwise([&counter, &input,
this](
const Expression& e) {
528  relu.Elementwise([
this](
const Expression& e) {
// Constant-sign expressions are resolved directly, without a new variable.
530    const Formula condition{e > 0};
532    if (is_true(condition)) {
534    }
else if (is_false(condition)) {
535      return Expression{0};
538    const Variable relu_var{fmt::format(
"{}_relu_{}", input, ++counter)};
540    const Variable relu_var{
"r"};
// relu_var == (e >= 0 ? e : 0); the extra relu_var >= 0 bound and the guided
// constraint help the solver prune the search.
543    context_.AssertPiecewiseLinearFunction(relu_var, e >= 0, e, 0);
544    context_.Assert(relu_var >= 0);
545    context_.AddGuidedConstraint(std::make_unique<ReluConstraint>(relu_var, e, context_.predicate_abstractor()));
546    return Expression{relu_var};
548  available_inputs_.emplace(output, relu);
549  DLINEAR_DEBUG_FMT(
"Relu node: {} = 0 if input < 0 else {}", output, input);
550  DLINEAR_TRACE_FMT(
"{}", relu);
555void OnnxDriver::AddNode<NodeOpType::Sign>(const ::onnx::NodeProto& node) {
556 DLINEAR_ASSERT(node.op_type() ==
"Sign",
"NodeProto must have an op_type of Sign");
557 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
558 EnsureInput(node, 1);
560 const std::string& input = node.input(0);
561 const std::string& output = node.output(0);
562 Tensor sign = Tensor{available_inputs_.at(input)};
564 sign.Elementwise([](
const Expression& e) {
return if_then_else(e == 0, 0, if_then_else(e >= 0, 1, -1)); });
565 available_inputs_.emplace(output, sign);
566 DLINEAR_DEBUG_FMT(
"Sign node: {} = Sign({})", output, input);
567 DLINEAR_TRACE_FMT(
"{}", sign);
572void OnnxDriver::AddNode<NodeOpType::Sigmoid>(const ::onnx::NodeProto& node) {
573 DLINEAR_ASSERT(node.op_type() ==
"Sigmoid",
"NodeProto must have an op_type of Sigmoid");
574 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
575 EnsureInput(node, 1);
577 const std::string& input = node.input(0);
578 const std::string& output = node.output(0);
579 Tensor relu = Tensor{available_inputs_.at(input)};
581 relu.Elementwise([](
const Expression& e) {
582 if (!is_constant(e)) DLINEAR_RUNTIME_ERROR(
"Cannot apply the sigmoid function to a non constant value");
583 return 1 / (1 + exp(-get_constant_value(e).get_d()));
585 available_inputs_.emplace(output, relu);
586 DLINEAR_DEBUG_FMT(
"Relu node: {} = 0 if input < 0 else {}", output, input);
587 DLINEAR_TRACE_FMT(
"{}", relu);
// Handles a Slice node: data[starts:ends:steps] along the given axes. The
// optional 4th/5th inputs (axes, steps) default to empty vectors when absent.
// NOTE(review): extraction lost the leading `template <>`, the trailing
// arguments of the final trace call (axis_v, steps_v) and the closing brace.
592void OnnxDriver::AddNode<NodeOpType::Slice>(const ::onnx::NodeProto& node) {
593  DLINEAR_ASSERT(node.op_type() ==
"Slice",
"NodeProto must have an op_type of Slice");
594  DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
595  EnsureInput(node, 3, 5);
597  const std::string& data = node.input(0);
598  const std::string& starts = node.input(1);
599  const std::string& ends = node.input(2);
// Optional inputs: an empty name marks them as absent.
600  const std::string& axis = node.input_size() > 3 ? node.input(3) :
"";
601  const std::string& steps = node.input_size() > 4 ? node.input(4) :
"";
602  const std::string& output = node.output(0);
// The slice parameters are constant tensors, converted to integer vectors.
603  const std::vector<std::int64_t> starts_v =
static_cast<std::vector<std::int64_t>
>(available_inputs_.at(starts));
604  const std::vector<std::int64_t> ends_v =
static_cast<std::vector<std::int64_t>
>(available_inputs_.at(ends));
605  const std::vector<std::int64_t> axis_v =
606      axis.empty() ? std::vector<std::int64_t>{} :
static_cast<std::vector<std::int64_t>
>(available_inputs_.at(axis));
607  const std::vector<std::int64_t> steps_v =
608      steps.empty() ? std::vector<std::int64_t>{} :
static_cast<std::vector<std::int64_t>
>(available_inputs_.at(steps));
609  available_inputs_.emplace(output, available_inputs_.at(data).Slice(starts_v, ends_v, axis_v, steps_v));
611  DLINEAR_DEBUG_FMT(
"Slice node: {} = {}[{}:{}:{}:{}]", output, data, starts, ends, axis, steps);
612  DLINEAR_TRACE_FMT(
"{} = {}[{}:{}:{}:{}", available_inputs_.at(output), available_inputs_.at(data), starts_v, ends_v,
// Handles a Softmax node: exp(x - max(x)) normalized by its sum along `axis`
// (default -1). The max subtraction is the usual numerical-stability shift.
// NOTE(review): the assertion below passes std::for_each's return value (the
// lambda itself) as the condition — a captureless lambda converts to a
// function pointer, so the assert is vacuously true. It should presumably use
// std::all_of. Also the line reshaping `sum` to `shape` (original line 638)
// was lost in extraction, as were `template <>` and the closing brace.
619void OnnxDriver::AddNode<NodeOpType::Softmax>(const ::onnx::NodeProto& node) {
620  DLINEAR_ASSERT(node.op_type() ==
"Softmax",
"NodeProto must have an op_type of Softmax");
621  DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
622  EnsureInput(node, 1);
624  const std::string& input = node.input(0);
625  const std::string& output = node.output(0);
// Shift by the maximum before exponentiating to avoid overflow.
627  const Expression& max = available_inputs_.at(input).Max();
628  const xt::xarray<Expression> softmax_values{xt::exp((available_inputs_.at(input) - max).values())};
629  const std::int64_t axis = GetAttribute<std::int64_t>(node,
"axis", -1);
631  DLINEAR_ASSERT(std::for_each(available_inputs_.at(input).begin(), available_inputs_.at(input).end(),
632                               [](
const Expression& e) { return is_constant(e); }),
"Softmax input must be constant");
635  xt::xarray<Expression> sum{xt::sum(softmax_values, axis)};
// The reduced axis is kept with size 1 so the division broadcasts correctly.
636  auto shape = available_inputs_.at(input).values().shape();
637  shape.at(axis < 0 ? shape.size() + axis : axis) = 1;
639  available_inputs_.emplace(output, softmax_values / sum);
640  DLINEAR_DEBUG_FMT(
"Softmax node: {} = softmax({}, axis = {})", output, input, axis);
641  DLINEAR_TRACE_FMT(
"{} = softmax({}, axis = {})", available_inputs_.at(output), available_inputs_.at(input), axis);
// Handles a Squeeze node: removes size-1 dimensions, either all of them
// (1-input form) or only those listed in the second input.
// NOTE(review): extraction lost the leading `template <>`, the early
// `return;` and closing brace separating the two branches (original lines
// 658-659), and the final closing brace.
646void OnnxDriver::AddNode<NodeOpType::Squeeze>(const ::onnx::NodeProto& node) {
647  DLINEAR_ASSERT(node.op_type() ==
"Squeeze",
"NodeProto must have an op_type of Squeeze");
648  DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
649  EnsureInput(node, 1, 2);
651  const std::string& input1 = node.input(0);
652  const std::string& output = node.output(0);
// 1-input form: squeeze every size-1 axis.
654  if (node.input_size() == 1) {
655    available_inputs_.emplace(output, available_inputs_.at(input1).Squeeze());
656    DLINEAR_DEBUG_FMT(
"Squeeze node: {} = squeeze({})", output, input1);
657    DLINEAR_TRACE_FMT(
"{} = squeeze({})", available_inputs_.at(output), available_inputs_.at(input1));
// 2-input form: squeeze only the axes listed in the (constant) second input.
662  const std::string& input2 = node.input(1);
663  available_inputs_.emplace(output, available_inputs_.at(input1).Squeeze(
664                                        static_cast<std::vector<std::int64_t>
>(available_inputs_.at(input2))));
665  DLINEAR_DEBUG_FMT(
"Squeeze node: {} = squeeze({}, {})", output, input1, input2);
666  DLINEAR_TRACE_FMT(
"{} = squeeze({}, {})", available_inputs_.at(output), available_inputs_.at(input1),
667                    available_inputs_.at(input2));
672void OnnxDriver::AddNode<NodeOpType::Sub>(const ::onnx::NodeProto& node) {
673 DLINEAR_ASSERT(node.op_type() ==
"Sub",
"NodeProto must have an op_type of Sub");
674 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
675 EnsureInput(node, 2);
677 const std::string& input1 = node.input(0);
678 const std::string& input2 = node.input(1);
679 const std::string& output = node.output(0);
680 available_inputs_.emplace(output, available_inputs_.at(input1) - available_inputs_.at(input2));
681 DLINEAR_DEBUG_FMT(
"Sub node: {} = {} - {}", output, input1, input2);
682 DLINEAR_TRACE_FMT(
"{} = {} - {}", available_inputs_.at(output), available_inputs_.at(input1),
683 available_inputs_.at(input2));
688void OnnxDriver::AddNode<NodeOpType::Transpose>(const ::onnx::NodeProto& node) {
689 DLINEAR_ASSERT(node.op_type() ==
"Transpose",
"NodeProto must have an op_type of Transpose");
690 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
691 EnsureInput(node, 1);
693 const std::string& input = node.input(0);
694 const std::string& output = node.output(0);
695 std::vector<std::int64_t> perm{GetAttribute<std::vector<std::int64_t>>(node,
"perm", {{}})};
696 available_inputs_.emplace(output, Tensor{available_inputs_.at(input)}.Transpose(perm));
697 DLINEAR_DEBUG_FMT(
"Transpose {} = {}^T", output, input);
698 DLINEAR_TRACE_FMT(
"{} = {}^T", available_inputs_.at(output), available_inputs_.at(input));
703void OnnxDriver::AddNode<NodeOpType::Unsqueeze>(const ::onnx::NodeProto& node) {
704 DLINEAR_ASSERT(node.op_type() ==
"Unsqueeze",
"NodeProto must have an op_type of Unsqueeze");
705 DLINEAR_ASSERT(node.output_size() == 1,
"NodeProto must have exactly 1 output");
706 EnsureInput(node, 2);
708 const std::string& input1 = node.input(0);
709 const std::string& input2 = node.input(1);
710 const std::string& output = node.output(0);
711 available_inputs_.emplace(output, Tensor{available_inputs_.at(input1)}.Unsqueeze(available_inputs_.at(input2)));
712 DLINEAR_DEBUG_FMT(
"Transpose {} = unsqueeze({}, {})", output, input1, input2);
713 DLINEAR_TRACE_FMT(
"{} = unsqueeze({}, {})", available_inputs_.at(output), available_inputs_.at(input1),
714 available_inputs_.at(input2));
// --- Fragment of the OnnxDriver::AddNode(node) dispatcher (signature, the
// DLINEAR_PYDLINEAR branch body and several braces lost in extraction) ---
// Dispatches a node to its registered handler; nodes whose inputs are not yet
// available are delayed, and unsupported op types raise a runtime error.
719  DLINEAR_ASSERT(node.has_op_type(),
"NodeProto must have an op_type");
720#ifdef DLINEAR_PYDLINEAR
724  DLINEAR_TRACE_FMT(
"AddNode({})", node.name());
725  if (DLINEAR_TRACE_ENABLED) {
726    for ([[maybe_unused]]
const std::string& input : node.input()) DLINEAR_TRACE_FMT(
"Node input: {}", input);
// A node can only be processed once every one of its inputs has been computed.
728  const bool missing_input = std::any_of(node.input().begin(), node.input().end(), [
this](
const std::string& input) {
729    return !available_inputs_.contains(input);
732    DLINEAR_TRACE_FMT(
"Missing input for node {}. Delaying addition", node.name());
// `it` presumably comes from a (lost) lookup in the handler table.
738    DLINEAR_RUNTIME_ERROR_FMT(
"Onnx operation {} not currently supported", node.op_type());
740  it->second(*
this, node);
// --- Fragment of the node-processing loop (signature and loop tail lost in
// extraction) --- Repeatedly sweeps the node list, erasing nodes as they are
// successfully added; any leftovers (unresolvable inputs) are reported.
745  std::list<const ::onnx::NodeProto*> nodes;
746  for (const ::onnx::NodeProto& node :
model_.graph().node()) nodes.push_back(&node);
750  for (
auto it = nodes.begin(); it != nodes.end(); it++) {
// erase() returns the next iterator; the (lost) surrounding logic presumably
// compensates for the loop's it++ — confirm against the full source.
752      it = nodes.erase(it);
758  if (!nodes.empty()) {
759    DLINEAR_ERROR(
"Failed to add all nodes");
760    if (DLINEAR_TRACE_ENABLED) {
761      for ([[maybe_unused]] const ::onnx::NodeProto* node : nodes)
762        DLINEAR_ERROR_FMT(
"Failed to add node: {}", node->name());
// --- Fragment of OnnxDriver::AddValueInfo (signature, the PYDLINEAR branch
// body and the kTensorType case body lost in extraction) --- Validates a
// ValueInfoProto and dispatches on its content type; only tensors are handled.
768  DLINEAR_ASSERT(value_info.has_type(),
"ValueInfoProto must have a type");
769  DLINEAR_ASSERT(value_info.has_name(),
"ValueInfoProto must have a name");
770#ifdef DLINEAR_PYDLINEAR
773  DLINEAR_TRACE_FMT(
"AddValueInfo({})", value_info.name());
774  switch (value_info.type().value_case()) {
775    case ::onnx::TypeProto::kTensorType:
// Any non-tensor value type is unsupported.
779      DLINEAR_UNREACHABLE();
// --- Fragment of OnnxDriver::AddValueInfoTensor (signature and tail lost in
// extraction) --- Creates a named variable tensor for a graph input ("X"
// prefix) or output ("Y" prefix) and registers it in variables_.
784  DLINEAR_ASSERT(value_info.has_name(),
"ValueInfoProto must have a name");
785  DLINEAR_ASSERT(value_info.type().value_case() == ::onnx::TypeProto::kTensorType,
"ValueInfoProto must be a tensor");
786  DLINEAR_TRACE_FMT(
"AddValueInfoTensor({}, {})", value_info.name(), is_input);
787  const auto [it, res] =
variables_.emplace(value_info.name(),
Tensor(value_info, is_input ?
"X" :
"Y"));
790  DLINEAR_DEBUG_FMT(
"Added variables tensor: {} -> {}", it->first, it->second);
791  if (is_input) DLINEAR_TRACE_FMT(
"Added input: {} -> {}", value_info.name(), it->second);
// --- Orphaned fragment (the enclosing function's signature was lost in
// extraction): returns the wrapped Variable when `expr` is a plain variable.
795  if (is_variable(expr))
return get_variable(expr);
// Pretty-prints the driver: lists the variable tensors, then the available
// input tensors. NOTE(review): the closing of the INPUTS loop and the trailing
// ")" / `return os;` lines were lost in extraction.
802std::ostream& operator<<(std::ostream& os,
const OnnxDriver& model) {
803  os <<
"OnnxDriver(\n";
804  os <<
"------\nVARIABLES\n------\n";
805  for (
const auto& [name, variables] : model.variables()) {
806    os << name <<
": " << variables <<
"\n";
808  os <<
"------\nINPUTS\n------\n";
809  for (
const auto& [name, values] : model.available_inputs()) {
810    os << name <<
": " << values <<
"\n";
// Dispatch table mapping each supported onnx op_type string to the matching
// AddNode<NodeOpType::...> specialization, used by the AddNode dispatcher.
// NOTE(review): the table's declaration line (the map type and variable name)
// and its closing `};` were lost in extraction, along with the entries whose
// original line numbers are skipped (818-819, 827, 830, 832).
816    {
"Abs", &OnnxDriver::AddNode<NodeOpType::Abs>},
817    {
"Add", &OnnxDriver::AddNode<NodeOpType::Add>},
820    {
"Concat", &OnnxDriver::AddNode<NodeOpType::Concat>},
821    {
"Constant", &OnnxDriver::AddNode<NodeOpType::Constant>},
822    {
"Conv", &OnnxDriver::AddNode<NodeOpType::Conv>},
823    {
"Dropout", &OnnxDriver::AddNode<NodeOpType::Dropout>},
824    {
"Flatten", &OnnxDriver::AddNode<NodeOpType::Flatten>},
825    {
"Gather", &OnnxDriver::AddNode<NodeOpType::Gather>},
826    {
"Gemm", &OnnxDriver::AddNode<NodeOpType::Gemm>},
828    {
"Identity", &OnnxDriver::AddNode<NodeOpType::Identity>},
829    {
"LeakyRelu", &OnnxDriver::AddNode<NodeOpType::LeakyRelu>},
831    {
"MatMul", &OnnxDriver::AddNode<NodeOpType::MatMul>},
833    {
"Mul", &OnnxDriver::AddNode<NodeOpType::Mul>},
834    {
"Relu", &OnnxDriver::AddNode<NodeOpType::Relu>},
835    {
"Reshape", &OnnxDriver::AddNode<NodeOpType::Reshape>},
836    {
"Sign", &OnnxDriver::AddNode<NodeOpType::Sign>},
837    {
"Sigmoid", &OnnxDriver::AddNode<NodeOpType::Sigmoid>},
838    {
"Slice", &OnnxDriver::AddNode<NodeOpType::Slice>},
839    {
"Softmax", &OnnxDriver::AddNode<NodeOpType::Softmax>},
840    {
"Squeeze", &OnnxDriver::AddNode<NodeOpType::Squeeze>},
841    {
"Sub", &OnnxDriver::AddNode<NodeOpType::Sub>},
842    {
"Transpose", &OnnxDriver::AddNode<NodeOpType::Transpose>},
843    {
"Unsqueeze", &OnnxDriver::AddNode<NodeOpType::Unsqueeze>},
// Explicit instantiations of OnnxDriver::GetAttribute<T> for each supported
// attribute type, so the template definition can live in this .cpp file.
// NOTE(review): the leading `template T OnnxDriver::GetAttribute<T>(...` half
// of each declaration was lost in extraction; only parameter tails remain.
860    const std::optional<bool>& default_value)
const;
862    const std::optional<std::string>& default_value)
const;
864    const std::optional<std::int64_t>& default_value)
const;
866    const ::onnx::NodeProto& node,
const std::string& name,
867    const std::optional<std::vector<std::int64_t>>& default_value)
const;
869    const std::optional<float>& default_value)
const;
871    const std::optional<std::vector<float>>& default_value)
const;
873    const ::onnx::NodeProto& node,
const std::string& name,
874    const std::optional<std::vector<std::string>>& default_value)
const;
876    const ::onnx::NodeProto& node,
const std::string& name,
877    const std::optional<const ::onnx::TensorProto*>& default_value)
const;