/**
 * @file Driver.cpp
 * @brief ONNX parser driver for dlinear 0.0.1, a delta-complete SMT solver for linear programming.
 */
#include "dlinear/parser/onnx/Driver.h"

#include <algorithm>
#include <bit>
#include <fstream>
#include <limits>

#include "dlinear/parser/onnx/NodeOpType.h"
#include "dlinear/solver/LeakyReluConstraint.h"
#include "dlinear/solver/ReluConstraint.h"
#include "dlinear/util/exception.h"

#ifdef DLINEAR_PYDLINEAR
#include "pydlinear/interrupt.h"
#endif
19
20namespace dlinear::onnx {
21
22static_assert(std::endian::native == std::endian::little, "Only little-endian systems are supported for onnx parsing");
23
24namespace {
25inline void invalid_number_of_inputs([[maybe_unused]] const ::onnx::NodeProto& node,
26 [[maybe_unused]] const int actualNumberOfInputs,
27 [[maybe_unused]] const int lowerBound, [[maybe_unused]] const int upperBound) {
28 if (lowerBound == upperBound) {
29 DLINEAR_RUNTIME_ERROR_FMT("Onnx operation '{}' expected to have exactly {} inputs, but found {}", node.op_type(),
30 lowerBound, actualNumberOfInputs);
31 } else {
32 DLINEAR_RUNTIME_ERROR_FMT("Onnx operation '{}' expected to have between {} and {} inputs, but found {}",
33 node.op_type(), lowerBound, upperBound, actualNumberOfInputs);
34 }
35}
36
37inline const ::onnx::AttributeProto* FindAttribute(const ::onnx::NodeProto& node, const std::string& name,
38 ::onnx::AttributeProto_AttributeType expectedType,
39 bool throw_on_missing = false) {
40 for (const ::onnx::AttributeProto& attr : node.attribute()) {
41 if (attr.name() == name) {
42 if (attr.type() != expectedType) {
43 DLINEAR_RUNTIME_ERROR_FMT("Attribute '{}' must be of type {}", name,
44 AttributeProto_AttributeType_Name(expectedType));
45 }
46 return &attr;
47 }
48 }
49 if (throw_on_missing)
50 DLINEAR_RUNTIME_ERROR_FMT("Onnx node of type {} is missing the expected attribute {}", node.op_type(), name);
51 return nullptr;
52}
53
59template <IsAnyOf<bool, float, std::int64_t, std::string, std::vector<float>, std::vector<std::int64_t>,
60 std::vector<std::string>, const ::onnx::TensorProto*>
61 T>
62constexpr ::onnx::AttributeProto_AttributeType GetAttributeType() {
63 if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, std::int64_t>) {
64 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INT;
65 } else if constexpr (std::is_same_v<T, float>) {
66 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT;
67 } else if constexpr (std::is_same_v<T, std::string>) {
68 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING;
69 } else if constexpr (std::is_same_v<T, std::vector<float>>) {
70 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS;
71 } else if constexpr (std::is_same_v<T, std::vector<std::int64_t>>) {
72 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS;
73 } else if constexpr (std::is_same_v<T, std::vector<std::string>>) {
74 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS;
75 } else if constexpr (std::is_same_v<T, const ::onnx::TensorProto*>) {
76 return ::onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR;
77 }
78 DLINEAR_UNREACHABLE();
79}
80
86template <IsAnyOf<bool, float, std::int64_t, std::string, std::vector<float>, std::vector<std::int64_t>,
87 std::vector<std::string>, const ::onnx::TensorProto*>
88 T>
89T GetAttributeValue(const ::onnx::AttributeProto* attr) {
90 DLINEAR_ASSERT(attr != nullptr, "AttributeProto must not be null");
91 if constexpr (std::is_same_v<T, bool>) {
92 return attr->i() != 0;
93 } else if constexpr (std::is_same_v<T, std::int64_t>) {
94 return attr->i();
95 } else if constexpr (std::is_same_v<T, float>) {
96 return attr->f();
97 } else if constexpr (std::is_same_v<T, std::string>) {
98 return attr->s();
99 } else if constexpr (std::is_same_v<T, std::vector<float>>) {
100 return std::vector<float>{attr->floats().begin(), attr->floats().end()};
101 } else if constexpr (std::is_same_v<T, std::vector<std::int64_t>>) {
102 return std::vector<std::int64_t>{attr->ints().begin(), attr->ints().end()};
103 } else if constexpr (std::is_same_v<T, std::vector<std::string>>) {
104 return std::vector<std::string>{attr->strings().begin(), attr->strings().end()};
105 } else if constexpr (std::is_same_v<T, const ::onnx::TensorProto*>) {
106 return &attr->t();
107 }
108 DLINEAR_UNREACHABLE();
109}
110
111} // namespace
112
/// Construct an OnnxDriver operating on @p context, registered under the name "OnnxDriver".
OnnxDriver::OnnxDriver(Context& context) : Driver{context, "OnnxDriver"} {}
114
115bool OnnxDriver::ParseStreamCore(std::istream& in) {
116 const bool res = model_.ParseFromIstream(&in);
117 if (!res) {
118 DLINEAR_ERROR("OnnxDriver::ParseStreamCore(): Failed to parse model from input stream");
119 return false;
120 }
121 ParseGraph();
122 return true;
123}
124
125bool OnnxDriver::ParseFile(const std::string& filename) {
126 std::ifstream input(filename, std::ios::binary);
127 if (!input.is_open()) {
128 DLINEAR_ERROR_FMT("OnnxDriver::ParseFile({}): Failed to open file", filename);
129 return false;
130 }
131 return ParseStream(input);
132}
133
135 DLINEAR_TRACE("OnnxDriver::ParseGraph()");
136 if (!model_.has_graph()) DLINEAR_RUNTIME_ERROR("ModelProto must have a graph");
137 std::unordered_set<std::string> initializers;
138 for (const ::onnx::TensorProto& tensor : model_.graph().initializer()) {
139 AddInitializer(tensor);
140 initializers.insert(tensor.name());
141 }
142 for (const ::onnx::ValueInfoProto& input : model_.graph().input()) {
143 if (!available_inputs_.contains(input.name())) AddValueInfo(input, true);
144 }
145 for (const ::onnx::ValueInfoProto& output : model_.graph().output()) AddValueInfo(output);
146 AddNodes();
147 DLINEAR_DEBUG_FMT("OnnxDriver::ParseGraph(): assertions {}", context_.assertions());
148}
149
150template <IsAnyOf<bool, float, std::int64_t, std::string, std::vector<float>, std::vector<std::int64_t>,
151 std::vector<std::string>, const ::onnx::TensorProto*>
152 T>
153T OnnxDriver::GetAttribute(const ::onnx::NodeProto& node, const std::string& name,
154 const std::optional<T>& default_value) const {
155 const ::onnx::AttributeProto* const attr =
156 FindAttribute(node, name, GetAttributeType<T>(), !default_value.has_value());
157 return attr == nullptr ? default_value.value() : GetAttributeValue<T>(attr);
158}
159
160void OnnxDriver::EnsureInput(const ::onnx::NodeProto& node, const int lb, const int ub) {
161 if (node.input_size() < lb || node.input_size() > ub) invalid_number_of_inputs(node, node.input_size(), lb, ub);
162}
163void OnnxDriver::EnsureInput(const ::onnx::NodeProto& node, const int exact) {
164 if (node.input_size() != exact) invalid_number_of_inputs(node, node.input_size(), exact, exact);
165}
166
/// Register the initializer @p tensor as an available input, keyed by its name.
/// @param tensor initializer of the graph; must carry both a name and a data type.
void OnnxDriver::AddInitializer(const ::onnx::TensorProto& tensor) {
  DLINEAR_ASSERT(tensor.has_name(), "TensorProto must have a name");
  DLINEAR_ASSERT(tensor.has_data_type(), "TensorProto must have a data_type");
  DLINEAR_TRACE_FMT("AddInitializer({})", tensor.name());
  available_inputs_.emplace(tensor.name(), tensor);
}
173
174void OnnxDriver::AddFormula(const std::string& output_name) {
175 if (!variables_.contains(output_name) || !available_inputs_.contains(output_name)) return;
176 DLINEAR_DEBUG_FMT("AddFormula({})", output_name);
177 const Tensor& var_tensor = variables_.at(output_name);
178 const Tensor& final_tensor = available_inputs_.at(output_name);
179 DLINEAR_TRACE_FMT("AddFormula({}): {} == {}", output_name, var_tensor, final_tensor);
180 for (const Formula& f : (var_tensor == final_tensor)) Assert(f);
181 DLINEAR_TRACE_FMT("Added formula for {}. Current assertions: {}", output_name, context_.assertions());
182}
183
184template <>
185void OnnxDriver::AddNode<NodeOpType::Abs>(const ::onnx::NodeProto& node) {
186 DLINEAR_ASSERT(node.op_type() == "Abs", "NodeProto must have an op_type of Abs");
187 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
188 EnsureInput(node, 1);
189
190 const std::string& input = node.input(0);
191 const std::string& output = node.output(0);
192 available_inputs_.emplace(output, Tensor{available_inputs_.at(input)}.Abs());
193 DLINEAR_DEBUG_FMT("Abs node: {} = |{}|", output, input);
194 DLINEAR_TRACE_FMT("{} = |{}|", available_inputs_.at(output), available_inputs_.at(input));
195 AddFormula(output);
196}
197
198template <>
199void OnnxDriver::AddNode<NodeOpType::Add>(const ::onnx::NodeProto& node) {
200 DLINEAR_ASSERT(node.op_type() == "Add", "NodeProto must have an op_type of Add");
201 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
202 EnsureInput(node, 2);
203
204 const std::string& input1 = node.input(0);
205 const std::string& input2 = node.input(1);
206 const std::string& output = node.output(0);
207 available_inputs_.emplace(output, available_inputs_.at(input1) + available_inputs_.at(input2));
208 DLINEAR_DEBUG_FMT("Add node: {} = {} + {}", output, input1, input2);
209 DLINEAR_TRACE_FMT("{} = {} + {}", available_inputs_.at(output), available_inputs_.at(input1),
210 available_inputs_.at(input2));
211 AddFormula(output);
212}
213
214template <>
215void OnnxDriver::AddNode<NodeOpType::Concat>(const ::onnx::NodeProto& node) {
216 DLINEAR_ASSERT(node.op_type() == "Concat", "NodeProto must have an op_type of Concat");
217 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
218 EnsureInput(node, 1, 2147483647);
219
220 const auto axis = GetAttribute<std::int64_t>(node, "axis");
221 const std::string& output = node.output(0);
222 const std::string& input1 = node.input(0);
223 const std::vector<std::string> inputs(node.input().begin() + 1, node.input().end());
224 std::vector<Tensor> tensors;
225 tensors.reserve(inputs.size());
226 std::transform(inputs.begin(), inputs.end(), std::back_inserter(tensors),
227 [this](const std::string& input) { return available_inputs_.at(input); });
228
229 available_inputs_.emplace(output, Tensor{available_inputs_.at(input1)}.Concat(tensors, axis));
230 DLINEAR_DEBUG_FMT("Concat node: {} = concat({})", output, inputs);
231 DLINEAR_TRACE_FMT("{} = concat({})", available_inputs_.at(output), tensors);
232 AddFormula(output);
233}
234
/// Constant node: introduce a constant tensor under the node's output name.
/// The value comes from the node's single attribute, which may be a full
/// tensor (t), a single float (f) or a single integer (i).
template <>
void OnnxDriver::AddNode<NodeOpType::Constant>(const ::onnx::NodeProto& node) {
  DLINEAR_ASSERT(node.op_type() == "Constant", "NodeProto must have an op_type of Constant");
  DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
  DLINEAR_ASSERT(node.attribute_size() == 1, "NodeProto must have exactly 1 attribute");

  const std::string& output = node.output(0);
  const ::onnx::AttributeProto& attr = node.attribute(0);
  if (attr.has_t()) {
    // Full tensor constant.
    available_inputs_.emplace(output, Tensor{attr.t()});
  } else if (attr.has_f()) {
    // Scalar float constant, wrapped in a one-element tensor.
    Tensor c{1};
    c[0] = attr.f();
    available_inputs_.emplace(output, std::move(c));
  } else if (attr.has_i()) {
    // Scalar integer constant, wrapped in a one-element tensor.
    Tensor c{1};
    c[0] = attr.i();
    available_inputs_.emplace(output, std::move(c));
  } else {
    DLINEAR_RUNTIME_ERROR("Constant node must have a tensor, float, or integer attribute");
  }
  DLINEAR_DEBUG_FMT("Constant node: {}", output);
  DLINEAR_TRACE_FMT("{}", available_inputs_.at(output));
  AddFormula(output);
}
260
/// Conv node: n-dimensional convolution of input x with the weight tensor w,
/// with an optional bias as the third input.
/// The auto_pad attribute, when not NOTSET, is resolved into explicit pads
/// before delegating to Tensor::Convolution.
template <>
void OnnxDriver::AddNode<NodeOpType::Conv>(const ::onnx::NodeProto& node) {
  DLINEAR_ASSERT(node.op_type() == "Conv", "NodeProto must have an op_type of Conv");
  DLINEAR_ASSERT(node.input_size() == 2 || node.input_size() == 3, "NodeProto must have [2-3] inputs");
  DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");

  const std::string& input1 = node.input(0);
  const std::string& input2 = node.input(1);

  const std::string& output = node.output(0);
  const Tensor& x = available_inputs_.at(input1);
  const Tensor& w = available_inputs_.at(input2);

  // Attribute defaults; a missing kernel_shape is inferred from the spatial
  // dimensions of w (everything past the first two axes).
  std::string auto_pad{GetAttribute<std::string>(node, "auto_pad", "NOTSET")};
  std::vector<std::int64_t> dilations{GetAttribute<std::vector<std::int64_t>>(node, "dilations", {{1, 1}})};
  const auto group = GetAttribute<std::int64_t>(node, "group", 1);
  std::vector<std::int64_t> kernel_shape{GetAttribute<std::vector<std::int64_t>>(
      node, "kernel_shape", {{w.values().shape().begin() + 2, w.values().shape().end()}})};
  std::vector<std::int64_t> pads{GetAttribute<std::vector<std::int64_t>>(node, "pads", {{0, 0, 0, 0}})};
  std::vector<std::int64_t> strides{GetAttribute<std::vector<std::int64_t>>(node, "strides", {{1, 1}})};

  if (auto_pad != "NOTSET") {
    // auto_pad overrides any explicit pads attribute.
    pads.clear();
    pads.assign(2 * strides.size(), 0);  // If auto_pad is VALID, we are done
    if (auto_pad != "VALID") {
      // SAME_*: choose pads so the output spatial size is ceil(input / stride).
      for (std::size_t i = 0; i < strides.size(); ++i) {
        const std::int64_t out_dim =
            (static_cast<std::int64_t>(x.values().shape()[i + 2]) + strides[i] - 1) / strides[i];
        // fks/lks split the kernel extent around its center (even kernels lose one on the trailing side).
        const std::int64_t fks = kernel_shape[i] / 2;
        const std::int64_t lks = kernel_shape[i] / 2 - (kernel_shape[i] & 1 ? 0 : 1);
        const std::int64_t pad = out_dim * strides[i] + fks * dilations[i] + lks * dilations[i] -
                                 static_cast<std::int64_t>(x.values().shape()[i + 2]);
        // NOTE(review): onnx defines SAME_UPPER as putting the extra odd pad at the end and
        // SAME_LOWER at the beginning; the two branches below appear inverted — verify.
        if (auto_pad == "SAME_LOWER") {
          pads[i] = pad / 2;
          pads[i + strides.size()] = pad / 2 + (pad & 1);
        } else if (auto_pad == "SAME_UPPER") {
          pads[i] = pad / 2 + (pad & 1);
          pads[i + strides.size()] = pad / 2;
        }
      }
    }
  }

  Tensor conv{x.Convolution(w, dilations, group, kernel_shape, pads, strides)};
  if (node.input_size() > 2) {
    // Optional bias: reshaped so it broadcasts over the channel dimension.
    Tensor& b = available_inputs_.at(node.input(2));
    b.Reshape({1, static_cast<std::int64_t>(b.size()), 1, 1});
    conv += b;
  }
  available_inputs_.emplace(output, std::move(conv));

  DLINEAR_DEBUG_FMT("Conv node: {} <- conv({}, {}, {}, {}, {}, {}, {}, {})", output, input1, input2, auto_pad,
                    dilations, group, kernel_shape, pads, strides);
  DLINEAR_TRACE_FMT("{} <- conv({}, {})", available_inputs_.at(output), x, w);
  AddFormula(output);
}
317
318template <>
319void OnnxDriver::AddNode<NodeOpType::Dropout>(const ::onnx::NodeProto& node) {
320 DLINEAR_ASSERT(node.op_type() == "Dropout", "NodeProto must have an op_type of Dropout");
321 DLINEAR_ASSERT(node.output_size() == 1 || node.output_size() == 2, "NodeProto must have [1-2] output");
322 EnsureInput(node, 1, 3);
323
324 if (node.input_size() == 3 && available_inputs_.at(node.input(2)).values().size() > 0 &&
325 available_inputs_.at(node.input(2))[0] != 0) {
326 DLINEAR_RUNTIME_ERROR("training_mode must be false in Dropout node");
327 }
328
329 const std::string& input = node.input(0);
330 const std::string& output = node.output(0);
331 available_inputs_.emplace(output, available_inputs_.at(input));
332 DLINEAR_DEBUG_FMT("Dropout node: {} = {}", output, input);
333 DLINEAR_TRACE_FMT("{} = {}", available_inputs_.at(output), available_inputs_.at(input));
334 AddFormula(output);
335}
336
337template <>
338void OnnxDriver::AddNode<NodeOpType::Flatten>(const ::onnx::NodeProto& node) {
339 DLINEAR_ASSERT(node.op_type() == "Flatten", "NodeProto must have an op_type of Flatten");
340 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
341 EnsureInput(node, 1);
342
343 const std::string& input = node.input(0);
344 const std::string& output = node.output(0);
345 const std::int64_t axis = GetAttribute<std::int64_t>(node, "axis", 1);
346 available_inputs_.emplace(output, Tensor{available_inputs_.at(input)}.Flatten(axis));
347 DLINEAR_DEBUG_FMT("Flatten node: {} <- {}", output, input);
348 DLINEAR_TRACE_FMT("{} <- {}", available_inputs_.at(output), available_inputs_.at(input));
349 AddFormula(output);
350}
351
352template <>
353void OnnxDriver::AddNode<NodeOpType::Gather>(const ::onnx::NodeProto& node) {
354 DLINEAR_ASSERT(node.op_type() == "Gather", "NodeProto must have an op_type of Gather");
355 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
356 EnsureInput(node, 2);
357
358 const std::string& input1 = node.input(0);
359 const std::string& input2 = node.input(1);
360 const std::string& output = node.output(0);
361 const std::int64_t axis = GetAttribute<std::int64_t>(node, "axis", 0);
362 available_inputs_.emplace(output, available_inputs_.at(input1).Gather(available_inputs_.at(input2), axis));
363
364 DLINEAR_DEBUG_FMT("Gather node: {} = {}[{}, axis = {}]", output, input1, input2, axis);
365 DLINEAR_TRACE_FMT("{} = {}[{}, axis = {}]", available_inputs_.at(output), available_inputs_.at(input1),
366 available_inputs_.at(input2), axis);
367 AddFormula(output);
368}
369
370template <>
371void OnnxDriver::AddNode<NodeOpType::Gemm>(const ::onnx::NodeProto& node) {
372 DLINEAR_ASSERT(node.op_type() == "Gemm", "NodeProto must have an op_type of Abs");
373 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
374 EnsureInput(node, 2, 3);
375
376 const std::string& input1 = node.input(0);
377 const std::string& input2 = node.input(1);
378 const std::string& output = node.output(0);
379 const float alpha = GetAttribute<float>(node, "alpha", 1);
380 const bool transA = GetAttribute<bool>(node, "transA", false);
381 const bool transB = GetAttribute<bool>(node, "transB", false);
382
383 Tensor A{available_inputs_.at(input1)};
384 if (transA) A.Transpose();
385 Tensor B{available_inputs_.at(input2)};
386 if (transB) B.Transpose();
387 Tensor gemm = A.MatMul(B) * alpha;
388
389 if (node.attribute_size() == 2) {
390 DLINEAR_DEBUG_FMT("Gemm node: {} = {} * {} x {}", output, alpha, input1, input2);
391 DLINEAR_TRACE_FMT("{} = {} * {} x {}", gemm, alpha, available_inputs_.at(input1), available_inputs_.at(input2));
392 }
393
394 if (node.input_size() == 3) {
395 const auto beta = GetAttribute<float>(node, "beta", 1);
396 const std::string& input3 = node.input(2);
397 gemm += available_inputs_.at(input3) * beta;
398 DLINEAR_DEBUG_FMT("Gemm node: {} = {} * {} x {} + {} * {}", output, alpha, input1, input2, beta, input3);
399 DLINEAR_TRACE_FMT("{} = {} * {} x {} + {} * {}", gemm, alpha, available_inputs_.at(input1),
400 available_inputs_.at(input2), beta, available_inputs_.at(input3));
401 }
402
403 available_inputs_.emplace(output, gemm);
404
405 AddFormula(output);
406}
407
408template <>
409void OnnxDriver::AddNode<NodeOpType::Identity>(const ::onnx::NodeProto& node) {
410 DLINEAR_ASSERT(node.op_type() == "Identity", "NodeProto must have an op_type of Identity");
411 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
412 EnsureInput(node, 1);
413
414 const std::string& input = node.input(0);
415 const std::string& output = node.output(0);
416 available_inputs_.emplace(output, available_inputs_.at(input));
417 DLINEAR_DEBUG_FMT("Identity node: {} = {}", output, input);
418 DLINEAR_TRACE_FMT("{} = {}", available_inputs_.at(output), available_inputs_.at(input));
419 AddFormula(output);
420}
421
422template <>
423void OnnxDriver::AddNode<NodeOpType::LeakyRelu>(const ::onnx::NodeProto& node) {
424 DLINEAR_ASSERT(node.op_type() == "LeakyRelu", "NodeProto must have an op_type of LeakyRelu");
425 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
426 EnsureInput(node, 1);
427 const float alpha = GetAttribute<float>(node, "alpha", 0.01);
428
429 const std::string& input = node.input(0);
430 const std::string& output = node.output(0);
431 Tensor relu = Tensor{available_inputs_.at(input)};
432
433#ifndef NDEBUG
434 std::size_t counter = 0;
435 relu.Elementwise([alpha, &counter, &input, this](const Expression& e) {
436#else
437 relu.Elementwise([alpha, this](const Expression& e) {
438#endif
439 const Formula condition{e > 0};
440 // Trivial cases for the ReLU function
441 if (is_true(condition)) {
442 return e;
443 } else if (is_false(condition)) {
444 return alpha * e;
445 }
446#ifndef NDEBUG
447 const Variable relu_var{fmt::format("{}_leaky_relu_{}", input, ++counter)};
448#else
449 const Variable relu_var{"lr"};
450#endif
451 // Adding the fresh ITE variable as a guided constraint
452 context_.AssertPiecewiseLinearFunction(relu_var, e >= 0, e, alpha * e);
453 context_.AddGuidedConstraint(
454 std::make_unique<LeakyReluConstraint>(relu_var, e, alpha, context_.predicate_abstractor()));
455 return Expression{relu_var};
456 });
457 available_inputs_.emplace(output, relu);
458 DLINEAR_DEBUG_FMT("Relu node: {} = 0 if input < 0 else {} * {}", output, alpha, input);
459 DLINEAR_TRACE_FMT("{}", relu);
460 AddFormula(output);
461}
462
463template <>
464void OnnxDriver::AddNode<NodeOpType::MatMul>(const ::onnx::NodeProto& node) {
465 DLINEAR_ASSERT(node.op_type() == "MatMul", "NodeProto must have an op_type of MatMul");
466 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
467 EnsureInput(node, 2);
468
469 const std::string& input1 = node.input(0);
470 const std::string& input2 = node.input(1);
471 const std::string& output = node.output(0);
472 available_inputs_.emplace(output, available_inputs_.at(input1).MatMul(available_inputs_.at(input2)));
473 DLINEAR_DEBUG_FMT("MatMul node: {} = {} x {}", output, input1, input2);
474 DLINEAR_TRACE_FMT("{} = {} x {}", available_inputs_.at(output), available_inputs_.at(input1),
475 available_inputs_.at(input2));
476 AddFormula(output);
477}
478
479template <>
480void OnnxDriver::AddNode<NodeOpType::Mul>(const ::onnx::NodeProto& node) {
481 DLINEAR_ASSERT(node.op_type() == "Mul", "NodeProto must have an op_type of Mul");
482 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
483 EnsureInput(node, 2);
484
485 const std::string& input1 = node.input(0);
486 const std::string& input2 = node.input(1);
487 const std::string& output = node.output(0);
488 available_inputs_.emplace(output, available_inputs_.at(input1) * available_inputs_.at(input2));
489 DLINEAR_DEBUG_FMT("Mul node: {} = {} * {}", output, input1, input2);
490 DLINEAR_TRACE_FMT("{} = {} * {}", available_inputs_.at(output), available_inputs_.at(input1),
491 available_inputs_.at(input2));
492 AddFormula(output);
493}
494
495template <>
496void OnnxDriver::AddNode<NodeOpType::Reshape>(const ::onnx::NodeProto& node) {
497 DLINEAR_ASSERT(node.op_type() == "Reshape", "NodeProto must have an op_type of Reshape");
498 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
499 EnsureInput(node, 2);
500
501 const std::string& input1 = node.input(0);
502 const std::string& input2 = node.input(1);
503 const std::string& output = node.output(0);
504 const bool allow_zero = GetAttribute<bool>(node, "allowzero", false);
505
506 const Tensor& shape = available_inputs_.at(input2);
507 available_inputs_.emplace(output, Tensor{available_inputs_.at(input1)}.Reshape(shape, allow_zero));
508 DLINEAR_DEBUG_FMT("Reshape node: {} = reshape({}, {})", output, input1, input2);
509 DLINEAR_TRACE_FMT("{} = reshape({}, {})", available_inputs_.at(output), available_inputs_.at(input1),
510 available_inputs_.at(input2));
511 AddFormula(output);
512}
513
/// Relu node: y = max(0, x), applied element-wise.
/// Constant elements are folded immediately; symbolic elements are replaced by a
/// fresh variable tied to a piecewise-linear function and a guided ReluConstraint.
template <>
void OnnxDriver::AddNode<NodeOpType::Relu>(const ::onnx::NodeProto& node) {
  DLINEAR_ASSERT(node.op_type() == "Relu", "NodeProto must have an op_type of Relu");
  DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
  EnsureInput(node, 1);

  const std::string& input = node.input(0);
  const std::string& output = node.output(0);
  Tensor relu = Tensor{available_inputs_.at(input)};

#ifndef NDEBUG
  // In debug builds the fresh variables get unique, readable names.
  std::size_t counter = 0;
  relu.Elementwise([&counter, &input, this](const Expression& e) {
#else
  relu.Elementwise([this](const Expression& e) {
#endif
    const Formula condition{e > 0};
    // Trivial cases for the ReLU function
    if (is_true(condition)) {
      return e;
    } else if (is_false(condition)) {
      return Expression{0};
    }
#ifndef NDEBUG
    const Variable relu_var{fmt::format("{}_relu_{}", input, ++counter)};
#else
    const Variable relu_var{"r"};
#endif
    // Adding the fresh ITE variable as a guided constraint
    context_.AssertPiecewiseLinearFunction(relu_var, e >= 0, e, 0);
    context_.Assert(relu_var >= 0);
    context_.AddGuidedConstraint(std::make_unique<ReluConstraint>(relu_var, e, context_.predicate_abstractor()));
    return Expression{relu_var};
  });
  available_inputs_.emplace(output, relu);
  DLINEAR_DEBUG_FMT("Relu node: {} = 0 if input < 0 else {}", output, input);
  DLINEAR_TRACE_FMT("{}", relu);
  AddFormula(output);
}
553
554template <>
555void OnnxDriver::AddNode<NodeOpType::Sign>(const ::onnx::NodeProto& node) {
556 DLINEAR_ASSERT(node.op_type() == "Sign", "NodeProto must have an op_type of Sign");
557 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
558 EnsureInput(node, 1);
559
560 const std::string& input = node.input(0);
561 const std::string& output = node.output(0);
562 Tensor sign = Tensor{available_inputs_.at(input)};
563
564 sign.Elementwise([](const Expression& e) { return if_then_else(e == 0, 0, if_then_else(e >= 0, 1, -1)); });
565 available_inputs_.emplace(output, sign);
566 DLINEAR_DEBUG_FMT("Sign node: {} = Sign({})", output, input);
567 DLINEAR_TRACE_FMT("{}", sign);
568 AddFormula(output);
569}
570
571template <>
572void OnnxDriver::AddNode<NodeOpType::Sigmoid>(const ::onnx::NodeProto& node) {
573 DLINEAR_ASSERT(node.op_type() == "Sigmoid", "NodeProto must have an op_type of Sigmoid");
574 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
575 EnsureInput(node, 1);
576
577 const std::string& input = node.input(0);
578 const std::string& output = node.output(0);
579 Tensor relu = Tensor{available_inputs_.at(input)};
580
581 relu.Elementwise([](const Expression& e) {
582 if (!is_constant(e)) DLINEAR_RUNTIME_ERROR("Cannot apply the sigmoid function to a non constant value");
583 return 1 / (1 + exp(-get_constant_value(e).get_d()));
584 });
585 available_inputs_.emplace(output, relu);
586 DLINEAR_DEBUG_FMT("Relu node: {} = 0 if input < 0 else {}", output, input);
587 DLINEAR_TRACE_FMT("{}", relu);
588 AddFormula(output);
589}
590
591template <>
592void OnnxDriver::AddNode<NodeOpType::Slice>(const ::onnx::NodeProto& node) {
593 DLINEAR_ASSERT(node.op_type() == "Slice", "NodeProto must have an op_type of Slice");
594 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
595 EnsureInput(node, 3, 5);
596
597 const std::string& data = node.input(0);
598 const std::string& starts = node.input(1);
599 const std::string& ends = node.input(2);
600 const std::string& axis = node.input_size() > 3 ? node.input(3) : "";
601 const std::string& steps = node.input_size() > 4 ? node.input(4) : "";
602 const std::string& output = node.output(0);
603 const std::vector<std::int64_t> starts_v = static_cast<std::vector<std::int64_t>>(available_inputs_.at(starts));
604 const std::vector<std::int64_t> ends_v = static_cast<std::vector<std::int64_t>>(available_inputs_.at(ends));
605 const std::vector<std::int64_t> axis_v =
606 axis.empty() ? std::vector<std::int64_t>{} : static_cast<std::vector<std::int64_t>>(available_inputs_.at(axis));
607 const std::vector<std::int64_t> steps_v =
608 steps.empty() ? std::vector<std::int64_t>{} : static_cast<std::vector<std::int64_t>>(available_inputs_.at(steps));
609 available_inputs_.emplace(output, available_inputs_.at(data).Slice(starts_v, ends_v, axis_v, steps_v));
610
611 DLINEAR_DEBUG_FMT("Slice node: {} = {}[{}:{}:{}:{}]", output, data, starts, ends, axis, steps);
612 DLINEAR_TRACE_FMT("{} = {}[{}:{}:{}:{}", available_inputs_.at(output), available_inputs_.at(data), starts_v, ends_v,
613 axis_v, steps_v);
614
615 AddFormula(output);
616}
617
618template <>
619void OnnxDriver::AddNode<NodeOpType::Softmax>(const ::onnx::NodeProto& node) {
620 DLINEAR_ASSERT(node.op_type() == "Softmax", "NodeProto must have an op_type of Softmax");
621 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
622 EnsureInput(node, 1);
623
624 const std::string& input = node.input(0);
625 const std::string& output = node.output(0);
626
627 const Expression& max = available_inputs_.at(input).Max();
628 const xt::xarray<Expression> softmax_values{xt::exp((available_inputs_.at(input) - max).values())};
629 const std::int64_t axis = GetAttribute<std::int64_t>(node, "axis", -1);
630
631 DLINEAR_ASSERT(std::for_each(available_inputs_.at(input).begin(), available_inputs_.at(input).end(),
632 [](const Expression& e) { return is_constant(e); }),
633 "Softmax input must be constant");
634
635 xt::xarray<Expression> sum{xt::sum(softmax_values, axis)};
636 auto shape = available_inputs_.at(input).values().shape();
637 shape.at(axis < 0 ? shape.size() + axis : axis) = 1;
638 sum.reshape(shape);
639 available_inputs_.emplace(output, softmax_values / sum);
640 DLINEAR_DEBUG_FMT("Softmax node: {} = softmax({}, axis = {})", output, input, axis);
641 DLINEAR_TRACE_FMT("{} = softmax({}, axis = {})", available_inputs_.at(output), available_inputs_.at(input), axis);
642 AddFormula(output);
643}
644
645template <>
646void OnnxDriver::AddNode<NodeOpType::Squeeze>(const ::onnx::NodeProto& node) {
647 DLINEAR_ASSERT(node.op_type() == "Squeeze", "NodeProto must have an op_type of Squeeze");
648 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
649 EnsureInput(node, 1, 2);
650
651 const std::string& input1 = node.input(0);
652 const std::string& output = node.output(0);
653
654 if (node.input_size() == 1) {
655 available_inputs_.emplace(output, available_inputs_.at(input1).Squeeze());
656 DLINEAR_DEBUG_FMT("Squeeze node: {} = squeeze({})", output, input1);
657 DLINEAR_TRACE_FMT("{} = squeeze({})", available_inputs_.at(output), available_inputs_.at(input1));
658 AddFormula(output);
659 return;
660 }
661
662 const std::string& input2 = node.input(1);
663 available_inputs_.emplace(output, available_inputs_.at(input1).Squeeze(
664 static_cast<std::vector<std::int64_t>>(available_inputs_.at(input2))));
665 DLINEAR_DEBUG_FMT("Squeeze node: {} = squeeze({}, {})", output, input1, input2);
666 DLINEAR_TRACE_FMT("{} = squeeze({}, {})", available_inputs_.at(output), available_inputs_.at(input1),
667 available_inputs_.at(input2));
668 AddFormula(output);
669}
670
671template <>
672void OnnxDriver::AddNode<NodeOpType::Sub>(const ::onnx::NodeProto& node) {
673 DLINEAR_ASSERT(node.op_type() == "Sub", "NodeProto must have an op_type of Sub");
674 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
675 EnsureInput(node, 2);
676
677 const std::string& input1 = node.input(0);
678 const std::string& input2 = node.input(1);
679 const std::string& output = node.output(0);
680 available_inputs_.emplace(output, available_inputs_.at(input1) - available_inputs_.at(input2));
681 DLINEAR_DEBUG_FMT("Sub node: {} = {} - {}", output, input1, input2);
682 DLINEAR_TRACE_FMT("{} = {} - {}", available_inputs_.at(output), available_inputs_.at(input1),
683 available_inputs_.at(input2));
684 AddFormula(output);
685}
686
687template <>
688void OnnxDriver::AddNode<NodeOpType::Transpose>(const ::onnx::NodeProto& node) {
689 DLINEAR_ASSERT(node.op_type() == "Transpose", "NodeProto must have an op_type of Transpose");
690 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
691 EnsureInput(node, 1);
692
693 const std::string& input = node.input(0);
694 const std::string& output = node.output(0);
695 std::vector<std::int64_t> perm{GetAttribute<std::vector<std::int64_t>>(node, "perm", {{}})};
696 available_inputs_.emplace(output, Tensor{available_inputs_.at(input)}.Transpose(perm));
697 DLINEAR_DEBUG_FMT("Transpose {} = {}^T", output, input);
698 DLINEAR_TRACE_FMT("{} = {}^T", available_inputs_.at(output), available_inputs_.at(input));
699 AddFormula(output);
700}
701
702template <>
703void OnnxDriver::AddNode<NodeOpType::Unsqueeze>(const ::onnx::NodeProto& node) {
704 DLINEAR_ASSERT(node.op_type() == "Unsqueeze", "NodeProto must have an op_type of Unsqueeze");
705 DLINEAR_ASSERT(node.output_size() == 1, "NodeProto must have exactly 1 output");
706 EnsureInput(node, 2);
707
708 const std::string& input1 = node.input(0);
709 const std::string& input2 = node.input(1);
710 const std::string& output = node.output(0);
711 available_inputs_.emplace(output, Tensor{available_inputs_.at(input1)}.Unsqueeze(available_inputs_.at(input2)));
712 DLINEAR_DEBUG_FMT("Transpose {} = unsqueeze({}, {})", output, input1, input2);
713 DLINEAR_TRACE_FMT("{} = unsqueeze({}, {})", available_inputs_.at(output), available_inputs_.at(input1),
714 available_inputs_.at(input2));
715 AddFormula(output);
716}
717
718bool OnnxDriver::AddNode(const ::onnx::NodeProto& node) {
719 DLINEAR_ASSERT(node.has_op_type(), "NodeProto must have an op_type");
720#ifdef DLINEAR_PYDLINEAR
721 py_check_signals();
722#endif
723
724 DLINEAR_TRACE_FMT("AddNode({})", node.name());
725 if (DLINEAR_TRACE_ENABLED) {
726 for ([[maybe_unused]] const std::string& input : node.input()) DLINEAR_TRACE_FMT("Node input: {}", input);
727 }
728 const bool missing_input = std::any_of(node.input().begin(), node.input().end(), [this](const std::string& input) {
729 return !available_inputs_.contains(input);
730 });
731 if (missing_input) {
732 DLINEAR_TRACE_FMT("Missing input for node {}. Delaying addition", node.name());
733 return false;
734 }
735
736 const auto it = node_handlers.find(node.op_type());
737 if (it == node_handlers.end()) {
738 DLINEAR_RUNTIME_ERROR_FMT("Onnx operation {} not currently supported", node.op_type());
739 }
740 it->second(*this, node);
741 return true;
742}
743
745 std::list<const ::onnx::NodeProto*> nodes;
746 for (const ::onnx::NodeProto& node : model_.graph().node()) nodes.push_back(&node);
747 bool added = true;
748 while (added) {
749 added = false;
750 for (auto it = nodes.begin(); it != nodes.end(); it++) {
751 if (AddNode(**it)) {
752 it = nodes.erase(it);
753 it--;
754 added = true;
755 }
756 }
757 }
758 if (!nodes.empty()) {
759 DLINEAR_ERROR("Failed to add all nodes");
760 if (DLINEAR_TRACE_ENABLED) {
761 for ([[maybe_unused]] const ::onnx::NodeProto* node : nodes)
762 DLINEAR_ERROR_FMT("Failed to add node: {}", node->name());
763 }
764 }
765}
766
767void OnnxDriver::AddValueInfo(const ::onnx::ValueInfoProto& value_info, const bool is_input) {
768 DLINEAR_ASSERT(value_info.has_type(), "ValueInfoProto must have a type");
769 DLINEAR_ASSERT(value_info.has_name(), "ValueInfoProto must have a name");
770#ifdef DLINEAR_PYDLINEAR
771 py_check_signals();
772#endif
773 DLINEAR_TRACE_FMT("AddValueInfo({})", value_info.name());
774 switch (value_info.type().value_case()) {
775 case ::onnx::TypeProto::kTensorType:
776 AddValueInfoTensor(value_info, is_input);
777 break;
778 default:
779 DLINEAR_UNREACHABLE();
780 }
781}
782
783void OnnxDriver::AddValueInfoTensor(const ::onnx::ValueInfoProto& value_info, const bool is_input) {
784 DLINEAR_ASSERT(value_info.has_name(), "ValueInfoProto must have a name");
785 DLINEAR_ASSERT(value_info.type().value_case() == ::onnx::TypeProto::kTensorType, "ValueInfoProto must be a tensor");
786 DLINEAR_TRACE_FMT("AddValueInfoTensor({}, {})", value_info.name(), is_input);
787 const auto [it, res] = variables_.emplace(value_info.name(), Tensor(value_info, is_input ? "X" : "Y"));
788 if (is_input) available_inputs_.emplace(value_info.name(), it->second);
789 for (const auto& exp : it->second) context_.DeclareVariable(get_variable(exp));
790 DLINEAR_DEBUG_FMT("Added variables tensor: {} -> {}", it->first, it->second);
791 if (is_input) DLINEAR_TRACE_FMT("Added input: {} -> {}", value_info.name(), it->second);
792}
793
795 if (is_variable(expr)) return get_variable(expr);
796 auto it = equal_vars_.find(expr);
797 if (it != equal_vars_.end()) return it->second;
798 const auto [ins, _] = equal_vars_.emplace(expr, Variable{"var/" + expr.to_string()});
799 return ins->second;
800}
801
802std::ostream& operator<<(std::ostream& os, const OnnxDriver& model) {
803 os << "OnnxDriver(\n";
804 os << "------\nVARIABLES\n------\n";
805 for (const auto& [name, variables] : model.variables()) {
806 os << name << ": " << variables << "\n";
807 }
808 os << "------\nINPUTS\n------\n";
809 for (const auto& [name, values] : model.available_inputs()) {
810 os << name << ": " << values << "\n";
811 }
812 return os << ")";
813}
814
815const std::map<std::string, std::function<void(OnnxDriver&, const ::onnx::NodeProto&)>> OnnxDriver::node_handlers{
816 {"Abs", &OnnxDriver::AddNode<NodeOpType::Abs>},
817 {"Add", &OnnxDriver::AddNode<NodeOpType::Add>},
818 // {"AveragePool", &OnnxDriver::AddNode<NodeOpType::AveragePool>},
819 // {"BatchNormalization", &OnnxDriver::AddNode<NodeOpType::BatchNormalization>},
820 {"Concat", &OnnxDriver::AddNode<NodeOpType::Concat>},
821 {"Constant", &OnnxDriver::AddNode<NodeOpType::Constant>},
822 {"Conv", &OnnxDriver::AddNode<NodeOpType::Conv>},
823 {"Dropout", &OnnxDriver::AddNode<NodeOpType::Dropout>},
824 {"Flatten", &OnnxDriver::AddNode<NodeOpType::Flatten>},
825 {"Gather", &OnnxDriver::AddNode<NodeOpType::Gather>},
826 {"Gemm", &OnnxDriver::AddNode<NodeOpType::Gemm>},
827 // {"GlobalAveragePool", &OnnxDriver::AddNode<NodeOpType::GlobalAveragePool>},
828 {"Identity", &OnnxDriver::AddNode<NodeOpType::Identity>},
829 {"LeakyRelu", &OnnxDriver::AddNode<NodeOpType::LeakyRelu>},
830 // {"LRN", &OnnxDriver::AddNode<NodeOpType::LRN>},
831 {"MatMul", &OnnxDriver::AddNode<NodeOpType::MatMul>},
832 // {"MaxPool", &OnnxDriver::AddNode<NodeOpType::MaxPool>},
833 {"Mul", &OnnxDriver::AddNode<NodeOpType::Mul>},
834 {"Relu", &OnnxDriver::AddNode<NodeOpType::Relu>},
835 {"Reshape", &OnnxDriver::AddNode<NodeOpType::Reshape>},
836 {"Sign", &OnnxDriver::AddNode<NodeOpType::Sign>},
837 {"Sigmoid", &OnnxDriver::AddNode<NodeOpType::Sigmoid>},
838 {"Slice", &OnnxDriver::AddNode<NodeOpType::Slice>},
839 {"Softmax", &OnnxDriver::AddNode<NodeOpType::Softmax>},
840 {"Squeeze", &OnnxDriver::AddNode<NodeOpType::Squeeze>},
841 {"Sub", &OnnxDriver::AddNode<NodeOpType::Sub>},
842 {"Transpose", &OnnxDriver::AddNode<NodeOpType::Transpose>},
843 {"Unsqueeze", &OnnxDriver::AddNode<NodeOpType::Unsqueeze>},
844};
845
846// else if (strcmp(nodeType, "BatchNormalization") == 0) {
847// batchNormEquations(node, makeEquations);
848// }
849// else if (strcmp(nodeType, "MaxPool") == 0) {
850// maxPoolEquations(node, makeEquations);
851// }
852// else if (strcmp(nodeType, "LeakyRelu") == 0) {
853// leakyReluEquations(node, makeEquations);
854// }
855// else if (strcmp(nodeType, "Tanh") == 0) {
856// tanhEquations(node, makeEquations);
857// }
858
859template bool OnnxDriver::GetAttribute(const ::onnx::NodeProto& node, const std::string& name,
860 const std::optional<bool>& default_value) const;
861template std::string OnnxDriver::GetAttribute(const ::onnx::NodeProto& node, const std::string& name,
862 const std::optional<std::string>& default_value) const;
863template std::int64_t OnnxDriver::GetAttribute(const ::onnx::NodeProto& node, const std::string& name,
864 const std::optional<std::int64_t>& default_value) const;
865template std::vector<std::int64_t> OnnxDriver::GetAttribute(
866 const ::onnx::NodeProto& node, const std::string& name,
867 const std::optional<std::vector<std::int64_t>>& default_value) const;
868template float OnnxDriver::GetAttribute(const ::onnx::NodeProto& node, const std::string& name,
869 const std::optional<float>& default_value) const;
870template std::vector<float> OnnxDriver::GetAttribute(const ::onnx::NodeProto& node, const std::string& name,
871 const std::optional<std::vector<float>>& default_value) const;
872template std::vector<std::string> OnnxDriver::GetAttribute(
873 const ::onnx::NodeProto& node, const std::string& name,
874 const std::optional<std::vector<std::string>>& default_value) const;
875template const ::onnx::TensorProto* OnnxDriver::GetAttribute(
876 const ::onnx::NodeProto& node, const std::string& name,
877 const std::optional<const ::onnx::TensorProto*>& default_value) const;
878
879} // namespace dlinear::onnx
Context class that holds the set of constraints and provide Assert/Push/Pop/CheckSat functionalities.
Definition Context.h:31
void DeclareVariable(const Variable &v, bool is_model_variable=true)
Declare a variable v.
Definition Context.cpp:31
const ScopedVector< Formula > & assertions() const
Get the asserted formulas.
Definition Context.cpp:71
The Driver is the base class for all the parsers.
Definition Driver.h:26
Context & context_
The context filled during parsing of the expressions.
Definition Driver.h:166
bool ParseStream(std::istream &in, const std::string &sname="stream input")
Invoke the scanner and parser for a stream.
Definition Driver.cpp:30
Represents a symbolic form of an expression.
std::string to_string() const
Returns string representation of Expression.
Represents a symbolic form of a first-order logic formula.
Represents a symbolic variable.
The OnnxDriver class reads the onnx file.
Definition Driver.h:37
void AddNodes()
Go through all the nodes in the graph and construct the final assertions.
Definition Driver.cpp:744
static const std::map< std::string, std::function< void(OnnxDriver &, const ::onnx::NodeProto &)> > node_handlers
Map between node op_type and the corresponding handler.
Definition Driver.h:59
bool ParseFile(const std::string &filename) override
Invoke the scanner and parser on a file.
Definition Driver.cpp:125
T GetAttribute(const ::onnx::NodeProto &node, const std::string &name, const std::optional< T > &default_value={}) const
Get the attribute name from the node.
Definition Driver.cpp:153
const Variable & ToEqualVar(const Expression &expression)
Associate to a linear expression a fresh variable.
Definition Driver.cpp:794
bool AddNode(const ::onnx::NodeProto &node)
Go through a specific node and add the corresponding assertions.
Definition Driver.cpp:718
std::unordered_map< Expression, Variable > equal_vars_
Variables created to summarize linear constraints.
Definition Driver.h:156
void AddValueInfo(const ::onnx::ValueInfoProto &value_info, bool is_input=false)
Add the input and output to the Context.
Definition Driver.cpp:767
bool ParseStreamCore(std::istream &in) override
Parse the stream.
Definition Driver.cpp:115
void AddFormula(const std::string &output)
Add the formulas to the Context.
Definition Driver.cpp:174
std::unordered_map< std::string, Tensor > available_inputs_
Available inputs in the model.
Definition Driver.h:155
::onnx::ModelProto model_
The onnx model obtained from the file.
Definition Driver.h:153
OnnxDriver(Context &context)
Construct a new parser driver context.
Definition Driver.cpp:113
void AddInitializer(const ::onnx::TensorProto &tensor)
Add an initializer to the available_inputs_.
Definition Driver.cpp:167
std::unordered_map< std::string, Tensor > variables_
Variables in the model.
Definition Driver.h:154
void ParseGraph()
Parse the model_ 's graph.
Definition Driver.cpp:134
void AddValueInfoTensor(const ::onnx::ValueInfoProto &value_info, bool is_input=false)
Add the input and output to the Context.
Definition Driver.cpp:783
Tensor & Abs()
Apply the Abs function to the tensor.
Definition Tensor.cpp:216
Namespace for the ONNX parser of the dlinear library.
Definition Driver.cpp:20
@ B
Both upper and lower bound are equal (fixed)