TensorGraph
The TensorGraph class creates and dispatches tensor operations through a Compute interface. Its operations cover tensor manipulation (clear, range, copy, concatenation, transposition), element arithmetic, matrix multiplication, convolution and deconvolution, normalization, pooling, and activation and math functions. Per-dispatch flags control options such as the value format and the wrap mode, and create() accepts flag and operation mask sets.
#include <parallel/TellusimTensorGraph.h>
Constructors
TensorGraph()
Methods
Clear the graph.
void clear()
Check whether the graph has been created.
bool isCreated() const
Create the graph for a device.
bool create(const Device &device, Flags flags = FlagsAll, Masks masks = MasksAll, Async *async = nullptr)
Dispatch a tensor operation with zero to three source tensors.
bool dispatch(Compute &compute, Operation op, Tensor dest, Flags flags = FlagNone) const
bool dispatch(Compute &compute, Operation op, Tensor dest, Tensor src_0, Flags flags = FlagNone) const
bool dispatch(Compute &compute, Operation op, Tensor dest, Tensor src_0, Tensor src_1, Flags flags = FlagNone) const
bool dispatch(Compute &compute, Operation op, Tensor dest, Tensor src_0, Tensor src_1, Tensor src_2, Flags flags = FlagNone) const
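Type | Name | Description |
---|---|---|
Compute | compute | Compute interface used to record the dispatch. |
TensorGraph::Operation | op | Graph operation. |
Tensor | dest | Destination tensor. |
Tensor | src_0, src_1, src_2 | Source tensors. The number of sources depends on the operation. |
TensorGraph::Flags | flags | Operation flags. |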
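The table does not spell out what each operation computes from its sources. As a hedged illustration of how the two- and three-source overloads could map to arithmetic, the sketch below is a CPU reference that assumes Mul, Add, and Mad act element-wise on equally sized tensors, with Mad computing dest = src_0 * src_1 + src_2. The `Values` type and the `referenceMul`, `referenceAdd`, and `referenceMad` functions are hypothetical and are not part of Tellusim.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical flat tensor storage used only for this CPU reference.
using Values = std::vector<float>;

// Assumed Mul semantics: dest[i] = src_0[i] * src_1[i].
Values referenceMul(const Values &src_0, const Values &src_1) {
    Values dest(src_0.size());
    for(std::size_t i = 0; i < dest.size(); i++) dest[i] = src_0[i] * src_1[i];
    return dest;
}

// Assumed Add semantics: dest[i] = src_0[i] + src_1[i].
Values referenceAdd(const Values &src_0, const Values &src_1) {
    Values dest(src_0.size());
    for(std::size_t i = 0; i < dest.size(); i++) dest[i] = src_0[i] + src_1[i];
    return dest;
}

// Assumed Mad semantics with three sources: dest[i] = src_0[i] * src_1[i] + src_2[i].
Values referenceMad(const Values &src_0, const Values &src_1, const Values &src_2) {
    Values dest(src_0.size());
    for(std::size_t i = 0; i < dest.size(); i++) dest[i] = src_0[i] * src_1[i] + src_2[i];
    return dest;
}
```

A reference like this can be used to validate GPU results read back from a destination tensor, once the real per-operation semantics are confirmed.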
Copy data between a tensor and a texture, optionally limited to a texture region and slice.
bool dispatch(Compute &compute, Tensor dest, Texture &src, Region region, Slice slice = Slice()) const
bool dispatch(Compute &compute, Tensor dest, Texture &src, Slice slice = Slice()) const
bool dispatch(Compute &compute, Texture &dest, Tensor src, Region region, Slice slice = Slice()) const
bool dispatch(Compute &compute, Texture &dest, Tensor src, Slice slice = Slice()) const
Enums
Operation
Graph operations.
Name | Value | Description |
---|---|---|
Clear | 0 | Clear tensor values. |
Range | 1 | Initialize tensor values with a sequence from 0 to size. |
Copy | 2 | Copy tensor with the same or different layout. |
Cat | 3 | Concatenate two tensors. |
Transpose | 4 | Transpose tensor. |
MatMul | 5 | Matrix multiplication. |
Mul | 6 | Value multiplication. |
Mad | 7 | Value multiplication and addition. |
Div | 8 | Value division. |
Add | 9 | Value addition. |
Conv | 10 | Convolution. |
DeConv | 11 | Deconvolution. |
BatchNorm | 12 | Batch normalization. |
BatchMad | 13 | Batch multiplication and addition. |
SoftMin | 14 | Softmin function. |
SoftMax | 15 | Softmax function. |
MaxPool | 16 | Maximum pooling. |
AvgPool | 17 | Average pooling. |
GELU | 18 | Gaussian error linear unit function. |
ReLU | 19 | Rectified linear unit function. |
SiLU | 20 | Sigmoid linear unit function. |
Sigm | 21 | Sigmoid function. |
Tanh | 22 | Hyperbolic tangent function. |
Sin | 23 | Sine function. |
Cos | 24 | Cosine function. |
Exp | 25 | Exponential function. |
NumOperations | 26 | Number of operations. |
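The activation and normalization entries name standard functions. The sketch below gives their textbook CPU definitions so GPU output can be checked against them; it is not Tellusim's kernel code, and the `reference*` function names are hypothetical. GELU is written in its exact erf form, while GPU implementations, possibly including this one, often use a tanh approximation, so small differences are expected.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// ReLU: max(x, 0).
float referenceReLU(float x) { return std::max(x, 0.0f); }

// Sigmoid: 1 / (1 + e^-x).
float referenceSigm(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// SiLU: x * sigmoid(x).
float referenceSiLU(float x) { return x * referenceSigm(x); }

// GELU (exact form): 0.5 * x * (1 + erf(x / sqrt(2))).
float referenceGELU(float x) { return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); }

// Softmax over a vector; subtracting the maximum keeps exp() from overflowing.
std::vector<float> referenceSoftMax(const std::vector<float> &src) {
    std::vector<float> dest(src.size());
    if(src.empty()) return dest;
    float max_value = *std::max_element(src.begin(), src.end());
    float sum = 0.0f;
    for(std::size_t i = 0; i < src.size(); i++) {
        dest[i] = std::exp(src[i] - max_value);
        sum += dest[i];
    }
    for(float &value : dest) value /= sum;
    return dest;
}

// Softmin is softmax applied to the negated input.
std::vector<float> referenceSoftMin(const std::vector<float> &src) {
    std::vector<float> negated(src.size());
    for(std::size_t i = 0; i < src.size(); i++) negated[i] = -src[i];
    return referenceSoftMax(negated);
}
```

Softmax gives the largest input the largest weight, and softmin gives the smallest input the largest weight.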
Flags
Graph flags.
Name | Value |
---|---|
FlagNone | 0 |
FlagSizeQuery | (1 << 0) |
FlagFormatRf32 | (1 << 1) |
FlagFormatRf16 | (1 << 2) |
FlagTranspose | (1 << 3) |
FlagWrapClamp | (1 << 4) |
FlagWrapRepeat | (1 << 5) |
FlagReadScale | (1 << 6) |
FlagReadBias | (1 << 7) |
FlagConvert | (1 << 8) |
FlagKernel | (1 << 9) |
FlagGELU | (1 << 10) |
FlagReLU | (1 << 11) |
FlagSiLU | (1 << 12) |
FlagSigm | (1 << 13) |
FlagTanh | (1 << 14) |
FlagSin | (1 << 15) |
FlagCos | (1 << 16) |
FlagExp | (1 << 17) |
FlagFormat | FlagFormatRf32 \| FlagFormatRf16 |
FlagWrap | FlagWrapClamp \| FlagWrapRepeat |
FlagRead | FlagReadScale \| FlagReadBias |
FlagUnit | FlagGELU \| FlagReLU \| FlagSiLU |
FlagMath | FlagSigm \| FlagTanh \| FlagSin \| FlagCos \| FlagExp |
FlagsAll | FlagFormat \| FlagTranspose \| FlagWrap \| FlagRead \| FlagConvert \| FlagKernel \| FlagUnit \| FlagMath |
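The composite flags are bitwise ORs of the single-bit flags. The sketch below mirrors the table values in a standalone enum to show how the groups combine and how to test a flag set. The enum and the `hasAnyFlag` and `hasAllFlags` helpers are hypothetical stand-ins; application code would use TensorGraph::Flags directly. Working through the table shows that FlagsAll covers bits 1 to 17 and leaves out FlagSizeQuery.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the TensorGraph::Flags values listed above.
enum Flags : std::uint32_t {
    FlagNone = 0,
    FlagSizeQuery = (1u << 0),
    FlagFormatRf32 = (1u << 1),
    FlagFormatRf16 = (1u << 2),
    FlagTranspose = (1u << 3),
    FlagWrapClamp = (1u << 4),
    FlagWrapRepeat = (1u << 5),
    FlagReadScale = (1u << 6),
    FlagReadBias = (1u << 7),
    FlagConvert = (1u << 8),
    FlagKernel = (1u << 9),
    FlagGELU = (1u << 10),
    FlagReLU = (1u << 11),
    FlagSiLU = (1u << 12),
    FlagSigm = (1u << 13),
    FlagTanh = (1u << 14),
    FlagSin = (1u << 15),
    FlagCos = (1u << 16),
    FlagExp = (1u << 17),
    FlagFormat = FlagFormatRf32 | FlagFormatRf16,
    FlagWrap = FlagWrapClamp | FlagWrapRepeat,
    FlagRead = FlagReadScale | FlagReadBias,
    FlagUnit = FlagGELU | FlagReLU | FlagSiLU,
    FlagMath = FlagSigm | FlagTanh | FlagSin | FlagCos | FlagExp,
    FlagsAll = FlagFormat | FlagTranspose | FlagWrap | FlagRead | FlagConvert | FlagKernel | FlagUnit | FlagMath,
};

// FlagsAll sets bits 1..17; FlagSizeQuery (bit 0) is not included.
static_assert(FlagsAll == 0x3FFFEu, "FlagsAll should cover bits 1 to 17");

// True when at least one bit of mask is set in flags.
constexpr bool hasAnyFlag(std::uint32_t flags, std::uint32_t mask) { return (flags & mask) != 0; }

// True when every bit of mask is set in flags.
constexpr bool hasAllFlags(std::uint32_t flags, std::uint32_t mask) { return (flags & mask) == mask; }
```

For example, `FlagFormatRf16 | FlagWrapClamp` selects one of the two format bits, so it matches FlagFormat with `hasAnyFlag` but not with `hasAllFlags`.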
Masks
Graph operation masks, one bit per Operation value.
Name | Value |
---|---|
MaskNone | 0 |
MaskClear | (1 << Clear) |
MaskRange | (1 << Range) |
MaskCopy | (1 << Copy) |
MaskCat | (1 << Cat) |
MaskTranspose | (1 << Transpose) |
MaskMatMul | (1 << MatMul) |
MaskMul | (1 << Mul) |
MaskMad | (1 << Mad) |
MaskDiv | (1 << Div) |
MaskAdd | (1 << Add) |
MaskConv | (1 << Conv) |
MaskDeConv | (1 << DeConv) |
MaskBatchNorm | (1 << BatchNorm) |
MaskBatchMad | (1 << BatchMad) |
MaskSoftMin | (1 << SoftMin) |
MaskSoftMax | (1 << SoftMax) |
MaskMaxPool | (1 << MaxPool) |
MaskAvgPool | (1 << AvgPool) |
MaskGELU | (1 << GELU) |
MaskReLU | (1 << ReLU) |
MaskSiLU | (1 << SiLU) |
MaskSigm | (1 << Sigm) |
MaskTanh | (1 << Tanh) |
MaskSin | (1 << Sin) |
MaskCos | (1 << Cos) |
MaskExp | (1 << Exp) |
MasksAll | (1 << NumOperations) - 1 |
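Each mask is a single bit at the position of its Operation value, so a custom mask for create() is an OR of the operations a graph needs. The sketch below reproduces that arithmetic with a hypothetical standalone mirror of the Operation enum. The `operationMask` and `isOperationEnabled` helpers are illustrative names, not Tellusim API. The assumption that create() prepares only the masked operations is suggested by its `masks` parameter but is not stated in this reference.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the TensorGraph::Operation values listed above.
enum Operation : std::uint32_t {
    Clear = 0, Range, Copy, Cat, Transpose, MatMul, Mul, Mad, Div, Add,
    Conv, DeConv, BatchNorm, BatchMad, SoftMin, SoftMax, MaxPool, AvgPool,
    GELU, ReLU, SiLU, Sigm, Tanh, Sin, Cos, Exp, NumOperations,
};

// One bit per operation, matching the MaskXxx = (1 << Xxx) pattern.
constexpr std::uint32_t operationMask(Operation op) { return 1u << op; }

// MasksAll sets the low NumOperations bits, one for every operation.
constexpr std::uint32_t MasksAll = (1u << NumOperations) - 1u;

// True when the bit for op is present in masks.
constexpr bool isOperationEnabled(std::uint32_t masks, Operation op) { return (masks & operationMask(op)) != 0; }
```

For example, a small fully connected network might only need `operationMask(MatMul) | operationMask(Add) | operationMask(ReLU)`, which sets bits 5, 9, and 19.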