TensorGraph
The TensorGraph class provides a set of tensor operations (element-wise arithmetic, matrix multiplication, convolution, normalization, pooling, and activation functions) that are dispatched on a device through a Compute context. The operations and features available to a graph are selected at creation time with Flags and Masks. Each dispatch can then adjust its behavior, such as data format, wrapping mode, or a fused activation function, through per-operation flags.
#include <parallel/TellusimTensorGraph.h>
Constructors
TensorGraph()
Methods
Clear graph.
void clear()
Check whether the graph is created.
bool isCreated() const
Create graph for the specified device.
bool create(const Device &device, Flags flags = FlagsAll, Masks masks = MasksAll, Async *async = nullptr)
Dispatch a tensor operation.
bool dispatch(Compute &compute, Operation op, Tensor dest, Flags flags = FlagNone) const
bool dispatch(Compute &compute, Operation op, Tensor dest, Tensor src_0, Flags flags = FlagNone) const
bool dispatch(Compute &compute, Operation op, Tensor dest, Tensor src_0, Tensor src_1, Flags flags = FlagNone) const
bool dispatch(Compute &compute, Operation op, Tensor dest, Tensor src_0, Tensor src_1, Tensor src_2, Flags flags = FlagNone) const
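| Type | Name | Description |
|---|---|---|
| Compute | compute | Compute context used for the dispatch. |
| TensorGraph::Operation | op | Graph operation. |
| Tensor | dest | Destination tensor. |
| Tensor | src_0, src_1, src_2 | Source tensors. The number of sources depends on the operation. |
| TensorGraph::Flags | flags | Operation flags. |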
Dispatch a texture-to-tensor or tensor-to-texture transfer, optionally restricted to a texture region and slice.
bool dispatch(Compute &compute, Tensor dest, Texture &src, Region region, Slice slice = Slice()) const
bool dispatch(Compute &compute, Tensor dest, Texture &src, Slice slice = Slice()) const
bool dispatch(Compute &compute, Texture &dest, Tensor src, Region region, Slice slice = Slice()) const
bool dispatch(Compute &compute, Texture &dest, Tensor src, Slice slice = Slice()) const
Enums
Operation
Graph operations.
| Name | Value | Description |
|---|---|---|
| Clear | 0 | Clear tensor values. |
| Range | 1 | Initialize tensor values with a sequence from 0 to size. |
| Copy | 2 | Copy tensor with the same or different layout. |
| Cat | 3 | Concatenate two tensors. |
| Transpose | 4 | Transpose tensor. |
| MatMul | 5 | Matrix multiplication. |
| Mul | 6 | Value multiplication. |
| Mad | 7 | Value multiplication and addition. |
| Div | 8 | Value division. |
| Add | 9 | Value addition. |
| Conv | 10 | Convolution. |
| DeConv | 11 | Deconvolution. |
| BatchNorm | 12 | Batch normalization. |
| BatchMad | 13 | Batch multiplication and addition. |
| SoftMin | 14 | Softmin function. |
| SoftMax | 15 | Softmax function. |
| MaxPool | 16 | Maximum pooling. |
| AvgPool | 17 | Average pooling. |
| GELU | 18 | Gaussian error linear unit function. |
| ReLU | 19 | Rectified linear unit function. |
| SiLU | 20 | Sigmoid linear unit function. |
| Sigm | 21 | Sigmoid function. |
| Tanh | 22 | Hyperbolic tangent function. |
| Sin | 23 | Sine function. |
| Cos | 24 | Cosine function. |
| Exp | 25 | Exponential function. |
| NumOperations | 26 | Number of operations. |
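To make several entries in the table concrete, here is a hypothetical CPU reference in plain C++ with no Tellusim dependency. ReLU, Sigm, SiLU, SoftMax, and SoftMin follow their standard mathematical definitions. The Range and Mad semantics (dest[i] = i and a * b + c) are assumptions based on the one-line descriptions above. The `reference` namespace and its functions are illustrative and are not part of the Tellusim API.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for a few TensorGraph operations.
// ReLU, Sigm, SiLU, SoftMax, and SoftMin use their standard definitions;
// the Range and Mad semantics are assumptions inferred from the table.
// The GPU kernels may differ in precision and tensor layout handling.
namespace reference {

// ReLU: max(0, x)
inline float relu(float x) { return x > 0.0f ? x : 0.0f; }

// Sigm: logistic sigmoid 1 / (1 + e^-x)
inline float sigm(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// SiLU: x * sigmoid(x)
inline float silu(float x) { return x * sigm(x); }

// Range (assumed): dest[i] = i for i in [0, size)
inline std::vector<float> range(std::size_t size) {
	std::vector<float> dest(size);
	for(std::size_t i = 0; i < size; i++) dest[i] = static_cast<float>(i);
	return dest;
}

// Mad (assumed): dest[i] = a[i] * b[i] + c[i]
inline std::vector<float> mad(const std::vector<float> &a, const std::vector<float> &b, const std::vector<float> &c) {
	assert(a.size() == b.size() && a.size() == c.size());
	std::vector<float> dest(a.size());
	for(std::size_t i = 0; i < a.size(); i++) dest[i] = a[i] * b[i] + c[i];
	return dest;
}

// SoftMax: exp(x_i - max) / sum_j exp(x_j - max); subtracting max avoids overflow
inline std::vector<float> softmax(const std::vector<float> &src) {
	assert(!src.empty());
	float max_value = *std::max_element(src.begin(), src.end());
	std::vector<float> dest(src.size());
	float sum = 0.0f;
	for(std::size_t i = 0; i < src.size(); i++) {
		dest[i] = std::exp(src[i] - max_value);
		sum += dest[i];
	}
	for(float &value : dest) value /= sum;
	return dest;
}

// SoftMin: softmax of the negated values, so smaller inputs get larger weights
inline std::vector<float> softmin(const std::vector<float> &src) {
	std::vector<float> negated(src.size());
	std::transform(src.begin(), src.end(), negated.begin(), [](float value) { return -value; });
	return softmax(negated);
}

} // namespace reference
```

A reference like this can help validate GPU results on small tensors, provided its assumed semantics are first checked against the actual kernels.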
Flags
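Graph flags. They are accepted by create(), where the default is FlagsAll, and by dispatch(), where the default is FlagNone.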
| Name | Value |
|---|---|
| FlagNone | 0 |
| FlagSizeQuery | (1 << 0) |
| FlagFormatRf32 | (1 << 1) |
| FlagFormatRf16 | (1 << 2) |
| FlagTranspose | (1 << 3) |
| FlagWrapClamp | (1 << 4) |
| FlagWrapRepeat | (1 << 5) |
| FlagReadScale | (1 << 6) |
| FlagReadBias | (1 << 7) |
| FlagConvert | (1 << 8) |
| FlagKernel | (1 << 9) |
| FlagGELU | (1 << 10) |
| FlagReLU | (1 << 11) |
| FlagSiLU | (1 << 12) |
| FlagSigm | (1 << 13) |
| FlagTanh | (1 << 14) |
| FlagSin | (1 << 15) |
| FlagCos | (1 << 16) |
| FlagExp | (1 << 17) |
| FlagFormat | FlagFormatRf32 \| FlagFormatRf16 |
| FlagWrap | FlagWrapClamp \| FlagWrapRepeat |
| FlagRead | FlagReadScale \| FlagReadBias |
| FlagUnit | FlagGELU \| FlagReLU \| FlagSiLU |
| FlagMath | FlagSigm \| FlagTanh \| FlagSin \| FlagCos \| FlagExp |
| FlagsAll | FlagFormat \| FlagTranspose \| FlagWrap \| FlagRead \| FlagConvert \| FlagKernel \| FlagUnit \| FlagMath |
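The composite flags are bitwise ORs of the single-bit flags. The following standalone mirror of the values checks the composites at compile time. It lives in a hypothetical `flags_example` namespace, not the real header. It also demonstrates one consequence of the table: FlagsAll covers bits 1 through 17 and does not include FlagSizeQuery. The `hasFlags` helper is illustrative only.

```cpp
#include <cassert>
#include <cstdint>

// Standalone mirror of the TensorGraph::Flags values, for illustration only.
// Real code should use TensorGraph::Flags from TellusimTensorGraph.h.
namespace flags_example {

enum Flags : std::uint32_t {
	FlagNone = 0,
	FlagSizeQuery = (1u << 0),
	FlagFormatRf32 = (1u << 1),
	FlagFormatRf16 = (1u << 2),
	FlagTranspose = (1u << 3),
	FlagWrapClamp = (1u << 4),
	FlagWrapRepeat = (1u << 5),
	FlagReadScale = (1u << 6),
	FlagReadBias = (1u << 7),
	FlagConvert = (1u << 8),
	FlagKernel = (1u << 9),
	FlagGELU = (1u << 10),
	FlagReLU = (1u << 11),
	FlagSiLU = (1u << 12),
	FlagSigm = (1u << 13),
	FlagTanh = (1u << 14),
	FlagSin = (1u << 15),
	FlagCos = (1u << 16),
	FlagExp = (1u << 17),
	// Composite flags: bitwise OR of the single-bit flags
	FlagFormat = FlagFormatRf32 | FlagFormatRf16,
	FlagWrap = FlagWrapClamp | FlagWrapRepeat,
	FlagRead = FlagReadScale | FlagReadBias,
	FlagUnit = FlagGELU | FlagReLU | FlagSiLU,
	FlagMath = FlagSigm | FlagTanh | FlagSin | FlagCos | FlagExp,
	FlagsAll = FlagFormat | FlagTranspose | FlagWrap | FlagRead | FlagConvert | FlagKernel | FlagUnit | FlagMath,
};

// Returns true when every bit of required is set in flags.
constexpr bool hasFlags(std::uint32_t flags, std::uint32_t required) {
	return (flags & required) == required;
}

// FlagsAll sets bits 1 through 17; FlagSizeQuery (bit 0) is not included.
static_assert(FlagsAll == ((1u << 18) - 2u), "FlagsAll covers bits 1..17");
static_assert(!hasFlags(FlagsAll, FlagSizeQuery), "FlagSizeQuery is outside FlagsAll");

} // namespace flags_example
```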
Masks
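Graph masks. Each mask bit corresponds to one Operation value. The masks argument of create() (MasksAll by default) selects the operations the graph is created for.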
| Name | Value |
|---|---|
| MaskNone | 0 |
| MaskClear | (1 << Clear) |
| MaskRange | (1 << Range) |
| MaskCopy | (1 << Copy) |
| MaskCat | (1 << Cat) |
| MaskTranspose | (1 << Transpose) |
| MaskMatMul | (1 << MatMul) |
| MaskMul | (1 << Mul) |
| MaskMad | (1 << Mad) |
| MaskDiv | (1 << Div) |
| MaskAdd | (1 << Add) |
| MaskConv | (1 << Conv) |
| MaskDeConv | (1 << DeConv) |
| MaskBatchNorm | (1 << BatchNorm) |
| MaskBatchMad | (1 << BatchMad) |
| MaskSoftMin | (1 << SoftMin) |
| MaskSoftMax | (1 << SoftMax) |
| MaskMaxPool | (1 << MaxPool) |
| MaskAvgPool | (1 << AvgPool) |
| MaskGELU | (1 << GELU) |
| MaskReLU | (1 << ReLU) |
| MaskSiLU | (1 << SiLU) |
| MaskSigm | (1 << Sigm) |
| MaskTanh | (1 << Tanh) |
| MaskSin | (1 << Sin) |
| MaskCos | (1 << Cos) |
| MaskExp | (1 << Exp) |
| MasksAll | (1 << NumOperations) - 1 |
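Each mask is a single bit indexed by its Operation value, so MasksAll = (1 << 26) - 1 enables all 26 operations. The sketch below mirrors these values in a hypothetical `masks_example` namespace and builds a narrower mask. The `MlpMasks` selection (MatMul, Add, ReLU) is an illustrative assumption about what a small network might need. The helper names are not part of the Tellusim API.

```cpp
#include <cassert>
#include <cstdint>

// Standalone mirror of TensorGraph::Operation and the mask rule Mask<Op> = 1 << Op.
// For illustration only; real code should use the values from TellusimTensorGraph.h.
namespace masks_example {

enum Operation : std::uint32_t {
	Clear, Range, Copy, Cat, Transpose, MatMul, Mul, Mad, Div, Add,
	Conv, DeConv, BatchNorm, BatchMad, SoftMin, SoftMax, MaxPool, AvgPool,
	GELU, ReLU, SiLU, Sigm, Tanh, Sin, Cos, Exp,
	NumOperations,
};

// Mask bit for a single operation.
constexpr std::uint32_t maskOf(Operation op) { return 1u << op; }

// Every operation enabled: (1 << NumOperations) - 1.
constexpr std::uint32_t MasksAll = (1u << NumOperations) - 1u;

// Returns true when masks enables the operation.
constexpr bool isEnabled(std::uint32_t masks, Operation op) { return (masks & maskOf(op)) != 0u; }

// Hypothetical reduced mask for a small network that only needs MatMul, Add, and ReLU.
constexpr std::uint32_t MlpMasks = maskOf(MatMul) | maskOf(Add) | maskOf(ReLU);

static_assert(NumOperations == 26u, "matches the Operation table");
static_assert(MasksAll == 0x3ffffffu, "26 low bits set");

} // namespace masks_example
```

With the real header, the equivalent selection would combine TensorGraph::MaskMatMul, MaskAdd, and MaskReLU and pass the result as the masks argument of create().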