#include "dlinear/parser/onnx/Tensor.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>

#include "dlinear/symbolic/literal.h"
#include "dlinear/util/exception.h"
#include "dlinear/util/logging.h"
inline std::vector<int64_t> get_dims(const ::onnx::ValueInfoProto &value_info) {
  DLINEAR_ASSERT(value_info.has_type(), "ValueInfoProto must have type");
  DLINEAR_ASSERT(value_info.type().has_tensor_type(), "ValueInfoProto must have tensor_type");
  DLINEAR_ASSERT(value_info.type().tensor_type().has_shape(), "ValueInfoProto must have shape");
  std::vector<int64_t> dims;
  dims.reserve(value_info.type().tensor_type().shape().dim_size());
  for (const ::onnx::TensorShapeProto_Dimension &dim : value_info.type().tensor_type().shape().dim()) {
    if (dim.has_dim_value()) {
      dims.push_back(dim.dim_value());
    } else if (dim.has_dim_param()) {
      DLINEAR_WARN_FMT("Parametric dimension {} is being set to 1", dim.dim_param());
      dims.push_back(1);
    } else {
      DLINEAR_UNREACHABLE();
    }
  }
  return dims;
}
inline std::vector<int64_t> get_dims(const ::onnx::TensorProto &tensor) {
  if (tensor.dims_size() == 0) return {1};
  std::vector<int64_t> dims;
  dims.reserve(tensor.dims_size());
  for (const std::int64_t dim : tensor.dims()) {
    DLINEAR_ASSERT(dim > 0, "All dimensions of a tensor must be >= 1");
    dims.push_back(dim);
  }
  return dims;
}
Tensor::Tensor(const ::onnx::ValueInfoProto &value_info, const std::string &name)
    : values_{xt::xarray<Expression>::from_shape(get_dims(value_info))} {
  // ...
}
Tensor::Tensor(const ::onnx::TensorProto &tensor) : values_{xt::xarray<Expression>::from_shape(get_dims(tensor))} {
  DLINEAR_ASSERT(tensor.has_data_type(), "TensorProto must have a data_type");
  const auto size = static_cast<std::int64_t>(values_.size());
  const void *const raw_data = tensor.has_raw_data() ? tensor.raw_data().data() : nullptr;

  switch (tensor.data_type()) {
    case ::onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = raw_data == nullptr ? tensor.float_data(i) : static_cast<const float *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = raw_data == nullptr ? tensor.double_data(i) : static_cast<const double *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = raw_data == nullptr ? tensor.uint64_data(i) : static_cast<const uint64_t *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_INT64:
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = raw_data == nullptr ? tensor.int64_data(i) : static_cast<const int64_t *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = raw_data == nullptr ? tensor.int32_data(i) : static_cast<const int32_t *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_INT8:
      DLINEAR_ASSERT(raw_data != nullptr, "Raw data must be provided for int8 data type");
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = static_cast<const int8_t *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_INT16:
      DLINEAR_ASSERT(raw_data != nullptr, "Raw data must be provided for int16 data type");
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = static_cast<const int16_t *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_INT32:
      DLINEAR_ASSERT(raw_data != nullptr, "Raw data must be provided for int32 data type");
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = static_cast<const int32_t *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
      DLINEAR_ASSERT(raw_data != nullptr, "Raw data must be provided for uint8 data type");
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = static_cast<const uint8_t *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
      DLINEAR_ASSERT(raw_data != nullptr, "Raw data must be provided for uint32 data type");
      for (int i = 0; i < size; ++i) {
        values_.flat(i) = static_cast<const uint32_t *>(raw_data)[i];
      }
      break;
    case ::onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
    default:
      DLINEAR_RUNTIME_ERROR_FMT("Unsupported data type: {}", tensor.data_type());
  }
}
std::int64_t Tensor::dim(const std::size_t i) const {
  if (i >= values_.dimension()) return 1;
  return static_cast<std::int64_t>(values_.shape(i));
}
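// Out-of-range axes deliberately report a size of 1 rather than throwing,
// presumably so that missing trailing dimensions behave like size-1
// (broadcast-style) dimensions in code such as MatMul below.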
  for (std::size_t i = 0; i < values_.size(); i++) {
Tensor &Tensor::Flatten(const std::int64_t axis) {
  if (axis < 0 || axis >= static_cast<std::int64_t>(values_.size()))
    DLINEAR_OUT_OF_RANGE_FMT("Invalid axis. Must be in [{}, {}]", 0, values_.size());
  const std::int64_t rows =
      std::reduce(values_.shape().cbegin(), values_.shape().cbegin() + axis, 1, std::multiplies<std::int64_t>{});
  const std::int64_t cols =
      std::reduce(values_.shape().cbegin() + axis, values_.shape().cend(), 1, std::multiplies<std::int64_t>{});
  values_.reshape({rows, cols});
  return *this;
}
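// Flatten collapses everything before `axis` into the row count and everything
// from `axis` onwards into the column count. For example, a {2, 3, 4} tensor
// flattened at axis 1 becomes {2, 12} (rows = 2, cols = 3 * 4).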
Tensor &Tensor::Unsqueeze(const Tensor &axes) {
  std::vector<std::size_t> new_shape(values_.shape().size() + axes.size(), 0);
  for (const std::int64_t axes_value : static_cast<std::vector<std::int64_t>>(axes)) new_shape.at(axes_value) = 1;
  for (std::size_t i = 0, j = 0; i < new_shape.size(); i++) {
    if (new_shape[i] != 1) new_shape[i] = values_.shape()[j++];
  }
  values_.reshape(new_shape);
  return *this;
}
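// Unsqueeze inserts size-1 dimensions at the requested positions of the new
// shape and copies the existing dimensions into the remaining slots in order.
// For example, a {3, 4} tensor unsqueezed with axes {0, 2} becomes {1, 3, 1, 4}.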
Tensor &Tensor::Reshape(const Tensor &tensor_dim, const bool allow_zero) {
  DLINEAR_ASSERT(std::none_of(tensor_dim.begin(), tensor_dim.end(),
                              [](const Expression &e) { return get_constant_value(e) < 0 && get_constant_value(e) != -1; }),
                 "The dimension must be a positive integer or -1");
  DLINEAR_ASSERT(std::count_if(tensor_dim.begin(), tensor_dim.end(),
                               [](const Expression &e) { return get_constant_value(e) == -1; }) <= 1,
                 "At most one dimension can be -1");
  const auto dims = static_cast<std::vector<std::int64_t>>(tensor_dim);
  std::vector<std::size_t> new_dims;
  new_dims.reserve(tensor_dim.size());
  for (const std::int64_t &dim : dims) {
    if (dim == 0 && !allow_zero) {
      new_dims.push_back(values_.shape(new_dims.size()));
      continue;
    }
    if (dim == -1) {
      new_dims.push_back(values_.size() / std::reduce(dims.cbegin(), dims.cend(), -1, std::multiplies<std::int64_t>{}));
      continue;
    }
    new_dims.push_back(dim);
  }
  values_.reshape(new_dims);
  return *this;
}
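// ONNX Reshape semantics: a 0 keeps the corresponding input dimension (unless
// allow_zero is set) and a single -1 is inferred from the remaining entries.
// For example, reshaping a tensor of 24 elements with dims {2, -1, 4} yields
// {2, 3, 4}: the reduce starts at -1, so the product is -1 * 2 * -1 * 4 = 8 and
// the inferred dimension is 24 / 8 = 3.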
Tensor &Tensor::Slice(const Tensor &starts, const Tensor &ends, const Tensor &axes, const Tensor &steps) {
  return Slice(static_cast<std::vector<std::int64_t>>(starts), static_cast<std::vector<std::int64_t>>(ends),
               static_cast<std::vector<std::int64_t>>(axes), static_cast<std::vector<std::int64_t>>(steps));
}
Tensor &Tensor::Slice(const std::vector<std::int64_t> &starts, const std::vector<std::int64_t> &ends,
                      const std::vector<std::int64_t> &axes, const std::vector<std::int64_t> &steps) {
  if (starts.empty() || ends.empty()) DLINEAR_OUT_OF_RANGE("Starts and ends must not be empty");
  if (starts.size() != ends.size()) DLINEAR_OUT_OF_RANGE("Starts and ends must have the same size");
  if (!axes.empty() && axes.size() != starts.size()) DLINEAR_OUT_OF_RANGE("Axes must have the same size as starts");
  if (!steps.empty() && steps.size() != starts.size()) DLINEAR_OUT_OF_RANGE("Steps must have the same size as starts");

  xt::xstrided_slice_vector sv(values_.dimension(), xt::all());
  for (std::size_t i = 0; i < starts.size(); i++) {
    const std::int64_t start =
        starts[i] < 0 ? starts[i] + dim(axes.empty() ? static_cast<std::int64_t>(i) : axes[i]) : starts[i];
    std::int64_t end = ends[i] < 0 ? ends[i] + dim(axes.empty() ? static_cast<std::int64_t>(i) : axes[i]) : ends[i];
    const std::int64_t axis = axes.empty() ? static_cast<std::int64_t>(i) : axes[i];
    const std::int64_t step = steps.empty() ? 1 : steps[i];
    end = std::min(end, dim(axis));
    if (start >= dim(axis)) DLINEAR_OUT_OF_RANGE_FMT("Invalid start value: {}", start);
    if (step <= 0) DLINEAR_OUT_OF_RANGE_FMT("Invalid step value: {}", step);
    if (start >= end) DLINEAR_OUT_OF_RANGE_FMT("Invalid slice: start {} >= end {}", start, end);
    sv[axis] = xt::range(start, end, step);
  }
  values_ = xt::strided_view(values_, sv);
  return *this;
}
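// Negative starts/ends are interpreted relative to the end of the sliced axis,
// mirroring the ONNX Slice operator. For a 1-D tensor of length 5, starts={-3},
// ends={5} normalises to the half-open range [2, 5) with step 1.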
Tensor Tensor::Concat(const std::vector<Tensor> &rhs, const std::int64_t axis) {
  const std::size_t normalized_axis = axis < 0 ? values_.dimension() + axis : axis;
  xt::xarray<Expression> values = values_;
  for (const Tensor &t : rhs) values = xt::concatenate(xt::xtuple(values, t.values_), normalized_axis);
  return Tensor{values};
}
Tensor Tensor::Gather(const Tensor &indices, const std::int64_t axis) {
  if (axis < 0 || axis >= static_cast<std::int64_t>(values_.dimension()))
    DLINEAR_OUT_OF_RANGE_FMT("Invalid axis. Must be in [{}, {}]", 0, values_.dimension());

  std::vector<std::int64_t> new_shape{};
  new_shape.insert(new_shape.end(), values_.shape().begin(), values_.shape().begin() + axis);
  new_shape.insert(new_shape.end(), indices.values_.shape().begin(), indices.values_.shape().end());
  new_shape.insert(new_shape.end(), values_.shape().begin() + axis + 1, values_.shape().end());
  xt::xarray<Expression> new_values = xt::zeros<Expression>(new_shape);

  std::size_t counter = 0;
  for (const auto &index : indices) {
    xt::xstrided_slice_vector data_slices{};
    xt::xstrided_slice_vector new_values_slices{};
    for (int i = 0; i < axis; ++i) {
      data_slices.emplace_back(xt::all());
      new_values_slices.emplace_back(xt::all());
    }
    for (size_t i = 1; i < indices.ndim(); ++i) {
      new_values_slices.emplace_back(0);
    }
    data_slices.emplace_back(get_constant_value(index).get_num().get_ui());
    new_values_slices.emplace_back(counter++);
    data_slices.emplace_back(xt::ellipsis());
    new_values_slices.emplace_back(xt::ellipsis());

    auto data_slice = xt::strided_view(values_, data_slices);
    auto new_slice = xt::strided_view(new_values, new_values_slices);
    for (size_t j = 0; j < data_slice.size(); ++j) {
      new_slice(j) = data_slice(j);
    }
  }
  return Tensor{new_values};
}
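// Gather replaces the gathered axis by the shape of `indices` and copies the
// selected slices one by one. For example, gathering rows {2, 0} of a {3, 2}
// matrix along axis 0 produces a {2, 2} tensor holding row 2 followed by row 0.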
Tensor Tensor::Convolution(const Tensor &w, const std::vector<std::int64_t> &dilation, const std::int64_t group,
                           const std::vector<std::int64_t> &kernel_shape, const std::vector<std::int64_t> &pads,
                           const std::vector<std::int64_t> &stride) const {
  DLINEAR_ASSERT(values_.dimension() == 4, "Convolution can only be applied to a 4D tensors");
  DLINEAR_ASSERT(w.values_.dimension() == 4, "Convolution can only be applied to a 4D tensors");
  DLINEAR_ASSERT(values_.shape()[1] == w.values_.shape()[1] * group,
                 "The number of input channels must be equal to the number of output channels times the group");
  DLINEAR_ASSERT(w.values_.shape()[0] % group == 0, "The number of output channels must be divisible by the group");
  DLINEAR_ASSERT(group == 1, "Group convolution is not supported yet");

  [[maybe_unused]] const std::size_t batch_size = values_.shape()[0];
  [[maybe_unused]] const std::size_t input_channels = values_.shape()[1];
  const std::vector<std::size_t> remaining_input_shapes{values_.shape().begin() + 2, values_.shape().end()};

  [[maybe_unused]] const std::size_t feature_map = w.values_.shape()[0];
  DLINEAR_ASSERT(w.values_.shape()[1] == input_channels / group,
                 "The number of input channels must be equal to the number of output channels");
  const std::vector<std::size_t> remaining_kernel_shapes{w.values_.shape().begin() + 2, w.values_.shape().end()};

  const auto image = xt::view(values_, 0ul, 0ul);
  std::vector<std::size_t> new_shape{};
  for (std::size_t i = 0; i < image.shape().size(); i++) {
    const std::size_t pad_offset = pads.size() / 2;
    new_shape.push_back(
        1 + (image.shape()[i] + pads[i] + pads[i + pad_offset] - dilation[i] * (w.values_.shape()[i + 2] - 1) - 1) /
                stride[i]);
  }

  std::vector<std::size_t> new_values_shape{1, w.values_.shape()[0]};
  new_values_shape.insert(new_values_shape.end(), new_shape.begin(), new_shape.end());
  xt::xarray<Expression> new_values{new_values_shape};

  for (std::size_t i = 0; i < w.values_.shape()[0]; i++) {
    const auto kernel = xt::view(w.values_, i, 0ul, xt::range(0, kernel_shape[0]), xt::range(0, kernel_shape[1]));
    xt::xarray<Expression> row_values{Convolution(image, kernel, new_shape, dilation, group, pads, stride)};
    for (std::size_t j = 1; j < values_.shape()[1]; j++) {
      row_values += Convolution(xt::view(values_, 0ul, j),
                                xt::view(w.values_, i, j, xt::range(0, kernel_shape[0]), xt::range(0, kernel_shape[1])),
                                new_shape, dilation, group, pads, stride);
    }
    auto new_values_view = xt::view(new_values, 0l, i, xt::all(), xt::all());
    std::size_t counter = 0;
    for (Expression &e : new_values_view) {
      e = row_values.flat(counter++);
    }
  }
  return Tensor{std::move(new_values)};
}
xt::xarray<Expression> Tensor::Convolution(const xt::xarray<Expression> &image, const xt::xarray<Expression> &kernel,
                                           const std::vector<std::size_t> &new_shape,
                                           const std::vector<std::int64_t> &dilation, std::int64_t,
                                           const std::vector<std::int64_t> &pads,
                                           const std::vector<std::int64_t> &stride) const {
  DLINEAR_ASSERT(pads.size() == 4, "Pads must have 4 elements");
  DLINEAR_ASSERT(dilation.size() == 2, "Dilations must have 2 elements");
  DLINEAR_ASSERT(stride.size() == 2, "Strides must have 2 elements");
  DLINEAR_ASSERT(image.dimension() == 2, "Image must be a 2D tensor");
  DLINEAR_ASSERT(kernel.dimension() == 2, "Kernel must be a 2D tensor");
  xt::xarray<Expression> new_values{xt::zeros<Expression>(new_shape)};

  std::size_t out_r = 0;
  std::size_t out_c = 0;
  const auto ih = static_cast<std::int64_t>(image.shape()[0]);
  const auto iw = static_cast<std::int64_t>(image.shape()[1]);
  const auto kh = static_cast<std::int64_t>(kernel.shape()[0]);
  const auto kw = static_cast<std::int64_t>(kernel.shape()[1]);
  const std::int64_t fkmh = kh / 2;
  const std::int64_t fkmw = kw / 2;
  const std::int64_t lkmh = kw / 2 - (kh & 1 ? 0 : 1);
  const std::int64_t lkmw = kw / 2 - (kw & 1 ? 0 : 1);
  for (std::int64_t r = -pads[0] + fkmh * dilation[0]; r < ih + pads[2] - lkmh * dilation[0]; r += stride[0]) {
    for (std::int64_t c = -pads[1] + fkmw * dilation[1]; c < iw + pads[3] - lkmw * dilation[1]; c += stride[1]) {
      new_values(out_r, out_c) = 0;
      for (std::int64_t i = 0; i < kh; i++) {
        for (std::int64_t j = 0; j < kw; j++) {
          const std::int64_t ir = r + (i - fkmh) * dilation[0];
          const std::int64_t ic = c + (j - fkmw) * dilation[1];
          new_values(out_r, out_c) +=
              ir >= 0 && ir < ih && ic >= 0 && ic < iw ? image(ir, ic) * kernel(i, j) : Expression::Zero();
        }
      }
      out_c++;
    }
    out_r++;
    out_c = 0;
  }

  new_values.reshape({1, 1, new_values.shape()[0], new_values.shape()[1]});
  return new_values;
}
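// The spatial output size follows the usual convolution formula used above:
//   out = 1 + (in + pad_begin + pad_end - dilation * (k - 1) - 1) / stride.
// For a 5x5 image, a 3x3 kernel, pads {1, 1, 1, 1}, dilation {1, 1} and stride
// {1, 1} this gives 1 + (5 + 1 + 1 - 2 - 1) / 1 = 5, i.e. a 5x5 output. Pixels
// that fall outside the padded image contribute Expression::Zero().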
Expression Tensor::Max() const {
  DLINEAR_ASSERT(values_.size() > 0, "Cannot get the maximum value of an empty tensor");
  return *std::max_element(values_.begin(), values_.end(), [](const Expression &lhs, const Expression &rhs) {
    if (!is_constant(lhs) || !is_constant(rhs)) DLINEAR_RUNTIME_ERROR_FMT("Cannot compare {} and {}", lhs, rhs);
    return get_constant_value(lhs) < get_constant_value(rhs);
  });
}
Tensor Tensor::Squeeze() const { return Tensor{xt::squeeze(values_)}; }
Tensor Tensor::Squeeze(std::vector<std::int64_t> axes) const {
  for (std::int64_t &axis : axes) {
    if (axis >= static_cast<std::int64_t>(values_.dimension()))
      DLINEAR_OUT_OF_RANGE_FMT("Invalid axis. Must be in [{}, {}]", -values_.dimension() + 1, -1);
    if (axis < 0) axis += static_cast<std::int64_t>(values_.dimension());
  }
  return Tensor{xt::squeeze(values_, axes)};
}
Tensor Tensor::Pad(const std::vector<std::int64_t> &pads) const {
  if ((pads.size() & 1) != 0) DLINEAR_OUT_OF_RANGE("Pads must have an even number of elements");
  if (pads.size() > values_.dimension() * 2)
    DLINEAR_OUT_OF_RANGE_FMT("Pads must have at most {} elements", values_.dimension() * 2);

  std::vector<std::size_t> new_shape(values_.shape().size(), 0);
  for (std::size_t i = 0; i < values_.shape().size(); i++) {
    new_shape[i] = values_.shape()[i] + (i >= pads.size() / 2 ? 0 : pads[i] + pads[i + pads.size() / 2]);
  }

  xt::xstrided_slice_vector slices(values_.dimension());
  for (std::size_t i = 0; i < values_.dimension(); i++) {
    const std::size_t offset = i >= pads.size() / 2 ? 0 : pads[i];
    slices[i] = xt::range(offset, values_.shape()[i] + offset);
  }

  xt::xarray<Expression> new_values = xt::zeros<Expression>(new_shape);
  std::size_t counter = 0;
  for (Expression &e : xt::strided_view(new_values, slices)) {
    e = values_.flat(counter++);
  }
  return Tensor{new_values};
}
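// Pads follow the ONNX layout: the first half of `pads` holds the leading pad
// of each padded axis, the second half the trailing pad. For example, a {2, 3}
// tensor padded with {1, 0, 1, 0} becomes a zero-filled {4, 3} tensor with the
// original values written at row offset 1.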
Tensor Tensor::MatMul(const Tensor &rhs) const {
  if (values_.dimension() > 2 || rhs.values_.dimension() > 2)
    DLINEAR_RUNTIME_ERROR("MatMul can only be applied to Matrices and Vectors");
  if (dim(1) != rhs.dim(0)) {
    DLINEAR_RUNTIME_ERROR_FMT("Invalid MatMul between [{}x{}] and [{}x{}]", dim(0), dim(1), rhs.dim(0), rhs.dim(1));
  }
  DLINEAR_ASSERT(dim(0) > 0 && dim(1) > 0 && rhs.dim(0) > 0 && rhs.dim(1) > 0, "All dimensions must be > 0");
  Tensor new_tensor{dim(0), rhs.dim(1)};
  for (std::int64_t row = 0; row < dim(0); row++) {
    for (std::int64_t col = 0; col < rhs.dim(1); col++) {
      new_tensor(row, col) = (*this)(row, 0l) * rhs(0l, col);
      for (std::int64_t inner = 1; inner < dim(1); inner++) {
        new_tensor(row, col) += (*this)(row, inner) * rhs(inner, col);
      }
    }
  }
  if (values_.dimension() != 2) {
    new_tensor.values_.reshape({rhs.dim(1)});
  } else if (rhs.values_.dimension() != 2) {
    new_tensor.values_.reshape({dim(0)});
  }
  return new_tensor;
}
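// Plain O(n^3) symbolic matrix multiplication: a [2x3] times [3x4] product
// yields a [2x4] tensor of linear expressions. When either operand is 1-D the
// result is reshaped back to a vector at the end.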
Tensor &Tensor::operator+=(const Expression &rhs) {
  if (is_constant(rhs) && get_constant_value(rhs) == 0) return *this;
  values_ += rhs;
  return *this;
}
Tensor &Tensor::operator-=(const Expression &rhs) {
  if (is_constant(rhs) && get_constant_value(rhs) == 0) return *this;
  values_ -= rhs;
  return *this;
}
Tensor &Tensor::operator*=(const Expression &rhs) {
  if (is_constant(rhs) && get_constant_value(rhs) == 1) return *this;
  values_ *= rhs;
  return *this;
}
Tensor &Tensor::operator/=(const Expression &rhs) {
  if (is_constant(rhs) && get_constant_value(rhs) == 1) return *this;
  values_ /= rhs;
  return *this;
}

Tensor &Tensor::operator+=(const Tensor &rhs) {
  if (rhs.values_.size() == 1) return *this += rhs.values_.flat(0);
  values_ += rhs.values_;
  return *this;
}
Tensor &Tensor::operator-=(const Tensor &rhs) {
  if (rhs.values_.size() == 1) return *this -= rhs.values_.flat(0);
  values_ -= rhs.values_;
  return *this;
}
Tensor &Tensor::operator*=(const Tensor &rhs) {
  if (rhs.values_.size() == 1) return *this *= rhs.values_.flat(0);
  values_ *= rhs.values_;
  return *this;
}
Tensor &Tensor::operator/=(const Tensor &rhs) {
  if (rhs.values_.size() == 1) return *this /= rhs.values_.flat(0);
  values_ /= rhs.values_;
  return *this;
}
std::vector<Formula> Tensor::operator<(const Tensor &rhs) const {
  if (values_.size() == 1 && rhs.values_.size() == 1) return {values_.flat(0) < rhs.values_.flat(0)};
  if (values_.shape() != rhs.values_.shape())
    DLINEAR_RUNTIME_ERROR_FMT("Invalid comparison between {} and {}", values_.shape(), rhs.values_.shape());
  std::vector<Formula> formulas;
  formulas.reserve(values_.size());
  for (std::size_t i = 0; i < values_.size(); i++) formulas.push_back(values_.flat(i) < rhs[i]);
  return formulas;
}
std::vector<Formula> Tensor::operator<=(const Tensor &rhs) const {
  if (values_.size() == 1 && rhs.values_.size() == 1) return {values_.flat(0) <= rhs.values_.flat(0)};
  if (values_.shape() != rhs.values_.shape())
    DLINEAR_RUNTIME_ERROR_FMT("Invalid comparison between {} and {}", values_.shape(), rhs.values_.shape());
  std::vector<Formula> formulas;
  formulas.reserve(values_.size());
  for (std::size_t i = 0; i < values_.size(); i++) formulas.push_back(values_.flat(i) <= rhs[i]);
  return formulas;
}
std::vector<Formula> Tensor::operator>(const Tensor &rhs) const {
  if (values_.size() == 1 && rhs.values_.size() == 1) return {values_.flat(0) > rhs.values_.flat(0)};
  if (values_.shape() != rhs.values_.shape())
    DLINEAR_RUNTIME_ERROR_FMT("Invalid comparison between {} and {}", values_.shape(), rhs.values_.shape());
  std::vector<Formula> formulas;
  formulas.reserve(values_.size());
  for (std::size_t i = 0; i < values_.size(); i++) formulas.push_back(values_.flat(i) > rhs[i]);
  return formulas;
}
std::vector<Formula> Tensor::operator>=(const Tensor &rhs) const {
  if (values_.size() == 1 && rhs.values_.size() == 1) return {values_.flat(0) >= rhs.values_.flat(0)};
  if (values_.shape() != rhs.values_.shape())
    DLINEAR_RUNTIME_ERROR_FMT("Invalid comparison between {} and {}", values_.shape(), rhs.values_.shape());
  std::vector<Formula> formulas;
  formulas.reserve(values_.size());
  for (std::size_t i = 0; i < values_.size(); i++) formulas.push_back(values_.flat(i) >= rhs[i]);
  return formulas;
}
std::vector<Formula> Tensor::operator==(const Tensor &rhs) const {
  if (values_.size() == 1 && rhs.values_.size() == 1) return {values_.flat(0) == rhs.values_.flat(0)};
  if (values_.shape() != rhs.values_.shape())
    DLINEAR_RUNTIME_ERROR_FMT("Invalid comparison between {} and {}", values_.shape(), rhs.values_.shape());
  std::vector<Formula> formulas;
  formulas.reserve(values_.size());
  for (std::size_t i = 0; i < values_.size(); i++) formulas.push_back(values_.flat(i) == rhs[i]);
  return formulas;
}
std::vector<Formula> Tensor::operator!=(const Tensor &rhs) const {
  if (values_.size() == 1 && rhs.values_.size() == 1) return {values_.flat(0) != rhs.values_.flat(0)};
  if (values_.shape() != rhs.values_.shape())
    DLINEAR_RUNTIME_ERROR_FMT("Invalid comparison between {} and {}", values_.shape(), rhs.values_.shape());
  std::vector<Formula> formulas;
  formulas.reserve(values_.size());
  for (std::size_t i = 0; i < values_.size(); i++) formulas.push_back(values_.flat(i) != rhs[i]);
  return formulas;
}
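// Each comparison is elementwise and produces one symbolic Formula per entry
// (after a scalar-vs-scalar shortcut), so comparing two {2, 2} tensors yields
// four formulas; mismatched shapes are rejected up front.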
Expression &Tensor::operator[](const int index) { return values_.flat(index); }
const Expression &Tensor::operator[](const int index) const { return values_.flat(index); }
Expression &Tensor::operator[](const std::size_t index) { return values_.flat(index); }
const Expression &Tensor::operator[](const std::size_t index) const { return values_.flat(index); }

Expression &Tensor::operator()(std::initializer_list<std::int64_t> dims) { return values_.flat(ComputeOffset(dims)); }
const Expression &Tensor::operator()(std::initializer_list<std::int64_t> dims) const {
  return values_.flat(ComputeOffset(dims));
}
Tensor::operator std::vector<std::int64_t>() const {
  std::vector<std::int64_t> result;
  result.reserve(values_.size());
  for (const Expression &e : values_) {
    DLINEAR_ASSERT(is_constant(e), "Values must be constants");
    DLINEAR_ASSERT(get_constant_value(e).get_den().get_ui() == 1, "Values must be integers");
    result.push_back(e.Evaluate().get_num().get_si());
  }
  return result;
}
Tensor::operator std::vector<double>() const {
  std::vector<double> result;
  result.reserve(values_.size());
  for (const Expression &e : values_) {
    DLINEAR_ASSERT(is_constant(e), "Values must be constants");
    result.push_back(e.Evaluate().get_d());
  }
  return result;
}
Tensor::operator std::vector<std::size_t>() const {
  std::vector<std::size_t> result;
  result.reserve(values_.size());
  for (const Expression &e : values_) {
    DLINEAR_ASSERT(is_constant(e), "Values must be constants");
    DLINEAR_ASSERT(get_constant_value(e).get_den().get_ui() == 1, "Values must be integers");
    result.push_back(e.Evaluate().get_num().get_ui());
  }
  return result;
}
std::size_t Tensor::ComputeOffset(const std::initializer_list<std::int64_t> dims) const {
  const std::size_t being_offset = dims.size() > values_.dimension() ? dims.size() - values_.dimension() : 0;
  return ComputeOffset(dims.begin() + being_offset, dims.size() - being_offset);
}
std::size_t Tensor::ComputeOffset(const std::int64_t *const dims, const std::size_t size) const {
  DLINEAR_ASSERT(size <= values_.dimension(), "Invalid number of dimensions");
  std::size_t offset = 0;
  std::size_t stride = 1;
  for (std::size_t i = 0; i < size; i++) {
    offset += dims[size - i - 1] * stride;
    stride *= values_.shape().rbegin()[static_cast<std::int64_t>(i)];
  }
  return offset;
}
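// ComputeOffset is the usual row-major (C-order) linearisation: the rightmost
// index varies fastest. For a {2, 3, 4} tensor the strides are {12, 4, 1}, so
// the indices {1, 2, 3} map to offset 1 * 12 + 2 * 4 + 3 * 1 = 23. Extra
// leading indices beyond the tensor rank are skipped via being_offset.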
std::ostream &operator<<(std::ostream &os, const Tensor &tensor) {
  return os << "Tensor(" << tensor.values().shape() << ")\n" << tensor.values();
}
std::ostream &operator<<(std::ostream &os, const xt::xarray<dlinear::Expression> &values) {
  for (const Expression &e : values) os << e << '\n';
  return os;
}
std::ostream &operator<<(std::ostream &os, const xt::xarray<dlinear::Expression>::shape_type &shape) {
  for (const std::size_t dim : shape) os << dim << ' ';
  return os;
}